In [1]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from tabulate import tabulate as tab
import types # to use dictionary as an object

# !pip3 install torch torchvision 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay


# !pip3 install umap-learn
# import umap

#!pip3 install plotly
import plotly.graph_objects as go # interactive plots

# synthData_path = os.path.join('..', 'synthetic_data') |
# sys.path.append(synthData_path)
# import generate_synthetic_data as gsd

import gmvae_architecture as ga
import gmvae_performance_and_validation as gpv

import wandb


In [2]:
import warnings
# Specifically for sklearn/numpy matmul issues
warnings.filterwarnings('ignore', message='.*matmul.*')
# warnings.filterwarnings('ignore', category=RuntimeWarning)

# Train and log 

In [3]:
wandb.init(
    project="GMVAE_paper",
    config={
        "learning_rate": 1e-3, # 1e-4 in paper
        "epochs": 2,           # 100 in paper
        "batch_size": 128,     # 100 in paper
        "z_dim": 8,            # latent dimension 
        "alpha": 50,           # classification weight in the loss (50 in paper)
        "noise_level": 0.001   # availale: 0.0, 0.001, 0.01, 0.05
    }
)
config = wandb.config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data = np.load(f"../synthetic_data/synthetic_training_allCases_120k_noise_{config.noise_level}.npz")
X_train = data["X"]
Y_train = data["y"]
shifts_pileup_train = data["meta"]
config["L"] = X_train.shape[1] # 296

data = np.load(f"../synthetic_data/synthetic_test_allCases_480k_noise_{config.noise_level}.npz")
X_test = data["X"]
Y_test = data["y"]
shifts_pileup_test = data["meta"]

model = ga.GMVAE(L=config.L, z_dim=config.z_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate) # weight_decay=1e-5

train_dataset = TensorDataset(
    torch.tensor(X_train, dtype=torch.float32),
    torch.tensor(Y_train, dtype=torch.long)
)
test_dataset = TensorDataset(
    torch.tensor(X_test, dtype=torch.float32),
    torch.tensor(Y_test, dtype=torch.long)
)

train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    drop_last=True
)
test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=True
)

# 4. Training Loop
# Watch gradients
# log_freq=100 logs every 100 batches to keep your dashboard fast
wandb.watch(model, log="all", log_freq=100, log_graph=True)

analyzer = gpv.GMVAEAnalyzer(
    model=model,
    dataloader=train_loader
)
# Tell W&B to use 'epoch' as the X-axis for these specific keys
wandb.define_metric("Clustering/*", step_metric="epoch")
wandb.define_metric("silhouette_score", step_metric="epoch")
wandb.define_metric("reconstruction_sample", step_metric="epoch")
wandb.define_metric("latent_space", step_metric="epoch")

for epoch in range(config.epochs):
    # log every 5 epochs
    ga.train_epoch_wandb(model, train_loader, optimizer, device, alpha=config.alpha)
    if epoch % 5 == 0:
        # ga.log_visualizations(model, test_loader, device, epoch)
        score = gpv.log_clustering_quality(analyzer, epoch, sample_size=5000)
        print(f"Epoch {epoch}: Silhouette Score = {score:.4f}")


# 5. validation
final_analyzer = gpv.GMVAEAnalyzer(model, test_loader, device=device)
gpv.run_final_inference_report(final_analyzer)

# 6. Finish the run
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mebertholet[0m ([33mebertholet-tel-aviv-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


Epoch 0: Silhouette Score = 0.1763
üöÄ Starting Final Inference Evaluation...
‚úÖ Final Report Sent to W&B.


0,1
Clustering/Silhouette_Score,‚ñÅ
ce_loss,‚ñà‚ñÅ
epoch,‚ñÅ
epoch_loss,‚ñà‚ñÅ
kl_div,‚ñà‚ñÅ
recon_loss,‚ñà‚ñÅ
train_acc,‚ñÅ‚ñà

0,1
Clustering/Silhouette_Score,0.17627
ce_loss,0.02079
epoch,0.0
epoch_loss,1.04499
final_silhouette_score,0.0635
kl_div,0.00023
recon_loss,0.00547
train_acc,99.57311
