In [None]:
import numpy as np
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
import torch
from torch import nn
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import DataLoader, TensorDataset,Dataset
import os
import copy

In [None]:
# Predictor array; per the file name: de-seasoned, normalized, resized SST/SSS/PSL fields.
data = np.load(
    '../../CESM_data/ResNet50_data/CESM_data_sst_sss_psl_deseason_normalized_resized.npy'
)
# Regression labels; per the file name: an AMV index.
target = np.load(
    '../../CESM_data/ResNet50_data/CESM_label_amv_index.npy'
)

In [None]:
# --- Data geometry ---
channels = 3          # presumably the number of stacked predictor fields -- TODO confirm
ens = 40              # how many leading entries of the ensemble axis are kept
tstep = 86            # time-axis length used when forming (input, label) pairs

# --- Training setup ---
percent_train = 0.95  # leading fraction of samples used for training; the rest is validation
batch_size = 32       # mini-batch size for the DataLoaders
max_epochs = 10       # number of passes over the training loader

In [None]:
# Keep only the first `ens` entries along the second axis; the other four axes are untouched.
data = data[:, :ens, :, :, :]
# Bare expression so the cell displays the resulting shape.
data.shape

In [None]:
# Keep the first `ens` rows of the labels, mirroring the subset taken from `data`.
target = target[:ens, :]
# Bare expression so the cell displays the resulting shape.
target.shape

In [None]:
# Lead, in steps of the time axis: the fields at step t are paired with the
# label at step t + lead.
lead = 10

# Number of (input, label) pairs available per ensemble member.
n_pairs = tstep - lead

# Labels: drop the first `lead` steps of every member, then flatten
# (member, time) into a single sample axis -> (ens * n_pairs, 1).
y = target[:, lead:].reshape(ens * n_pairs, 1)

# Predictors: drop the last `lead` steps, flatten (member, time) in the same
# order as `y`, then move the sample axis to the front
# -> (samples, channels, height, width).
# The channel count comes from `channels` and the spatial size from the array
# itself, instead of the literals 3 / 244 / 244.
height, width = data.shape[-2:]
X = (
    data[:, :, :n_pairs, :, :]
    .reshape(channels, ens * n_pairs, height, width)
    .transpose(1, 0, 2, 3)
)

In [None]:
# Index of the first validation sample: the leading `percent_train` share of
# the ens * (tstep - lead) samples goes to training.
n_train = int(np.floor(percent_train * (tstep - lead) * ens))

# Sequential (unshuffled) split, cast to float32 and wrapped as torch tensors.
X_train = torch.from_numpy(X[:n_train, :, :, :].astype(np.float32))
X_val = torch.from_numpy(X[n_train:, :, :, :].astype(np.float32))

y_train = torch.from_numpy(y[:n_train, :].astype(np.float32))
y_val = torch.from_numpy(y[n_train:, :].astype(np.float32))

In [None]:
# Training batches are reshuffled every epoch. The samples are stored
# member-by-member in time order, so without shuffling each batch would hold
# consecutive steps from a single member.
train_loader = DataLoader(
    TensorDataset(X_train, y_train),
    batch_size=batch_size,
    shuffle=True,
)
# Validation data is only scored, so it stays in its stored order.
val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=batch_size)

In [None]:
# ResNet-50 backbone from torchvision, requesting its pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; confirm the installed version before changing this call.
model = models.resnet50(pretrained=True)

In [None]:
# Freeze every parameter currently in the network: with requires_grad off,
# backward() computes no gradient for it and the optimizer leaves it unchanged.
for frozen_weight in model.parameters():
    frozen_weight.requires_grad = False

In [None]:
# Swap the classification head for a single-output regression layer.
# The input width is read from the layer being replaced instead of hard-coding
# 2048, so the cell keeps working if the backbone is changed.
# A freshly built nn.Linear has trainable parameters, so after the freeze loop
# this is the only part of the network that learns.
model.fc = nn.Linear(model.fc.in_features, 1)

