In [1]:
from Models.MLP import SimpleMLP, train_step, val_step

import torch
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


# WARNING using 31 for data

In [2]:
from sklearn.model_selection import train_test_split

# Number of snapshot ids to split; ids run from 0 to nsnapshots - 1.
nsnapshots = 133

# Seed NumPy's global RNG. No random_state is passed to train_test_split
# below, so the split relies on this global seed for reproducibility.
np.random.seed(52)

# 80 % of the snapshot ids go to training, the remaining 20 % to validation.
full_indx = np.arange(nsnapshots)
train_idx, val_idx = train_test_split(
    full_indx,
    train_size=0.8,
    test_size=0.2,
)

In [3]:
# Show the training ids followed by the validation ids, one array per line.
print(train_idx, val_idx, sep="\n")

[ 21  95  92  52 110  58  44   8  83  35  50  56 124  36  41 128  98 117
  14  60  70 102  29  62  53  75 108 118 116  68  57  93  80  37  73  32
  77 119  87  48  55  67  28 115 111 125  91 120 106  12  99 121 129 131
  10  94 101  17   9  49  30  38   0  26  85  15 130  25   6  82  65 107
  59  64  74  84  78  43  34  20 114   7 126  71  22 100  39  63  76 122
  79  45  61  42  46  54 112 132  16 113   5  33  97  86  13  11]
[ 40  81  90  69  88  19 127  27 123  66 109 103  96  89  51 104   4   3
   2  47  31  23  18  24  72   1 105]


In [4]:
# Name of the dataset layout to train on. It is passed as `files=` to
# getting_datapaths and is also recorded in each checkpoint and its file name.
data_source = 'kedar'

# Dataset guider

In [5]:
def getting_datapaths(
    train_idx,
    val_idx,
    files='sunny', 
    target_cols=["ccn_001", "ccn_003", "ccn_006"],
    exclude_cols = ["XLONG", "XLAT"],
):
    """Build per-snapshot feature/target file paths and column index lists.

    Parameters
    ----------
    train_idx, val_idx : iterable
        Snapshot ids interpolated into the file names of the training and
        validation splits respectively.
    files : {'sunny', 'kedar'}
        Selects the directory and file-naming scheme of the data files.
    target_cols : sequence of str
        Target column names to select. Each must occur exactly once in the
        ``y_name.npy`` column list.
    exclude_cols : sequence of str
        Feature column names to leave out of ``feat_idx``.

    Returns
    -------
    train_files : list
        ``[feature_paths, target_paths]`` (two ``np.ndarray`` of str) for the
        training split.
    val_files : list
        Same structure for the validation split.
    feat_idx : list of int
        Positions of the retained feature columns in ``x_name.npy``.
    target_idx : list of int
        Positions of ``target_cols`` in ``y_name.npy``, in the requested order.

    Raises
    ------
    ValueError
        If ``files`` is not a known layout, or a requested target column is
        missing from (or duplicated in) the target column list.

    Notes
    -----
    The list defaults are only read, never mutated, so sharing them between
    calls is safe.
    """
    # Resolve the data layout first so an unknown source fails fast with a
    # clear message instead of an UnboundLocalError at the return statement.
    if files == 'sunny':
        data_featloc = "/home/kwoksun2/hackathon_data/norm_data"
        data_targloc = "/home/kwoksun2/hackathon_data/norm_data"
        feat_template = "t{t}_feat.npy"
        targ_template = "t{t}_targ.npy"
    elif files == 'kedar':
        data_featloc = "/home/kphadke/hackathon/norm_data_timesteps_feat"
        data_targloc = "/home/kphadke/hackathon/norm_data_timesteps_targ"
        feat_template = "{t}feat_norm.npy"
        targ_template = "{t}targ_norm.npy"
    else:
        raise ValueError(
            f"unknown files={files!r}; expected 'sunny' or 'kedar'"
        )

    # Column names are always read from this location, whichever data layout
    # was selected above.
    names_featloc = "/home/kphadke/hackathon/norm_data_timesteps_feat"
    names_targloc = "/home/kphadke/hackathon/norm_data_timesteps_targ"

    featcols = np.load(f"{names_featloc}/x_name.npy")
    targcols = np.load(f"{names_targloc}/y_name.npy")

    # Look up every requested target column, reporting bad names explicitly.
    target_idx = []
    for name in target_cols:
        hits = np.where(targcols == name)[0]
        if hits.size != 1:
            raise ValueError(
                f"target column {name!r} matched {hits.size} entries in "
                f"{names_targloc}/y_name.npy; expected exactly one"
            )
        target_idx.append(int(hits[0]))

    # Keep the position of every feature column that is not excluded.
    feat_idx = [
        pos for pos, name in enumerate(featcols) if name not in exclude_cols
    ]

    def _paths(directory, template, indices):
        # One path per snapshot id, in the order the ids were given.
        return np.array([f"{directory}/{template.format(t=t)}" for t in indices])

    train_feat_files = _paths(data_featloc, feat_template, train_idx)
    val_feat_files   = _paths(data_featloc, feat_template, val_idx)

    train_targ_files = _paths(data_targloc, targ_template, train_idx)
    val_targ_files   = _paths(data_targloc, targ_template, val_idx)

    return [train_feat_files, train_targ_files], [val_feat_files, val_targ_files], feat_idx, target_idx

