In [2]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np

import tqdm

import data
import anomaly_det_model as anom

In [None]:
# 1) Setup
# Seed the RNGs this notebook relies on so DataLoader shuffling and model
# initialisation are repeatable across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Data roots default to the cluster locations but can be overridden from the
# environment instead of editing absolute paths in the notebook.
cdm_path = os.environ.get('CDM_PATH', '/n/netscratch/iaifi_lab/Lab/msliu/CMD/data/IllustrisTNG/')
wdm_path = os.environ.get('WDM_PATH', '/n/netscratch/iaifi_lab/Lab/ccuestalazaro/DREAMS/Images/WDM/boxes/')

MAP_TYPE = "Mcdm"
SUITE = "IllustrisTNG"
DATASET = "LH"
MAP_RESOLUTION = 256  # not used in this section of the notebook

# Map indices for the train / test split; the same split is applied to CDM and WDM.
TRAIN_IDX = range(0, 12000)
TEST_IDX = range(12000, 15000)

# Parameter names requested from each suite; the two lists differ only in the last entry.
CDM_PARAMS = ['Omega_m', 'sigma_8', 'A_SN1', 'A_SN2', 'A_AGN1', 'A_AGN2', 'Omega_b']
WDM_PARAMS = ['Omega_m', 'sigma_8', 'A_SN1', 'A_SN2', 'A_AGN1', 'A_AGN2', 'Wdm_mass']


# 3) Prepare dataset (single-sample access)
def make_camels(root, dataset, parameters, idx_list):
    """Build a ``data.CAMELS`` dataset using the notebook-wide MAP_TYPE and SUITE.

    Args:
        root: directory passed to ``data.CAMELS`` as ``root``.
        dataset: value passed as ``dataset`` (``DATASET`` for CDM, ``"WDM"`` for WDM).
        parameters: parameter names to request; copied so each dataset owns its list.
        idx_list: map indices to include in this dataset.

    Returns:
        The constructed ``data.CAMELS`` instance.
    """
    return data.CAMELS(
        root=root,
        idx_list=idx_list,
        map_type=MAP_TYPE,
        suite=SUITE,
        dataset=dataset,
        parameters=list(parameters),
    )


cdm_train = make_camels(cdm_path, DATASET, CDM_PARAMS, TRAIN_IDX)
wdm_train = make_camels(wdm_path, "WDM", WDM_PARAMS, TRAIN_IDX)
cdm_test = make_camels(cdm_path, DATASET, CDM_PARAMS, TEST_IDX)
wdm_test = make_camels(wdm_path, "WDM", WDM_PARAMS, TEST_IDX)

# 4) Wrapper to add class labels
class LabeledDataset(Dataset):
    """Wrap a dataset of ``(img, params)`` pairs, replacing ``params`` with a fixed class label.

    Defined once at module level rather than inside ``add_label``:
    - All wrappers then share one type.
    - Instances are picklable by reference, which DataLoader workers need under the
      "spawn" start method.

    NOTE(review): under spawn from Jupyter, the class must also live in an importable
    .py module.
    """

    def __init__(self, base, lbl):
        self.base = base
        # Stored as a Python int so default_collate produces an int64 tensor, which is
        # the class-index target dtype nn.CrossEntropyLoss requires. A float label
        # collates to float64 and makes CrossEntropyLoss fail for (N,) targets.
        self.lbl = int(lbl)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        # The base dataset yields (img, params); params are dropped, only the image is kept.
        img, _params = self.base[idx]
        return img, self.lbl


def add_label(dataset, label):
    """Return ``dataset`` wrapped so each item is ``(img, label)`` with an integer class label.

    Args:
        dataset: indexable dataset whose items unpack as ``(img, params)``.
        label: class index (e.g. 0 for CDM, 1 for WDM); converted with ``int()``.

    Returns:
        A ``LabeledDataset`` with the same length as ``dataset``.
    """
    return LabeledDataset(dataset, label)

In [None]:
# 5) Label (CDM = class 0, WDM = class 1), combine, and create DataLoaders
CDM_LABEL = 0
WDM_LABEL = 1
BATCH_SIZE = 64
NUM_WORKERS = 4

cdm_train_l = add_label(cdm_train, CDM_LABEL)
wdm_train_l = add_label(wdm_train, WDM_LABEL)
cdm_test_l = add_label(cdm_test, CDM_LABEL)
wdm_test_l = add_label(wdm_test, WDM_LABEL)

train_ds = ConcatDataset([cdm_train_l, wdm_train_l])
test_ds = ConcatDataset([cdm_test_l, wdm_test_l])

# Page-locked host memory speeds up host->GPU copies; it has no benefit on CPU.
PIN_MEMORY = device == 'cuda'

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,  # the concatenation is CDM-then-WDM; shuffle so batches mix classes
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)
test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

# Classifier over single-channel maps; hidden/dr semantics are defined in anomaly_det_model.
model = anom.AnomalyDetectorImg(hidden=5, dr=0.1, channels=1).to(device)


In [None]:
# NOTE: this cell reuses the `model` built in the previous cell. Re-running it alone
# continues training from the current weights with a fresh optimizer; re-run the
# previous cell first to start from scratch.
LR = 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)


def train_one_epoch(model, loader, criterion, optimizer, device, desc=None):
    """Run one optimisation pass over ``loader``.

    Returns:
        Sample-weighted mean training loss, or ``nan`` if ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    n_seen = 0
    for imgs, labels in tqdm.tqdm(loader, desc=desc):
        imgs = imgs.to(device)
        # CrossEntropyLoss with (N,) targets needs int64 class indices; cast
        # defensively in case the dataset yields float labels.
        labels = labels.to(device).long()
        logits = model(imgs)  # classifier logits, shape (N, num_classes)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_n = imgs.size(0)
        total_loss += loss.item() * batch_n  # criterion averages per batch; re-weight by batch size
        n_seen += batch_n
    return total_loss / n_seen if n_seen else float('nan')


@torch.no_grad()
def evaluate(model, loader, device):
    """Return top-1 accuracy of ``model`` on ``loader`` (``nan`` if the loader is empty)."""
    model.eval()
    correct = 0
    n_seen = 0
    for imgs, labels in loader:
        imgs = imgs.to(device)
        labels = labels.to(device).long()
        preds = model(imgs).argmax(dim=1)
        correct += (preds == labels).sum().item()
        n_seen += labels.size(0)
    return correct / n_seen if n_seen else float('nan')


# 8) Training loop
num_epochs = 10
for epoch in range(num_epochs):
    avg_loss = train_one_epoch(
        model, train_loader, criterion, optimizer, device,
        desc=f"Epoch {epoch+1}/{num_epochs}",
    )
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_loss:.4f}")

    acc = evaluate(model, test_loader, device)
    print(f"          Test Acc:  {acc:.4f}\n")

Epoch 1/10: 100%|██████████| 375/375 [00:59<00:00,  6.32it/s]

Epoch 1/10 - Train Loss: 0.1063





          Test Acc:  0.9973



Epoch 2/10: 100%|██████████| 375/375 [00:49<00:00,  7.62it/s]

Epoch 2/10 - Train Loss: 0.0016





          Test Acc:  0.9997



Epoch 3/10:  16%|█▌        | 60/375 [00:09<00:49,  6.37it/s]