In [None]:
import torch
import numpy as np
import sys,os
import matplotlib.pyplot as plt
sys.path.append('..')
from utils.utils import set_seed
import warnings
from scipy.interpolate import interp1d
warnings.filterwarnings("ignore")
import pickle

In [None]:
patch_L = [8, 16, 32, 64]
# Preferred GPU index for this experiment. Fall back to the first CUDA device
# when this machine has fewer GPUs, and to CPU without CUDA; a non-existent
# device ordinal would otherwise make torch.load(map_location=...) fail.
GPU_INDEX = 4
if torch.cuda.is_available():
    device = torch.device(f"cuda:{GPU_INDEX}" if torch.cuda.device_count() > GPU_INDEX else "cuda:0")
else:
    device = torch.device("cpu")
set_seed(42)
mag_list = [-0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75]

CKPT_DIR = '../checkpoints'
AE_models = {}
for patch_Length in patch_L:
    key = 'patch_L_' + str(patch_Length)
    # sorted() makes the result independent of os.listdir order when several
    # run folders match; the last match (in sorted order) is kept.
    for run_name in sorted(os.listdir(CKPT_DIR)):
        if f'patch_L_{patch_Length}' in run_name and 'AE' in run_name:
            ckpt_path = os.path.join(CKPT_DIR, run_name, 'model.pt')
            if key in AE_models:
                # print, not warnings.warn: warnings are globally silenced above
                print(f'Warning: several AE checkpoints match {key}; replacing previous one with {ckpt_path}')
            # NOTE(review): this unpickles a full module; recent torch versions default
            # to weights_only=True and may need weights_only=False here — confirm.
            model = torch.load(ckpt_path, map_location=device)
            AE_models[key] = model
            print(f'Load model from {ckpt_path}')
print(f'Load {len(AE_models.keys())} models from ../checkpoints')
missing_patch_L = [p for p in patch_L if f'patch_L_{p}' not in AE_models]
if missing_patch_L:
    print(f'Warning: no AE checkpoint found for patch_L in {missing_patch_L}')

In [None]:
# Encode every test trajectory with every AE and min-max rescale it using the
# encoder's stored min_val / max_val. Keys: 'patch_L_{p}_mag_{m}' for latents,
# 'mag_{m}' for the per-trajectory kMC sample times.
test_latent = {}
test_time = {}
for mag in mag_list:
    run_dir = f'../raw_data/L64_MC500_h0.1_T2.50_mag{mag}'
    data = torch.load(f'{run_dir}/X0_test.pt', map_location=device).unsqueeze(2)
    test_time[f'mag_{str(mag)}'] = np.load(f'{run_dir}/kmc_times.npy')
    for patch_Length in patch_L:
        model = AE_models[f'patch_L_{patch_Length}']
        model.eval()
        # one encoder call per trajectory (first axis of data)
        with torch.no_grad():
            latent = torch.stack([model.encoder(sample.to(torch.float32)) for sample in data])
        lo, hi = model.encoder.min_val, model.encoder.max_val
        test_latent[f'patch_L_{patch_Length}_mag_{str(mag)}'] = (latent - lo) / (hi - lo)

In [None]:
# Patch length whose AE latent space is shown in the single-model plots below.
# It must be one of the trained patch lengths, otherwise later dict lookups fail.
plot_L = 16
assert plot_L in patch_L, f'plot_L={plot_L} is not one of patch_L={patch_L}'

In [None]:
# ========== visualization: every test trajectory, latent dim 0, one panel per mag ==============
n_mag = len(mag_list)
fig, axes_grid = plt.subplots(1, n_mag, figsize=(30, 5), squeeze=False)
for col, mag_value in enumerate(mag_list):
    axes = axes_grid[0][col]
    axes.set_title(f'mag = {mag_value}', fontsize=20)
    if col == 0:
        axes.set_xlabel('time', fontsize=20)
    axes.set_ylim(-1, 1)
    tra = test_latent[f'patch_L_{plot_L}_mag_{mag_value}'].cpu().detach().numpy()
    t = test_time[f'mag_{mag_value}']
    for row, traj in enumerate(tra):
        axes.plot(t[row, :], traj[:, 0])
    # red marker: the nominal initial magnetization at t = 0
    axes.plot(0, mag_value, 'ro', markersize=10)
