In [None]:
import numpy as np
import matplotlib.pyplot as plt

import sys
import os
sys.path.append(os.getcwd())

from train_cnn_mlp import preload_hdf5_to_memory

import matplotlib as mpl
# Use serif typography for every figure in this notebook; matplotlib walks the
# "font.serif" list in order and uses the first font that is installed.
mpl.rcParams["font.family"] = "serif"
mpl.rcParams["font.serif"] = ["TeX Gyre Pagella", "Book Antiqua", "Palatino Linotype", "DejaVu Serif"]

import pickle
import jax
import jax.numpy as jnp

from cnn_mlp_model import CNNMLPModel, ModelConfig

### Import data

In [None]:
# Only the held-out test split is loaded here; the train path is kept for reference.
test_path = '/projects/mccleary_group/habjan.e/TNG/Data/CNN_MLP_data/CNN_MLP_test.h5'
train_path = '/projects/mccleary_group/habjan.e/TNG/Data/CNN_MLP_data/CNN_MLP_train.h5'

data_dict = preload_hdf5_to_memory(test_path)

# Unpack the arrays the following cells rely on (images, per-galaxy features,
# regression targets and the validity mask), in that order.
data_im, data_gal, data_tar, data_mask = (
    data_dict[key] for key in ('images', 'gal_features', 'targets', 'mask')
)

### Plot data

In [None]:
test_idx = 900  # cluster index within the test split; reused by the prediction cells below

imgs = data_im[test_idx, :, :, :]  # image stack for this cluster; channels 0-2 are plotted
cmap = 'cubehelix'


def _show_channel(fig, ax, img, cbar_label, cmap, vmin=None, vmax=None):
    """Render one image channel on ``ax`` with coordinate labels and a labelled colorbar.

    ``vmin``/``vmax`` of ``None`` keep matplotlib's default colour scaling.
    Returns the ``AxesImage`` created by ``imshow``.
    """
    im = ax.imshow(img, vmin=vmin, vmax=vmax, cmap=cmap, origin="upper")
    ax.set_xlabel("X-coordinate", fontsize=14, fontweight='semibold')
    ax.set_ylabel("Y-coordinate", fontsize=14, fontweight='semibold')
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02)
    cbar.set_label(cbar_label, fontsize=14, fontweight='semibold')
    return im


fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 10), constrained_layout=True)
fig.set_constrained_layout_pads(w_pad=0.2, h_pad=0.2, wspace=0.1, hspace=0.1)

# Channel 0: colour scale clipped to the 50th-98th percentile of the map.
_show_channel(fig, axes[0], imgs[0], "Projected Mass", cmap,
              vmin=np.quantile(imgs[0], 0.5), vmax=np.quantile(imgs[0], 0.98))
_show_channel(fig, axes[1], imgs[1], "Galaxy Density", cmap)
# Raw string: in a plain literal "\m" is an invalid escape sequence
# (SyntaxWarning on Python 3.12+, planned to become a SyntaxError).
_show_channel(fig, axes[2], imgs[2], r"$\mathbf{V_{z}}$", cmap)

fig.savefig("/home/habjan.e/TNG/cluster_deprojection/figures/cnnmlp_data.png", bbox_inches="tight")
plt.show()

### Import model parameters and build model object

In [None]:
# Trained parameters, presumably the Flax "params" pytree consumed by model.apply below.
# NOTE: pickle.load can execute arbitrary code -- only load checkpoint files you produced yourself.
params_path = "/home/habjan.e/TNG/cluster_deprojection/cnn_mlp_model/CNN_MLP_models/cnn_mlp_params_cnnmlp_v1.pkl"
with open(params_path, "rb") as params_file:
    params = pickle.load(params_file)


# Architecture hyper-parameters. NOTE(review): these must match the configuration
# the checkpoint was trained with, otherwise the parameter tree will not fit the model.
model_kwargs = dict(
    smoother_kernel=5,
    cnn_channels=(32, 64, 128, 256),
    cnn_dropout=0.0,
    mlp_hidden=(128, 128),
    head_hidden=(256, 256),
    dropout=0.0,
    output_dim=3,
)
cfg = ModelConfig(**model_kwargs)
model = CNNMLPModel(cfg=cfg)

