In [1]:
# evaluate TAB
from torch.utils.data import TensorDataset
import torch
import numpy as np
import time
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Names of the input (feature) columns, in column order.
# The order matters: later cells index the feature array by position in this list.
feat_cols = [
    # totals
    'TOT_NUM_CONC', 'TOT_MASS_CONC',
    # `pmc_`-prefixed columns
    'pmc_SO4', 'pmc_NO3', 'pmc_Cl', 'pmc_NH4',
    'pmc_ARO1', 'pmc_ARO2', 'pmc_ALK1', 'pmc_OLE1',
    'pmc_API1', 'pmc_API2', 'pmc_LIM1', 'pmc_LIM2',
    'pmc_OC', 'pmc_BC', 'pmc_H2O',
    # upper-case state columns ('XLAT' and 'XLONG' are deliberately left out)
    'TEMPERATURE', 'REL_HUMID', 'ALT', 'Z',
    # remaining species columns
    'h2so4', 'hno3', 'hcl', 'nh3',
    'no', 'no2', 'no3', 'n2o5', 'hono', 'hno4',
    'o3', 'o1d', 'O3P', 'oh', 'ho2', 'h2o2',
    'co', 'so2', 'ch4', 'c2h6', 'ch3o2', 'ethp',
    'hcho', 'ch3oh', 'ANOL', 'ch3ooh', 'ETHOOH',
    'ald2', 'hcooh', 'RCOOH', 'c2o3', 'pan',
    'aro1', 'aro2', 'alk1', 'ole1',
    'api1', 'api2', 'lim1', 'lim2',
    'par', 'AONE', 'mgly', 'eth', 'OLET', 'OLEI',
    'tol', 'xyl', 'cres', 'to2', 'cro', 'open',
    'onit', 'rooh', 'ro2', 'ano2', 'nap', 'xo2', 'xpar',
    'isop', 'isoprd', 'isopp', 'isopn', 'isopo2',
    'api', 'lim',
    'dms', 'msa', 'dmso', 'dmso2',
    'ch3so2h', 'ch3sch2oo', 'ch3so2', 'ch3so3',
    'ch3so2oo', 'ch3so2ch2oo', 'SULFHOX',
    # final two columns
    'P', 'PB',
]

# Names of the output (target) columns, in column order.
# Later cells index the target array by position in this list.
target_cols = [
    'ccn_001', 'ccn_003', 'ccn_006',
    'CHI', 'CHI_CCN',
    'D_ALPHA', 'D_GAMMA',
    'D_ALPHA_CCN', 'D_GAMMA_CCN',
    'PM25',
]

In [7]:
from Models.MLP import SimpleMLP, predictions, val_step

In [19]:
# Load the saved model object from a checkpoint file in the working directory.
# The result is used directly as a model below (`model.eval()`, `model(...)`),
# so the file holds a whole pickled model rather than a bare state dict.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# checkpoints that come from a trusted source.
# NOTE(review): unpickling presumably needs the model's class (likely the
# SimpleMLP imported above) to be importable -- confirm. Recent PyTorch releases
# default to `weights_only=True`, which may refuse this kind of checkpoint;
# verify against the installed version.
mlp_model = torch.load("2100_mlp.pt")

In [9]:
# Load the feature / target arrays and wrap them in a dataset.
# NOTE(review): columns are assumed to line up positionally with `feat_cols` /
# `target_cols` -- confirm against the code that wrote these .npy files.
feats = np.load('numpy_data/t0_feat.npy')
targs = np.load('numpy_data/t0_targ.npy')

tds = TensorDataset( torch.from_numpy(feats), torch.from_numpy(targs) )


def _log_stats(column, name):
    """Return (mean, std, min) of a column's strictly positive entries in log space.

    Args:
        column: 1-D NumPy array holding one feature or target column.
        name: column name, used only to make the error message actionable.

    Returns:
        Tuple ``(log_mean, log_std, floor)`` where ``log_mean`` / ``log_std`` are
        the mean and standard deviation of ``log(column[column > 0])`` and
        ``floor`` is ``exp`` of the smallest log value (i.e. the smallest
        positive entry, after a log/exp round trip).

    Raises:
        ValueError: if the column has no strictly positive entries, since the
            log statistics are undefined in that case.
    """
    positive = column[column > 0]
    if positive.size == 0:
        raise ValueError(f"column {name!r} has no strictly positive values; "
                         "cannot compute log statistics")
    log_vals = np.log(positive)  # computed once and reused for all three statistics
    # The floor is kept as exp(log(.).min()) rather than positive.min() so it stays
    # numerically identical to what this cell has always produced.
    return log_vals.mean(), log_vals.std(), np.exp(log_vals.min())


def _stat_tensor(stat, cols):
    """Pack the per-column values of ``stat`` for ``cols`` (in order) into a 1-D tensor."""
    return torch.from_numpy(np.array([stat[c] for c in cols]))


# Per-column log-space statistics, keyed by column name. Features and targets
# share the same dictionaries (their names do not overlap).
means = {}
stds  = {}
mins  = {}
for data, cols in ((feats, feat_cols), (targs, target_cols)):
    for i, f in enumerate(cols):
        means[f], stds[f], mins[f] = _log_stats(data[:, i], f)

