### DAGMM on Arrhythmia

This notebook trains **DAGMM** (Deep Autoencoding Gaussian Mixture Model) on the **Arrhythmia** dataset, with **Gaussian**, **Laplace**, or **Student‑t** mixture components.
Set `dist_type` in the configuration cell below to choose the component distribution: `'gaussian'`, `'laplace'`, or `'student_t'`.
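
To give a feel for what the three options change, here is a minimal univariate sketch of the log-densities involved. It is an illustration only: the function name `log_pdf` and its arguments are hypothetical, and the repo's `model.py` may parameterize the components differently (for example, with full covariance matrices).

```python
import math

def log_pdf(x, mu, scale, dist_type="gaussian", nu=4.0):
    """Univariate log-density for the three component families (illustrative only)."""
    z = (x - mu) / scale
    if dist_type == "gaussian":
        return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)
    if dist_type == "laplace":
        return -abs(z) - math.log(2.0 * scale)
    if dist_type == "student_t":
        # nu is the degrees of freedom; smaller nu means heavier tails
        log_norm = (math.lgamma((nu + 1.0) / 2.0) - math.lgamma(nu / 2.0)
                    - 0.5 * math.log(nu * math.pi) - math.log(scale))
        return log_norm - 0.5 * (nu + 1.0) * math.log1p(z * z / nu)
    raise ValueError(f"unknown dist_type: {dist_type!r}")

# Log-density of a point six scale units away from the mean under each family
tail = {d: log_pdf(6.0, 0.0, 1.0, d) for d in ("gaussian", "laplace", "student_t")}
```

Both heavy-tailed families give the far-out point a much higher log-density than the Gaussian does, which is why they tend to be less sensitive to outliers in the training data.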

In [15]:
## If you haven't installed the repo dependencies in this environment, uncomment and run:
# !pip install -r requirements.txt
# !pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu

In [16]:
import sys
from pathlib import Path

import torch

# Assumes the notebook's working directory is the repo root; adjust otherwise.
sys.path.append(str(Path().resolve()))

from arrhythmia import ArrhythmiaLoader
from model import DaGMM
from solver import Solver

In [17]:
# ==== Configuration ====
data_path = 'arrhythmia.data'   # change if your dataset lives elsewhere
dist_type = 'gaussian'   # 'gaussian' | 'laplace' | 'student_t'
student_nu = 4.0         # only used if dist_type == 'student_t'
mode = 'train'           # 'train' or 'test'

# Training params
batch_size = 1024  # adjust per dataset size
num_epochs = 100
lr = 1e-4
gmm_k = 4
lambda_energy = 0.1
lambda_cov_diag = 0.005
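
The two `lambda_*` values weight the extra terms in the DAGMM objective. In the original DAGMM paper the loss is the mean reconstruction error, plus `lambda_energy` times the mean sample energy, plus `lambda_cov_diag` times a penalty on small covariance diagonal entries. The sketch below assumes the repo's `Solver` follows that form, which this notebook does not show; `dagmm_objective` and its inputs are hypothetical names.

```python
def dagmm_objective(recon_errors, energies, cov_diags,
                    lambda_energy=0.1, lambda_cov_diag=0.005):
    """Combine the three DAGMM loss terms (paper form; the repo may differ).

    recon_errors: per-sample reconstruction errors
    energies:     per-sample energies (negative log-likelihood under the mixture)
    cov_diags:    one list of covariance diagonal entries per mixture component
    """
    n = len(recon_errors)
    recon_term = sum(recon_errors) / n
    energy_term = sum(energies) / n
    # Penalize near-zero diagonal entries, which would make a component degenerate
    cov_penalty = sum(1.0 / d for diag in cov_diags for d in diag)
    return recon_term + lambda_energy * energy_term + lambda_cov_diag * cov_penalty

# Toy values, made up for illustration
loss = dagmm_objective(
    recon_errors=[0.2, 0.4],
    energies=[-3.0, -1.0],
    cov_diags=[[0.5, 2.0], [1.0, 1.0]],
)
```

Raising `lambda_energy` pushes the encoder toward latent codes the mixture models well; raising `lambda_cov_diag` guards harder against collapsed components.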

In [18]:
# ==== Data loader ====
from torch.utils.data import DataLoader

dataset = ArrhythmiaLoader(data_path, mode=mode)
# Shuffle only while training so evaluation order stays fixed
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=(mode == 'train'))
print(f'Train set size: {len(dataset.train) if mode == "train" else "N/A"}')
print(f'Test set size : {len(dataset.test)}')

Train set size: 193
Test set size : 259
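
With 193 training samples and `batch_size = 1024`, every epoch consists of a single batch. A quick way to check this for other settings is sketched below; `batches_per_epoch` is a hypothetical helper that mirrors how `DataLoader` counts batches when `drop_last` is left at its default of `False`.

```python
import math

def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per pass over n_samples items."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)

full_batch = batches_per_epoch(193, 1024)   # the setting used in this notebook
mini_batch = batches_per_epoch(193, 64)     # a smaller batch size for comparison
```

If the solver takes one optimizer step per batch (an assumption, since its training loop is not shown here), 100 epochs at the current setting amount to only 100 parameter updates.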


The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  features[col].fillna(features[col].mean(), inplace=True)
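
The warning above comes from pandas and points at the chained in-place `fillna` call quoted on its last line. A minimal sketch of the assignment form the warning recommends follows; the toy `features` frame is made up for illustration and stands in for whatever the loader actually reads.

```python
import pandas as pd

# Toy frame with missing values (hypothetical data, not the Arrhythmia file)
features = pd.DataFrame({"a": [1.0, None, 3.0], "b": [None, 2.0, 4.0]})

for col in features.columns:
    # Assign the filled column back instead of calling fillna(..., inplace=True)
    features[col] = features[col].fillna(features[col].mean())
```

Assigning the result back modifies `features` directly, so the imputation keeps working under the pandas 3.0 behavior the warning describes.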


In [19]:
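# ==== Initialize model & solver ====
# Infer the feature count from the loaded data (279 attributes after imputation)
input_dim = dataset.train.shape[1] if mode == 'train' else dataset.test.shape[1]
config = {
    'lr': lr,
    'num_epochs': num_epochs,
    'batch_size': batch_size,
    'gmm_k': gmm_k,
    'lambda_energy': lambda_energy,
    'lambda_cov_diag': lambda_cov_diag,
    'dist_type': dist_type,
    'student_nu': student_nu,
    'model_save_path': './models',
    'input_dim': input_dim,
}
solver = Solver(data_loader, config)
# Fallback: if Solver did not honor config['input_dim'], resize the first encoder
# layer and the last decoder layer. Layers swapped in after Solver is built may not
# be registered with its optimizer, so prefer fixing the dimension via config.
if solver.dagmm.encoder[0].in_features != input_dim:
    device = next(solver.dagmm.parameters()).device
    solver.dagmm.encoder[0] = torch.nn.Linear(input_dim, solver.dagmm.encoder[0].out_features).to(device)
    solver.dagmm.decoder[-1] = torch.nn.Linear(solver.dagmm.decoder[-1].in_features, input_dim).to(device)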

In [20]:
# ==== Train or Test ====
if mode == 'train':
    solver.train()
else:
    solver.test()

100%|██████████| 1/1 [00:00<00:00, 55.05it/s]

In [21]:
print(f"Results for {dist_type} distribution:")
solver.test()

Results for gaussian distribution:
Threshold : -8.932000732421875
Accuracy : 0.7529, Precision : 0.5161, Recall : 0.4848, F-score : 0.5000


(0.752895752895753, 0.5161290322580645, 0.48484848484848486, 0.5)
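
The four numbers are tied together by the usual confusion-matrix formulas. The sketch below recomputes them from counts; the counts are not printed by the notebook and were inferred here as the values consistent with the reported metrics on the 259-sample test set, so treat them as an assumption. `binary_metrics` is a hypothetical helper, not part of the repo.

```python
def binary_metrics(tp, fp, fn, tn):
    """Accuracy, precision, recall and F-score from confusion-matrix counts."""
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f_score = 2 * precision * recall / (precision + recall)
    return accuracy, precision, recall, f_score

# Inferred counts (assumption): 32 + 30 + 34 + 163 = 259 test samples
accuracy, precision, recall, f_score = binary_metrics(tp=32, fp=30, fn=34, tn=163)
```

Read this way, the Gaussian run flags 62 samples as anomalous and about half of them are true anomalies, while roughly half of the 66 true anomalies are missed.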