### Pick a cluster to apply the model to

In [None]:
# Evaluate the same cluster that was visualised above.
sample_idx = test_idx
images_4chw = data_im[sample_idx]  # channel-first image stack for this cluster
gal_feat, targets = data_gal[sample_idx], data_tar[sample_idx]
mask = data_mask[sample_idx].astype(bool)  # True marks galaxy slots to predict

### Transpose images to match CNN encoder: channel-first -> channel-last, then add a batch axis of size 1
images_hwc = np.transpose(images_4chw, (1, 2, 0)).astype(np.float32)
images_b = jnp.asarray(images_hwc[np.newaxis, ...])

### Iterate through cluster to make predictions

In [None]:
# Predict every valid (unmasked) galaxy in the cluster. The cluster images are the
# same for each call; only the per-galaxy MLP features change.
valid_idx = np.where(mask)[0]
assert valid_idx.size > 0, f"cluster {sample_idx} has no valid (unmasked) galaxies"
pred_list = []

for gal_idx in valid_idx:
    # Use the loop's own galaxy index. The previous version read valid_idx[0] on every
    # iteration, so each row of pred_arr was the first galaxy's prediction.
    mlp_in_b = jnp.asarray(gal_feat[int(gal_idx)][None, ...])  # add a leading batch axis of size 1
    pred_b = model.apply({"params": params}, images_b, mlp_in_b, deterministic=True)
    pred_list.append(np.array(pred_b[0]))  # drop the batch axis

# Row k corresponds to galaxy valid_idx[k]; presumably shape (n_valid, output_dim) -- TODO confirm.
pred_arr = np.array(pred_list)

### Plot results

In [None]:
# Factors that convert the stored targets/predictions back to physical units.
# NOTE(review): presumably these mirror the normalisation applied when the
# training data were written -- confirm against the preprocessing code.
POS_SCALE = 1.5  # z-position -> Mpc
VEL_SCALE = 800  # velocities -> km/s


def _padded_limits(values, frac=0.1):
    """Return ``(lo, hi)`` spanning ``values`` plus a margin of ``frac`` times the data range.

    The margin is additive, so it widens the window correctly for negative values too.
    Scaling min/max by 0.9/1.1 instead shrinks the window, and clips points, whenever
    an extreme is negative. Constant input gets a margin based on its magnitude
    (or 1 if it is zero).
    """
    lo, hi = float(np.min(values)), float(np.max(values))
    span = hi - lo
    if span == 0:
        span = abs(hi) or 1.0
    pad = frac * span
    return lo - pad, hi + pad


def _parity_panel(ax, truth, pred, xlabel, ylabel):
    """Scatter predictions against truth with a dashed 1:1 line and identical x/y limits."""
    lo, hi = _padded_limits(np.concatenate([truth, pred]))
    ax.scatter(truth, pred, c='k', s=10)
    ax.plot([lo, hi], [lo, hi], c='k', linestyle='--')  # 1:1 reference across the full panel
    ax.set_xlabel(xlabel, fontsize=15)
    ax.set_ylabel(ylabel, fontsize=15)
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)


fig, axs = plt.subplots(1, 3, figsize=(15, 4), gridspec_kw={'wspace': 0.35})

# (target column, unit scale, axis quantity label) for each panel.
panels = [
    (0, POS_SCALE, r'$z$-position $\left[ Mpc \right]$'),
    (1, VEL_SCALE, r'$v_{x}$ $\left[ km s^{-1} \right]$'),
    (2, VEL_SCALE, r'$v_{y}$ $\left[ km s^{-1} \right]$'),
]
for ax, (col, scale, quantity) in zip(axs, panels):
    _parity_panel(
        ax,
        targets[valid_idx, col] * scale,
        pred_arr[:, col] * scale,
        xlabel='BAHAMAS ' + quantity,
        ylabel='CNN-MLP ' + quantity,
    )

fig.savefig("/home/habjan.e/TNG/cluster_deprojection/figures/cnnmlp_predictions.png", bbox_inches="tight")
plt.show()