# Tensors ordered like feat_cols / target_cols, consumed by val_step for
# flooring and standardising batches.
feat_mean_list = _stat_tensor(means, feat_cols)
feat_std_list = _stat_tensor(stds, feat_cols)
feat_min_list = _stat_tensor(mins, feat_cols)

targ_mean_list = _stat_tensor(means, target_cols)
targ_std_list = _stat_tensor(stds, target_cols)
targ_min_list = _stat_tensor(mins, target_cols)

In [None]:
# NOTE(review): this cell used to start with `tds = ConcatDataset(allds[::2])`.
# Neither `ConcatDataset` nor `allds` is imported or defined in the visible part
# of this notebook, so the cell raised NameError on a fresh kernel (it was never
# executed: `In [None]`) and would have replaced the dataset being evaluated.
# The split now operates on the `tds` TensorDataset built in the data-loading cell.

# 80 / 10 / 10 split; the test set takes the remainder so the sizes sum to L.
L = len(tds)
ntrain = int(L * 0.8)
nval   = int(L * 0.1)
ntest  = L - ntrain - nval

bsz = 256

# Fixed generator seed so the split is reproducible across runs.
ds_train, ds_val, ds_test = torch.utils.data.random_split(
    tds, [ntrain, nval, ntest], generator=torch.Generator().manual_seed(42)
)
# Only the training loader shuffles; validation / test keep dataset order.
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=bsz, shuffle=True)
dl_val = torch.utils.data.DataLoader(ds_val, batch_size=bsz, shuffle=False)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=bsz, shuffle=False)

In [11]:
# Loader over the complete `tds` dataset in batches of 256 (default, unshuffled order).
dl = torch.utils.data.DataLoader(dataset=tds, batch_size=256)

In [14]:
def loss_fn(ypred, y):
    """Mean squared error between predictions ``ypred`` and references ``y``.

    The squared difference is averaged over every element, yielding a scalar.
    """
    residual = ypred - y
    return (residual * residual).mean()

@torch.no_grad()
def val_step(dl, model):
    """Evaluate ``model`` over every batch of ``dl`` and time the pass.

    NOTE: this definition shadows the ``val_step`` imported from ``Models.MLP``
    earlier in the notebook; calls after this cell use this version.

    Each batch is preprocessed on the CPU with the module-level statistics
    tensors: values below the per-column floor (``feat_min_list`` /
    ``targ_min_list``) are replaced by that floor, the result is
    log-transformed, then standardised with the matching ``*_mean_list`` /
    ``*_std_list``. The normalised features are fed to the model and the
    prediction is compared to the normalised targets with ``loss_fn``.

    Args:
        dl: iterable yielding ``(X, y)`` tensor batches (e.g. a DataLoader).
        model: the network to evaluate; must provide ``eval()`` and be callable.

    Returns:
        Tuple ``(mean_loss, elapsed_seconds)``. ``mean_loss`` is the average of
        the per-batch losses (each batch weighs the same, so a smaller final
        batch counts as much as a full one); ``elapsed_seconds`` is wall-clock
        time for the whole pass, including data loading.

    Raises:
        ValueError: if ``dl`` yields no batches.
    """
    tstart = time.time()

    # Switching to eval mode is loop-invariant, so do it once up front.
    model.eval()

    # Run on the device the model actually lives on instead of unconditionally
    # calling .cuda(), which fails for a CPU-resident model or a CUDA-less host.
    parameters = getattr(model, "parameters", None)
    first_param = next(parameters(), None) if callable(parameters) else None
    if first_param is not None:
        device = first_param.device
    else:
        # Parameter-free model: keep the original preference for the GPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Row-shaped floors so they broadcast across the batch dimension.
    feat_floor = feat_min_list.unsqueeze(0)
    targ_floor = targ_min_list.unsqueeze(0)

    total_loss = 0.0
    n_batches = 0
    for X, y in tqdm(dl):
        # Floor before the log so non-positive entries do not produce -inf / nan.
        padded = torch.where(X < feat_floor, feat_floor, X)
        Xpad_norm = (padded.log() - feat_mean_list) / feat_std_list

        padded = torch.where(y < targ_floor, targ_floor, y)
        ypad_norm = (padded.log() - targ_mean_list) / targ_std_list

        ypred = model(Xpad_norm.to(device))

        loss = loss_fn(ypred, ypad_norm.to(device))
        total_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("val_step: the data loader yielded no batches")
    return total_loss / n_batches, time.time() - tstart

In [15]:
# Evaluate the loaded model on every batch of `dl`; the cell shows val_step's
# (loss, elapsed seconds) tuple.
val_step(dl=dl, model=mlp_model)

100%|██████████| 4094/4094 [00:40<00:00, 100.80it/s]


(0.037853904605447755, 40.67121481895447)

In [17]:
# Repeat of the same evaluation call over `dl`.
val_step(dl=dl, model=mlp_model)

100%|██████████| 4094/4094 [00:29<00:00, 139.27it/s]


(0.03156511715120119, 29.403302669525146)

In [20]:
# Same evaluation call over `dl`, using whatever `mlp_model` currently holds.
val_step(dl=dl, model=mlp_model)

100%|██████████| 4094/4094 [00:33<00:00, 120.68it/s]


(0.03191129043998591, 33.957342863082886)