In [None]:
dirc_path = '/home/habjan.e/'

import sys
sys.path.append(dirc_path + "TNG/TNG_cluster_dynamics")
import TNG_DA
import numpy as np 
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
from IPython.display import display, Markdown
from astropy.io import fits
from astropy.table import Table
import pickle

import os
sys.path.append(os.getcwd())
from training_structure import train_model, predict

import jraph
import jax.numpy as jnp

### Import trained model parameters

In [None]:
# Tag identifying the training run whose weights (and loss arrays) are loaded.
suffix_jax = '_testing'

# Trained GNN parameters are stored as a pickle under ./GNN_models/.
save_path = f"{os.getcwd()}/GNN_models/gnn_model_params{suffix_jax}.pkl"

# NOTE: pickle.load can execute arbitrary code -- only load parameter files
# produced by your own training runs.
with open(save_path, mode='rb') as param_file:
    loaded_params = pickle.load(param_file)

### Import model

In [None]:
# `sys` is imported and os.getcwd() is already appended to sys.path in the first
# cell, so only the model class needs to be imported here.
from gnn import GraphConvNet

# These hyperparameters define the parameter tree, so they must match the
# configuration that produced `loaded_params`.
# NOTE(review): latent_size presumably has to agree with LATENT_SIZE (width of
# the dummy globals built in make_graph) -- confirm against gnn.GraphConvNet.
model = GraphConvNet(
    latent_size=128,
    hidden_size=256,
    num_mlp_layers=3,
    message_passing_steps=5,
    skip_connections=True,
    edge_skip_connections=True,
    norm="pair",
    attention=True,
    shared_weights=False,
    relative_updates=False,
    output_dim=3,
    dropout_rate=0.1,
)

### (Optional) Make predictions on the full train/test sets with the trained model

In [None]:
batch_file_path = "/projects/mccleary_group/habjan.e/TNG/Data/GNN_SBI_data/graph_data"

# Batch prediction over the full train/test graph sets is switched off;
# set this flag to True to (re)compute pred_/tgt_/mask_ arrays for both splits.
RUN_BATCH_PREDICTIONS = False

if RUN_BATCH_PREDICTIONS:
    pred_train, tgt_train, mask_train = predict(model = model, params = loaded_params, data_dir = batch_file_path, data_prefix = 'train')
    pred_test, tgt_test, mask_test = predict(model = model, params = loaded_params, data_dir = batch_file_path, data_prefix = 'test')

### Pick a cluster to look at

In [None]:
# Index of the TNG cluster handed to TNG_DA.get_cluster_props.
cluster_ind = 3

(pos, vel, groups,
 subhalo_masses, h, halo_mass) = TNG_DA.get_cluster_props(cluster_ind)

### Define helper functions to build a k-nearest-neighbour graph and run the model (regression of z-position and x/y-velocities)

In [None]:
LATENT_SIZE = 128   # width of the (all-zero) graph globals passed to the model
KNN_K = 16          # default number of nearest neighbours per node

def make_graph(nodes_np: np.ndarray, knn_k: int = KNN_K) -> jraph.GraphsTuple:
    """Convert an (N, D) numpy array of node features into a k-NN GraphsTuple.

    Every node i gets directed edges i -> j to its ``k`` nearest neighbours j
    (Euclidean distance in feature space), never to itself. ``k`` is clipped
    to ``N - 1`` so that small clusters (N <= knn_k) still yield matching
    sender/receiver arrays instead of crashing.

    Parameters
    ----------
    nodes_np : np.ndarray
        Node features of shape (N, D); in this notebook D = 3 (x, y, v_z).
    knn_k : int, optional
        Maximum number of neighbours per node (default ``KNN_K``).

    Returns
    -------
    jraph.GraphsTuple
        A single graph with N nodes, N * k edges whose features are
        ``[dst - src, |dst - src|]`` (D + 1 columns), and zero globals of
        shape (1, LATENT_SIZE).

    Notes
    -----
    Rows are not filtered here: padded or NaN rows take part in the
    neighbour search like any other node.
    """
    nodes = jnp.asarray(nodes_np, dtype=jnp.float32)
    n_nodes = nodes.shape[0]
    k = max(min(int(knn_k), n_nodes - 1), 0)

    # Pair-wise squared distances in feature space.
    diffs = nodes[:, None, :] - nodes[None, :, :]
    d2 = jnp.sum(diffs ** 2, axis=-1)
    # Put +inf on the diagonal so a node is never chosen as its own neighbour.
    # jnp.where is used (not `eye * inf`) because 0 * inf = nan off the diagonal.
    d2 = jnp.where(jnp.eye(n_nodes, dtype=bool), jnp.inf, d2)
    knn_idx = jnp.argsort(d2, axis=1)[:, :k]

    senders = jnp.repeat(jnp.arange(n_nodes, dtype=jnp.int32), k)
    receivers = knn_idx.reshape(-1).astype(jnp.int32)

    # Edge features: relative displacement and its length.
    src = nodes[senders]
    dst = nodes[receivers]
    rel = dst - src
    dist = jnp.linalg.norm(rel, axis=-1, keepdims=True)
    edges = jnp.concatenate([rel, dist], axis=-1)

    dummy_globals = jnp.zeros((1, LATENT_SIZE), dtype=jnp.float32)

    return jraph.GraphsTuple(
        nodes=nodes,
        edges=edges,
        senders=senders,
        receivers=receivers,
        n_node=jnp.array([n_nodes], dtype=jnp.int32),
        n_edge=jnp.array([edges.shape[0]], dtype=jnp.int32),
        globals=dummy_globals,
    )

