In [14]:
import os
import torch
from copy import deepcopy
import numpy as np
import xarray as xr
import pandas as pd
import torch.nn as nn
import random
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import netCDF4 as nc

In [15]:
def set_seed(seed = 427):
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)

In [None]:
def data_enhance(df, data_type='feature'):
    if data_type=='feature':
        df_ = df
        for i in range(1, 12):
            dfi = df_.reshape(-1, 24, 72)
            dfi = dfi[i:-(12-i), :, :]
            dfi = dfi.reshape(-1, 12, 24, 72)
            df = np.concatenate((df, dfi), axis=0)
    elif data_type=='label':
        df_ = np.concatenate((df[:, :12], df[-1:, 12:24], df[-1:, 24:36]), axis=0)
        for i in range(1, 12):
            df_ = df_.reshape(-1)
            dfi1 = df_[i:-(36-i)].reshape(-1, 12)
            dfi2 = df_[i+12:-(24-i)].reshape(-1, 12)
            dfi3 = df_[i+24:-(12-i)].reshape(-1, 12)
            dfi = np.concatenate((dfi1, dfi2, dfi3), axis=1)
            df = np.concatenate((df, dfi), axis=0)
    return df

In [16]:
def load_data2():
    # CMIP data    
    train = xr.open_dataset('../data/enso_round1_train_20210201/CMIP_train.nc')
    label = xr.open_dataset('../data/enso_round1_train_20210201/CMIP_label.nc')    
   
    train_sst = data_enhance(train['sst'][:, :12].values)  # (4645, 12, 24, 72)
    train_t300 = data_enhance(train['t300'][:, :12].values)
    train_ua = data_enhance(train['ua'][:, :12].values)
    train_va = data_enhance(train['va'][:, :12].values)
    train_label = data_enhance(label['nino'].values, data_type='label')
    train_label = train_label[:, 12:36]

    train_ua = np.nan_to_num(train_ua)
    train_va = np.nan_to_num(train_va)
    train_t300 = np.nan_to_num(train_t300)
    train_sst = np.nan_to_num(train_sst)

    # SODA data    
    train2 = xr.open_dataset('./data/enso_round1_train_20210201/SODA_train.nc')
    label2 = xr.open_dataset('./data/enso_round1_train_20210201/SODA_label.nc')
    
    train_sst2 = data_enhance(train2['sst'][:, :12].values)  # (4645, 12, 24, 72)
    train_t3002 = data_enhance(train2['t300'][:, :12].values)
    train_ua2 = data_enhance(train2['ua'][:, :12].values)
    train_va2 = data_enhance(train2['va'][:, :12].values)
    train_label2 = data_enhance(label2['nino'].values, data_type='label')
    train_label2 = train_label2[:, 12:36]

    print('Train samples: {}, Valid samples: {}'.format(len(train_label), len(train_label2)))

    dict_train = {
        'sst':train_sst,
        't300':train_t300,
        'ua':train_ua,
        'va': train_va,
        'label': train_label}
    dict_valid = {
        'sst':train_sst2,
        't300':train_t3002,
        'ua':train_ua2,
        'va': train_va2,
        'label': train_label2}
    train_dataset = EarthDataSet(dict_train)
    valid_dataset = EarthDataSet(dict_valid)
    return train_dataset, valid_dataset
    

class EarthDataSet(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data['sst'])

    def __getitem__(self, idx):   
        return (self.data['sst'][idx], self.data['t300'][idx], self.data['ua'][idx], self.data['va'][idx]), self.data['label'][idx]

