In [2]:
import os
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [None]:
# Make the experiments root the working directory; the project-local packages
# imported later (vanilla_2d, extrusion_2d) presumably live under it.
# os.chdir is plain Python and, unlike a bare `cd`, does not depend on IPython's
# automagic being enabled.
os.chdir('/home/CAMPUS/hdasari/apebench_experiments/mse_experiments')

In [None]:
# Held-out KS-2D test trajectories. A leading length-1 axis is prepended so the
# array carries the outer "experiments" axis that KSTrajectoryDataset unpacks
# (it expects a 6-dimensional array).
exp_test_path = '/home/CAMPUS/hdasari/apebench_experiments/ks_2d/data/KS_2d_test_data_exp1.npy'
exp_test = np.load(exp_test_path)[np.newaxis, ...]
exp_test.shape

In [5]:
from torch.utils.data import Dataset

class KSTrajectoryDataset(Dataset):
    """Map-style dataset that yields one full KS trajectory per item.

    The (experiment, simulation) axes of a 6-d array are flattened into a flat
    list of trajectories in experiment-major order (all sims of experiment 0,
    then all sims of experiment 1, ...). Every trajectory is converted to a
    float32 tensor up front and kept in memory.

    Parameters
    ----------
    ks_array : array-like with 6 dimensions
        Indexed as ``ks_array[exp, sim, :, channel, :, :]``; axes 0-2 are
        (experiments, sims, time) and axis 3 is presumably a channel axis
        (TODO confirm against the data generator). The last two axes are
        the spatial (H, W) grid.
    channel : int, default 0
        Which index of axis 3 to keep for every trajectory.

    Raises
    ------
    ValueError
        If ``ks_array`` does not have exactly 6 dimensions.
    """

    def __init__(self, ks_array, channel=0):
        shape = tuple(ks_array.shape)
        if len(shape) != 6:
            raise ValueError(
                "expected a 6-d array (experiments, sims, time, channel, H, W), "
                f"got shape {shape}"
            )
        num_experiments, num_sims = shape[:2]

        self.inputs = []
        for exp in range(num_experiments):
            for sim in range(num_sims):
                x_seq = ks_array[exp, sim, :, channel, :, :]  # shape: (time_steps, H, W)
                self.inputs.append(torch.tensor(x_seq, dtype=torch.float32))

    def __len__(self):
        """Number of trajectories (experiments * sims)."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return trajectory ``idx`` as a float32 tensor of shape (time_steps, H, W)."""
        return self.inputs[idx]


In [None]:
from vanilla_2d.vanilla_unet2d import UNet2d
from extrusion_2d.src_codes.models.primary_func import PrimaryNetwork
# Use the GPU when available. Both checkpoints below are remapped onto `device`
# with map_location, so the cell also runs on a CPU-only machine even if the
# checkpoints were saved from CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Extrusion model: PrimaryNetwork is given the path of a 1-D UNet checkpoint at
# construction; its own trained state dict is then loaded on top.
unet_1d_weights_path = '/home/CAMPUS/hdasari/apebench_experiments/mse_experiments/vanilla_1d/checkpoints/new_june18_2_mse_epoch_20_unet_1d_weights_biases.pth'
model = PrimaryNetwork(unet_1d_weights_path=unet_1d_weights_path, device=device).to(device)
checkpoint_hyper = torch.load(
    '/home/CAMPUS/hdasari/apebench_experiments/mse_experiments/extrusion_2d/checkpoints/new_check_june18/model_epoch_20.pth',
    map_location=device,
)
model.load_state_dict(checkpoint_hyper["model_state_dict"])
model.eval()

# Baseline 2-D UNet. Despite its name, `checkpoint_path` holds the loaded
# checkpoint dict, not a path.
model_unet = UNet2d().to(device)
checkpoint_path = torch.load(
    '/home/CAMPUS/hdasari/apebench_experiments/mse_experiments/vanilla_2d/checkpoints/new_check_june18_5/model_epoch_20.pth',
    map_location=device,
)
model_unet.load_state_dict(checkpoint_path['model_state_dict'])
model_unet.eval();  # trailing ';' suppresses the full module repr as cell output

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

# Evaluate both models on every test trajectory:
# 1. roll each model out autoregressively from the first ground-truth frame;
# 2. animate predictions and absolute errors side by side;
# 3. report the per-model average RMSE across trajectories.

SUPTITLE = ("Extrusion-Unet2D auto regressive Predictions with "
            "diff_gamma,hyp_diff_gamma,gradient_norm_delta = {-1, -15, -6}")


def _rollout(net, x0, n_steps):
    """Autoregressively roll `net` forward from a single initial frame.

    Each step feeds the previous prediction back in as a (1, 1, H, W) batch.
    Call under ``torch.no_grad()``.

    Parameters
    ----------
    net : torch.nn.Module
        Model assumed to map (1, 1, H, W) -> (1, 1, H, W) (per the original
        cell's shape comments — TODO confirm for both models).
    x0 : torch.Tensor
        Initial frame of shape (H, W), already on the model's device.
    n_steps : int
        Number of predictions to generate (must be >= 1).

    Returns
    -------
    np.ndarray of shape (n_steps, H, W); entry t is the prediction for frame t + 1.
    """
    state = x0.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
    frames = []
    for _ in range(n_steps):
        out = net(state).squeeze(0).squeeze(0)  # (H, W)
        frames.append(out.cpu().numpy())
        state = out.unsqueeze(0).unsqueeze(0).detach()
    return np.stack(frames)


def _rmse(pred, target):
    """Whole-trajectory RMSE: all time steps and pixels pooled into one number."""
    return float(np.sqrt(np.mean((pred - target) ** 2)))


