In [1]:
import sys
sys.path.append('./src')

from chatgpt_util_imports import generate_dataset
from data_loader import PatchDataset
from models import VAE
from train import load_checkpoint, print_loss_metrics, save_checkpoint, train_one_epoch, validate
from util import get_date_and_time

import numpy as np
import os
import pandas as pd
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split

In [2]:
# Pick the best available accelerator: CUDA, then Apple-silicon MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Seed torch's global RNG so model initialisation and DataLoader shuffling
# are reproducible under Restart Kernel -> Run All.
seed = 0
torch.manual_seed(seed)

In [3]:
# Load the ~31k patches compiled by `bwhitman` for `learnfm`.
fpath = 'data/compact.bin'

# Wrap the raw records from `generate_dataset` straight into a DataFrame so
# `df` only ever names a DataFrame: one row per patch, one column per feature.
df = pd.DataFrame(generate_dataset(fpath))
assert not df.empty, f"No patches loaded from {fpath}"

print(f"Num patches: {len(df)}")
print(f"Num features: {df.shape[1]}")

Num patches: 31380
Num features: 146


In [4]:
p_train = 0.8
batch_size = 32
split_seed = 0  # fixed so the train/val split is identical on every run

dataset = PatchDataset(df)

n_train = int(p_train * len(dataset))
n_val = len(dataset) - n_train  # remainder, so n_train + n_val == len(dataset)

# A dedicated, seeded generator makes the split reproducible without relying on
# (or advancing) torch's global RNG state.
split_generator = torch.Generator().manual_seed(split_seed)
train_dataset, val_dataset = random_split(dataset, [n_train, n_val], generator=split_generator)

# Training batches are reshuffled each epoch; validation order is fixed (shuffle defaults to False).
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

In [5]:
# Model sizes for VAE(n_features, n_hidden, n_latent); the input width is taken
# from the dataset's parameter list so it always matches the data.
n_features = len(dataset.parameter_names)
n_hidden = 64
n_latent = 16

# NOTE(review): lr=1e-6 is very small -- the logged MSE barely moves over the
# first 10 epochs. Consider tuning before a long run.
learning_rate = 1e-6

model = VAE(n_features, n_hidden, n_latent).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [6]:
# Training: train and validate every epoch. On every `save_interval_epochs`-th
# epoch (0, 10, 20, ...), save a checkpoint if the validation loss beats the best
# value seen at an earlier save point.
n_epochs = 1000
checkpoint_dir = 'checkpoints/'
save_interval_epochs = 10

# Create the checkpoint directory up front so the first save cannot fail on a
# fresh checkout where it does not exist yet.
os.makedirs(checkpoint_dir, exist_ok=True)

best_loss = np.inf
for i in range(n_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, device)
    val_loss = validate(model, val_loader, device)

    print(f"Epoch {i}")
    print_loss_metrics(train_loss, val_loss)

    # Only consider checkpointing on save-interval epochs.
    if i % save_interval_epochs != 0:
        continue

    if val_loss['loss'] < best_loss:
        fname = f"chkpnt_{get_date_and_time()}_epoch{i}.pt"
        # NOTE: this rebinds `fpath` (the data path from the loading cell). Saves
        # only happen on improvement, so after the loop `fpath` points at the
        # best saved checkpoint.
        fpath = os.path.join(checkpoint_dir, fname)

        save_checkpoint(model, optimizer, fpath)

        print("=" * 80)
        print(f"Saved checkpoint at: {fpath}")
        print("=" * 80)

        best_loss = val_loss['loss']

Epoch 0
Trn:   0.0430	MSE: 0.0319	KLD: 0.0111
Val:   0.0424	MSE: 0.0316	KLD: 0.0109

Saved checkpoint at: checkpoints/chkpnt_20231001_215119_epoch0.pt


Epoch 1
Trn:   0.0424	MSE: 0.0319	KLD: 0.0106
Val:   0.0419	MSE: 0.0315	KLD: 0.0104

Epoch 2
Trn:   0.0419	MSE: 0.0318	KLD: 0.0101
Val:   0.0414	MSE: 0.0315	KLD: 0.0099

Epoch 3
Trn:   0.0415	MSE: 0.0318	KLD: 0.0097
Val:   0.0410	MSE: 0.0315	KLD: 0.0095

Epoch 4
Trn:   0.0411	MSE: 0.0318	KLD: 0.0093
Val:   0.0406	MSE: 0.0315	KLD: 0.0091

Epoch 5
Trn:   0.0407	MSE: 0.0318	KLD: 0.0089
Val:   0.0403	MSE: 0.0315	KLD: 0.0088

Epoch 6
Trn:   0.0403	MSE: 0.0318	KLD: 0.0086
Val:   0.0399	MSE: 0.0314	KLD: 0.0085

Epoch 7
Trn:   0.0400	MSE: 0.0317	KLD: 0.0083
Val:   0.0396	MSE: 0.0314	KLD: 0.0082

Epoch 8
Trn:   0.0397	MSE: 0.0317	KLD: 0.0080
Val:   0.0393	MSE: 0.0314	KLD: 0.0079

Epoch 9
Trn:   0.0394	MSE: 0.0317	KLD: 0.0077
Val:   0.0391	MSE: 0.0314	KLD: 0.0077

Epoch 10
Trn:   0.0392	MSE: 0.0317	KLD: 0.0075
Val:   0.0388	MSE: 0.0314	KLD: 0.00