In [17]:
class simpleSpatailTimeNN(nn.Module):
    def __init__(self, n_cnn_layer:int=1, kernals:list=[3], n_lstm_units:int=64):
        super(simpleSpatailTimeNN, self).__init__()
        self.conv1 = nn.ModuleList([nn.Conv2d(in_channels=12, out_channels=12, kernel_size=i) for i in kernals]) 
        self.conv2 = nn.ModuleList([nn.Conv2d(in_channels=12, out_channels=12, kernel_size=i) for i in kernals])
        self.conv3 = nn.ModuleList([nn.Conv2d(in_channels=12, out_channels=12, kernel_size=i) for i in kernals])
        self.conv4 = nn.ModuleList([nn.Conv2d(in_channels=12, out_channels=12, kernel_size=i) for i in kernals])
        self.pool1 = nn.AdaptiveAvgPool2d((22, 1))
        self.pool2 = nn.AdaptiveAvgPool2d((1, 70))
        self.pool3 = nn.AdaptiveAvgPool2d((1, 128))
        self.batch_norm = nn.BatchNorm1d(12, affine=False)
        self.lstm = nn.LSTM(1540 * 4, n_lstm_units, 2, bidirectional=True)
        self.linear = nn.Linear(128, 24)

    def forward(self, sst, t300, ua, va):
        for conv1 in self.conv1:
            sst = conv1(sst)  # batch * 12 * (24 - 2) * (72 -2)
        for conv2 in self.conv2:
            t300 = conv2(t300)
        for conv3 in self.conv3:
            ua = conv3(ua)
        for conv4 in self.conv4:
            va = conv4(va)

        sst = torch.flatten(sst, start_dim=2)  # batch * 12 * 1540
        t300 = torch.flatten(t300, start_dim=2)
        ua = torch.flatten(ua, start_dim=2)
        va = torch.flatten(va, start_dim=2)  # if flat, lstm input_dims = 1540 * 4              
            
        x = torch.cat([sst, t300, ua, va], dim=-1)    
        x = self.batch_norm(x)
        x, _ = self.lstm(x)
        x = self.pool3(x).squeeze(dim=-2)
        x = self.linear(x)
        return x

In [18]:
def coreff(x, y):
    x_mean = np.mean(x)
    y_mean = np.mean(y)
    c1 = sum((x - x_mean) * (y - y_mean))
    c2 = sum((x - x_mean)**2) * sum((y - y_mean)**2)
    return c1/np.sqrt(c2)

def rmse(preds, y):
    return np.sqrt(sum((preds - y)**2)/preds.shape[0])

def eval_score(preds, label):
    # preds = preds.cpu().detach().numpy().squeeze()
    # label = label.cpu().detach().numpy().squeeze()
    acskill = 0
    RMSE = 0
    a = 0
    a = [1.5]*4 + [2]*7 + [3]*7 + [4]*6
    for i in range(24):
        RMSE += rmse(label[:, i], preds[:, i])
        cor = coreff(label[:, i], preds[:, i])
    
        acskill += a[i] * np.log(i+1) * cor
    return 2/3 * acskill - RMSE

In [None]:
import numpy as np
import os
import random
import argparse
import torch
from torch import double, optim, var
import torch.nn as nn


def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYHTONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def save_model(model, model_save_address='./model'):
    
    modelSave1 = model_save_address + '/model.pth.tar'
    try:
        torch.save({'state_dict': model.state_dict(), }, modelSave1)
    except:
        torch.save({'state_dict': model.module.state_dict(), }, modelSave1)
    print('Model saved!')

