In [1]:
import sys
import os
import toml
import glob

import torch
import torcharrow
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchdata import datapipes as DataPipe

import pandas as pd
import numpy as np
import scipy.optimize as SciOpt
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

sys.path.append(os.path.join(sys.path[0], '../..'))

from data.io import Reader
from lsm.lsmtype import Policy
from jobs.train import TrainJob
from model.trainer import Trainer
import lsm.cost as CostFunc

In [2]:
# Load the experiment settings through the project's Reader helper. Later cells
# index the result by section: 'io', 'train', 'test' and 'lsm'.
# NOTE(review): the path is relative — presumably resolved against the kernel's
# working directory, so the notebook must be started from its own folder; confirm.
config = Reader.read_config('../../endure.toml')

In [3]:
class EndureQDataSet(torch.utils.data.Dataset):
    """Map-style dataset of Q-cost samples read from one or more CSV files.

    Each CSV must provide the input columns ``h, z0, z1, q, w, T, Q`` and the
    label columns ``z0_cost, z1_cost, q_cost, w_cost``. While loading,
    ``h, z0, z1, q, w`` are standardised with ``config['train']['mean_bias']``
    and ``config['train']['std_bias']``; ``T`` and ``Q`` are shifted using
    ``config['lsm']['size_ratio']['min']``.

    Attributes:
        inputs: float32 tensor with one row per sample, columns ordered as
            ``h, z0, z1, q, w, T, Q``.
        labels: float32 tensor with one row per sample, columns ordered as
            ``z0_cost, z1_cost, q_cost, w_cost``.
    """

    # Columns standardised with the configured mean/std vectors. Must stay a
    # list: pandas treats a tuple key as a single (multi-level) column label.
    _NORM_COLS = ['h', 'z0', 'z1', 'q', 'w']

    def __init__(self, config, fnames):
        """Load, normalise and tensorise every CSV listed in ``fnames``.

        Args:
            config: Nested mapping providing ``train.mean_bias``,
                ``train.std_bias`` and ``lsm.size_ratio.min``.
            fnames: Iterable of CSV paths to concatenate into one dataset.

        Raises:
            ValueError: If ``fnames`` is empty.
        """
        self._config = config
        self._mean = np.array(self._config['train']['mean_bias'], np.float32)
        self._std = np.array(self._config['train']['std_bias'], np.float32)

        self._label_cols = ['z0_cost', 'z1_cost', 'q_cost', 'w_cost']
        self._input_cols = ['h', 'z0', 'z1', 'q', 'w', 'T', 'Q']

        self._df = self._load_data(fnames)

        self.inputs = torch.from_numpy(self._df[self._input_cols].values).float()
        self.labels = torch.from_numpy(self._df[self._label_cols].values).float()

    def _load_data(self, fnames):
        """Read every CSV in ``fnames`` and return one processed DataFrame."""
        fnames = list(fnames)
        if not fnames:
            # A glob that matched nothing ends up here; fail with a message that
            # points at the real cause instead of pandas' concat error.
            raise ValueError(
                'EndureQDataSet received no CSV files to load; '
                'check the data directory and glob pattern'
            )
        frames = [pd.read_csv(fname) for fname in fnames]
        # ignore_index: each file restarts its row labels at 0, so without it
        # the combined frame would carry duplicated index values.
        return self._process_df(pd.concat(frames, ignore_index=True))

    def _process_df(self, df):
        """Normalise ``df`` in place and return it.

        ``h, z0, z1, q, w`` become ``(x - mean) / std``; ``T`` is shifted by the
        minimum size ratio and ``Q`` by that minimum minus one.
        """
        min_size_ratio = self._config['lsm']['size_ratio']['min']
        df[self._NORM_COLS] -= self._mean
        df[self._NORM_COLS] /= self._std
        df['T'] = df['T'] - min_size_ratio
        df['Q'] = df['Q'] - (min_size_ratio - 1)
        return df

    def __len__(self):
        """Return the number of samples (rows) in the dataset."""
        return len(self._df)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as ``(label, input)`` — labels come first."""
        return self.labels[idx], self.inputs[idx]

In [4]:
# Directory holding the training CSV shards.
data_root = os.path.join(config['io']['data_dir'], config['train']['dir'])
# glob returns matches in arbitrary order; sorting gives every run the same
# file order and therefore the same row order inside the datasets.
train_files = sorted(glob.glob(os.path.join(data_root, '*.csv')))
test_files = sorted(glob.glob(os.path.join(config['io']['data_dir'], config['test']['dir'], '*.csv')))

['/data/train-data/qcost/qcost-0036.csv',
 '/data/train-data/qcost/qcost-0032.csv',
 '/data/train-data/qcost/qcost-0012.csv',
 '/data/train-data/qcost/qcost-0048.csv',
 '/data/train-data/qcost/qcost-0020.csv',
 '/data/train-data/qcost/qcost-0052.csv',
 '/data/train-data/qcost/qcost-0060.csv',
 '/data/train-data/qcost/qcost-0004.csv',
 '/data/train-data/qcost/qcost-0028.csv',
 '/data/train-data/qcost/qcost-0040.csv',
 '/data/train-data/qcost/qcost-0024.csv',
 '/data/train-data/qcost/qcost-0000.csv',
 '/data/train-data/qcost/qcost-0008.csv',
 '/data/train-data/qcost/qcost-0056.csv',
 '/data/train-data/qcost/qcost-0016.csv',
 '/data/train-data/qcost/qcost-0044.csv',
 '/data/train-data/qcost/qcost-0037.csv',
 '/data/train-data/qcost/qcost-0049.csv',
 '/data/train-data/qcost/qcost-0061.csv',
 '/data/train-data/qcost/qcost-0053.csv',
 '/data/train-data/qcost/qcost-0021.csv',
 '/data/train-data/qcost/qcost-0029.csv',
 '/data/train-data/qcost/qcost-0005.csv',
 '/data/train-data/qcost/qcost-004

In [5]:
%%time
train_dataset = EndureQDataSet(config, train_files)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [None]:
%%time
test_dataset = EndureQDataSet(config, test_files)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)

In [6]:
# Build the training job from the loaded config; its `model`, `optimizer` and
# `loss_fn` attributes are handed to the Trainer in the following cell.
train_job = TrainJob(config)

In [7]:
# Assemble the Trainer from the job's model/optimizer/loss and the two loaders.
# NOTE(review): `test_dataloader` is defined in a cell marked `In [None]` (never
# executed in the saved notebook) — on a fresh kernel, run that cell first.
trainer = Trainer(
    config=config,
    model=train_job.model,
    optimizer=train_job.optimizer,
    loss_fn=train_job.loss_fn,
    train_data=train_dataloader,
    test_data=test_dataloader,)

In [8]:
# Run the trainer. The saved output shows a loss progress bar that ends in a
# KeyboardInterrupt, i.e. this recorded run was stopped by hand before finishing.
trainer.run()

loss 240.680634:  10%|█▌              | 102097/1048576 [02:37<24:21, 647.53it/s]


KeyboardInterrupt: 