In [1]:
import os
import sys

import math
import time
import datetime
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from tcunet import Unet2D
# from YourDataset import YourDataset  # Import your custom dataset here
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from torchinfo import summary

torch.manual_seed(23)

import pickle

scaler = GradScaler()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


  scaler = GradScaler()


In [2]:
# Define your custom loss function here
class CustomLoss(nn.Module):
    """Relative L2 loss with optional per-element weights (`Lambda`).

    Without `Lambda` the loss is ``||y_true - y_pred||_2 / ||y_true||_2``.
    With `Lambda` the weights are first updated from the absolute residual and
    the loss becomes ``||Lambda * |y_true - y_pred|||_2 / ||y_true||_2``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true, Par, Lambda=None):
        """Compute the loss and, when given, the updated weights.

        Args:
            y_pred: predicted tensor.
            y_true: target tensor, broadcast-compatible with `y_pred`.
            Par: mapping; only ``Par['gamma']`` and ``Par['eta']`` are read,
                and only when `Lambda` is provided.
            Lambda: optional weight tensor, broadcast-compatible with the
                residual. ``None`` selects the unweighted loss.

        Returns:
            ``(loss, Lambda)`` where `loss` is a scalar tensor and `Lambda` is
            the updated weight tensor, or ``None`` if none was passed in.

        Note:
            ``||y_true||_2 == 0`` is not guarded; the division then follows
            the usual floating-point rules (inf/NaN).
        """
        if Lambda is None:
            loss = torch.norm(y_true - y_pred, p=2) / torch.norm(y_true, p=2)
            return loss, Lambda

        residue = torch.absolute(y_true - y_pred)

        # A perfect prediction gives max(residue) == 0; dividing by it would
        # turn Lambda (and then the loss) into NaN. Clamping the denominator
        # to the smallest normal number keeps 0/eps == 0 and leaves every
        # other case numerically unchanged.
        max_residue = torch.clamp(torch.max(residue), min=torch.finfo(residue.dtype).tiny)

        # NOTE(review): the update stays attached to the autograd graph exactly
        # as before; confirm with the training loop whether a detach is wanted.
        Lambda = Par['gamma'] * Lambda + Par['eta'] * residue / max_residue
        loss = torch.norm(Lambda * residue, p=2) / torch.norm(y_true, p=2)

        return loss, Lambda

class YourDataset(Dataset):
    """Map-style dataset yielding ``(x, t, y)`` triples by index.

    The three containers are indexed with the same `idx`; the dataset length
    is ``len(x)``. An optional `transform` is called as
    ``transform(x_sample, t_sample, y_sample)`` and must return three values.
    """

    def __init__(self, x, t, y, transform=None):
        self.x, self.t, self.y = x, t, y
        self.transform = transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        sample = (self.x[idx], self.t[idx], self.y[idx])

        if not self.transform:
            return sample

        # Unpack explicitly so the result is always a 3-tuple, whatever
        # sequence type the transform hands back.
        x_out, t_out, y_out = self.transform(*sample)
        return x_out, t_out, y_out

class YourDataset_L(Dataset):
    """Map-style dataset yielding ``(x, t, y, Lambda, idx)`` by index.

    Same as a plain ``(x, t, y)`` dataset, but each item also carries the
    matching `Lambda` entry and the integer index it was fetched with. The
    optional `transform` receives and returns only ``(x, t, y)``; the
    `Lambda` entry is passed through untouched.
    """

    def __init__(self, x, t, y, Lambda, transform=None):
        self.x, self.t, self.y = x, t, y
        self.Lambda = Lambda
        self.transform = transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        x_out, t_out, y_out = self.x[idx], self.t[idx], self.y[idx]
        weight = self.Lambda[idx]

        if self.transform:
            x_out, t_out, y_out = self.transform(x_out, t_out, y_out)

        return x_out, t_out, y_out, weight, idx

def preprocess(traj, Par):
    """Turn trajectories into (input window, time, single target frame) samples.

    Sliding windows are taken along axis 1 of `traj`. Each input window holds
    ``Par['lb']`` consecutive steps; the matching target window holds
    ``Par['lf']`` steps and starts on the last step of the input window. Every
    target window is then unrolled into ``lf`` separate samples, one per
    target step, each paired with a time value from ``linspace(0, 1, lf)``.

    Args:
        traj: 4-D array indexed as ``[trajectory, step, nx, ny]``.
        Par: mapping providing ``'lb'``, ``'lf'``, ``'nx'`` and ``'ny'``.

    Returns:
        ``(x, y, t)`` with shapes ``[n*lf, lb, nx, ny]``, ``[n*lf, 1, nx, ny]``
        and ``[n*lf]``, where ``n`` is the total number of windows.

    Raises:
        ValueError: from ``sliding_window_view`` when a trajectory is too
            short to hold a single window.
    """
    lb, lf = Par['lb'], Par['lf']
    nx, ny = Par['nx'], Par['ny']

    # Use an explicit stop index: the former `:-(lf-1)` slice degenerates to
    # `:-0` (an empty slice) when lf == 1. `max(..., 0)` keeps an over-long
    # `lf` from wrapping around as a negative index.
    x_stop = max(traj.shape[1] - (lf - 1), 0)

    x = sliding_window_view(traj[:, :x_stop, :, :], window_shape=lb, axis=1).transpose(0, 1, 4, 2, 3).reshape(-1, lb, nx, ny)
    y = sliding_window_view(traj[:, lb - 1:, :, :], window_shape=lf, axis=1).transpose(0, 1, 4, 2, 3).reshape(-1, lf, nx, ny)
    t = np.linspace(0, 1, lf).reshape(-1, 1)

    nt = y.shape[1]
    n_samples = y.shape[0]

    t = np.tile(t, [n_samples, 1]).reshape(-1,)                        # [n*nt]
    x = np.repeat(x, nt, axis=0)                                       # [n*nt, lb, nx, ny]
    y = y.reshape(y.shape[0] * y.shape[1], 1, y.shape[2], y.shape[3])  # [n*nt, 1, nx, ny]

    print('x: ', x.shape)
    print('y: ', y.shape)
    print('t: ', t.shape)
    print()
    return x, y, t

def get_flat_gradients(param_tensors):
    """Concatenate the gradients of `param_tensors` into one 1-D tensor.

    Args:
        param_tensors: iterable of objects exposing a ``.grad`` attribute
            (e.g. ``model.parameters()``). Entries whose ``.grad`` is ``None``
            are skipped.

    Returns:
        A 1-D tensor holding every available gradient, flattened and joined
        in iteration order.

    Raises:
        Whatever ``torch.cat`` raises for an empty list, when no entry has a
        gradient.
    """
    # `reshape(-1)` rather than `view(-1)`: `view` raises on gradients whose
    # memory layout is not contiguous, while `reshape` copies only when it
    # has to.
    grad_list = [p.grad.reshape(-1) for p in param_tensors if p.grad is not None]
    return torch.cat(grad_list)

def get_snr(L_theta_ls):
    """Signal-to-noise ratio of a stack of gradient vectors.

    The arrays are concatenated along axis 0 into ``[NB, W]``. Non-finite
    entries are reported and then replaced (NaN -> 0, +inf -> 1e12,
    -inf -> -1e12) before the statistics are taken.

    Args:
        L_theta_ls: sequence of arrays that can be concatenated along axis 0.

    Returns:
        ``(snr, NUM, DEN)`` where ``NUM`` is the 2-norm of the column-wise
        mean, ``DEN`` the 2-norm of the column-wise standard deviation, and
        ``snr = NUM / DEN``. ``DEN == 0`` is not guarded; the division then
        follows NumPy's floating-point rules.
    """
    L_theta = np.concatenate(L_theta_ls, axis=0)  # [NB, W]

    # Detect non-finite values BEFORE sanitising. Checking after `nan_to_num`
    # (as was done previously) can never trigger, so the warnings were dead.
    has_nan = np.isnan(L_theta).any()
    has_inf = np.isinf(L_theta).any()

    L_theta = np.nan_to_num(L_theta, nan=0.0, posinf=1e12, neginf=-1e12)
    mu = np.mean(L_theta, axis=0)   # [W,]
    sig = np.std(L_theta, axis=0)   # [W,]
    NUM = np.linalg.norm(mu, ord=2)
    DEN = np.linalg.norm(sig, ord=2)
    snr = NUM / DEN

    # Save MU and STD as well!
    # Do not use GradScaler for calculating SNR
    # Set gamma as gamma = 1 - eta

    if has_nan:
        print("Warning: NaN detected in gradients at L_theta")
    if has_inf:
        print("Warning: inf detected in gradients at L_theta")

    return snr, NUM, DEN


In [3]:
debug = False  # True: keep only the first 100 trajectories (quick smoke run)

# Load the raw NumPy arrays and build the train / validation / test trajectories.
#########################
x = np.load('../data/x.npy')  #[_, 64, 64]
t = np.load('../data/t.npy')  #[_, 200]
y = np.load('../data/y.npy')  #[_, 64, 64, 200]

if debug:
    x, y = x[:100], y[:100]

# 80 % / 10 % / 10 % split boundaries along the trajectory axis.
idx1, idx2 = int(0.8 * x.shape[0]), int(0.9 * x.shape[0])

# Append y after x along a new trailing axis ([_, 64, 64, 201]), then move
# that axis to position 1 so the result is laid out as [_, 201, 64, 64].
traj = np.concatenate((x[..., np.newaxis], y), axis=-1).transpose(0, 3, 1, 2)

# Keep every 4th entry along axis 1 in each split.
traj_train, traj_val, traj_test = traj[:idx1, ::4], traj[idx1:idx2, ::4], traj[idx2:, ::4]

# Run configuration. Keys are inserted in the same order as before so the
# printed / pickled dictionary is unchanged.
Par = {
    'nx': traj_train.shape[2],
    'ny': traj_train.shape[3],
    'nf': 1,
    'd_emb': 128,
    'lb': 1,    # input window length passed to preprocess
    'lf': 51,   # target window length passed to preprocess
    'num_epochs': 5 if debug else 500,
}

print('\nTrain Dataset')
x_train, y_train, t_train = preprocess(traj_train, Par)
print('\nValidation Dataset')
x_val, y_val, t_val = preprocess(traj_val, Par)
print('\nTest Dataset')
x_test, y_test, t_test = preprocess(traj_test, Par)

t_min, t_max = np.min(t_train), np.max(t_train)

# Normalisation statistics are taken from the training split only.
Par.update({
    'inp_shift': np.mean(x_train),
    'inp_scale': np.std(x_train),
    'out_shift': np.mean(y_train),
    'out_scale': np.std(y_train),
    't_shift': t_min,
    't_scale': t_max - t_min,
})

# eta is tied to gamma (eta = 1 - gamma), which makes Lambda_max equal to 1.
Par['gamma'] = 0.99
Par['eta'] = 1 - Par['gamma']

Par['do_rba'] = False
Par['get_snr'] = True

Par['Lambda_max'] = Par['eta'] / (1 - Par['gamma'])

# Per-element weights, initialised to half of Lambda_max, one per training target value.
Lambda = np.ones(y_train.shape, dtype=np.float32) * Par['Lambda_max'] / 2.0
print("Lambda: ", Lambda.shape)

print("Par: \n", Par)

# Persist the configuration next to the notebook.
with open('Par.pkl', 'wb') as f:
    pickle.dump(Par, f)
#########################
#########################

# Convert every split to float32 tensors (copies of the NumPy arrays).
x_train_tensor, t_train_tensor, y_train_tensor, Lambda_tensor = (
    torch.tensor(arr, dtype=torch.float32) for arr in (x_train, t_train, y_train, Lambda)
)
x_val_tensor, t_val_tensor, y_val_tensor = (
    torch.tensor(arr, dtype=torch.float32) for arr in (x_val, t_val, y_val)
)
x_test_tensor, t_test_tensor, y_test_tensor = (
    torch.tensor(arr, dtype=torch.float32) for arr in (x_test, t_test, y_test)
)

# Only the training dataset carries the per-sample Lambda weights and index.
train_dataset = YourDataset_L(x_train_tensor, t_train_tensor, y_train_tensor, Lambda_tensor)
val_dataset = YourDataset(x_val_tensor, t_val_tensor, y_val_tensor)
test_dataset = YourDataset(x_test_tensor, t_test_tensor, y_test_tensor)

# Sequential (unshuffled) loaders, same batch size for every split.
train_batch_size = val_batch_size = test_batch_size = 100
train_loader = DataLoader(train_dataset, batch_size=train_batch_size)
val_loader = DataLoader(val_dataset, batch_size=val_batch_size)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size)


Train Dataset
x:  (40800, 1, 64, 64)
y:  (40800, 1, 64, 64)
t:  (40800,)


Validation Dataset
x:  (5100, 1, 64, 64)
y:  (5100, 1, 64, 64)
t:  (5100,)


Test Dataset
x:  (5100, 1, 64, 64)
y:  (5100, 1, 64, 64)
t:  (5100,)

Lambda:  (40800, 1, 64, 64)
Par: 
 {'nx': 64, 'ny': 64, 'nf': 1, 'd_emb': 128, 'lb': 1, 'lf': 51, 'num_epochs': 500, 'inp_shift': np.float64(0.007669870189599345), 'inp_scale': np.float64(0.061450183570653995), 'out_shift': np.float64(-0.001635048307912137), 'out_scale': np.float64(0.03807830166415873), 't_shift': np.float64(0.0), 't_scale': np.float64(1.0), 'gamma': 0.99, 'eta': 0.010000000000000009, 'do_rba': False, 'get_snr': True, 'Lambda_max': 1.0}


In [4]:
# Build the Unet2D model on the selected device in float32 and print its layer summary.
model = Unet2D(dim=16, Par=Par, dim_mults=(1, 2, 4, 8)).to(device).to(torch.float32)
print(summary(model, input_size=((1,)+x_train.shape[1:], (1,))  ))

# Restore the saved weights. `map_location=device` remaps the stored tensors
# onto the active device, so a checkpoint written on a GPU still loads when
# the notebook falls back to the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
path_model = 'models/best_model.pt'
model.load_state_dict(torch.load(path_model, map_location=device))

# Loss used for the validation / test metrics below.
criterion = CustomLoss()

Layer (type:depth-idx)                                  Output Shape              Param #
Unet2D                                                  [1, 1, 64, 64]            --
├─Conv2d: 1-1                                           [1, 16, 64, 64]           2,368
├─Sequential: 1-2                                       [1, 64]                   --
│    └─SinusoidalPosEmb: 2-1                            [1, 16]                   --
│    └─Linear: 2-2                                      [1, 64]                   1,088
│    └─GELU: 2-3                                        [1, 64]                   --
│    └─Linear: 2-4                                      [1, 64]                   4,160
├─ModuleList: 1-3                                       --                        --
│    └─ModuleList: 2-5                                  --                        --
│    │    └─ResnetBlock: 3-1                            [1, 16, 64, 64]           6,784
│    │    └─ResnetBlock: 3-2                    

In [5]:
# Validation pass: average of the per-batch loss over the validation loader.
model.eval()
val_loss = 0.0
with torch.no_grad():
    for x, t, y_true in val_loader:
        # `torch.cuda.amp.autocast()` is deprecated (FutureWarning); this is
        # the supported equivalent. Mixed precision is enabled only on CUDA,
        # matching the old behaviour, which disabled itself without a GPU.
        with torch.autocast(device_type="cuda", enabled=(device.type == "cuda")):
            y_pred = model(x.to(device), t.to(device))
            loss, _   = criterion(y_pred, y_true.to(device), Par)
            val_loss += loss.item()

val_loss /= len(val_loader)
print(f'Val Loss: {val_loss:.4e}')

  with autocast():


Val Loss: 3.6775e-02


In [6]:
y_true_ls = []
y_pred_ls = []

# Test pass: average of the per-batch loss, keeping targets and predictions
# for export.
model.eval()
test_loss = 0.0
with torch.no_grad():
    for x, t, y_true in test_loader:
        # `torch.cuda.amp.autocast()` is deprecated (FutureWarning); this is
        # the supported equivalent. Mixed precision is enabled only on CUDA.
        with torch.autocast(device_type="cuda", enabled=(device.type == "cuda")):
            y_pred = model(x.to(device), t.to(device))
            loss, _ = criterion(y_pred, y_true.to(device), Par)

        y_true_ls.append(y_true)
        # Store predictions on the CPU in float32: this avoids accumulating
        # every batch in device memory and keeps the saved predictions in the
        # same dtype as the targets even if autocast produced half precision.
        y_pred_ls.append(y_pred.detach().float().cpu())
        test_loss += loss.item()

test_loss /= len(test_loader)
print(f'Test Loss: {test_loss:.4e}')

  with autocast():


Test Loss: 3.6934e-02


In [7]:
# Regroup the flat per-step samples into trajectories of Par['lf'] frames.
# The window length and grid size come from Par instead of being hard-coded,
# so this cell follows any change to the configuration. The grouping relies on
# the test loader being unshuffled (samples of one window are contiguous).
traj_shape = (-1, Par['lf'], Par['nx'], Par['ny'])
Y_TRUE = torch.cat(y_true_ls, axis=0).reshape(traj_shape).detach().cpu().numpy()
Y_PRED = torch.cat(y_pred_ls, axis=0).reshape(traj_shape).detach().cpu().numpy()

print('true: ', Y_TRUE.shape)
print('pred: ', Y_PRED.shape)

true:  (100, 51, 64, 64)
pred:  (100, 51, 64, 64)


In [8]:
# Write the stacked test targets and predictions to .npy files in the
# current working directory.
np.save(file="Y_TRUE.npy", arr=Y_TRUE)
np.save(file="Y_PRED.npy", arr=Y_PRED)