def prediction(model, params, in_graph):
    """
    Apply a trained model to one graph in inference mode.

    Parameters
    ----------
    model : module exposing ``apply(variables, graph, deterministic=...)``
    params : parameter tree, wrapped as ``{'params': params}`` for ``apply``
    in_graph : input graph passed straight to ``model.apply``

    Returns
    -------
    The ``.nodes`` attribute of the graph returned by ``model.apply``.
    """
    variables = {'params': params}
    # deterministic=True is forwarded to the model (presumably turns off dropout).
    output_graph = model.apply(variables, in_graph, deterministic = True)
    return output_graph.nodes

### Make a graph for a single cluster

In [None]:
# Standardise the three observables the GNN sees (x and y positions, z-velocity)
# using NaN-aware per-column means and standard deviations.
x_mean, y_mean, vz_mean = (np.nanmean(col) for col in (pos[:, 0], pos[:, 1], vel[:, 2]))
x_std, y_std, vz_std = (np.nanstd(col) for col in (pos[:, 0], pos[:, 1], vel[:, 2]))

obs_arr = np.stack(
    [
        (pos[:, 0] - x_mean) / x_std,
        (pos[:, 1] - y_mean) / y_std,
        (vel[:, 2] - vz_mean) / vz_std,
    ],
    axis=1,
)  # one row per galaxy: (x, y, v_z)

cl_graph = make_graph(obs_arr)

### Predict on this single cluster

In [None]:
preds = prediction(in_graph = cl_graph, params = loaded_params, model = model)  # per-galaxy normalised outputs

### Compare true TNG positions/velocities with predictions

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(16, 4), gridspec_kw={'wspace': 0.3})

# The GNN only sees (x, y, v_z), so its z-position output is de-normalised with
# the average of the x/y statistics and its x/y-velocity outputs with the v_z
# statistics.
# NOTE(review): assumes training targets were standardised the same way -- confirm.
pos_mean = np.mean([x_mean, y_mean])
pos_std = np.mean([x_std, y_std])

preds_np = np.asarray(preds)

panels = [
    (pos[:, 2], preds_np[:, 0] * pos_std + pos_mean,
     'TNG galaxy z-position [kpc]', 'GNN z-position [kpc]'),
    (vel[:, 0], preds_np[:, 1] * vz_std + vz_mean,
     'TNG galaxy x-velocity [km s$^{-1}$]', 'GNN x-velocity [km s$^{-1}$]'),
    (vel[:, 1], preds_np[:, 2] * vz_std + vz_mean,
     'TNG galaxy y-velocity [km s$^{-1}$]', 'GNN y-velocity [km s$^{-1}$]'),
]

for ax, (truth, pred, xlabel, ylabel) in zip(axs, panels):
    ax.scatter(truth, pred, c='k', s=10)

    # Shared square limits padded by 5% of the data range. Scaling the end
    # points by 1.1 (the old approach) clips data whenever min > 0 or max < 0;
    # nanmin/nanmax keep NaN galaxies from turning the limits into NaN.
    both = np.concatenate([truth, pred])
    lo, hi = np.nanmin(both), np.nanmax(both)
    pad = 0.05 * (hi - lo) if hi > lo else 1.0
    lo, hi = lo - pad, hi + pad

    # 1:1 reference line spanning exactly the visible range.
    ax.plot([lo, hi], [lo, hi], c='k', linestyle='--')
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_xlabel(xlabel, fontsize = 17.5)
    ax.set_ylabel(ylabel, fontsize = 17.5)

plt.show()

### Import loss arrays

In [None]:
# Directory holding the per-epoch loss arrays saved during training.
loss_dir = '/home/habjan.e/TNG/Sandbox_notebooks/phase_space_recon/Loss_arrays/'

train_loss = np.load(f'{loss_dir}train_loss{suffix_jax}.npy')
test_loss = np.load(f'{loss_dir}test_loss{suffix_jax}.npy')

# 1-based epoch numbers for the x-axis, one per recorded validation loss.
epochs = np.arange(len(test_loss)) + 1

### Plot losses

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 4), gridspec_kw={'wspace': 0.2})

# One panel per loss curve: (values, colour, legend label).
loss_curves = [
    (train_loss, 'blue', 'Training Loss'),
    (test_loss, 'red', 'Validation Loss'),
]

for ax, (loss_values, colour, curve_label) in zip(axs, loss_curves):
    ax.plot(epochs, loss_values, color = colour, label = curve_label)
    ax.set_xlabel('Epoch', fontsize = 20)

    # Scientific-notation y ticks with a single 10^n factor and no offset term.
    y_formatter = ScalarFormatter(useMathText=True)
    y_formatter.set_scientific(True)
    y_formatter.set_powerlimits((0, 0))
    y_formatter.set_useOffset(False)
    ax.yaxis.set_major_formatter(y_formatter)

    ax.legend()

