In [9]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np

import tqdm

from cosmo_compression.data import data
from cosmo_compression.downstream import anomaly_det_model as anom

In [None]:
# Setup: compute device, data location and the constants describing which
# CAMELS maps this notebook works with.
dataset_path = '/monolith/global_data/astro_compression/CAMELS/'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

MAP_TYPE, SUITE, DATASET = "Mcdm", "IllustrisTNG", "LH"
MAP_RESOLUTION = 256

# Keyword arguments shared by every CAMELS dataset built below.
_camels_common = dict(root=dataset_path, map_type=MAP_TYPE, suite=SUITE)


def _param_names(last):
    """Return a fresh parameter-name list whose final entry is `last`."""
    return ['Omega_m','sigma_8','A_SN1','A_SN2','A_AGN1','A_AGN2', last]


# Training splits: indices 0..999 of the DATASET set and of the "WDM" set.
cdm_train = data.CAMELS(
    idx_list=range(0, 1000),
    dataset=DATASET,
    parameters=_param_names('Omega_b'),
    **_camels_common,
)
wdm_train = data.CAMELS(
    idx_list=range(0, 1000),
    dataset="WDM",
    parameters=_param_names('Wdm_mass'),
    **_camels_common,
)

# Held-out splits: indices 12000..14999 of the same two sets.
cdm_test = data.CAMELS(
    idx_list=range(12000, 15000),
    dataset=DATASET,
    parameters=_param_names('Omega_b'),
    **_camels_common,
)
wdm_test = data.CAMELS(
    idx_list=range(12000, 15000),
    dataset="WDM",
    parameters=_param_names('Wdm_mass'),
    **_camels_common,
)


class LabeledDataset(Dataset):
    """Dataset view that swaps each sample's second element for a fixed label.

    The wrapped dataset must yield two-element items; the first element is
    passed through unchanged and the second is discarded.

    Defined at module level (rather than inside `add_label`) so that every
    wrapped dataset shares one class and instances can be pickled, which
    multi-worker DataLoaders need under the spawn/forkserver start methods.
    """

    def __init__(self, base, lbl):
        self.base = base  # underlying dataset being wrapped
        self.lbl = lbl    # constant label returned for every item

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        # Only the first element is kept; the second is replaced by the label.
        img, _ = self.base[idx]
        return img, self.lbl


def add_label(dataset, label):
    """Wrap `dataset` so that indexing returns `(img, label)`.

    Args:
        dataset: indexable, sized collection whose items unpack into exactly
            two elements.
        label: value returned as the second element of every item.

    Returns:
        A `LabeledDataset` with the same length as `dataset`.
    """
    return LabeledDataset(dataset, label)

In [11]:
# Attach class labels: 0 to the cdm_* datasets, 1 to the wdm_* datasets.
cdm_train_l, wdm_train_l = add_label(cdm_train, 0), add_label(wdm_train, 1)
cdm_test_l, wdm_test_l = add_label(cdm_test, 0), add_label(wdm_test, 1)

# Merge both classes into one training set and one test set.
train_ds = ConcatDataset([cdm_train_l, wdm_train_l])
test_ds = ConcatDataset([cdm_test_l, wdm_test_l])

# Both loaders share batch size and worker count; only training is shuffled.
_loader_kwargs = dict(batch_size=64, num_workers=4)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

# Build the classifier and move it to the selected device in one step.
model = anom.AnomalyDetectorImg(hidden=5, dr=0.1, channels=1).to(device)


In [12]:
# Optimiser and loss used to train the model on the 0/1 class labels.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Train for a fixed number of epochs, scoring the test loader after each one.
num_epochs = 10
n_train = len(train_loader.dataset)
n_test = len(test_loader.dataset)
for epoch in range(num_epochs):
    # --- optimisation pass over the training set ---
    model.train()
    total_loss = 0.0
    progress = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for imgs, labels in progress:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(imgs), labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch figure is a per-sample mean.
        total_loss += loss.item() * imgs.size(0)
    avg_loss = total_loss / n_train
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_loss:.4f}")

    # --- accuracy on the test loader, gradients disabled ---
    model.eval()
    correct = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            preds = model(imgs.to(device)).argmax(dim=1)
            correct += (preds == labels.to(device)).sum().item()
    acc = correct / n_test
    print(f"          Test Acc:  {acc:.4f}\n")

Epoch 1/10: 100%|██████████| 32/32 [00:02<00:00, 13.02it/s]

Epoch 1/10 - Train Loss: 0.8134





          Test Acc:  0.5245



Epoch 2/10: 100%|██████████| 32/32 [00:02<00:00, 13.04it/s]

Epoch 2/10 - Train Loss: 0.3267





          Test Acc:  0.6390



Epoch 3/10: 100%|██████████| 32/32 [00:02<00:00, 13.03it/s]

Epoch 3/10 - Train Loss: 0.0662





          Test Acc:  0.8127



Epoch 4/10: 100%|██████████| 32/32 [00:02<00:00, 12.96it/s]

Epoch 4/10 - Train Loss: 0.0129





          Test Acc:  0.9252



Epoch 5/10: 100%|██████████| 32/32 [00:02<00:00, 13.01it/s]

Epoch 5/10 - Train Loss: 0.0023





          Test Acc:  0.9348



Epoch 6/10: 100%|██████████| 32/32 [00:02<00:00, 13.01it/s]

Epoch 6/10 - Train Loss: 0.0015





          Test Acc:  0.9372



Epoch 7/10: 100%|██████████| 32/32 [00:02<00:00, 12.96it/s]

Epoch 7/10 - Train Loss: 0.0010





          Test Acc:  0.9413



Epoch 8/10: 100%|██████████| 32/32 [00:02<00:00, 13.06it/s]

Epoch 8/10 - Train Loss: 0.0007





          Test Acc:  0.9448



Epoch 9/10: 100%|██████████| 32/32 [00:02<00:00, 13.02it/s]

Epoch 9/10 - Train Loss: 0.0005





          Test Acc:  0.9438



Epoch 10/10: 100%|██████████| 32/32 [00:02<00:00, 13.02it/s]

Epoch 10/10 - Train Loss: 0.0004





          Test Acc:  0.9480



In [None]:
# Export the first and last N_EXPORT maps of each source file as small,
# standalone .npy files in the current working directory.
#
# The source files are opened memory-mapped so that only the exported slices
# are read from disk instead of loading each full array into RAM.
# NOTE(review): the original call passed allow_pickle=True. Memory-mapping
# requires a plain (non-object) array and never unpickles, so that opt-in is
# dropped here. This assumes the map files hold ordinary numeric arrays —
# confirm before relying on it.
# NOTE(review): the "test" slice here is the last N_EXPORT maps, whereas the
# CAMELS test datasets built earlier use indices 12000..14999 — confirm that
# the difference is intended.
N_EXPORT = 1000
_EXPORT_SOURCES = (
    ('cdm', 'Maps_Mcdm_IllustrisTNG_LH_z=0.00.npy'),
    ('wdm', 'Maps_Mcdm_IllustrisTNG_WDM_z=0.00.npy'),
)

for _prefix, _fname in _EXPORT_SOURCES:
    full_ds = np.load(os.path.join(dataset_path, _fname), mmap_mode='r')
    np.save(f'{_prefix}_train.npy', full_ds[:N_EXPORT])
    np.save(f'{_prefix}_test.npy', full_ds[-N_EXPORT:])


