In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np

import tqdm

import data
import anomaly_det_model as anom

In [2]:
# Setup: device, paths, and the held-out CDM / WDM test splits.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Data root; can be overridden with the CAMELS_DATA_DIR environment variable.
# The default preserves the original cluster path.
base_path = os.environ.get('CAMELS_DATA_DIR', '/n/netscratch/iaifi_lab/Lab/msliu/data')
# The checkpoint sits in the data root, so derive it instead of repeating the prefix.
ckpt_path = os.path.join(base_path, 'anomaly_detection_weights.pth')

MAP_TYPE = "Mcdm"
SUITE = "IllustrisTNG"
DATASET = "LH"
MAP_RESOLUTION = 256  # NOTE(review): not referenced in the visible cells

# Map indices used as the test split, shared by both the CDM and WDM datasets.
TEST_IDX = range(12000, 15000)

# Parameters common to both catalogues; the 7th entry differs (Omega_b vs Wdm_mass).
# NOTE(review): the per-item parameter vector is dropped by add_label, so only the
# images are used for this evaluation.
_ASTRO_PARAMS = ['Omega_m', 'sigma_8', 'A_SN1', 'A_SN2', 'A_AGN1', 'A_AGN2']
_camels_common = dict(root=base_path, idx_list=TEST_IDX, map_type=MAP_TYPE, suite=SUITE)

cdm_test = data.CAMELS(
    dataset=DATASET,
    parameters=_ASTRO_PARAMS + ['Omega_b'],
    **_camels_common,
)

wdm_test = data.CAMELS(
    dataset="WDM",
    parameters=_ASTRO_PARAMS + ['Wdm_mass'],
    **_camels_common,
)

# Wrapper that pairs every image of a dataset with one fixed class label.
def add_label(dataset, label):
    """Wrap ``dataset`` so that item ``i`` becomes ``(img, float(label))``.

    Each base item must unpack into exactly two values ``(img, params)``;
    the second value is discarded and replaced by the constant label.

    Args:
        dataset: indexable, sized object returning 2-tuples.
        label: class label attached to every item (converted with ``float``
            on each access).

    Returns:
        A ``torch.utils.data.Dataset`` view over ``dataset``.
    """
    class LabeledDataset(Dataset):
        def __init__(self, base, lbl):
            self.base, self.lbl = base, lbl

        def __len__(self):
            return len(self.base)

        def __getitem__(self, idx):
            img, _ = self.base[idx]  # unpacking enforces the 2-tuple contract
            return img, float(self.lbl)

    return LabeledDataset(dataset, label)

In [3]:
# Label CDM maps as class 0 and WDM maps as class 1, then concatenate them.
cdm_test_l = add_label(cdm_test, 0)
wdm_test_l = add_label(wdm_test, 1)

test_ds = ConcatDataset([cdm_test_l, wdm_test_l])

# shuffle=False keeps evaluation order deterministic (CDM items, then WDM items).
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=1)

# NOTE(review): hidden/dr/channels must match the architecture the checkpoint
# was trained with.
model = anom.AnomalyDetectorImg(hidden=5, dr=0.1, channels=1)
# map_location remaps stored tensors onto the current device, so a checkpoint
# saved with CUDA tensors still loads when `device` is 'cpu'.
state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)

In [4]:
def evaluate_model(model, test_loader, device=None):
    """Return the classification accuracy of ``model`` over ``test_loader``.

    Args:
        model: module mapping a float image batch to logits of shape
            ``(batch, n_classes)`` with ``n_classes >= 2``.
        test_loader: iterable of ``(imgs, labels)`` tensor batches, where
            ``labels`` holds integer-valued class ids (any numeric dtype).
        device: where to run inference. Defaults to the device of the
            model's first parameter, or CPU for a parameter-free model.

    Returns:
        Fraction of samples whose argmax prediction equals the label.

    Raises:
        ValueError: if the loader yields no samples, or the model does not
            produce a 2-D logit tensor with at least two columns (argmax over
            a single column is always 0 and would silently report the
            class-0 fraction as "accuracy").
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs = imgs.to(device).float()
            labels = labels.to(device).long()
            logits = model(imgs)
            if logits.ndim != 2 or logits.size(1) < 2:
                raise ValueError(
                    f"expected logits of shape (batch, n_classes>=2), got {tuple(logits.shape)}"
                )
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    return correct / total


# Test accuracy on CDM (label 0) vs WDM (label 1) maps.
# NOTE(review): both splits use the same index range, so the classes are
# presumably balanced; a value near 0.5 would then be chance level. If so,
# check the model's output shape and whether the test images get the same
# preprocessing as in training.
test_accuracy = evaluate_model(model, test_loader)
print(test_accuracy)

0.5005
