In [1]:
import sys
import os
import torch
import logging
import toml
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm

sys.path.append(os.path.join(sys.path[0], '../..'))

from data.io import Reader
from data.kcost_dataset import KCostDataSetSplit, KCostDataSet
from model.kcost import KCostModel, KCostModelAlpha

In [2]:
# All run settings come from one TOML file; the path is relative to the kernel's working directory.
CONFIG_PATH = '../../config/training.toml'
cfg = toml.load(CONFIG_PATH)
# Read from the config instead of hard-coding, so it cannot drift from static_params.max_levels.
MAX_LEVELS = cfg['static_params']['max_levels']
cfg

{'model': 'KCostModel',
 'log': {'level': 'INFO', 'name': 'endure'},
 'io': {'data_dir': '/scratchHDDa/ndhuynh/data-cold',
  'train_dir': 'training_data',
  'train_data': ['train_0.feather',
   'train_1.feather',
   'train_2.feather',
   'train_3.feather',
   'train_4.feather',
   'train_5.feather',
   'train_6.feather',
   'train_7.feather',
   'train_8.feather',
   'train_9.feather',
   'train_10.feather',
   'train_11.feather',
   'train_12.feather',
   'train_13.feather',
   'train_14.feather',
   'train_15.feather',
   'train_16.feather',
   'train_17.feather',
   'train_18.feather',
   'train_19.feather'],
  'test_data': ['train_20.feather']},
 'static_params': {'max_levels': 16,
  'max_size_ratio': 50,
  'mean_bias': [4.75, 0.5, 0.5, 0.5, 0.5],
  'std_bias': [2.74, 0.3, 0.3, 0.3, 0.3],
  'out_dims': 4},
 'hyper_params': {'num_cont_vars': 5,
  'num_cate_vars': 17,
  'hidden_layers': 2,
  'embedding_size': 17},
 'train': {'max_epochs': 128,
  'batch_size': 32,
  'learning_rate': 0

In [3]:
%%time
paths = [os.path.join(cfg['io']['data_dir'], cfg['io']['train_dir'], fname) for fname in cfg['io']['train_data']]
data = KCostDataSet(cfg, paths)
val_len = int(len(data) * cfg['validate']['percent'])
train_len = len(data) - val_len
train, val = torch.utils.data.random_split(data, [train_len, val_len])
train = DataLoader(train, batch_size=cfg['train']['batch_size'], shuffle=True)
val = DataLoader(val, batch_size=cfg['validate']['batch_size'], shuffle=False)
print(f'Validate length: {val_len}\nTraining length: {train_len}')

Validate length: 8192
Training length: 73728
CPU times: user 564 ms, sys: 102 ms, total: 667 ms
Wall time: 117 ms


In [4]:
def train_loop(dataloader, model, loss_fn, optimizer, log_every=1000):
    """Run one training epoch: forward, loss, backward and optimizer step per batch.

    Args:
        dataloader: iterable yielding ``(X, y)`` batches.
        model: module called as ``model(X)``; switched to train mode first.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        log_every: refresh the progress-bar description every this many batches.

    Returns:
        Mean of the per-batch loss values for the epoch, or ``nan`` if the
        loader yielded no batches.

    Raises:
        ValueError: if ``log_every`` is less than 1.
    """
    if log_every < 1:
        raise ValueError('log_every must be >= 1')

    model.train()
    pbar = tqdm(dataloader, desc='Training')
    total_loss = 0.0
    num_batches = 0
    for batch, (X, y) in enumerate(pbar):
        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a detached tensor: no graph is kept alive, and there is no per-batch .item() sync.
        total_loss += loss.detach()
        num_batches += 1

        if batch % log_every == 0:
            pbar.set_description(f'loss: {loss.item():>4f}')

    return float(total_loss) / num_batches if num_batches else float('nan')
            
def test_loop(dataloader, model, loss_fn):
    num_batches = len(dataloader)
    test_loss = 0

    model.eval()
    with torch.no_grad():
        for X, y in tqdm(dataloader, desc='Validate'):
            pred = model(X)
            test_loss += loss_fn(pred, y).item()

    test_loss /= num_batches
    print(f'validation loss: {test_loss:>8f}\n')
    
    return test_loss

In [5]:
train_cfg = cfg['train']

# NOTE(review): the loaded config names 'KCostModel' under cfg['model'], but this cell builds
# KCostModelAlpha — confirm which architecture this run is meant to train.
model = KCostModelAlpha(cfg)
loss_fn = nn.MSELoss()

optimizer = torch.optim.SGD(model.parameters(), lr=train_cfg['learning_rate'])
# Multiplies the learning rate by `learning_rate_decay` on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer,
    gamma=train_cfg['learning_rate_decay'],
)

In [6]:
MAX_EPOCHS = cfg['train']['max_epochs']
val_history = []  # validation loss after each completed epoch, in order
best_loss, best_state = float('inf'), None
for t in range(MAX_EPOCHS):
    print(f"Epoch [{t + 1}/{MAX_EPOCHS}]")
    train_loop(train, model, loss_fn, optimizer)
    scheduler.step()
    curr_loss = test_loop(val, model, loss_fn)
    val_history.append(curr_loss)
    if curr_loss < best_loss:
        best_loss = curr_loss
        # state_dict() returns references to live tensors; clone so later epochs don't overwrite the snapshot.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

Epoch [1/128]


Training:   0%|          | 0/2304 [00:00<?, ?it/s]

Validate:   0%|          | 0/1 [00:00<?, ?it/s]

validation loss: 2.104850

Epoch [2/128]


Training:   0%|          | 0/2304 [00:00<?, ?it/s]

Validate:   0%|          | 0/1 [00:00<?, ?it/s]

validation loss: 1.280241

Epoch [3/128]


Training:   0%|          | 0/2304 [00:00<?, ?it/s]

Validate:   0%|          | 0/1 [00:00<?, ?it/s]

validation loss: 0.839560

Epoch [4/128]


Training:   0%|          | 0/2304 [00:00<?, ?it/s]

Validate:   0%|          | 0/1 [00:00<?, ?it/s]

validation loss: 0.734201

Epoch [5/128]


Training:   0%|          | 0/2304 [00:00<?, ?it/s]

KeyboardInterrupt: 