axs[0].set_ylabel(r'MSE Loss', fontsize = 20)

### Import a BAHAMAS cluster from the preprocessed GNN training file (the simulation name stored in the file is printed below)

In [None]:
import h5py

data_path = "/projects/mccleary_group/habjan.e/TNG/Data/GNN_SBI_data/"
train_file = "GNN_data_train.h5"
test_file = "GNN_data_test.h5"

key = '000256'  # sample id, zero-padded to 6 digits

# Read one preprocessed sample (padded graph inputs/targets, mask, raw
# phase-space columns and the per-sample normalisation statistics).
with h5py.File(data_path + train_file, "r") as f:

    # Summarise the file and fail clearly on a bad id, instead of printing
    # every sample key (which floods the output for large files).
    print(f"{len(f)} samples in {train_file}")
    if key not in f:
        raise KeyError(f"Sample {key!r} not found in {data_path + train_file}")

    grp = f[key]  # this is an h5py Group

    # DATASETS (arrays) -> use [:] to read
    nodes   = grp["padded_nodes"][:]
    # Cast to bool: if the mask were stored as 0/1 numbers, `targets[mask, 0]`
    # downstream would do integer row indexing instead of boolean masking.
    mask    = grp["node_mask"][:].astype(bool)
    targets = grp["padded_targets"][:]

    projection_vector = grp["projection_vector"][:]
    x_position = grp["x_position"][:]
    y_position = grp["y_position"][:]
    z_position = grp["z_position"][:]
    x_velocity = grp["x_velocity"][:]
    y_velocity = grp["y_velocity"][:]
    z_velocity = grp["z_velocity"][:]
    # (optionally) subhalo masses if you need them
    # subhalo_masses = grp["subhalo_masses"][:]

    # ATTRIBUTES (scalars/metadata) -> use .attrs[]
    sim          = grp.attrs["simulation"]
    cluster_idx  = grp.attrs["cluster_index"]
    # NOTE: this replaces the TNG `halo_mass` returned by get_cluster_props earlier.
    halo_mass    = grp.attrs["cluster_mass"]

    x_ro_mean  = grp.attrs["x_position_mean"]
    y_ro_mean  = grp.attrs["y_position_mean"]
    z_ro_mean  = grp.attrs["z_position_mean"]
    vx_ro_mean = grp.attrs["x_velocity_mean"]
    vy_ro_mean = grp.attrs["y_velocity_mean"]
    vz_ro_mean = grp.attrs["z_velocity_mean"]

    x_ro_std  = grp.attrs["x_position_std"]
    y_ro_std  = grp.attrs["y_position_std"]
    z_ro_std  = grp.attrs["z_position_std"]
    vx_ro_std = grp.attrs["x_velocity_std"]
    vy_ro_std = grp.attrs["y_velocity_std"]
    vz_ro_std = grp.attrs["z_velocity_std"]

    # If simulation comes back as bytes, decode it:
    if isinstance(sim, (bytes, np.bytes_)):
        sim = sim.decode("utf-8")


print(sim, cluster_idx, np.log10(halo_mass))

### Make predictions

In [None]:
# NOTE(review): `nodes` still includes the padding rows flagged by `mask`, so
# padded entries take part in the k-NN graph; confirm this matches how graphs
# were built during training.
cl_graph = make_graph(nodes)
preds = prediction(in_graph = cl_graph, params = loaded_params, model = model)

### Plot results

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(16, 4), gridspec_kw={'wspace': 0.3})

# Compare only real (unpadded) galaxies. Cast to bool in case the mask is stored
# as 0/1 numbers, which numpy would otherwise treat as row indices.
node_sel = np.asarray(mask, dtype=bool)
preds_np = np.asarray(preds)

# Targets and predictions are plotted in the units stored in the file; no
# de-normalisation is applied here. Column order: z-position, x-vel, y-vel.
quantities = ['z-position', 'x-velocity', 'y-velocity']

for col, (ax, quantity) in enumerate(zip(axs, quantities)):
    truth = targets[node_sel, col]
    pred = preds_np[node_sel, col]
    ax.scatter(truth, pred, c='k', s=10)

    # Shared square limits padded by 5% of the data range (scaling the end
    # points by 1.1 clips data whenever min > 0 or max < 0).
    both = np.concatenate([truth, pred])
    lo, hi = np.nanmin(both), np.nanmax(both)
    pad = 0.05 * (hi - lo) if hi > lo else 1.0
    lo, hi = lo - pad, hi + pad

    # 1:1 reference line spanning exactly the visible range.
    ax.plot([lo, hi], [lo, hi], c='k', linestyle='--')
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    # Label with the simulation name read from the file rather than hard-coded "TNG".
    ax.set_xlabel(f'{sim} galaxy {quantity}', fontsize = 17.5)
    ax.set_ylabel(f'GNN {quantity}', fontsize = 17.5)

plt.show()