In [None]:
import torch
import numpy as np
import sys,os
import matplotlib.pyplot as plt
sys.path.append('..')
from utils.utils import set_seed
import warnings
from scipy.interpolate import interp1d
warnings.filterwarnings("ignore")
import pickle

In [None]:
# Prefer the GPU used for the original experiments (cuda:6); fall back to the
# first GPU, or the CPU, when that index does not exist on this machine so that
# the map_location below does not fail.
GPU_INDEX = 6
if torch.cuda.device_count() > GPU_INDEX:
    device = torch.device(f"cuda:{GPU_INDEX}")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
set_seed(42)
mag_list = [-0.5, -0.25, 0.0, 0.25, 0.5]
N_TEST_TRAJ = 20  # number of test trajectories kept per magnetisation value
test_latent = {}
test_time = {}
for mag in mag_list:
    data_dir = f'../raw_data/L64_MC500_h0.1_T1.10_mag{mag}'
    data = torch.load(f'{data_dir}/X0_test.pt', map_location=device)
    kmc_time = np.load(f'{data_dir}/kmc_times.npy')
    # Averages over dims 2 and 3 (presumably the two lattice axes — TODO confirm)
    # to get one scalar per (trajectory, time step). Slicing first means only the
    # kept trajectories are cast to float32; the per-row mean is unchanged.
    latent = torch.mean(data[:N_TEST_TRAJ].to(torch.float32), dim=(2, 3))
    test_latent[f'mag_{mag}'] = latent
    test_time[f'mag_{mag}'] = kmc_time

In [None]:
# ========== visualization ==============
# One panel per magnetisation: every test trajectory against its KMC time
# stamps, with a red dot marking the point (t=0, mag).
fig = plt.figure(figsize=(30, 5))
n_panels = len(mag_list)
for panel, mag in enumerate(mag_list, start=1):
    axes = fig.add_subplot(1, n_panels, panel)
    axes.set_title(f'mag = {mag}', fontsize=20)
    if panel == 1:
        axes.set_xlabel('time', fontsize=20)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    axes.set_ylim(-1, 1)
    trajectories = test_latent[f'mag_{mag}'].cpu().detach().numpy()
    times = test_time[f'mag_{mag}']
    for row in range(trajectories.shape[0]):
        axes.plot(times[row, :], trajectories[row, :])
    axes.plot(0, mag, 'ro', markersize=10)
plt.tight_layout()
plt.show()
plt.close()

In [None]:
# Mean validation time step; used below as the rollout dt and as the spacing of
# the common time grid.
val_dt = torch.load('../raw_data/L64_MC200_h0.1_T1.10/time_step_val.pt', map_location=device)
mean_dt = val_dt.mean()
mean_dt

In [None]:
# Common time grid with spacing mean_dt.
# NOTE(review): 500 presumably matches the length of the MC500 test trajectories
# and of the model rollouts plotted against this grid — confirm.
n_grid = 500
t_mean = np.arange(n_grid) * mean_dt.item()

In [None]:
def calculate_trajectory_mean(tra_np, t_np, t_mean):
    """Average a batch of irregularly sampled trajectories on a common time grid.

    Each trajectory ``tra_np[i]`` is sampled at times ``t_np[i]``. For each one:
    - samples whose value or time is NaN are dropped;
    - the remaining samples are sorted by time;
    - they are linearly interpolated onto ``t_mean``.

    Grid points outside a trajectory's time span are linearly extrapolated.
    Trajectories with fewer than two valid samples are skipped.

    Args:
        tra_np: array where ``tra_np[i]`` holds the values of trajectory ``i``.
        t_np: array where ``t_np[i]`` holds the matching sample times (same
            shape as ``tra_np[i]``); must have at least ``tra_np.shape[0]`` rows.
        t_mean: 1-D array of query times forming the common grid.

    Returns:
        Array of shape ``t_mean.shape`` holding the mean over the usable
        trajectories, or ``None`` if no trajectory has two valid samples.
    """
    tra_interpolated = []
    for i in range(tra_np.shape[0]):
        t_traj = t_np[i]
        tra_traj = tra_np[i]

        # Drop samples where either the value or its time stamp is NaN.
        valid_mask = ~np.isnan(tra_traj) & ~np.isnan(t_traj)
        if np.count_nonzero(valid_mask) < 2:
            # Linear interpolation needs at least two points.
            continue
        t_traj = t_traj[valid_mask]
        tra_traj = tra_traj[valid_mask]

        # Sort by time so the interpolant sees monotonically ordered x values.
        order = np.argsort(t_traj)
        interpolant = interp1d(t_traj[order], tra_traj[order], kind='linear',
                               bounds_error=False, fill_value='extrapolate')
        tra_interpolated.append(interpolant(t_mean))

    if not tra_interpolated:
        # Return a single None (not a tuple) so callers' `is not None` check works.
        return None

    return np.mean(np.array(tra_interpolated), axis=0)

