In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import json
import xarray
from pathlib import Path
import dask
from tqdm import tqdm

## Preprocess the SST data (`sstday.npy`)

In [None]:
prefix = '/g/data/x77/jm0124/ocean'


def sst(path=f'{prefix}/sstday.npy', out_dir=prefix,
        train_range=(220, 1315), valid_range=(1315, 2557), test_range=(2557, 2922)):
    """Normalise the SST record and save train/valid/test splits.

    Loads a 3-D array from ``path`` (first axis indexes frames), subtracts
    the per-pixel mean over frames, linearly rescales the whole array to
    [-1, 1], and writes the three splits to
    ``{out_dir}/sstday_train.npy``, ``sstday_valid.npy`` and ``sstday_test.npy``.
    ``out_dir`` must already exist.

    Parameters
    ----------
    path : str
        Input ``.npy`` file holding an array of shape (t, m, n).
    out_dir : str
        Directory the split files are written to.
    train_range, valid_range, test_range : tuple of int
        Half-open ``(start, stop)`` frame ranges along the first axis.
        Defaults give 1095 / 1242 / 365 frames respectively.

    Returns
    -------
    X_train, X_valid, X_test : np.ndarray
        Normalised splits, each of shape (frames, m, n).
    m, n : int
        Spatial dimensions of each frame.

    Raises
    ------
    ValueError
        If the input is not 3-D, has fewer frames than the requested
        ranges need, or is constant (so it cannot be rescaled).
    """
    X = np.load(path)
    if X.ndim != 3:
        raise ValueError(f"expected a 3-D (t, m, n) array, got shape {X.shape}")
    t, m, n = X.shape
    needed = max(train_range[1], valid_range[1], test_range[1])
    if t < needed:
        raise ValueError(f"{path} has {t} frames but the splits need {needed}")

    # NOTE(review): the mean and min/max below are computed over the whole
    # record (valid/test frames and frames outside every split included), so
    # held-out statistics leak into the training data. Kept unchanged because
    # models already trained on these files depend on this normalisation.
    X = X.reshape(t, m * n)
    X = X - X.mean(axis=0)  # per-pixel mean removal; out-of-place so int input also works

    span = np.ptp(X)
    if span == 0:
        raise ValueError("input is constant after mean removal; cannot rescale to [-1, 1]")
    X = 2 * (X - np.min(X)) / span - 1
    X = X.reshape(t, m, n)

    X_train = X[slice(*train_range)]
    X_valid = X[slice(*valid_range)]
    X_test = X[slice(*test_range)]

    np.save(f'{out_dir}/sstday_train.npy', X_train)
    np.save(f'{out_dir}/sstday_valid.npy', X_valid)
    np.save(f'{out_dir}/sstday_test.npy', X_test)

    return X_train, X_valid, X_test, m, n

In [None]:
# Run preprocessing once. Unpack the result instead of letting the cell echo
# three full arrays as its output.
X_train, X_valid, X_test, m, n = sst()

In [None]:
# Reload the training split that sst() saved under `prefix`, so later cells
# work without re-running the preprocessing.
path = '/'.join([prefix, 'sstday_train.npy'])
X_train = np.load(path)

In [None]:
# Visual sanity check of the normalised field. With sst()'s default split,
# the training set holds 1095 frames, so index 1094 is its last frame.
FRAME_IDX = 1094
fig, ax = plt.subplots()
im = ax.imshow(X_train[FRAME_IDX], cmap='coolwarm')
fig.colorbar(im, ax=ax, label='normalised value')
ax.set_title(f'Training frame {FRAME_IDX}')
plt.show()

## Build datasets and data loaders