plt.tight_layout()
plt.show()
plt.close()

In [None]:
# Mean kMC time step over the validation set; sets the spacing of the common time grid.
val_dt_path = '../raw_data/L64_MC200_h0.1_T2.50/time_step_val.pt'
val_dt = torch.load(val_dt_path, map_location=device)
mean_dt = val_dt.mean()
mean_dt

In [None]:
# Common evaluation grid: 500 equally spaced times with spacing mean_dt, starting at 0.
t_mean = mean_dt.item() * np.arange(500)

In [None]:
def calculate_trajectory_mean(tra_np, t_np, t_mean):
    """Average irregularly sampled trajectories on a common time grid.

    Trajectory ``tra_np[i]`` is sampled at its own times ``t_np[i]``. A sample is
    dropped if its time or any of its components is NaN. The remaining samples
    are sorted by time, linearly interpolated onto ``t_mean`` (with linear
    extrapolation outside the observed time range) and averaged over trajectories.

    Args:
        tra_np: array of shape (n_traj, n_steps) or (n_traj, n_steps, *feature_dims).
        t_np: array of shape (n_traj, n_steps) with the sample times; row i
            belongs to trajectory i.
        t_mean: 1-D array of evaluation times.

    Returns:
        Array of shape (len(t_mean), *feature_dims) with the mean trajectory, or
        None if no trajectory has at least two valid samples.
    """
    t_mean = np.asarray(t_mean)
    tra_interpolated = []
    for i in range(len(tra_np)):
        t_traj = np.asarray(t_np[i])
        tra_traj = np.asarray(tra_np[i])

        # NaN in any feature component invalidates the whole sample; axis=() for
        # 1-D trajectories means "no reduction", i.e. the element-wise mask.
        feature_axes = tuple(range(1, tra_traj.ndim))
        nan_in_sample = np.isnan(tra_traj).any(axis=feature_axes)
        valid_mask = ~np.isnan(t_traj) & ~nan_in_sample

        # Linear interpolation needs at least two points.
        if np.count_nonzero(valid_mask) < 2:
            continue

        t_valid = t_traj[valid_mask]
        tra_valid = tra_traj[valid_mask]
        order = np.argsort(t_valid)

        # axis=0 interpolates every feature component at once along time.
        interpolator = interp1d(t_valid[order], tra_valid[order], kind='linear', axis=0,
                                bounds_error=False, fill_value='extrapolate')
        tra_interpolated.append(interpolator(t_mean))

    if not tra_interpolated:
        return None

    return np.mean(np.stack(tra_interpolated), axis=0)

In [None]:
# ========== mean latent trajectory on t_mean for every (patch_L, mag) pair ==============
# Values are None when no trajectory had at least two valid samples.
test_latent_mean = {}
for mag in mag_list:
    mag_times = test_time[f'mag_{str(mag)}']
    for patch_Length in patch_L:
        key = f'patch_L_{patch_Length}_mag_{str(mag)}'
        latent_np = test_latent[key].cpu().detach().numpy()
        test_latent_mean[key] = calculate_trajectory_mean(latent_np, mag_times, t_mean)

In [None]:
# ========== visualization: mean latent trajectory (dim 0) per mag ==============
fig = plt.figure(figsize=(30, 5))
for i, mag_value in enumerate(mag_list):
    axes = fig.add_subplot(1, len(mag_list), i + 1)
    axes.set_title(f'mag = {mag_value}', fontsize=20)
    if i == 0:
        axes.set_xlabel('time', fontsize=20)
    axes.set_ylim(-1, 1)

    tra_mean = test_latent_mean[f'patch_L_{plot_L}_mag_{str(mag_value)}']
    if tra_mean is None:
        # calculate_trajectory_mean found no usable trajectory for this mag
        axes.text(0.5, 0.5, 'no valid trajectories', ha='center', va='center',
                  transform=axes.transAxes)
    else:
        axes.plot(t_mean, tra_mean[:, 0], 'b-', linewidth=2, label='mean trajectory')
    # red marker: the nominal initial magnetization at t = 0
    axes.plot(0, mag_value, 'ro', markersize=10)

    if i == 0:
        axes.legend()

plt.tight_layout()
plt.show()
plt.close()

