In [1]:
import os
import yaml
import torch
import pandas as pd
from torch.utils.data import DataLoader, ConcatDataset, Dataset
import os
from datasets.dataset_spec import SpectrogramDataset
from torch.nn.functional import adaptive_avg_pool2d

import torch.optim as optim
import torch.nn.functional as F
from sklearn.metrics import (
    roc_auc_score, precision_score,
    recall_score, f1_score
)

from datasets.loader_common import (
    select_dirs,
    get_machine_type_dict,
)
from models.branch_pretrained    import BranchPretrained
from models.branch_transformer_ae import BranchTransformerAE
from models.branch_contrastive   import BranchContrastive
from models.branch_diffusion     import BranchDiffusion
from models.branch_flow          import BranchFlow
from models.fusion_attention     import FusionAttention

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# result columns you requested
result_column_dict = {
    "single_domain": [
        "section", "AUC", "pAUC", "precision", "recall", "F1 score"
    ],
    "source_target": [
        "section",
        "AUC (source)", "AUC (target)",
        "pAUC",
        "pAUC (source)", "pAUC (target)",
        "precision (source)", "precision (target)",
        "recall (source)", "recall (target)",
        "F1 score (source)", "F1 score (target)"
    ]
}

In [8]:
def save_checkpoint(model, optimizer, epoch, path):
    torch.save({
        'epoch':       epoch,
        'model_state': model.state_dict(),
        'optim_state': optimizer.state_dict()
    }, path)



def evaluate(model, loader, device):
    model.eval()
    secs, scores, labels = [], [], []

    with torch.no_grad():
        for feats, labs, sections in loader:
            # feats: [B, H, W]  →  x: [B,1,H,W]
            x     = feats.to(device).squeeze().unsqueeze(1)
            labs  = labs.to(device)

            # ── Branch 1: embedding only ────────────────────────────
            z1 = b1(x)

            # ── Branch 2: transformer‐AE + reconstruction loss ──────
            recon2, z2 = b2(x)
            loss2 = F.mse_loss(recon2, x)

            # ── Branch 3: contrastive, returns (z3, loss3) ─────────
            z3, loss3 = b3(x, labs)

            # ── Branch 5: fusion input branch ───────────────────────
            z_cat = torch.cat([z1, z2, z3], dim=1)
            loss5 = b5(z_cat)

            # ── Fusion: stack the three anomaly scores and predict ──
            # shape: [B, 3]
            anomaly_vector = torch.stack([loss2, loss3, loss5], dim=1)
            scores_batch   = fusion(anomaly_vector)

            # ── Collect for metrics ─────────────────────────────────
            scores.extend(scores_batch.cpu().tolist())
            labels.extend(labs.cpu().tolist())
            secs.extend(sections)

    # ── Per‐section aggregation ─────────────────────────────────
    df      = pd.DataFrame({'section': secs, 'score': scores, 'label': labels})
    results = []
    for sec, grp in df.groupby('section'):
        y_true   = grp['label'].values
        y_score  = grp['score'].values

        auc_val   = roc_auc_score(y_true, y_score)
        p_auc_val = roc_auc_score(y_true, y_score, max_fpr=0.1)
        preds     = (y_score >= 0.5).astype(int)

        results.append({
            'section':   sec,
            'AUC':       auc_val,
            'pAUC':      p_auc_val,
            'precision': precision_score(y_true, preds, zero_division=0),
            'recall':    recall_score(y_true, preds, zero_division=0),
            'F1 score':  f1_score(y_true, preds, zero_division=0)
        })

    return results



# wrap to attach labels & section
class WrappedSpecDS(Dataset):
    def __init__(self, ds, train: bool):
        self.ds    = ds
        self.train = train

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        # grab whatever the underlying dataset returns
        item = self.ds[idx]

        # unpack first two elements, ignore any extras
        spec, fname, *rest = item   # rest could be an attr‐dict, metadata, etc.

        # label = 0 for normal in train, or infer anomaly from filename
        lbl = 0 if self.train else int("_anomaly_" in fname)

        # extract section code (e.g. "00")
        sec = fname.split("_")[1]

        return spec, lbl, sec