In [6]:
class OceanToOcean(Dataset):
    """Map-style dataset of sliding windows over a saved ``sstday_*.npy`` split.

    Each element ``ocean_array[r]`` is treated as a sequence along its first
    axis. For every centre ``j`` in ``[prediction_length, L - prediction_length)``
    (``L = ocean_array.shape[1]``), the window
    ``ocean_array[r, j - prediction_length : j + prediction_length]`` is one
    sample. A sample is ``(window, window reversed along its first axis)``.

    NOTE(review): if these files are the ones written by ``sst()`` (shape
    (t, m, n)), each ``ocean_array[r]`` is a single (m, n) frame, so windows
    run over the frame's first spatial axis rather than over time. Confirm
    this is intended.
    """

    def __init__(self, prediction_length, partition_name='train',
                 data_dir="/home/156/cn1951/kae-cyclones/input"):
        """Load ``{data_dir}/sstday_{partition_name}.npy``.

        Raises
        ------
        ValueError
            If ``prediction_length`` < 1. A zero or negative half-width would
            yield no samples.
        """
        if prediction_length < 1:
            raise ValueError(f"prediction_length must be >= 1, got {prediction_length}")
        self.ocean_array = np.load(f"{data_dir}/sstday_{partition_name}.npy")
        self.prediction_length = prediction_length

    def _windows_per_run(self):
        # Same count as iterating ocean_run[p:-p]: L - 2p, clamped at 0 when
        # the run is too short to hold a single window.
        run_length = self.ocean_array.shape[1]
        return max(run_length - 2 * self.prediction_length, 0)

    def __len__(self):
        # Must equal the size of the index space __getitem__ accepts, so the
        # sampler can reach every window of every run.
        return self.ocean_array.shape[0] * self._windows_per_run()

    def __getitem__(self, idx):
        """Return ``(window, reversed window)`` as tensors for flat index ``idx``.

        Indices enumerate runs in order, then windows within a run.
        ``window`` shares memory with ``ocean_array``. The reversed copy does not.

        Raises
        ------
        IndexError
            If ``idx`` is outside ``[0, len(self))``.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f"index {idx} out of range for dataset of length {len(self)}")
        run, offset = divmod(idx, self._windows_per_run())
        # The window centred at j = p + offset spans [offset, offset + 2p).
        stop = offset + 2 * self.prediction_length
        window = self.ocean_array[run, offset:stop]
        # np.flip returns a negative-stride view, which torch.from_numpy
        # rejects, hence the .copy().
        return torch.from_numpy(window), torch.from_numpy(np.flip(window, 0).copy())

In [7]:
def generate_ocean_ds(prediction_length=4):
    """Build the train / valid / test ``OceanToOcean`` datasets.

    Parameters
    ----------
    prediction_length : int, default 4
        Window half-width passed to every split. The trained model loaded
        later in this notebook uses ``steps=4, steps_back=4``, so keep the
        two in sync.

    Returns
    -------
    tuple
        ``(train_ds, val_ds, test_ds)`` built from the ``'train'``,
        ``'valid'`` and ``'test'`` partitions, in that order.
    """
    return tuple(
        OceanToOcean(prediction_length, partition)
        for partition in ('train', 'valid', 'test')
    )

In [None]:
BATCH_SIZE = 64
NUM_WORKERS = 8

train_ds, val_ds, test_ds = generate_ocean_ds()

# Shuffle only the training stream. Assuming validation loss is averaged over
# the whole set, order does not matter there, and a fixed order keeps
# validation passes repeatable.
loader = DataLoader(train_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                    pin_memory=True, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                        pin_memory=True, shuffle=False)

# NOTE(review): none of the cells shown here use these values, and
# input_size / alpha disagree with the arguments later passed to koopmanAE
# (input_size=150, alpha=16). Confirm which set is authoritative before reuse.
input_size = 2
alpha = 4
beta = 4
learning_rate = 1e-4

In [None]:
img, output = next(iter(loader))
# Each target is a reversed copy of its input window, so the batch shapes
# must agree. Fail loudly here rather than deep inside training.
assert img.shape == output.shape, (img.shape, output.shape)
img.shape

## Load trained model

In [8]:
from models import *

In [None]:
saved_models_path = '/home/156/cn1951/kae-cyclones/saved_models/ELEI-eigenloss_and_eigeninit-final.pt'

# Prefer the first GPU (what the original `.to(0)` targeted) and fall back to
# CPU so the notebook still runs on a machine without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = koopmanAE(b=16, steps=4, steps_back=4, alpha=16, eigen_init=True, input_size=150).to(device)
# map_location remaps tensors that were saved on a GPU onto `device`.
# NOTE(review): torch.load unpickles the file. Only load checkpoints you trust.
model.load_state_dict(torch.load(saved_models_path, map_location=device))
model.eval()

## Plot forward-loss curves

In [9]:
import pandas as pd

In [None]:
# Forward-loss curves for each training variant, exported from Weights & Biases.
path = '/home/156/cn1951/kae-cyclones/results/' + 'wandb_export_2022-07-05T20.csv'
df = pd.read_csv(filepath_or_buffer=path)
df.columns

In [None]:
# Pull one forward-loss series per run, in the order: baseline, eigenloss
# only, eigen-init only, and both combined.
regular, eigenloss, eigeninit, both = (
    df[column].to_numpy()
    for column in (
        'initial_run_ocean-ocean - forward loss',
        'eigenloss_only-ocean - forward loss',
        'eigen_init_vanilla-ocean - forward loss',
        'eigenloss_and_eigeninit-ocean - forward loss',
    )
)

In [None]:
# One curve per training variant. Without a legend the four lines are
# indistinguishable, so label each one after its W&B run.
fig, ax = plt.subplots()
for losses, label in [
    (regular, 'baseline (initial_run)'),
    (eigenloss, 'eigenloss only'),
    (eigeninit, 'eigen-init only'),
    (both, 'eigenloss + eigen-init'),
]:
    ax.loglog(df.Step, losses, label=label)
# NOTE(review): the x values come from W&B's `Step` column. They are labelled
# "Epoch" on the assumption that one step is logged per epoch; confirm.
ax.set(xlabel="Epoch", ylabel="Mean squared error (MSE)", title="Forward loss (ocean)")
ax.legend()
plt.show()