In [1]:
# Auto-reload imported project modules (e.g. the local `vae` and `trainer` modules) before each
# cell executes, so edits to those files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Imports
from itertools import groupby

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pandas as pd

from lion_pytorch import Lion

In [3]:
# Run configuration.
# SEED makes the run reproducible: torch.manual_seed seeds PyTorch's default generator on all
# devices, which drives weight initialisation, DataLoader shuffling and (presumably) the VAE's
# latent sampling.
SEED = 0
torch.manual_seed(SEED)

# Input CSV (read from /workspace/data) and model file names (written under /workspace/models).
data_set = 'one_hot_noNaNremoved.csv'
intermediate_models = 'best_model_fNaN.pth'
output_model = 'model_agep_fNaN.pth'

In [4]:
from vae import VAE
from vae import GroupSoftmax
from trainer import Trainer

In [5]:
# Device config: prefer CUDA whenever PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

Device: cuda


In [6]:
# Sanity check: parameter count of a VAE built for a reference one-hot layout.
# The probe is never trained; the group sizes actually used for training are derived from the
# data further down, so this list gets its own name instead of reusing `group_sizes`.
probe_group_sizes = [2, 5, 5, 11, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7]
# The input width is the total number of one-hot columns (433 for this list); derive it so the
# two arguments cannot disagree.
probe_model = VAE(sum(probe_group_sizes), 1500, 6, 500, probe_group_sizes)

total_params = sum(p.numel() for p in probe_model.parameters())
trainable_params = sum(p.numel() for p in probe_model.parameters() if p.requires_grad)
print(f"total params: {total_params:,}  trainable params: {trainable_params:,}")

# The probe is large (tens of millions of parameters); release it before the real model is built.
del probe_model


57670299


In [7]:
# Dataset import: read the one-hot encoded CSV named by `data_set` in the config cell.
# (A Windows copy previously lived at A:/csv_hus/psam_hus_pus_filtered.csv; the path used
# below is the WSL / container location.)
csv_path = f'/workspace/data/{data_set}'
data = pd.read_csv(csv_path)


In [8]:
# One-hot columns are named "<feature>:<category>". Count how many consecutive columns each
# feature owns, in column order; these counts become `group_sizes` for the VAE below.
# NOTE(review): group sizes are passed positionally, so presumably the model slices its
# output columns by them in order. That only works if each feature's columns form one
# contiguous run, which is checked explicitly here rather than assumed.
column_names = list(data.columns)
assert all(":" in name for name in column_names), (
    "every column is expected to be one-hot encoded as '<feature>:<category>'"
)

feature_prefixes = [name.split(":")[0] for name in column_names]
# groupby yields one (prefix, run) pair per maximal run of equal adjacent prefixes: O(n).
feature_runs = [(prefix, len(list(run))) for prefix, run in groupby(feature_prefixes)]
onehot_counts = dict(feature_runs)
# A feature appearing in two separate runs would collapse to one dict key.
assert len(onehot_counts) == len(feature_runs), (
    "a feature's one-hot columns are not contiguous; group sizes would not line up with columns"
)

print(list(onehot_counts.values()))

print(sum(onehot_counts.values()))

[3, 6, 6, 12, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7]
526


In [9]:
# Data loader.
# The whole table becomes a single float32 tensor and is served as ONE batch
# (batch_size == number of rows), i.e. full-batch training; shuffling therefore only
# permutes rows inside that single batch.
features = data.values
data_tensor = torch.tensor(features, dtype=torch.float32)
n_rows = len(data_tensor)
data_loader = DataLoader(data_tensor, shuffle=True, batch_size=n_rows)

# Class sizes: one entry per one-hot encoded feature, in order of first appearance
# among the CSV columns.
group_sizes = [size for size in onehot_counts.values()]


In [10]:
# Input width comes from the data instead of a hard-coded 526, so re-encoding the CSV
# (e.g. adding or removing a NaN category) cannot silently desynchronise model and data.
n_features = data_tensor.shape[1]
assert sum(group_sizes) == n_features, (
    f"group sizes cover {sum(group_sizes)} columns but the data has {n_features}"
)
# NOTE(review): 1500, 6 and 500 are passed positionally to VAE; their meaning (hidden / latent
# sizes?) is defined in vae.py — confirm there.
model = VAE(n_features, 1500, 6, 500, group_sizes)

# Optional warm start from a saved state dict; leave as None to train from scratch.
# torch.load unpickles the file, so only point this at trusted checkpoints.
pretrained_path = None  # e.g. '/workspace/models/model.pth'
if pretrained_path is not None:
    model.load_state_dict(torch.load(pretrained_path, map_location='cpu'))

In [11]:
# Move the model to the target device *before* building the optimizer, as the PyTorch
# optimizer docs recommend, so the optimizer is constructed over the final parameters.
# (If Trainer also moves the model, that second move is a no-op.)
model = model.to(device)

# Lion optimizer.
# NOTE(review): the lion-pytorch authors suggest a learning rate 3-10x smaller than you would
# use with Adam. An early loss spike during training may mean 1e-3 is too high — tune.
optimizer = Lion(model.parameters(), lr=1e-3)

trainer = Trainer(model, optimizer, device)

In [12]:
# Train the model.
# `model_path` is built from `intermediate_models` (a 'best_model_…' file name) and handed to
# Trainer.train — presumably where it writes its best checkpoint; confirm in trainer.py.
# NOTE(review): `gamma` and `alpha` are forwarded unchanged; their meaning is defined in trainer.py.
n_epochs = 4000
train_kwargs = {'gamma': 2, 'alpha': 0.122}
model_path = f'/workspace/models/{intermediate_models}'
trainer.train(data_loader, n_epochs, model_path, **train_kwargs)

Epoch 0, Loss: 10.687446594238281
Epoch 1, Loss: 985.0857543945312
Epoch 2, Loss: 0.6982658505439758
Epoch 3, Loss: 0.6279373168945312
Epoch 4, Loss: 0.6182161569595337
Epoch 5, Loss: 0.6118318438529968
Epoch 6, Loss: 0.6085224151611328
Epoch 7, Loss: 0.6067404747009277
Epoch 8, Loss: 0.6028412580490112
Epoch 9, Loss: 0.6021122336387634
Epoch 10, Loss: 0.5984706878662109
Epoch 11, Loss: 0.596794843673706
Epoch 12, Loss: 0.5973443984985352
Epoch 13, Loss: 0.5953237414360046
Epoch 14, Loss: 0.5934737920761108
Epoch 15, Loss: 0.5924119353294373
Epoch 16, Loss: 0.5903456807136536
Epoch 17, Loss: 0.588961124420166
Epoch 18, Loss: 0.5888233184814453
Epoch 19, Loss: 0.5871543884277344
Epoch 20, Loss: 0.5853736400604248
Epoch 21, Loss: 0.5839512348175049
Epoch 22, Loss: 0.5836941003799438
Epoch 23, Loss: 0.5818212032318115
Epoch 24, Loss: 0.5805931687355042
Epoch 25, Loss: 0.5793009996414185
Epoch 26, Loss: 0.5782347321510315
Epoch 27, Loss: 0.5768688321113586
Epoch 28, Loss: 0.576035916805267

In [13]:
# Persist the weights as they stand after the last training epoch that ran; this may differ
# from whatever checkpoint Trainer.train wrote to `model_path`.
final_weights = model.state_dict()
final_path = f'/workspace/models/{output_model}'
torch.save(final_weights, final_path)