In [None]:
import argparse
import math
import os
import pickle
import sys

import h5py
import joblib
import numpy as np
import torch

import data_io
import model
import normalize

In [None]:
def minkowski_error(prediction, target, minkowski_parameter=1.5):
    """
    Mean Minkowski error: mean(|prediction - target| ** p).

    With ``p`` between 1 (absolute error) and 2 (squared error) large
    residuals are penalised less strongly than with squared error, which
    makes the loss less sensitive to outliers.

    Args:
        prediction: Predicted tensor.
        target: Target tensor, broadcast-compatible with ``prediction``.
        minkowski_parameter: Exponent ``p`` applied to the absolute residual.

    Returns:
        0-d tensor holding the mean of the powered absolute residuals.
    """
    abs_residual = (prediction - target).abs()
    return abs_residual.pow(minkowski_parameter).mean()

In [None]:
# Fall back to the CPU so the notebook still runs on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
region = "9999LEAU_std"
identifier = "9999LEAU"

# Run options. `args.epochs`, `args.batch_size` and `args.log_interval` are read by
# the training cells but were never defined in the notebook (a leftover from the
# command-line script). Parse an empty argv so the notebook runs on a fresh kernel.
# TODO(review): the default values below are placeholders -- confirm them.
parser = argparse.ArgumentParser(description="Train the MLP")
parser.add_argument("--epochs", type=int, default=20, help="number of training epochs")
parser.add_argument("--batch-size", type=int, default=100, help="mini-batch size")
parser.add_argument("--log-interval", type=int, default=10,
                    help="number of batches between progress log lines")
args = parser.parse_args([])

nlevs = 45
# Input width: 4 level profiles (nlevs values each) + 3 extra input fields
# (presumably one value each -- TODO confirm); output width: 2 level profiles.
in_features, nb_classes = (nlevs * 4 + 3), (nlevs * 2)
nb_hidden_layer = 8
hidden_size = 512
mlp = model.MLP(in_features, nb_classes, nb_hidden_layer, hidden_size)
# mlp = model.MLP_BN(in_features, nb_classes, nb_hidden_layer, hidden_size)
# Move the model to the target device *before* building the optimizer, as the
# PyTorch optimizer documentation recommends.
mlp.to(device)
pytorch_total_params = sum(p.numel() for p in mlp.parameters() if p.requires_grad)
print("Number of trainable parameters: {0}".format(pytorch_total_params))

# The file name encodes the configuration: each zero-padded value precedes its
# label (e.g. "008_lyr" = 8 hidden layers), followed by the run identifier.
model_name = "q_qadv_t_tadv_swtoa_lhf_shf_qtphys_{0}_lyr_{1}_in_{2}_out_{3}_hdn_{4}_epch_{5}_btch_{6}_mae_vlr.tar".format(
    str(nb_hidden_layer).zfill(3),
    str(in_features).zfill(3),
    str(nb_classes).zfill(3),
    str(hidden_size).zfill(4),
    str(args.epochs).zfill(3),
    str(args.batch_size).zfill(5),
    identifier)
optimizer = torch.optim.Adam(mlp.parameters(), lr=1.e-3)
# optimizer = torch.optim.SGD(mlp.parameters(), lr=0.01)
# Halve the learning rate every 5 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# loss_function = torch.nn.MSELoss()
loss_function = torch.nn.L1Loss()

locations = {"train_test_datadir": "/content/drive/My Drive/ML/data",
             "model_loc": "/content/drive/My Drive/ML/data/model"}

print("Model's state_dict:")
for param_name, param_value in mlp.state_dict().items():
    print(param_name, "\t", param_value.size())

