In [1]:
import sys, os, toml
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from IPython.display import display, HTML
from itertools import combinations_with_replacement
from torch.utils.data import Dataset, DataLoader

sys.path.append(os.path.join(sys.path[0], '../..'))

from data.kcost_dataset import KCostDataSetSplit
from data.io import Writer, Reader
from model.kcost import KCostModel
from lsm.cost import EndureQFixedCost, EndureTierLevelCost, EndureKHybridCost
from lsm.lsmtype import LSMTree, LSMSystem, Policy

In [2]:
def wl_to_array(wl_dict):
    """Pull the workload fields out of a mapping in a fixed order.

    Args:
        wl_dict: Mapping that contains the keys 'id', 'z0', 'z1', 'q' and 'w'.

    Returns:
        A tuple ``(id, z0, z1, q, w)`` with the values stored under those keys.

    Raises:
        KeyError: If any of the five keys is missing from ``wl_dict``.
    """
    field_order = ('id', 'z0', 'z1', 'q', 'w')
    return tuple(wl_dict[field] for field in field_order)

# config = Reader.read_config(os.path.join(sys.path[0], '../..', 'config', 'endure.toml'))
# reader = Reader(config)
# writer = Writer(config)
# system_vars = LSMSystem(**config['system'])

# Parse the training configuration located two directories above sys.path[0].
config_file = os.path.join(sys.path[0], '../..', 'config', 'training.toml')
cfg = toml.load(config_file)

# Directory holding the data files, assembled from the [io] section of the config.
io_cfg = cfg['io']
data_path = os.path.join(io_cfg['data_dir'], io_cfg['train_dir'])

In [3]:
%%time
# Resolve every configured training file name against the data directory,
# then build the training dataset from those files.
train_files = cfg['io']['train_data']
paths = [os.path.join(data_path, name) for name in train_files]
train_data = KCostDataSetSplit(cfg, paths)

CPU times: user 11.1 s, sys: 6.8 s, total: 17.9 s
Wall time: 10 s


In [4]:
%%time
# Resolve every configured test file name against the data directory.
# NOTE(review): data_path was built from cfg['io']['train_dir'], so the test
# files are looked up in that same directory — confirm this is intended.
test_files = cfg['io']['test_data']
paths = [os.path.join(data_path, name) for name in test_files]

# Build the test dataset and wrap it in a loader that yields batches of 8192.
test_data = KCostDataSetSplit(cfg, paths)
test = DataLoader(dataset=test_data, batch_size=8192)

CPU times: user 1.49 s, sys: 840 ms, total: 2.33 s
Wall time: 749 ms


In [5]:
# Build the model from the config and restore the trained weights from the
# checkpoint stored two directories above sys.path[0].
model = KCostModel(cfg)
checkpoint_path = os.path.join(sys.path[0], '../..', 'checkpoint.pt')

# map_location='cpu' lets a checkpoint that was saved from a GPU run be loaded
# on a machine without CUDA; load_state_dict then copies the values into the
# model's existing parameters.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model_data = torch.load(checkpoint_path, map_location='cpu')

# Last expression of the cell: displays the key-matching report.
model.load_state_dict(model_data['model_state_dict'])

<All keys matched successfully>

In [6]:
# Read the first test file into a DataFrame so per-row results can be
# inspected alongside the model output.
# NOTE(review): only paths[0] is read here, while the `test` loader was built
# from every entry in `paths` — the two only line up when there is one file.
test_path = paths[0]
df = pd.read_feather(path=test_path)
df

Unnamed: 0,wl_id,h,T,z0,z1,q,w,B,phi,s,...,K_6,K_7,K_8,K_9,K_10,K_11,K_12,K_13,K_14,K_15
0,3,0.0,2,0.01,0.01,0.97,0.01,4,1,0.0,...,1,1,1,1,0,0,0,0,0,0
1,3,0.0,4,0.01,0.01,0.97,0.01,4,1,0.0,...,0,0,0,0,0,0,0,0,0,0
2,3,0.0,4,0.01,0.01,0.97,0.01,4,1,0.0,...,0,0,0,0,0,0,0,0,0,0
3,3,0.0,4,0.01,0.01,0.97,0.01,4,1,0.0,...,0,0,0,0,0,0,0,0,0,0
4,3,0.0,4,0.01,0.01,0.97,0.01,4,1,0.0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1470411,3,9.5,48,0.01,0.01,0.97,0.01,4,1,0.0,...,0,0,0,0,0,0,0,0,0,0
1470412,3,9.5,48,0.01,0.01,0.97,0.01,4,1,0.0,...,0,0,0,0,0,0,0,0,0,0
1470413,3,9.5,48,0.01,0.01,0.97,0.01,4,1,0.0,...,0,0,0,0,0,0,0,0,0,0
1470414,3,9.5,48,0.01,0.01,0.97,0.01,4,1,0.0,...,0,0,0,0,0,0,0,0,0,0


In [7]:
# Run the model over every batch of the test loader and store the predictions
# in df['model_cost'].
model_cost = []  # one prediction tensor per batch
model.eval()  # put the model in evaluation mode before predicting
with torch.no_grad():  # evaluation only, so skip autograd bookkeeping
    for x, _ in tqdm(test, desc='Testing'):
        model_cost.append(model(x))

# Convert explicitly to a 1-D NumPy array on the host rather than handing a
# torch.Tensor to pandas and relying on implicit coercion (which fails for
# tensors that do not live on the CPU).
predictions = torch.flatten(torch.cat(model_cost)).cpu().numpy()

# The column assignment is positional, so the loader must yield exactly one
# prediction per DataFrame row.
# NOTE(review): this also assumes the dataset yields samples in the same order
# as the rows of df — confirm against KCostDataSetSplit.
assert len(predictions) == len(df), (
    f'got {len(predictions)} predictions for {len(df)} rows in df'
)
df['model_cost'] = predictions

Testing:   0%|          | 0/180 [00:00<?, ?it/s]

In [8]:
# Mean squared error between the reference cost column and the model output:
# sum of squared differences divided by the number of rows in df.
squared_error = (df['new_cost'] - df['model_cost']) ** 2
mse = squared_error.sum() / len(df)
mse

157.1812958020101

In [9]:
# Spot check: model output for the first item of the training dataset.
# unsqueeze at dim 0 adds a leading dimension so the single input is passed
# as a batch of one.
# NOTE(review): this call runs with autograd enabled (the displayed tensor
# carries a grad_fn); wrap it in torch.no_grad() if gradients are not needed.
model(torch.unsqueeze(train_data[0][0], 0))

tensor([[5.1304]], grad_fn=<ReluBackward0>)

In [10]:
# Spot check: model output for the item at index 100 of the test dataset.
# unsqueeze at dim 0 adds a leading dimension so the single input is passed
# as a batch of one.
# NOTE(review): this call runs with autograd enabled (the displayed tensor
# carries a grad_fn); wrap it in torch.no_grad() if gradients are not needed.
model(torch.unsqueeze(test_data[100][0], 0))

tensor([[44.0423]], grad_fn=<ReluBackward0>)