In [6]:
# Directories that hold the feature / target column-name arrays.
featloc = "/home/kphadke/hackathon/norm_data_timesteps_feat"
targloc = "/home/kphadke/hackathon/norm_data_timesteps_targ"

# Load the column names of the feature and target arrays.
featcols = np.load(featloc + "/x_name.npy")
targcols = np.load(targloc + "/y_name.npy")

# Display the available target column names as the cell output.
targcols

array(['ccn_001', 'ccn_003', 'ccn_006', 'CHI', 'CHI_CCN', 'D_ALPHA',
       'D_GAMMA', 'D_ALPHA_CCN', 'D_GAMMA_CCN', 'PM25'], dtype='<U11')

In [7]:
# Feature columns that are left out of the model inputs.
exclude_cols = ["XLONG", "XLAT"]

# Target columns to predict; alternative target sets are kept commented out.
target_cols = ['ccn_001', 'ccn_003', 'ccn_006']
# target_cols=['CHI', 'CHI_CCN'] 
# target_cols=['D_ALPHA', 'D_GAMMA', 'D_ALPHA_CCN', 'D_GAMMA_CCN',]

# Resolve the file paths of both splits and the column indices to keep.
train_files, val_files, feat_idx, target_idx = getting_datapaths(
    train_idx, val_idx,
    files=data_source, target_cols=target_cols, exclude_cols=exclude_cols,
)

In [8]:
def prepare_dataloaders(feat_files, targ_files, feat_idx, target_idx, box_fraction=0.1, shuffle=True, batch_size=256):
    """Load a random subset of rows from each file pair into one DataLoader.

    Parameters
    ----------
    feat_files, targ_files : sequence of path-like
        Paired ``.npy`` files; the i-th feature file goes with the i-th
        target file.
    feat_idx : sequence of int
        Feature columns to keep.
    target_idx : sequence of int
        Target columns to keep.
    box_fraction : float
        Fraction of each file's rows to draw. Rows are drawn with
        ``np.random.randint``, i.e. with replacement, from NumPy's global RNG.
    shuffle : bool
        Forwarded to the ``DataLoader``.
    batch_size : int
        Forwarded to the ``DataLoader``; defaults to the previously
        hard-coded value of 256.

    Returns
    -------
    torch.utils.data.DataLoader
        Yields ``(features, targets)`` float32 batches.

    Raises
    ------
    ValueError
        If no file pair is given, or a feature file and its target file have
        different numbers of rows.
    """
    sampled_feats = []
    sampled_targs = []
    for f_fn, t_fn in zip(feat_files, targ_files):
        feats = np.load(f_fn)
        targs = np.load(t_fn)

        if targs.ndim == 1:
            # NOTE(review): assumes a flat target file stores 10 columns one
            # after another - confirm against the writer of these files.
            targs = targs.reshape(10,-1).transpose()

        # Size the sample from the data itself rather than a fixed grid size,
        # so files of any length are sampled over their full range and never
        # indexed out of bounds.
        n_rows = feats.shape[0]
        if targs.shape[0] != n_rows:
            raise ValueError(
                f"row count mismatch: {f_fn} has {n_rows} rows but "
                f"{t_fn} has {targs.shape[0]}"
            )
        num_samples = int(n_rows * box_fraction)

        # select a fraction of data randomly (same rows for features and targets)
        sidx = np.random.randint(0, n_rows, num_samples)

        # keep the subset
        sampled_feats.append(feats[sidx][:,feat_idx])
        sampled_targs.append(targs[sidx][:,target_idx])

    if not sampled_feats:
        raise ValueError("prepare_dataloaders needs at least one feature/target file pair")

    all_feats, all_targs = np.vstack(sampled_feats), np.vstack(sampled_targs)

    tds = torch.utils.data.TensorDataset(
        torch.from_numpy(all_feats).float(), 
        torch.from_numpy(all_targs).float()
    )
    return torch.utils.data.DataLoader(tds, batch_size=batch_size, shuffle=shuffle)