In [None]:
# Load the 'ours' and 'naive' latent-dynamics checkpoints for every patch length.
ckpt_ours = {}
ckpt_naive = {}
L = 64  # presumably the lattice side length (cf. the 'L64_...' data folders); unused in the cells shown
folder = '../checkpoints/'
for patch_Length in patch_L:
    key = 'patch_L_' + str(patch_Length)
    # sorted() makes the pick deterministic when several run folders match.
    for run_name in sorted(os.listdir(folder)):
        if f'patch_L_{patch_Length}' not in run_name:
            continue
        # 'ours' is checked first, so a folder name containing both keywords counts as 'ours'.
        if 'ours' in run_name:
            target = ckpt_ours
        elif 'naive' in run_name:
            target = ckpt_naive
        else:
            continue
        ckpt_path = os.path.join(folder, run_name, 'model.pt')
        target[key] = torch.load(ckpt_path, map_location=device)

for family, ckpts in (('ours', ckpt_ours), ('naive', ckpt_naive)):
    missing = [p for p in patch_L if f'patch_L_{p}' not in ckpts]
    if missing:
        # print, not warnings.warn: warnings are globally silenced at the top
        print(f"Warning: no '{family}' checkpoint found for patch_L in {missing}")

In [None]:
# Roll out each model family from the encoded initial state (time index 0) for as
# many steps as the test latent trajectories have, with unit step dt = 1.
# The loop order (all 'ours' first, then all 'naive'; mag outer, patch_L inner)
# is deliberate, so any RNG use inside predict() is consumed in the same order.
predict_ours = {}
predict_naive = {}

for ckpt_dict, pred_dict in ((ckpt_ours, predict_ours), (ckpt_naive, predict_naive)):
    for mag in mag_list:
        for patch_Length in patch_L:
            key = f'patch_L_{patch_Length}_mag_{str(mag)}'
            ckpt = ckpt_dict[f'patch_L_{patch_Length}']
            true_traj = test_latent[key]
            rollout = ckpt.predict(true_traj[:, 0], true_traj.shape[1], dt=torch.tensor(1., device=device))
            pred_dict[key] = rollout.detach().cpu().numpy()

In [None]:
# ========== visualization: latent dim 0, test trajectories (red) vs 'ours' rollouts (blue) ==============
n_panels = len(mag_list)
fig = plt.figure(figsize=(30, 5))
for panel_idx, mag_value in enumerate(mag_list):
    ax = fig.add_subplot(1, n_panels, panel_idx + 1)
    ax.set_title(f'mag = {mag_value}', fontsize=20)
    if panel_idx == 0:
        ax.set_xlabel('time', fontsize=20)
    ax.set_ylim(-1, 1)

    key = f'patch_L_{plot_L}_mag_{mag_value}'
    true_np = test_latent[key].cpu().detach().numpy()
    t_grid = test_time[f'mag_{mag_value}']
    for row, traj in enumerate(true_np):
        ax.plot(t_grid[row, :], traj[:, 0], 'r', alpha=0.1)
    for traj in predict_ours[key]:
        ax.plot(t_mean, traj[:, 0], 'b', alpha=0.1)
    # red marker: the nominal initial magnetization at t = 0
    ax.plot(0, mag_value, 'ro', markersize=10)
plt.tight_layout()
plt.show()
plt.close()

In [None]:
# ========== visualization: latent dim 1, test trajectories (red) vs 'ours' rollouts (blue) ==============
n_panels = len(mag_list)
fig = plt.figure(figsize=(30, 5))
for panel_idx, mag_value in enumerate(mag_list):
    ax = fig.add_subplot(1, n_panels, panel_idx + 1)
    ax.set_title(f'mag = {mag_value}', fontsize=20)
    if panel_idx == 0:
        ax.set_xlabel('time', fontsize=20)
    ax.set_ylim(-1, 1)

    key = f'patch_L_{plot_L}_mag_{mag_value}'
    true_np = test_latent[key].cpu().detach().numpy()
    t_grid = test_time[f'mag_{mag_value}']
    for row, traj in enumerate(true_np):
        ax.plot(t_grid[row, :], traj[:, 1], 'r', alpha=0.1)
    for traj in predict_ours[key]:
        ax.plot(t_mean, traj[:, 1], 'b', alpha=0.1)
plt.tight_layout()
plt.show()
plt.close()