def pad_collate(batch):
    specs, labels, secs = zip(*batch)  
    # each spec is [1, H, Wi]
    max_W = max(s.shape[-1] for s in specs)
    padded = [
        F.pad(s, (0, max_W - s.shape[-1]))   # keeps shape [1, H, max_W]
        for s in specs
    ]
    specs_tensor = torch.stack(padded, dim=0)  # [B,1,H,max_W]
    labels_tensor = torch.tensor(labels)
    return specs_tensor, labels_tensor, list(secs)

In [4]:
# ── Settings ───────────────────────────────────────────────────────────
mode            = 'dev'
config_path     = "/lustre1/g/geog_pyloo/11_octa/dcase2025-asd/config.yaml"
baseline_config = "/lustre1/g/geog_pyloo/11_octa/dcase2023_task2_baseline_ae/baseline.yaml"
device          = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cfg     = yaml.safe_load(open(config_path))
root    = cfg['dev_data_root'] if mode=='dev' else cfg['eval_data_root']
name    = 'DCASE2025T2'
param   = yaml.safe_load(open(baseline_config))
param["dev_directory"] = "/lustre1/g/geog_pyloo/11_octa/dcase2023_task2_baseline_ae/data/dcase2025t2/dev_data/raw"
base_dirs = select_dirs(param, mode=(mode=='dev'))

mt_dict   = get_machine_type_dict(name, mode=(mode=='dev'))


2025-05-18 23:25:18,701 - INFO - load_directory <- development


In [7]:
train_dsets = []
eval_dsets  = []

for mt, sect_info in mt_dict['machine_type'].items():
    for sec in sect_info['dev']:
        # raw spectrogram datasets:
        ds_train_raw = SpectrogramDataset(
            base_dir     = cfg['dev_data_root'],
            machine_type = mt,
            section      = sec,
            mode         = 'train',
            config       = cfg
        )
        ds_sup_raw   = SpectrogramDataset(
            base_dir     = cfg['dev_data_root'],
            machine_type = mt,
            section      = sec,
            mode         = 'supplemental',
            config       = cfg
        )
        ds_test_raw  = SpectrogramDataset(
            base_dir     = cfg['dev_data_root'],
            machine_type = mt,
            section      = sec,
            mode         = 'test',
            config       = cfg
        )

        # wrap them so that each sample = (spec, label:int, section:str)
        if len(ds_train_raw):
            train_dsets.append(WrappedSpecDS(ds_train_raw, train=True))
        if len(ds_sup_raw):
            train_dsets.append(WrappedSpecDS(ds_sup_raw,   train=True))
        if len(ds_test_raw):
            eval_dsets.append(WrappedSpecDS(ds_test_raw,  train=False))


# concatenate all per-section datasets
full_train_ds = ConcatDataset(train_dsets)
full_eval_ds  = ConcatDataset(eval_dsets)

# finally, wrap in DataLoaders
train_loader = DataLoader(
    full_train_ds,
    batch_size=cfg['batch_size'],
    shuffle=True,
    num_workers=0,
    collate_fn=pad_collate
)

eval_loader = DataLoader(
    full_eval_ds,
    batch_size   = cfg['batch_size'],
    shuffle      = False,
    # num_workers  = cfg.get('num_workers', 4)
    num_workers = 0,
    collate_fn=pad_collate
)


# ── Instantiate branches ────────────────────────────────────────────────
b1 = BranchPretrained(cfg['ast_model'], cfg).to(device)
b2 = BranchTransformerAE(cfg['latent_dim'], cfg).to(device)