In [None]:
class ConcatDataset(torch.utils.data.Dataset):
    """
    Per-sample dataset that joins several parallel arrays into one input
    vector and one target vector.

    Sample ``i`` is built as:
      * input  = concat(d[i, :nlevs] for d in x3data) + concat(d[i] for d in x2data)
      * target = concat(d[i, :nlevs] for d in ydata)
    and both tensors are moved to ``device``.

    Args:
        x3data: Sequence of arrays indexable as ``d[i, :nlevs]``
            (presumably (n_samples, n_levels) -- TODO confirm).
        x2data: Sequence of arrays indexable as ``d[i]``; each row is flattened,
            so both 1-D (n_samples,) and 2-D (n_samples, k) arrays work.
        ydata: Sequence of target arrays indexable as ``d[i, :nlevs]``.
        nlevs: Number of leading levels kept from each x3/y array.
        device: Device the returned tensors are moved to. When ``None`` the
            notebook-level ``device`` is used (the original behavior).
    """

    def __init__(self, x3data, x2data, ydata, nlevs, device=None):
        self.x3datasets = x3data
        self.x2datasets = x2data
        self.ydatasets = ydata
        self.nlevs = nlevs
        self.device = device

    def _target_device(self):
        # Resolve lazily so the module-level `device` is only needed when no
        # explicit device was passed to the constructor.
        return self.device if self.device is not None else device

    def __getitem__(self, i):
        x3 = [torch.tensor(d[i, :self.nlevs]) for d in self.x3datasets]
        # reshape(-1) turns a scalar row (from a 1-D array) into shape (1,);
        # torch.cat cannot concatenate 0-d tensors.
        x2 = [torch.tensor(d[i]).reshape(-1) for d in self.x2datasets]
        x = torch.cat(x3 + x2, dim=0)
        y = torch.cat([torch.tensor(d[i, :self.nlevs]) for d in self.ydatasets], dim=0)
        target_device = self._target_device()
        return (x.to(target_device), y.to(target_device))

    def __len__(self):
        # Every array is indexed in __getitem__, so the usable length is the
        # shortest of all of them (not just the x2 arrays).
        all_datasets = [*self.x3datasets, *self.x2datasets, *self.ydatasets]
        return min(len(d) for d in all_datasets)

In [None]:
nn_data = data_io.Data_IO(region, locations)

# Inputs that ConcatDataset indexes per sample only (d[i], no level slicing).
x2d = [
    nn_data.sw_toa_train,
    nn_data.lhf_train,
    nn_data.shf_train,
]
x2d_test = [
    nn_data.sw_toa_test,
    nn_data.lhf_test,
    nn_data.shf_test,
]

# Level-profile inputs; ConcatDataset keeps the first `nlevs` levels of each.
x3d = [
    nn_data.q_tot_train,
    nn_data.q_tot_adv_train,
    nn_data.theta_train,
    nn_data.theta_adv_train,
]
x3d_test = [
    nn_data.q_tot_test,
    nn_data.q_tot_adv_test,
    nn_data.theta_test,
    nn_data.theta_adv_test,
]

# Targets, also truncated to the first `nlevs` levels by ConcatDataset.
y = [
    nn_data.qphys_train,
    nn_data.theta_phys_train,
]
y_test = [
    nn_data.qphys_test,
    nn_data.theta_phys_test,
]


In [None]:
# Training batches are reshuffled every epoch; validation order is fixed.
train_dataset = ConcatDataset(x3d, x2d, y, nlevs)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
)

validate_dataset = ConcatDataset(x3d_test, x2d_test, y_test, nlevs)
validate_loader = torch.utils.data.DataLoader(
    validate_dataset,
    batch_size=args.batch_size,
    shuffle=False,
)

In [None]:
def train(epoch):
    """
    Train ``mlp`` for one epoch over ``train_loader``.

    Uses the notebook-level ``mlp``, ``optimizer``, ``loss_function``,
    ``train_loader``, ``device`` and ``args.log_interval``.

    Args:
        epoch: Epoch number, used only in the log messages.

    Returns:
        Mean per-sample training loss over the epoch.
    """
    mlp.train()
    running_loss = 0.0
    n_samples = 0
    n_batches = len(train_loader)
    n_total = len(train_loader.dataset)
    for batch_idx, (x, y) in enumerate(train_loader):
        x = x.to(device)
        # No-op if the dataset already placed y on `device`; keeps x and y together.
        y = y.to(device)
        optimizer.zero_grad()
        prediction = mlp(x)
        loss = loss_function(prediction, y)
        loss.backward()
        optimizer.step()

        batch_size = len(x)
        # loss_function is assumed to return a batch *mean* (the default
        # reduction='mean' of torch.nn.L1Loss / MSELoss). Weight it by the batch
        # size so the epoch average is a true per-sample mean. The old code
        # divided the sum of batch means by the dataset size, under-reporting
        # the loss by roughly a factor of batch_size.
        running_loss += loss.item() * batch_size
        n_samples += batch_size

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * batch_size, n_total,
                100. * batch_idx / n_batches, loss.item()))

    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    average_loss = running_loss / max(n_samples, 1)
    print('====> Epoch: {} Average loss: {:.6f}'.format(epoch, average_loss))

    return average_loss