In [9]:
# Each split is a [feature_paths, target_paths] pair; unpack both at once.
(train_feat_files, train_targ_files), (val_feat_files, val_targ_files) = train_files, val_files

In [10]:
# One input per retained feature column, one output per target column.
net = SimpleMLP(ninputs=len(feat_idx), nouts=len(target_idx))
# Move the model to the GPU before the optimizer captures its parameters.
net = net.cuda()

# Adam with no hyperparameters overridden.
optim = torch.optim.Adam(params=net.parameters())

In [11]:
# Display the training feature file paths as this cell's output.
train_feat_files

array(['/home/kphadke/hackathon/norm_data_timesteps_feat/21feat_norm.npy',
       '/home/kphadke/hackathon/norm_data_timesteps_feat/95feat_norm.npy',
       '/home/kphadke/hackathon/norm_data_timesteps_feat/92feat_norm.npy',
       '/home/kphadke/hackathon/norm_data_timesteps_feat/52feat_norm.npy',
       '/home/kphadke/hackathon/norm_data_timesteps_feat/110feat_norm.npy',
       '/home/kphadke/hackathon/norm_data_timesteps_feat/58feat_norm.npy',
       '/home/kphadke/hackathon/norm_data_timesteps_feat/44feat_norm.npy',
       '/home/kphadke/hackathon/norm_data_timesteps_feat/8feat_norm.npy',
       '/home/kphadke/hackathon/norm_data_timesteps_feat/83feat_norm.npy',
       '/home/kphadke/hackathon/norm_data_timesteps_feat/35feat_norm.npy',
       '/home/kphadke/hackathon/norm_data_timesteps_feat/50feat_norm.npy',
       '/home/kphadke/hackathon/norm_data_timesteps_feat/56feat_norm.npy',
       '/home/kphadke/hackathon/norm_data_timesteps_feat/124feat_norm.npy',
       '/home/kphadke/ha

In [None]:
# NOTE(review): nchunks is not used by this cell; kept in case other cells read it.
nchunks = 20

# per chunk 133/20 ~ 6
max_grp_size = 5
# np.array_split rejects 0 sections, so always use at least one group even
# when a split has fewer than max_grp_size files.
nchunks_train = max(1, len(train_feat_files) // max_grp_size)
nchunks_val   = max(1, len(val_feat_files) // max_grp_size)

idx_shuf = np.arange(len(train_feat_files))

# The validation files are never reshuffled, so their grouping is the same
# every epoch and is computed once up front.
val_feat_grps = np.array_split(val_feat_files, nchunks_val)
val_targ_grps = np.array_split(val_targ_files, nchunks_val)

# Per-epoch mean losses, appended once per epoch.
train_hist = []
val_hist   = []

max_epochs = 100
for i in range(max_epochs):
    # Reshuffle the training files so each epoch groups them differently;
    # the same permutation keeps feature and target files paired.
    np.random.shuffle(idx_shuf)

    train_feat_grps = np.array_split(train_feat_files[idx_shuf], nchunks_train)
    train_targ_grps = np.array_split(train_targ_files[idx_shuf], nchunks_train)

    total_train_loss = 0
    total_val_loss   = 0
    # NOTE(review): the timings returned by train_step / val_step are not
    # accumulated, so these two totals stay at 0.
    total_train_time = 0
    total_val_time   = 0

    # Train on one group of files at a time; the epoch loss is the mean of
    # the per-group losses.
    for feat_fns, targs_fns in zip(train_feat_grps, train_targ_grps):
        tdl = prepare_dataloaders(
            feat_fns, 
            targs_fns, 
            feat_idx, 
            target_idx, 
            shuffle=True
        )
        train_loss, train_time = train_step(tdl, net, optim)
        total_train_loss += train_loss / len(train_feat_grps)

    # Validate the same way. prepare_dataloaders subsamples rows at random,
    # so the validation rows differ from epoch to epoch.
    for feat_fns, targs_fns in zip(val_feat_grps, val_targ_grps):
        tdl = prepare_dataloaders(
            feat_fns, 
            targs_fns,
            feat_idx, 
            target_idx, 
            shuffle=False)
        val_loss, val_time = val_step(tdl, net)
        total_val_loss += val_loss / len(val_feat_grps)

    train_hist.append(total_train_loss)
    val_hist.append(total_val_loss)
    print(total_train_loss, total_val_loss)

    # checkpoint data
    checkpoint = {
        # Store the weights only, as the key name promises. Pickling the whole
        # module would tie the checkpoint to the exact class definition.
        'model_state_dict': net.state_dict(),
        'epoch': i,
        'datasource': data_source,
        'exclude_cols': exclude_cols,
        'target_cols': target_cols,
        'total_train_loss': total_train_loss,
        'total_val_loss': total_val_loss,
    }

    torch.save(checkpoint, f"ep_{i}_{data_source}_{','.join(target_cols)}_{','.join(exclude_cols)}_MLP.pt")

100%|██████████| 2457/2457 [00:29<00:00, 84.00it/s]
100%|██████████| 2047/2047 [00:21<00:00, 93.13it/s]
100%|██████████| 2047/2047 [00:22<00:00, 90.82it/s]
100%|██████████| 2047/2047 [00:21<00:00, 94.87it/s]
100%|██████████| 2047/2047 [00:25<00:00, 81.71it/s]
100%|██████████| 2047/2047 [00:22<00:00, 92.84it/s]
100%|██████████| 2047/2047 [00:23<00:00, 86.69it/s]
100%|██████████| 2047/2047 [00:22<00:00, 92.79it/s]
100%|██████████| 2047/2047 [00:23<00:00, 85.58it/s]
100%|██████████| 2047/2047 [00:22<00:00, 91.44it/s]
100%|██████████| 2047/2047 [00:23<00:00, 88.53it/s]
100%|██████████| 2047/2047 [00:22<00:00, 92.78it/s]
100%|██████████| 2047/2047 [00:21<00:00, 93.30it/s]
100%|██████████| 2047/2047 [00:25<00:00, 80.72it/s]
100%|██████████| 2047/2047 [00:28<00:00, 72.82it/s]
100%|██████████| 2047/2047 [00:28<00:00, 71.23it/s]
100%|██████████| 2047/2047 [00:29<00:00, 69.15it/s]
100%|██████████| 2047/2047 [00:30<00:00, 66.14it/s]
100%|██████████| 2047/2047 [00:29<00:00, 69.88it/s]
100%|███████

0.23800760109604194 0.19108818308558795


100%|██████████| 2457/2457 [00:33<00:00, 73.45it/s]
100%|██████████| 2047/2047 [00:22<00:00, 91.35it/s]
100%|██████████| 2047/2047 [00:26<00:00, 77.21it/s]
100%|██████████| 2047/2047 [00:22<00:00, 92.95it/s]
100%|██████████| 2047/2047 [00:24<00:00, 84.07it/s]
100%|██████████| 2047/2047 [00:22<00:00, 91.26it/s]
100%|██████████| 2047/2047 [00:22<00:00, 92.60it/s]
 28%|██▊       | 578/2047 [00:08<00:23, 63.22it/s]