print("AE pos-emb len:", b2.encoder.embeddings.position_embeddings.shape[1])
print("Expected seq_len:", 2 + ((cfg['n_mels'] - b2.encoder.config.patch_size)//b2.encoder.config.frequency_stride + 1) *
                                ((cfg['time_steps'] - b2.encoder.config.patch_size)//b2.encoder.config.time_stride + 1))

b3 = BranchContrastive(cfg['latent_dim'], cfg).to(device)
# b4 = BranchDiffusion(
#     image_size     = (cfg['n_mels'], cfg['time_steps']),                # height of spectrogram
#     unet_dim       = cfg['diffusion_unet_dim'],    # UNet base channels
#     unet_dim_mults = tuple(cfg['diffusion_mults']),# UNet channel multipliers
#     timesteps      = cfg['diffusion_steps']        # diffusion steps
# ).to(device)
b5     = BranchFlow(cfg['flow_dim']).to(device)
fusion = FusionAttention(num_branches=4).to(device)

optimizer = optim.Adam(
    list(b1.parameters()) + 
    list(b2.parameters()) +
    list(b3.parameters()) + 
    # list(b4.parameters()) +
    list(b5.parameters()) + 
    list(fusion.parameters()),
    lr=float(cfg['lr'])
)

best_auc = 0.0
os.makedirs(cfg['save_dir'], exist_ok=True)

metrics_csv = os.path.join(cfg['save_dir'], 'metrics_all_epochs.csv')


AST pos-emb old_len=1214, new_len=252
AE pos-emb len: 252
Expected seq_len: 252


In [6]:
# b1.train(); b2.train(); b3.train(); b4.train(); b5.train(); fusion.train()
# with torch.no_grad():
#     dummy = torch.randn(2,1, cfg['n_mels'], cfg['time_steps']).to(device)
#     z1 = b1(dummy)  # should now succeed without errors
# print("BranchPretrained forward OK →", z1.shape)

In [6]:
import gc
gc.collect()

40

In [10]:
# ── Training + Evaluation Loop ──────────────────────────────────────────
for epoch in range(1, cfg['epochs']+1):
    # train
    b1.train(); b2.train(); b3.train(); b5.train(); fusion.train()
    total_loss = 0.0
    for feats, labels, sections in train_loader:
        feats = feats.squeeze()        # now [B, H, W]
        feats = feats.unsqueeze(1)     # back to [B, 1, H, W]
        print(feats.shape)
        feats  = feats.to(device)  # shape: [B,1,H,W]
        labels = labels.to(device)
        # forward each branch
        z1 = b1(feats)  
        print(f"z1.shape = {z1.shape}")
        recon2, z2 = b2(feats)
        print(f"recon2.shape = {recon2.shape}, feats.shape = {feats.shape}")
        print(f"z2.shape = {z2.shape}")
        feats_ds = adaptive_avg_pool2d(feats, (cfg['n_mels'], recon2.shape[-1]))
        loss2 = F.mse_loss(recon2, feats_ds)
        z3, loss3 = b3(feats, labels)
        print(f"z3.shape = {z3.shape}, feats_ds.shape = {feats_ds.shape}")
        # loss4    = b4(feats)
        z_cat    = torch.cat([z1, z2, z3], dim=1)
        loss5    = b5(z_cat)

        total_branch_loss = (
            cfg['w2']*loss2 + 
            cfg['w3']*loss3 +
            # cfg['w4']*loss4 + 
            cfg['w5']*loss5
        )
        optimizer.zero_grad()
        total_branch_loss.backward()
        optimizer.step()
        total_loss += total_branch_loss.item()

    # evaluate
    epoch_results = evaluate(fusion, eval_loader, device)
    epoch_auc     = sum(r['AUC'] for r in epoch_results) / len(epoch_results)

    # save checkpoints
    save_checkpoint(fusion, optimizer, epoch,
                    os.path.join(cfg['save_dir'], 'checkpoint_last.pth'))
    if epoch_auc > best_auc:
        best_auc = epoch_auc
        save_checkpoint(fusion, optimizer, epoch,
                        os.path.join(cfg['save_dir'], 'checkpoint_best.pth'))

    # dump metrics
    df = pd.DataFrame(epoch_results)[ result_column_dict['single_domain'] ]
    df['epoch'] = epoch
    if epoch == 1:
        df.to_csv(metrics_csv, index=False)
    else:
        df.to_csv(metrics_csv, mode='a', header=False, index=False)

    print(f"Epoch {epoch}/{cfg['epochs']} — TrainLoss: {total_loss/len(train_loader):.4f} — Dev AUC: {epoch_auc:.4f}")

torch.Size([16, 1, 64, 512])
z1.shape = torch.Size([16, 128])
enc_out: torch.Size([16, 252, 512]) dec_inp: torch.Size([16, 1, 512])
recon2.shape = torch.Size([16, 1, 64, 252]), feats.shape = torch.Size([16, 1, 64, 512])
z2.shape = torch.Size([16, 128])
z3.shape = torch.Size([16, 128]), feats_ds.shape = torch.Size([16, 1, 64, 252])


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

In [None]:
feats.squeeze(dim=3).shape