In [25]:
from pathlib import Path

import asdf
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable

import orientations as ori
plt.rcParams.update({"font.family": "serif"})

# Root of the Latte metal-diffusion simulation outputs. Each simulation is read
# from its own sub-directory named after its config key (e.g. ".../m12f_res7100/").
DATA_DIR = "../../../data/latte_metaldiff"

# Snapshot numbers analysed for each simulation (every 5th snapshot within a
# per-simulation window).
config = {
    "m12f_res7100": np.arange(302, 524, step=5),
    "m12i_res7100": np.arange(200, 480, step=5),
    "m12m_res7100": np.arange(153, 564, step=5),
    "m12w_res7100": np.arange(152, 380, step=5),
}

# Length multiplier for the tensor direction vectors that the plotting cell
# labels "Major axis". Presumably they are unit vectors, so this sets their
# plotted length in kpc.
scale = 350


def snapshot_indices(snapshots, wanted):
    """Map each snapshot number in ``wanted`` to its position in ``snapshots``.

    Parameters
    ----------
    snapshots : array_like
        Snapshot numbers stored in an LMC-positions file.
    wanted : iterable of int
        Snapshot numbers to look up, in the order the result should follow.

    Returns
    -------
    numpy.ndarray of int
        ``idx`` such that ``snapshots[idx[k]] == wanted[k]`` for every k.

    Raises
    ------
    ValueError
        If a wanted snapshot is absent from ``snapshots`` or appears more than
        once. The previous ``np.array([np.where(...)[0] ...]).flatten()``
        idiom either failed with an opaque ragged-array error or returned an
        index array whose length no longer matched ``wanted``.
    """
    snapshots = np.asarray(snapshots)
    indices = []
    for snap in wanted:
        (matches,) = np.nonzero(snapshots == snap)
        if matches.size != 1:
            raise ValueError(
                f"snapshot {snap} found {matches.size} times in LMC file "
                "(expected exactly 1)"
            )
        indices.append(matches[0])
    return np.array(indices, dtype=int)


def load_simulation(sim, snapshots, scale):
    """Open one simulation's tensor/LMC files and derive the plotting inputs.

    Parameters
    ----------
    sim : str
        Simulation key, e.g. ``"m12f_res7100"``. It is used to build the file
        names and the snapshot-data directory.
    snapshots : array_like of int
        Snapshot numbers to select from the LMC-positions file.
    scale : float
        Multiplier applied to ``tensors["virial"][:, 0]``.

    Returns
    -------
    tensors, lmc : asdf.AsdfFile
        The opened files. They are left open so later cells can keep reading
        from them.
    major_axis : numpy.ndarray
        ``tensors["virial"][:, 0] * scale``.
        NOTE(review): this is not subset by ``idx``. That assumes the tensors
        file holds exactly ``snapshots``, in the same order — confirm.
    idx : numpy.ndarray of int
        Indices into the LMC arrays for ``snapshots``, in ``snapshots`` order.
    time
        First element returned by ``ori.getSnapshotData`` for the selected
        snapshots. The plotting cell treats it as time in Gyr.
    """
    tensors = asdf.open(f"tensors_{sim}.asdf")
    lmc = asdf.open(f"lmc_positions_{sim}.asdf")

    major_axis = np.array(tensors["virial"])[:, 0] * scale
    idx = snapshot_indices(lmc["snapshot"], snapshots)

    # Read snapshot data from this simulation's own directory. Previously
    # m12f/m12m/m12w all (apparently by copy-paste) read from m12i_res7100/.
    time = ori.getSnapshotData(f"{DATA_DIR}/{sim}/", lmc["snapshot"][idx])[0]
    return tensors, lmc, major_axis, idx, time


tensors_f, lmc_f, max_f, tensor_idx_f, m12f_time = load_simulation(
    "m12f_res7100", config["m12f_res7100"], scale)
tensors_i, lmc_i, max_i, tensor_idx_i, m12i_time = load_simulation(
    "m12i_res7100", config["m12i_res7100"], scale)
tensors_m, lmc_m, max_m, tensor_idx_m, m12m_time = load_simulation(
    "m12m_res7100", config["m12m_res7100"], scale)
tensors_w, lmc_w, max_w, tensor_idx_w, m12w_time = load_simulation(
    "m12w_res7100", config["m12w_res7100"], scale)

In [None]:
def plot_lmc_vs_major_axis(name, lmc, idx, major_axis, time, outdir="tests"):
    """Plot the LMC track and the scaled major-axis points in three projections.

    Draws x-y, x-z and y-z panels. Both the LMC positions (circles) and the
    major-axis points (plus signs) are coloured by ``time``. The figure gets a
    legend, a shared colourbar and a bold title, and is saved to
    ``<outdir>/<name>.pdf``.

    Parameters
    ----------
    name : str
        Used as the figure title and the output file stem.
    lmc : mapping
        Object with a ``"position"`` entry that supports ``[idx][:, k]``.
    idx : array of int
        Rows of ``lmc["position"]`` to plot.
    major_axis : numpy.ndarray
        Array indexable as ``[:, k]`` for k in 0..2, one row per point.
    time : array_like
        Colour values (Gyr), one per point.
    outdir : str or pathlib.Path, optional
        Output directory. It is created if it does not exist.
    """
    positions = lmc["position"][idx]
    axis_names = ("x", "y", "z")
    projections = ((0, 1), (0, 2), (1, 2))

    fig, axes = plt.subplots(1, 3, figsize=(10, 3), dpi=200)
    for panel, (i, j) in zip(axes, projections):
        lmc_points = panel.scatter(positions[:, i], positions[:, j], c=time, label="LMC")
        panel.scatter(major_axis[:, i], major_axis[:, j], c=time, marker="+",
                      label="Major axis")
        panel.set_xlabel(f"${axis_names[i]}$ [kpc]")
        panel.set_ylabel(f"${axis_names[j]}$ [kpc]")
        panel.axvline(0, c="k", ls="--", alpha=0.25)
        panel.axhline(0, c="k", ls="--", alpha=0.25)

    # The labels were previously set but never displayed. Without a legend the
    # circles (LMC) and pluses (major axis) are indistinguishable.
    axes[0].legend(fontsize="small")

    fig.tight_layout()

    # Colourbar sits just right of the panels; bbox_inches="tight" keeps it in the PDF.
    cax = fig.add_axes([1, 0.15, 0.025, 0.85])
    fig.colorbar(lmc_points, cax=cax, label="Time [Gyr]")

    # ha="center" centres the title on x=0.5. The default ha="left" would
    # start the text at the centre instead.
    fig.text(0.5, 1.1, name, fontsize=14, fontweight="bold", ha="center")

    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)  # savefig fails on a missing directory
    fig.savefig(outdir / f"{name}.pdf", bbox_inches="tight")
    plt.close(fig)


plot_lmc_vs_major_axis("m12f", lmc_f, tensor_idx_f, max_f, m12f_time)