In [None]:
test_latent_mean = {}
# ========== visualization ==============
# For each magnetisation, interpolate the test trajectories onto t_mean, store
# the mean curve in test_latent_mean (reused by later cells) and plot it.
fig = plt.figure(figsize=(30, 5))
for panel, mag in enumerate(mag_list):
    key = f'mag_{mag}'
    axes = fig.add_subplot(1, len(mag_list), panel + 1)
    axes.set_title(f'mag = {mag}', fontsize=20)
    if panel == 0:
        axes.set_xlabel('time', fontsize=20)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    axes.set_ylim(-1, 1)

    trajectories = test_latent[key].cpu().detach().numpy()
    tra_mean = calculate_trajectory_mean(trajectories, test_time[key], t_mean)
    test_latent_mean[key] = tra_mean

    if tra_mean is not None:
        axes.plot(t_mean, tra_mean, 'b-', linewidth=2, label='mean trajectory')
    axes.plot(0, mag, 'ro', markersize=10)

    if panel == 0:
        axes.legend()

plt.tight_layout()
plt.show()
plt.close()

In [None]:
# Load the trained "ours" and "naive" models for each patch size from
# ../checkpoints/<run_dir>/model.pt. A run directory is matched when its name
# contains `patch_L_<n>` and the method tag ('ours' is checked before 'naive').
ckpt_ours = {}
ckpt_naive = {}
patch_L = [8, 16, 32, 64]
L = 64  # presumably the lattice size (matches the L64 data dirs); not used below
folder = '../checkpoints/'
# Sorted so that, if several run directories match the same patch size, the
# outcome is deterministic (the lexicographically last one wins) rather than
# depending on os.listdir order.
run_dirs = sorted(os.listdir(folder))
for patch_Length in patch_L:
    tag = f'patch_L_{patch_Length}'
    for run_dir in run_dirs:
        if tag not in run_dir:
            continue
        if 'ours' in run_dir:
            target = ckpt_ours
        elif 'naive' in run_dir:
            target = ckpt_naive
        else:
            continue
        ckpt_path = os.path.join(folder, run_dir, 'model.pt')
        # The checkpoint is a pickled full model object (.predict is called on it
        # later), so weights_only=False is required: PyTorch >= 2.6 defaults to
        # weights_only=True and refuses arbitrary objects. Only load trusted files.
        target[tag] = torch.load(ckpt_path, map_location=device, weights_only=False)

In [None]:
# Roll out every loaded model from the first frame of each test trajectory,
# requesting a rollout as long as the ground truth and using the mean validation
# dt. Results are stored as numpy arrays under '<patch key>_<mag key>'.
predict_ours = {}
predict_naive = {}

for ckpts, predictions in ((ckpt_ours, predict_ours), (ckpt_naive, predict_naive)):
    for key_box, model in ckpts.items():
        for key_mag, true_traj in test_latent.items():
            initial_state = true_traj[:, :1]
            horizon = true_traj.shape[1]
            rollout = model.predict(initial_state, horizon, dt=mean_dt)
            predictions[f'{key_box}_{key_mag}'] = rollout.detach().cpu().numpy()

In [None]:
# ========== visualization ==============
# Ground-truth KMC trajectories overlaid with the (faint, blue) sampled rollouts
# of the "ours" model trained with patch_L = 8.
fig = plt.figure(figsize=(30, 5))
n_panels = len(mag_list)
for panel, mag in enumerate(mag_list, start=1):
    axes = fig.add_subplot(1, n_panels, panel)
    axes.set_title(f'mag = {mag}', fontsize=20)
    if panel == 1:
        axes.set_xlabel('time', fontsize=20)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    axes.set_ylim(-1, 1)

    truth = test_latent[f'mag_{mag}'].cpu().detach().numpy()
    times = test_time[f'mag_{mag}']
    for row in range(truth.shape[0]):
        axes.plot(times[row, :], truth[row, :])

    samples = predict_ours[f'patch_L_8_mag_{mag}']
    for row in range(samples.shape[0]):
        axes.plot(t_mean, samples[row, :, 0], 'b', alpha=0.01)

    axes.plot(0, mag, 'ro', markersize=10)
