In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

from utils import get_loaders
from models import SSLNet
from trainers import train_ssd
from evals import evaluate

# Report up front whether PyTorch can see a CUDA device.
cuda_available = torch.cuda.is_available()
print(cuda_available)

In [None]:
# Single source of truth for the run: every later cell reads its settings from here.
config = dict(
    # Method / loss settings
    method='ssd',
    temperature=0.07,
    contrast_mode='one',

    # Training
    epochs=1000,
    lr=5e-4,
    weight_decay=1e-5,
    patience=50,  # also used by the final reporting cell as an offset from the last epoch
    min_delta=1e-8,
    sched_patience=10,  # presumably learning-rate scheduler settings; confirm in train_ssd
    sched_factor=0.5,

    # Network
    dims=[30, 16, 8],
    drop=0.1,
    norm=True,
    activation='ReLU',

    # Dataset
    val_split=0.1,
    test_split=0.1,
    batch_size=1024,

    # Utility
    seed=15,
    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    device='cuda' if torch.cuda.is_available() else 'cpu',
    data_path='./Data/creditcard.csv',
    print_freq=20,
)

In [None]:
# Build the three loaders; every argument comes from the shared config.
train_loader, val_loader, test_loader = get_loaders(
    config['data_path'],
    config['val_split'],
    config['test_split'],
    config['seed'],
    config['batch_size'],
    config['method'],
)

# Training

In [None]:
# Build the network from the shared config and move it to the selected device.
# 'temperature' and 'seed' are deliberately NOT re-assigned here: they are set
# once in the config cell, and overriding them after the loaders and the model
# have been built would silently defeat any change made there.
model = SSLNet(config).to(config['device'])
print(model)

In [None]:
# Run training; train_ssd hands back the model plus six per-run histories
# (losses, AUPRC and FPR@95 for the training and validation sets).
(
    model,
    train_losses,
    val_losses,
    train_ap,
    val_ap,
    train_fpr,
    val_fpr,
) = train_ssd(model, train_loader, val_loader, config)

In [None]:
# Training vs. validation loss, one point per recorded entry.
fig, ax = plt.subplots()
ax.plot(train_losses, label="Train")
ax.plot(val_losses, label="Val")
ax.set(xlabel="Epoch", ylabel="Loss")
ax.legend()
plt.show()

In [None]:
# FPR@95 and AUPRC histories for both splits, drawn on one shared axis.
metric_curves = (
    (train_fpr, "Train FPR@95"),
    (val_fpr, "Val FPR@95"),
    (train_ap, "Train AUPRC"),
    (val_ap, "Val AUPRC"),
)
fig, ax = plt.subplots()
for history, curve_label in metric_curves:
    ax.plot(history, label=curve_label)
ax.set(xlabel="Epoch", ylabel="Metrics")
ax.legend()
plt.show()

In [None]:
# Score the trained model on the test loader.
# NOTE(review): the training loader is passed as well — presumably as the
# reference set the test samples are scored against; confirm in evals.evaluate.
test_ap, test_fpr = evaluate(
    model,
    train_loader,
    test_loader,
)

In [None]:
# Summarise train/validation metrics next to the held-out test metrics.
# The history index is hoisted out of the f-strings: reusing the enclosing
# quote character inside a replacement field (config['patience']) is a
# SyntaxError on Python versions before 3.12.
# NOTE(review): indexing `patience` entries from the end presumably targets the
# early-stopping checkpoint — confirm against train_ssd (the best epoch may sit
# at -(patience + 1), and the offset is meaningless if early stopping never fired).
report_idx = -config['patience']
print(f'Train FPR@95: {train_fpr[report_idx]:.4f}, Val FPR@95: {val_fpr[report_idx]:.4f}, Test FPR@95: {test_fpr:.4f}')
print(f'Train AUPRC: {train_ap[report_idx]:.4f}, Val AUPRC: {val_ap[report_idx]:.4f}, Test AUPRC: {test_ap:.4f}')