In [None]:
# Give Adam only the parameters that can receive gradients. Frozen weights
# would never be updated, so passing them just bloats the optimizer's
# parameter groups. If nothing is frozen this is the full parameter list,
# exactly as before.
trainable_params = [p for p in model.parameters() if p.requires_grad]
opt = torch.optim.Adam(trainable_params)

# Mean-squared error between the scalar prediction and the scalar label.
loss_fn = nn.MSELoss()

In [None]:
# Per-epoch mean losses; one entry is appended to each list per epoch.
epo_train_losses = []
epo_val_losses = []

for iepoch in tqdm(range(max_epochs)):

    # ---- training pass ----
    batch_train_losses = []
    model.train()
    for x_batch, y_batch in train_loader:
        opt.zero_grad()
        y_pred = model(x_batch)
        loss = loss_fn(y_pred, y_batch)
        loss.backward()
        opt.step()
        batch_train_losses.append(loss.item())
    # Unweighted mean over batches (a smaller final batch counts as much as a full one).
    epo_train_losses.append(sum(batch_train_losses) / len(batch_train_losses))

    # ---- validation pass ----
    batch_val_losses = []
    # Switch to eval mode so batch-norm layers use their running statistics
    # and are not updated by validation data; model.train() at the top of the
    # next epoch switches back.
    model.eval()
    with torch.no_grad():
        for x_batch_val, y_batch_val in val_loader:
            y_pred = model(x_batch_val)
            loss = loss_fn(y_pred, y_batch_val)
            batch_val_losses.append(loss.item())
    epo_val_losses.append(sum(batch_val_losses) / len(batch_val_losses))

In [None]:
# Learning curves: per-epoch mean loss on the training and validation loaders.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(epo_train_losses, label='train')
loss_ax.plot(epo_val_losses, label='validation')
loss_ax.set_xlabel('epoch')
loss_ax.set_ylabel('mean MSE loss')
loss_ax.set_title('Loss per epoch')
loss_ax.legend();

In [None]:
# Put the network in evaluation mode before the full-dataset predictions in the
# following cells. As the last expression of the cell, the call's return value
# (the module itself) is what the cell displays.
model.eval()

In [None]:
# Predict in mini-batches with autograd disabled. Pushing the whole training
# set through ResNet-50 in one forward pass needs an activation buffer for
# every sample at once and can exhaust memory.
# eval() is (re)asserted here so the cell does not depend on execution order.
model.eval()
with torch.no_grad():
    y_pred_val = torch.cat([model(chunk) for chunk in X_val.split(batch_size)])
    y_pred_train = torch.cat([model(chunk) for chunk in X_train.split(batch_size)])

In [None]:
# Predicted value (x) against the true label (y) for both splits.
scatter_fig, scatter_ax = plt.subplots()
scatter_ax.plot(
    y_pred_val.detach().numpy()[:, 0],
    y_val.detach().numpy()[:, 0],
    '.',
    label='validation',
)
scatter_ax.plot(
    y_pred_train.detach().numpy()[:, 0],
    y_train.detach().numpy()[:, 0],
    '.',
    label='train',
)
scatter_ax.set_xlabel('predicted')
scatter_ax.set_ylabel('target')
scatter_ax.set_title('Prediction vs. target at lead = {}'.format(lead))
scatter_ax.legend();

In [None]:
# Correlation matrix between predictions and labels on the validation split;
# the off-diagonal entry is the correlation coefficient.
val_predicted = y_pred_val.detach().numpy()[:, 0]
val_actual = y_val.detach().numpy()[:, 0]
np.corrcoef(val_predicted, val_actual)

In [None]:
# Correlation matrix between predictions and labels on the training split;
# the off-diagonal entry is the correlation coefficient.
train_predicted = y_pred_train.detach().numpy()[:, 0]
train_actual = y_train.detach().numpy()[:, 0]
np.corrcoef(train_predicted, train_actual)