In [None]:
# ========== visualization: mean trajectory (dim 0), true vs ours vs naive, per patch_L ==============
j = 0  # index into mag_list of the magnetization shown
mag_shown = mag_list[j]
# Panels go from the largest to the smallest patch length (works for any len(patch_L)).
panel_patch_L = patch_L[::-1]
fig = plt.figure(figsize=(32, 8), dpi=300)
for i, patch_len in enumerate(panel_patch_L):
    axes = fig.add_subplot(1, len(panel_patch_L), i + 1)
    if i == 0:
        axes.set_xlabel('time', fontsize=40)
    axes.set_title(f'patch_L = {patch_len}', fontsize=40)
    axes.set_ylim(-1, 1)

    key = f'patch_L_{patch_len}_mag_{mag_shown}'
    # Predictions live in the latent space of the AE for this patch length, so the
    # reference mean must come from the same AE (as in the error computation),
    # not from the plot_L AE.
    true_mean = test_latent_mean[key]
    pred_mean_ours = np.mean(predict_ours[key], axis=0)
    pred_mean_naive = np.mean(predict_naive[key], axis=0)

    axes.plot(t_mean, true_mean[:, 0], 'tab:green', label='true', linewidth=6)
    axes.plot(t_mean, pred_mean_ours[:, 0], 'tab:orange', label='our', linewidth=6)
    axes.plot(t_mean, pred_mean_naive[:, 0], 'tab:blue', label='naive', linewidth=6)

    # red marker: the nominal initial magnetization at t = 0
    axes.plot(0, mag_shown, 'ro', markersize=10)
plt.tight_layout()
# legend is attached to the current (last) panel
plt.legend(loc='lower right', fontsize=40)
plt.show()
plt.close()

In [None]:
# Persist the encoded test latents (converted to NumPy) and both sets of rollouts.
test_latent_np = {k: v.cpu().numpy() for k, v in test_latent.items()}
outputs_to_pickle = (
    ('test_latent.pkl', test_latent_np),
    ('predict_ours.pkl', predict_ours),
    ('predict_naive.pkl', predict_naive),
)
for out_name, payload in outputs_to_pickle:
    with open(out_name, 'wb') as fh:
        pickle.dump(payload, fh)

In [None]:
# ========== calculate error ==============
def relative_mse(pred_mean, true_mean):
    """Mean squared difference of two mean trajectories, normalised by the mean
    square of ``true_mean`` (both means taken over all entries)."""
    return np.mean((pred_mean - true_mean) ** 2) / np.mean(true_mean ** 2)


# One value per patch_L: the relative MSE of the mean rollout, averaged over mag_list.
mse_ours_list = []
mse_naive_list = []
for box_i in patch_L:
    print('patch_L:', box_i)
    print('mag_list:', mag_list)
    mse_ours_list_mag = []
    mse_naive_list_mag = []
    for mag_j in mag_list:
        key = f'patch_L_{box_i}_mag_{mag_j}'
        true_mean = test_latent_mean[key]
        if true_mean is None:
            raise ValueError(f'No valid test trajectories to average for {key}')
        # average over the first (trajectory) axis of each rollout batch
        pred_mean_ours = np.mean(predict_ours[key], axis=0)
        pred_mean_naive = np.mean(predict_naive[key], axis=0)

        mse_ours_list_mag.append(relative_mse(pred_mean_ours, true_mean))
        mse_naive_list_mag.append(relative_mse(pred_mean_naive, true_mean))

    mse_ours_list.append(np.mean(mse_ours_list_mag))
    mse_naive_list.append(np.mean(mse_naive_list_mag))

In [None]:
# ========== Save data to CSV file ==============
import pandas as pd

# Rows go from the largest to the smallest patch length. Derived from patch_L
# (rather than hard-coded) so labels always match the reversed error lists.
x_labels = patch_L[::-1]
assert len(x_labels) == len(mse_ours_list) == len(mse_naive_list), 'one error value per patch_L expected'
data_to_save = {
    'patch_side_length': x_labels,
    'mse_ours': mse_ours_list[::-1],  # reversed to match x_labels order
    'mse_naive': mse_naive_list[::-1],  # reversed to match x_labels order
}

df = pd.DataFrame(data_to_save)

csv_filename = 'ising_results.csv'
df.to_csv(csv_filename, index=False)
print(f"Data saved to {csv_filename}")

print("\nSaved data:")
print(df)