if __name__ == "__main__":
    # Parameters for training
    parser = argparse.ArgumentParser()
    parser.add_argument("--continue_training", type=bool, default=True)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--print_freq", type=int, default=5)
    parser.add_argument("--train_test_ratio", type=float, default=0.8)
    #parser.add_argument("--data_load_address", type=str, default='./channelData')
    parser.add_argument("--model_save_address", type=str, default='./model')
    parser.add_argument("--gpu_list", type=str, default='0,1')
    args = parser.parse_args(args=[])

    
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_list
    SEED = 42
    seed_everything(SEED)

    learning_rate = args.learning_rate
    num_workers = 4

    # parameters for data
    train_dataset, valid_dataset = load_data2()
    model = simpleSpatailTimeNN()
    
    if args.continue_training:
        model.load_state_dict({k.replace('module.',''):v for k,v in torch.load(args.model_save_address + '/model.pth.tar')['state_dict'].items()})
        #model.load_state_dict(torch.load(args.model_save_address + '/model.pth.tar')['state_dict'])
        
        
    if len(args.gpu_list.split(',')) > 1:
        model = torch.nn.DataParallel(model).cuda()  # model.module
    else:
        model = model.cuda()

    
    criterion_mse = nn.MSELoss().cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-5, last_epoch=-1)

    # Data loading 
    train_dataset, test_dataset = load_data2()      
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size)
 
    best_score = -100    
    
    for epoch in range(args.epochs):
        print('========================')
        print('lr:%.4e'%optimizer.param_groups[0]['lr'])     
        model.train()
        for i, ((sst, t300, ua, va), label) in enumerate(train_loader):                
            sst = sst.cuda().float()
            t300 = t300.cuda().float()
            ua = ua.cuda().float()
            va = va.cuda().float()
            
            optimizer.zero_grad()
            label = label.cuda().float()
            preds = model(sst, t300, ua, va)
            loss = criterion_mse(preds,label)
            loss.backward()
            optimizer.step()
            if epoch > 5: 
                scheduler.step() 
            
            if i % args.print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                    'Loss_mse {loss:.7f}\t'  
                    'lr: {lr:.7e}\t'.format(
                    epoch, i, len(train_loader), loss=loss.item(), lr=optimizer.param_groups[0]['lr']))

        model.eval()
        y_true, y_pred = [], []
        with torch.no_grad():
            for i, ((sst, t300, ua, va), label) in enumerate(test_loader):
                sst = sst.cuda().float()
                t300 = t300.cuda().float()
                ua = ua.cuda().float()
                va = va.cuda().float()
                label = label.cuda().float()
                preds = model(sst, t300, ua, va)

                y_pred.append(preds)
                y_true.append(label)

            y_true = torch.cat(y_true, axis=0)
            y_pred = torch.cat(y_pred, axis=0)
            score = eval_score(y_true.cpu().detach().numpy(), y_pred.cpu().detach().numpy())
            if score>best_score:
                save_model(model, args.model_save_address)
                best_score = score
                print('Epoch: {}, Valid Score {}'.format(epoch+1,score))
                save_model(model, args.model_save_address)
                print('Model saved successfully')
                
    del model, optimizer, train_loader,test_loader
    torch.cuda.empty_cache()

Train samples: 4645, Valid samples: 100
Train samples: 4645, Valid samples: 100
lr:2.0000e-04
Epoch: [0][0/37]	Loss_mse 0.0372694	lr: 2.0000000e-04	
Epoch: [0][5/37]	Loss_mse 0.0339390	lr: 2.0000000e-04	
Epoch: [0][10/37]	Loss_mse 0.0266590	lr: 2.0000000e-04	
Epoch: [0][15/37]	Loss_mse 0.0381869	lr: 2.0000000e-04	
Epoch: [0][20/37]	Loss_mse 0.0162637	lr: 2.0000000e-04	
Epoch: [0][25/37]	Loss_mse 0.0179244	lr: 2.0000000e-04	
Epoch: [0][30/37]	Loss_mse 0.0243381	lr: 2.0000000e-04	
Epoch: [0][35/37]	Loss_mse 0.0207633	lr: 2.0000000e-04	
Model saved!
Epoch: 1, Valid Score 79.94345188131022
Model saved!
Model saved successfully
lr:2.0000e-04
Epoch: [1][0/37]	Loss_mse 0.0292350	lr: 2.0000000e-04	
Epoch: [1][5/37]	Loss_mse 0.0339534	lr: 2.0000000e-04	
Epoch: [1][10/37]	Loss_mse 0.0261335	lr: 2.0000000e-04	
Epoch: [1][15/37]	Loss_mse 0.0377939	lr: 2.0000000e-04	
Epoch: [1][20/37]	Loss_mse 0.0168834	lr: 2.0000000e-04	
Epoch: [1][25/37]	Loss_mse 0.0189022	lr: 2.0000000e-04	
Epoch: [1][30/37]	Los