In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pandas as pd

from lion_pytorch import Lion

In [3]:
from vae import VAE
from vae import GroupSoftmax
from trainer import Trainer

In [4]:
# Device config: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

Device: cuda


In [6]:
# Group sizes written out literally (one row per run of equal sizes) so this
# cell can run before the dataset is loaded. They sum to 433 and match the
# per-prefix column counts printed by the later data-inspection cell.
group_sizes = [
    2, 5, 5, 11, 2, 2,
    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
    19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
]
# Separate VAE instance used in this cell only to measure the parameter count.
test = VAE(433, 1500, 6, 500, group_sizes)

# One pass over the parameters: tally all elements, and separately the
# elements of parameters that require gradients.
total_params = 0
trainable_params = 0
for param in test.parameters():
    n_elements = param.numel()
    total_params += n_elements
    if param.requires_grad:
        trainable_params += n_elements
print(trainable_params)


57670299


In [7]:
# Dataset import: read the CSV into a pandas DataFrame.
# Windows path
# data = pd.read_csv('A:/csv_hus/psam_hus_pus_filtered.csv')
# WSL path
csv_path = '/mnt/a/csv_hde/one_hot_data.csv'
data = pd.read_csv(csv_path)


In [8]:
# Derive the size of each column group from the column names.
# Each column is keyed by the text before its first ":" (presumably the
# columns are named "<feature>:<category>" -- TODO confirm against the CSV).
cols = [col.split(":")[0] for col in data.columns]

# Tally the prefixes in a single pass instead of rescanning every column once
# per column. Groups keep first-seen order, and the counts are plain Python
# ints. A column with no ":" counts as a group of one rather than zero, so the
# sizes always add up to the number of columns.
onehot_counts = {}
for prefix in cols:
    onehot_counts[prefix] = onehot_counts.get(prefix, 0) + 1
print(list(onehot_counts.values()))

print(sum(onehot_counts.values()))

[2, 5, 5, 11, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7]
433


In [9]:
# Data loader
# Convert the whole DataFrame to a single float32 tensor and wrap it in a
# DataLoader whose batch size equals the number of rows, i.e. one shuffled
# batch per pass over the loader.
data_tensor = torch.tensor(data.to_numpy(), dtype=torch.float32)
data_loader = DataLoader(
    data_tensor,
    batch_size=data_tensor.shape[0],
    shuffle=True,
)

# Class_sizes: the per-prefix column counts, in column order.
group_sizes = [*onehot_counts.values()]


In [10]:
# Build the VAE using the group sizes derived from the data columns.
# NOTE(review): 433 equals sum(group_sizes) in the recorded output, so it is
# presumably the input width; the meaning of 1500, 6 and 500 is not visible
# here -- confirm against vae.py before changing the dataset.
model = VAE(433, 1500, 6, 500, group_sizes)
# NOTE(review): the recorded training log shows the loss jumping to ~1.9e8 at
# epoch 1 before settling; consider whether lr=1e-3 is too aggressive for Lion.
optimizer = Lion(model.parameters(), lr=1e-3)

# Trainer receives the model, the optimizer and the device selected earlier.
trainer = Trainer(model, optimizer, device)

In [11]:
# Train the model.
# Path passed to trainer.train; presumably where the best checkpoint is
# written (the file name suggests so -- confirm in trainer.py).
model_path = '/mnt/a/models/best_model.pth'
# NOTE(review): the positional 1000 is presumably the number of epochs (the
# recorded log prints one line per epoch); gamma and alpha are forwarded to
# Trainer.train and their meaning is not visible here -- confirm in trainer.py.
trainer.train(data_loader, 1000, model_path, gamma=2, alpha=0.125)

Epoch 0, Loss: 5.238574028015137
Epoch 1, Loss: 188634432.0
Epoch 2, Loss: 6019.009765625
Epoch 3, Loss: 0.6944985389709473
Epoch 4, Loss: 0.6327990889549255
Epoch 5, Loss: 0.6285586357116699
Epoch 6, Loss: 0.6242597699165344
Epoch 7, Loss: 0.6191036701202393
Epoch 8, Loss: 0.6157796382904053
Epoch 9, Loss: 0.611255943775177
Epoch 10, Loss: 0.6048687696456909
Epoch 11, Loss: 0.6009372472763062
Epoch 12, Loss: 0.599827229976654
Epoch 13, Loss: 0.5979099869728088
Epoch 14, Loss: 0.5965037941932678
Epoch 15, Loss: 0.5951194167137146
Epoch 16, Loss: 0.5938442945480347
Epoch 17, Loss: 0.592543363571167
Epoch 18, Loss: 0.5912529230117798
Epoch 19, Loss: 0.5899651646614075
Epoch 20, Loss: 0.5887374877929688
Epoch 21, Loss: 0.5874881148338318
Epoch 22, Loss: 0.5863428711891174
Epoch 23, Loss: 0.5849425792694092
Epoch 24, Loss: 0.5836571455001831
Epoch 25, Loss: 0.5823898911476135
Epoch 26, Loss: 0.5813723802566528
Epoch 27, Loss: 0.5800507664680481
Epoch 28, Loss: 0.5789521932601929
Epoch 29, 

In [12]:
# Persist the model's current weights (its state_dict) to 'model.pth',
# a path relative to the working directory.
final_state = model.state_dict()
torch.save(final_state, 'model.pth')