In [11]:
import torch
import logging
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
import pandas as pd

logging.basicConfig(level=logging.INFO, format='[%(asctime)s]: %(message)s', datefmt='%H:%M:%S')

log = logging.getLogger()
KDATA_PATH = '/Users/ndhuynh/sandbox/data/cost_surface_k.feather'
MAX_LEVELS = 15

In [22]:
class KCostDataSet(Dataset):
    def __init__(self, transform=None, target_transform=None):
        df = pd.read_feather(KDATA_PATH)
        print('Read in dataframe')

        cont_inputs = ['h', 'z0', 'z1', 'q', 'w']
        cate_inputs = ['T'] + [f'K_{i}' for i in range(MAX_LEVELS)]
        output_cols = ['new_cost']

        mean = df[cont_inputs].mean()
        std = df[cont_inputs].std()
        std[std == 0] = 1
        df[cont_inputs] = (df[cont_inputs] - mean) / std

        self.cont_inputs = torch.from_numpy(df[cont_inputs].values).float()
        print('Normalized continous vars')
        self.cate_inputs = torch.from_numpy(df[cate_inputs].values).to(torch.int64)

        self.outputs = torch.from_numpy(df[output_cols].values).float()

    def __len__(self):
        return len(self.cont_inputs)

    def __getitem__(self, idx):
        categories = torch.flatten(nn.functional.one_hot(self.cate_inputs[idx], num_classes=50), start_dim=-2)
        inputs = torch.cat((self.cont_inputs[idx], categories), dim=-1)
        label = self.outputs[idx]

        return inputs, label

In [23]:
class KCost1Hidden(nn.Module):
    def __init__(self):
        super(KCost1Hidden, self).__init__()
        self.categorical_stack = nn.Sequential(
            nn.Linear(800, 16),
            nn.ReLU(),
            )
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(21, 21),
            nn.ReLU(),
            nn.Linear(21, 21),
            nn.ReLU(),
            nn.Linear(21, 1),
            nn.ReLU(),
        )

    def forward(self, x):
        out = self.categorical_stack(x[:, 5:])
        out = torch.cat((out, x[:, :5]), 1)
        out = self.linear_relu_stack(out)

        return out


In [24]:
data = KCostDataSet()
val_len = int(len(data) * 0.1)
train_len = len(data) - val_len
train, val = torch.utils.data.random_split(data, [train_len, val_len])
train = DataLoader(train, batch_size=1024, num_workers=0, shuffle=True)
val = DataLoader(val, batch_size=1024, num_workers=0, shuffle=False)
val_len, train_len

Read in dataframe
Normalized continous vars


(2434390, 21909515)

In [25]:
# torch.cat((data[0][0][:5], data[0][0][5:]))
data[5]

(tensor([-2.5769e+00, -2.0781e-14, -2.2411e-14, -2.4040e-14, -2.5874e-14,
          0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.

In [26]:
loss_fn = nn.MSELoss()
model = KCost1Hidden()
optimizer = torch.optim.Adam(model.parameters())

In [27]:
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % (500) == 0:
            loss, current = loss.item(), batch * len(X)
            log.info(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss = 0

    model.eval()
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()

    test_loss /= num_batches
    log.info(f'validation loss: {test_loss:>8f}\n')

In [None]:
epochs = 32
for t in range(epochs):
    log.info(f"Epoch {t+1}\n-------------------------------")
    train_loop(train, model, loss_fn, optimizer)
    test_loop(val, model, loss_fn)
log.info("Done!")

[21:59:26]: Epoch 1
-------------------------------
[21:59:26]: loss: 465.733948  [    0/21909515]
[21:59:34]: loss: 23.864552  [512000/21909515]
[21:59:40]: loss: 4.059256  [1024000/21909515]
[21:59:47]: loss: 2.831160  [1536000/21909515]
[21:59:53]: loss: 1.433657  [2048000/21909515]
[22:00:00]: loss: 1.388919  [2560000/21909515]
[22:00:07]: loss: 0.906500  [3072000/21909515]
[22:00:14]: loss: 0.641506  [3584000/21909515]
[22:00:21]: loss: 0.418503  [4096000/21909515]
[22:00:27]: loss: 0.290987  [4608000/21909515]
[22:00:34]: loss: 0.319168  [5120000/21909515]
[22:00:41]: loss: 0.264917  [5632000/21909515]
[22:00:50]: loss: 0.328859  [6144000/21909515]
[22:00:58]: loss: 0.208151  [6656000/21909515]
[22:01:07]: loss: 0.170035  [7168000/21909515]
[22:01:15]: loss: 0.204483  [7680000/21909515]
[22:01:24]: loss: 0.221106  [8192000/21909515]
[22:01:33]: loss: 0.286766  [8704000/21909515]
[22:01:41]: loss: 0.181674  [9216000/21909515]
[22:01:50]: loss: 0.126335  [9728000/21909515]
[22:01:5

In [126]:
x, y = data[1]
x, y

(tensor([ 0.2247, -0.8814,  2.6436, -0.8808, -0.8799,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  