In [None]:
def validate(epoch):
    """
    Evaluate ``mlp`` on ``validate_loader`` without tracking gradients.

    Uses the notebook-level ``mlp``, ``loss_function``, ``validate_loader``
    and ``device``.

    Args:
        epoch: Epoch number; unused, kept so the signature parallels ``train``.

    Returns:
        Mean per-sample validation loss.
    """
    mlp.eval()
    running_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for x, y in validate_loader:
            x = x.to(device)
            y = y.to(device)
            prediction = mlp(x)
            batch_size = len(x)
            # loss_function is assumed to return a batch mean (default
            # reduction='mean'); weight by batch size for a per-sample average.
            running_loss += loss_function(prediction, y).item() * batch_size
            n_samples += batch_size

    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    validation_loss = running_loss / max(n_samples, 1)
    print('====> validation loss: {:.6f}'.format(validation_loss))
    return validation_loss

In [None]:
def checkpoint_save(epoch: int, nn_model: torch.nn.Module, nn_optimizer: torch.optim.Optimizer, training_loss: list, validation_loss: list, model_name: str):
    """
    Save an intermediate training checkpoint into ``locations['model_loc']``.

    The file name is ``model_name`` with ``_chkepo_<epoch zero-padded to 3>``
    inserted before the extension, e.g. ``net.tar`` -> ``net_chkepo_007.tar``.
    Only the final extension is affected. This way a name without ``.tar`` can
    no longer map onto the final-model file.

    Args:
        epoch: Epoch number just completed.
        nn_model: Model whose ``state_dict`` is saved.
        nn_optimizer: Optimizer whose ``state_dict`` is saved.
        training_loss: Per-epoch training losses so far.
        validation_loss: Per-epoch validation losses so far.
        model_name: Base file name of the final model.
    """
    stem, ext = os.path.splitext(model_name)
    checkpoint_name = '{0}_chkepo_{1}{2}'.format(stem, str(epoch).zfill(3), ext)
    model_dir = locations['model_loc']
    # torch.save does not create missing directories.
    os.makedirs(model_dir, exist_ok=True)
    torch.save({'epoch': epoch,
                'model_state_dict': nn_model.state_dict(),
                'optimizer_state_dict': nn_optimizer.state_dict(),
                'training_loss': training_loss,
                'validation_loss': validation_loss},
               os.path.join(model_dir, checkpoint_name))

In [None]:
def train_main():
    """
    Run the full training loop and save the final model.

    For each of ``args.epochs`` epochs the function:
      1. trains the model,
      2. validates it,
      3. steps the LR scheduler,
      4. writes a checkpoint.
    After the loop it saves the final model/optimizer state and both loss
    histories to ``locations['model_loc']/model_name``.

    Returns:
        Tuple ``(training_loss, validation_loss)`` of per-epoch losses.
    """
    training_loss = []
    validation_loss = []
    # Pre-bind so the final save is well-defined even when args.epochs == 0.
    epoch = 0
    for epoch in range(1, args.epochs + 1):
        train_loss = train(epoch)
        validate_loss = validate(epoch)
        # StepLR is stepped once per epoch, after the optimizer updates.
        scheduler.step()
        training_loss.append(train_loss)
        validation_loss.append(validate_loss)
        checkpoint_save(epoch, mlp, optimizer, training_loss, validation_loss, model_name)

    # Save the final model.
    model_dir = locations['model_loc']
    os.makedirs(model_dir, exist_ok=True)
    torch.save({'epoch': epoch,
                'model_state_dict': mlp.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'training_loss': training_loss,
                'validation_loss': validation_loss},
               os.path.join(model_dir, model_name))
    return training_loss, validation_loss