def _animate_comparison(data_target, data_pred_unet, data_pred, suptitle):
    """Build a 2x4 animated figure comparing both models against ground truth.

    Layout: [0,0] ground truth, [0,1] UNet2D prediction, [1,1] Extrusion UNet2D
    prediction, [0,2]/[1,2] absolute errors, [0,3] mean absolute error per step.
    Predictions share the targets' colour range; both error maps share their
    own range, so the two models are directly comparable.

    All three arrays are (T-1, H, W). Returns ``(fig, ani)``; matplotlib needs
    the caller to keep a reference to ``ani`` for the animation to keep running.
    """
    data_diff = np.abs(data_target - data_pred)
    data_diff_unet = np.abs(data_target - data_pred_unet)

    vmin = min(data_pred.min(), data_target.min(), data_pred_unet.min())
    vmax = max(data_pred.max(), data_target.max(), data_pred_unet.max())
    vmin_data = min(data_diff.min(), data_diff_unet.min())
    vmax_data = max(data_diff.max(), data_diff_unet.max())

    fig, axes = plt.subplots(2, 4, figsize=(24, 10))
    fig.suptitle(suptitle, fontsize=16)

    def _panel(ax, data, title, lo, hi):
        # Each image panel gets a colour bar built from its OWN image artist.
        im = ax.imshow(data[0], cmap='plasma', origin='lower', vmin=lo, vmax=hi)
        ax.set_title(title)
        ax.set_xlabel("Width")
        ax.set_ylabel("Height")
        plt.colorbar(im, ax=ax)
        return im

    im_target = _panel(axes[0, 0], data_target, "Ground Truth ", vmin, vmax)
    im_pred_unet = _panel(axes[0, 1], data_pred_unet, "Pred_UNet2D ", vmin, vmax)
    im_pred_extr = _panel(axes[1, 1], data_pred, "Pred_ExtrusionUNet2D ", vmin, vmax)
    im_diff_unet = _panel(axes[0, 2], data_diff_unet, "Diff_Unet2D", vmin_data, vmax_data)
    im_diff_extr = _panel(axes[1, 2], data_diff, "Diff_ExtrusionUNet2D", vmin_data, vmax_data)

    ax_line = axes[0, 3]
    ax_line.plot(np.mean(data_diff_unet, axis=(1, 2)), label='UNet2D')
    ax_line.plot(np.mean(data_diff, axis=(1, 2)), label='Extrusion UNet2D')
    ax_line.set_title("Mean Abs Difference Over Time")
    ax_line.set_xlabel("Time Steps")
    ax_line.set_ylabel("Mean Abs Difference")
    ax_line.legend()

    # Unused panels.
    axes[1, 0].axis('off')
    axes[1, 3].axis('off')

    # (image artist, axes, title prefix, frames) for every animated panel.
    animated = [
        (im_target, axes[0, 0], "Targets", data_target),
        (im_pred_unet, axes[0, 1], "Pred_UNet2D", data_pred_unet),
        (im_pred_extr, axes[1, 1], "Pred_ExtrusionUNet2D ", data_pred),
        (im_diff_unet, axes[0, 2], "Diff_Unet2D", data_diff_unet),
        (im_diff_extr, axes[1, 2], "Diff_ExtrusionUNet2D", data_diff),
    ]

    def update(frame):
        for im, ax, label, data in animated:
            im.set_data(data[frame])
            ax.set_title(f"{label} - Time Step {frame}")
        return tuple(im for im, *_ in animated)

    ani = animation.FuncAnimation(
        fig, update, frames=data_target.shape[0], interval=200, blit=False
    )
    return fig, ani


test_dataset = KSTrajectoryDataset(exp_test)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

rmse_sum_extrusion = 0.0
rmse_sum_unet = 0.0
k = 0            # number of trajectories actually evaluated
epsilon = 1e-8   # keeps the MAPE denominator away from zero

with torch.no_grad():
    for batch in test_loader:
        traj = batch[0].to(device)  # batch_size=1, so traj is (T, H, W)
        time_steps = traj.shape[0]
        if time_steps < 2:
            continue  # no future frames to predict

        predictions = _rollout(model, traj[0], time_steps - 1)             # (T-1, H, W)
        predictions_unet = _rollout(model_unet, traj[0], time_steps - 1)   # (T-1, H, W)
        targets = traj[1:].cpu().numpy()                                   # (T-1, H, W)

        # Per-step MAPE (%). Dividing by |target| + epsilon (not target + epsilon)
        # keeps the denominator positive when targets are negative or near zero.
        mape_unet = np.mean(np.abs((predictions_unet - targets) / (np.abs(targets) + epsilon)), axis=(1, 2)) * 100
        mape_extrusion = np.mean(np.abs((predictions - targets) / (np.abs(targets) + epsilon)), axis=(1, 2)) * 100

        fig, ani = _animate_comparison(targets, predictions_unet, predictions, SUPTITLE)
        # NOTE(review): the default inline backend typically renders only a static
        # first frame of a FuncAnimation; use an interactive backend
        # (e.g. %matplotlib widget) or save it to view the animation:
        # ani.save("/home/CAMPUS/hdasari/apebench_experiments/mse_experiments/extrusion_2d/results/new_results_june18/exct_unet_2d_ks_exp_animation_three_views.gif", writer='pillow')
        plt.show()

        rmse_sum_extrusion += _rmse(predictions, targets)
        rmse_sum_unet += _rmse(predictions_unet, targets)
        k += 1

avg_extrusion_Rmse = rmse_sum_extrusion / k if k else float('nan')
avg_unet_Rmse = rmse_sum_unet / k if k else float('nan')
print(f"Average RMSE for Extrusion UNet2D: {avg_extrusion_Rmse}")
print(f"Average RMSE for UNet2D: {avg_unet_Rmse}")
