In [24]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to local .py files take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [68]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pandas as pd

from lion_pytorch import Lion

In [69]:
from vae import VAE
from vae import GroupSoftmax
#from trainer import Trainer



In [70]:
# Select the compute device: CUDA when a GPU is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

Device: cuda


In [96]:
# Group sizes written compactly: 48 entries that sum to 375, in the same
# order as the original literal list.
group_sizes = [2, 3, 5, 11] + [2] * 16 + [18] * 14 + [5] * 14
test = VAE(375, 500, 6, 375, group_sizes)

# Walk the parameters once, tallying every element and, separately, only the
# elements that take part in gradient updates.
total_params = 0
trainable_params = 0
for param in test.parameters():
    n_elements = param.numel()
    total_params += n_elements
    if param.requires_grad:
        trainable_params += n_elements
print(trainable_params)


6979875


In [8]:
# Load the dataset from CSV.
# Windows path kept for reference (note: it names a different file):
# data = pd.read_csv('A:/csv_hus/psam_hus_pus_filtered.csv')
# WSL path actually used:
csv_path = '/mnt/a/csv_hde/one_hot_data.csv'
data = pd.read_csv(csv_path)


In [9]:
# Prefix of every column name: the text before the first ":". Columns are
# assumed to follow a "<feature>:<category>" naming scheme - TODO confirm.
cols = [name.split(":")[0] for name in data.columns]

# Number of columns sharing each prefix, keyed in first-seen column order.
# A single pass replaces re-scanning every column name once per column, and
# yields plain Python ints. A column without any ":" now forms a group of
# size 1 instead of being reported with size 0.
onehot_counts = {}
for prefix in cols:
    onehot_counts[prefix] = onehot_counts.get(prefix, 0) + 1
print(list(onehot_counts.values()))

print(sum(onehot_counts.values()))

[2, 3, 5, 11, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]
375


In [10]:
# Data loader
# Convert the whole frame to a float32 tensor and serve it as one shuffled
# batch per epoch (the batch size equals the number of rows).
data_tensor = torch.tensor(data.to_numpy(), dtype=torch.float32)
data_loader = DataLoader(
    data_tensor,
    batch_size=data_tensor.size(0),
    shuffle=True,
)

# Group sizes, in the insertion order of onehot_counts.
group_sizes = [*onehot_counts.values()]


In [93]:
class Trainer:
    """Minimal training loop for a model that returns ``(x_hat, mu, logvar)``.

    The objective is the unweighted sum of a KL term (``kl_loss``) and a
    focal reconstruction term (``focal_loss``).

    Attributes:
        model: Module being optimised; moved to ``device`` on construction.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the updates.
        device: Device each batch is moved to before the forward pass.
        loss_history: Mean total loss of every epoch run so far by ``train``.
    """

    def __init__(self, model, optimizer, device):
        """Store the collaborators and move ``model`` onto ``device``."""
        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.loss_history = []
        self.model.to(device)

    def kl_loss(self, mu, logvar):
        """KL divergence of N(mu, exp(logvar)) from the standard normal.

        The per-element terms are averaged over all elements of the input
        (not summed per sample).
        """
        return 0.5 * torch.mean(mu.pow(2) + logvar.exp() - logvar - 1)

    def focal_loss(self, x, x_hat, gamma=2, alpha=0.17):
        """Binary focal loss between targets ``x`` and predictions ``x_hat``.

        Args:
            x: Target tensor; the formula treats values as 0/1 indicators.
            x_hat: Predicted probabilities, same shape as ``x``.
            gamma: Focusing exponent applied to the prediction error.
            alpha: Weight of the positive (``x == 1``) term; the negative
                term is weighted by ``1 - alpha``.

        Returns:
            Scalar tensor: the mean loss over all elements.
        """
        # Keep predictions strictly inside (0, 1) so both logs stay finite.
        eps = 1e-7
        x_hat = torch.clamp(x_hat, eps, 1 - eps)
        positive_term = alpha * x * (1 - x_hat) ** gamma * torch.log(x_hat)
        negative_term = (1 - alpha) * (1 - x) * x_hat ** gamma * torch.log(1 - x_hat)
        return -torch.mean(positive_term + negative_term)

    def train(self, train_loader, epochs):
        """Optimise the model on focal loss plus KL divergence.

        Prints both loss terms for every batch and the mean total loss for
        every epoch; the epoch means are also appended to ``loss_history``.

        Args:
            train_loader: Iterable yielding input batches (tensors).
            epochs: Number of passes over ``train_loader``.

        Raises:
            ValueError: If ``train_loader`` yields no batches in an epoch.
        """
        # Ensure mode-dependent layers (dropout, batch norm) are in training
        # mode even if the model was switched to eval() beforehand.
        self.model.train()
        for epoch in range(epochs):
            epoch_total = 0.0
            batch_count = 0
            for x in train_loader:
                x = x.to(self.device)
                x_hat, mu, logvar = self.model(x)
                kl = self.kl_loss(mu, logvar)
                focal = self.focal_loss(x, x_hat)
                print(f'KL Loss: {kl.item()}')
                print(f'Focal Loss: {focal.item()}')

                loss = kl + focal
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                epoch_total += loss.item()
                batch_count += 1

            # An empty loader would otherwise leave no loss to report.
            if batch_count == 0:
                raise ValueError('train_loader yielded no batches')

            epoch_loss = epoch_total / batch_count
            self.loss_history.append(epoch_loss)
            print(f'Epoch {epoch}, Loss: {epoch_loss}')



In [97]:
# Build the model, give its parameters to a Lion optimizer, and bind both to
# the selected device through a Trainer.
model = VAE(375, 500, 6, 250, group_sizes)
optimizer = Lion(model.parameters(), lr=1e-3)

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    device=device,
)

In [98]:
# Run the training loop for 1000 passes over the data loader.
trainer.train(train_loader=data_loader, epochs=1000)

KL Loss: 2.6071369647979736
Focal Loss: 0.021755890920758247
Epoch 0, Loss: 2.6288928985595703
KL Loss: 3.7626640796661377
Focal Loss: 0.018432030454277992
Epoch 1, Loss: 3.7810962200164795
KL Loss: 0.7606384754180908
Focal Loss: 0.018272288143634796
Epoch 2, Loss: 0.778910756111145
KL Loss: 0.6945781111717224
Focal Loss: 0.01689068414270878
Epoch 3, Loss: 0.7114688158035278
KL Loss: 0.6656094193458557
Focal Loss: 0.01624963991343975
Epoch 4, Loss: 0.6818590760231018
KL Loss: 0.6497973203659058
Focal Loss: 0.016561495140194893
Epoch 5, Loss: 0.6663588285446167
KL Loss: 0.6401262283325195
Focal Loss: 0.017929712310433388
Epoch 6, Loss: 0.6580559611320496
KL Loss: 0.6343746185302734
Focal Loss: 0.01613572984933853
Epoch 7, Loss: 0.6505103707313538
KL Loss: 0.6295315623283386
Focal Loss: 0.015923388302326202
Epoch 8, Loss: 0.6454549431800842
KL Loss: 0.6258752942085266
Focal Loss: 0.016329219564795494
Epoch 9, Loss: 0.6422045230865479
KL Loss: 0.6234618425369263
Focal Loss: 0.016151908785

In [99]:
# Save the model's state dict (not the module object itself) to disk.
torch.save(obj=model.state_dict(), f='model.pth')