plt.tight_layout()
plt.show()
plt.close()

In [None]:
# ========== visualization ==============
# Grid of mean trajectories: one row per patch size, one column per
# magnetisation; true mean (red) vs. "ours" (blue) vs. "naive" (green).
fig = plt.figure(figsize=(30, 20))
n_rows, n_cols = len(patch_L), len(mag_list)
for row, patch_Length in enumerate(patch_L):
    for col, mag in enumerate(mag_list):
        axes = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1)
        if row == 0:
            axes.set_title(f'mag = {mag}', fontsize=20)
        if col == 0:
            axes.set_ylabel(f'patch_L = {patch_Length}', fontsize=20)
        plt.xticks(fontsize=20)
        plt.yticks(fontsize=20)
        axes.set_ylim(-1, 1)

        pred_key = f'patch_L_{patch_Length}_mag_{mag}'
        samples_ours = predict_ours[pred_key]
        samples_naive = predict_naive[pred_key]
        true_mean = test_latent_mean[f'mag_{mag}']

        axes.plot(t_mean, true_mean, 'r-', label='true', linewidth=2)
        axes.plot(t_mean, np.mean(samples_ours, axis=0), 'b-', label='ours', linewidth=2)
        axes.plot(t_mean, np.mean(samples_naive, axis=0), 'g-', label='naive', linewidth=2)

        axes.plot(0, mag, 'ro', markersize=10)
plt.tight_layout()
# pyplot's current axes is the last panel, so the legend is drawn there only.
plt.legend(loc='lower right', fontsize=20)
plt.show()
plt.close()

In [None]:
# ========== visualization ==============
# Single-row figure for one magnetisation, panels ordered from the largest to
# the smallest patch size. Iterating reversed(patch_L) replaces the former
# hard-coded `patch_L[3-i]`, which only worked for exactly four patch sizes.
MAG_IDX = 3  # index into mag_list of the magnetisation shown (0.25 with the list above)
mag = mag_list[MAG_IDX]
patch_order = list(reversed(patch_L))
fig = plt.figure(figsize=(32, 8), dpi=300)
for i, patch_Length in enumerate(patch_order):
    axes = fig.add_subplot(1, len(patch_order), i + 1)
    if i == 0:
        axes.set_xlabel('time', fontsize=40)
    axes.set_title(f'patch_L = {patch_Length}', fontsize=40)
    plt.xticks(fontsize=40)
    plt.yticks(fontsize=40)
    axes.set_ylim(-1, 1)

    pred_key = f'patch_L_{patch_Length}_mag_{mag}'
    pred_tra_ours = predict_ours[pred_key]
    pred_tra_naive = predict_naive[pred_key]

    true_mean = test_latent_mean[f'mag_{mag}']
    pred_mean_ours = np.mean(pred_tra_ours, axis=0)
    pred_mean_naive = np.mean(pred_tra_naive, axis=0)

    axes.plot(t_mean, true_mean, 'tab:green', label='true', linewidth=6)
    axes.plot(t_mean, pred_mean_ours, 'tab:orange', label='our', linewidth=6)
    axes.plot(t_mean, pred_mean_naive, 'tab:blue', label='naive', linewidth=6)

    axes.plot(0, mag, 'ro', markersize=10)
plt.tight_layout()
# pyplot's current axes is the last panel, so the legend is drawn there only.
plt.legend(loc='lower right', fontsize=40)
plt.show()
plt.close()

In [None]:
# Persist the ground-truth latents (converted to numpy) and both prediction
# dictionaries as pickle files in the working directory.
test_latent_np = {k: v.cpu().numpy() for k, v in test_latent.items()}
outputs = {
    'test_latent.pkl': test_latent_np,
    'predict_ours.pkl': predict_ours,
    'predict_naive.pkl': predict_naive,
}
for filename, payload in outputs.items():
    with open(filename, 'wb') as fh:
        pickle.dump(payload, fh)