In [1]:
import torch
import logging
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split

import pandas as pd
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s]: %(message)s',
    datefmt='%H:%M:%S'
)

log = logging.getLogger()

In [2]:
class KCostDataSet(Dataset):
    def __init__(self, transform=None, target_transform=None):
        self.df = pd.read_csv('/Users/ndhuynh/sandbox/data/cost_surface_k.csv')
        data = self.df['K'].map(lambda x: list(map(int, x[1:-1].split())))
        Ks = pd.DataFrame(data.to_list()).add_prefix('K_').fillna(0)
        self.df = pd.concat([self.df, Ks], axis=1)
        
        max_levels = self.df.query('T == 2')['K'].apply(lambda x: len(x[1:-1].split())).max()
        input_cols = ['h', 'T', 'z0', 'z1', 'q', 'w'] + [f'K_{i}' for i in range(max_levels)]
        output_cols = ['new_cost']
        
        mean = self.df[input_cols].mean()
        std = self.df[input_cols].std()
        std[std == 0] = 1
        self.df[input_cols] = (self.df[input_cols] - mean) / std
        
        self.inputs = torch.from_numpy(self.df[input_cols].values).float()
        self.outputs = torch.from_numpy(self.df[output_cols].values).float()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        inputs = self.inputs[idx]
        label = self.outputs[idx]

        return inputs, label

In [3]:
class KCostNeuralNet(nn.Module):
    def __init__(self):
        super(KCostNeuralNet, self).__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(21, 21),
            nn.ReLU(),
            nn.Linear(21, 21),
            nn.ReLU(),
            nn.Linear(21, 1),
            nn.ReLU(),
        )

    def forward(self, x):
        out = self.linear_relu_stack(x)
        return out

In [4]:
data = KCostDataSet()
val_len = len(data) - int(len(data) * 0.05)
train_len = len(data) - val_len
train, val = torch.utils.data.random_split(data, [train_len, val_len])
train = DataLoader(train, batch_size=1024, num_workers=0, shuffle=True)
val = DataLoader(val, batch_size=1024, num_workers=0, shuffle=False)

In [5]:
loss_fn = nn.MSELoss()
model = KCostNeuralNet()
optimizer = torch.optim.Adam(model.parameters())

In [6]:
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % (500) == 0:
            loss, current = loss.item(), batch * len(X)
            log.info(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss = 0

    model.eval()
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()

    test_loss /= num_batches
    log.info(f'validation loss: {test_loss:>8f}\n')

In [7]:
epochs = 32
for t in range(epochs):
    log.info(f"Epoch {t+1}\n-------------------------------")
    train_loop(train, model, loss_fn, optimizer)
    test_loop(val, model, loss_fn)
log.info("Done!")

[16:35:49]: Epoch 1
-------------------------------
[16:35:49]: loss: 527.707825  [    0/1217195]
[16:35:50]: loss: 19.901567  [512000/1217195]
[16:35:51]: loss: 5.366539  [1024000/1217195]
[16:36:55]: validation loss: 5.128123

[16:36:55]: Epoch 2
-------------------------------
[16:36:56]: loss: 5.365809  [    0/1217195]
[16:36:58]: loss: 2.816974  [512000/1217195]
[16:37:00]: loss: 1.969640  [1024000/1217195]
[16:38:39]: validation loss: 1.673957

[16:38:39]: Epoch 3
-------------------------------
[16:38:39]: loss: 1.891552  [    0/1217195]
[16:38:41]: loss: 1.569899  [512000/1217195]
[16:38:43]: loss: 1.314659  [1024000/1217195]
[16:40:23]: validation loss: 1.173841

[16:40:23]: Epoch 4
-------------------------------
[16:40:23]: loss: 0.948044  [    0/1217195]
[16:40:25]: loss: 1.041244  [512000/1217195]
[16:40:28]: loss: 0.819875  [1024000/1217195]
[16:42:06]: validation loss: 0.920453

[16:42:06]: Epoch 5
-------------------------------
[16:42:06]: loss: 0.875225  [    0/121719

In [17]:
x, y = data[1]
x, y

(tensor([-2.5769e+00, -3.3111e+00, -2.0781e-14, -2.2411e-14, -2.4040e-14,
         -2.5874e-14, -2.3647e+00, -1.6430e+00, -9.3732e-01,  1.1826e+00,
          9.9400e+00,  2.7035e+01,  5.7396e+01,  8.6080e+01,  1.1828e+02,
          6.3374e+01, -1.0787e-02, -7.9190e-03, -5.7924e-03, -4.2199e-03,
         -3.0614e-03]),
 tensor([8.1739]))

In [16]:
model(data[1][0])

tensor([7.5977], grad_fn=<ReluBackward0>)