In [None]:
import timm

# Show the model names timm.list_models returns for the '*' wildcard,
# one name per line.
print(*timm.list_models('*'), sep="\n")


In [3]:
# Cell 1: Imports and Dummy Dataset

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import SwinModel, SwinConfig
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import os

# Dummy dataset loader (replace with ML-JET loader)
class DummyDataset(torch.utils.data.Dataset):
    """Random stand-in dataset: one image tensor plus three label sets.

    Every item is a pair ``(image, targets)``:

    * ``image`` is a float tensor of shape ``(1, 32, 32)`` drawn with
      ``torch.rand`` (uniform on [0, 1)).
    * ``targets`` is a dict with the keys

      - ``'energy_loss_output'``: float tensor of shape ``(1,)`` holding 0. or 1.
      - ``'alpha_output'``: integer class index in ``[0, 3)``
      - ``'q0_output'``: integer class index in ``[0, 4)``

    Args:
        size: Number of samples to generate.
        seed: Optional integer seed. When given, all tensors are drawn from a
            private ``torch.Generator`` so the dataset is reproducible and the
            global RNG state is not consumed. When ``None`` (the default) the
            global RNG is used, exactly as before this parameter existed.
    """

    def __init__(self, size=1000, seed=None):
        # generator=None makes torch fall back to the global default RNG.
        rng = None
        if seed is not None:
            rng = torch.Generator().manual_seed(seed)

        self.data = torch.rand(size, 1, 32, 32, generator=rng)
        # Kept as float with a trailing singleton dim, i.e. shape (size, 1).
        self.labels_energy = torch.randint(0, 2, (size, 1), generator=rng).float()
        # Integer class indices, shape (size,).
        self.labels_alpha = torch.randint(0, 3, (size,), generator=rng)
        self.labels_q0 = torch.randint(0, 4, (size,), generator=rng)

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, targets)`` for the sample at position ``idx``."""
        return self.data[idx], {
            'energy_loss_output': self.labels_energy[idx],
            'alpha_output': self.labels_alpha[idx],
            'q0_output': self.labels_q0[idx]
        }


In [4]:
# Cell 2: Custom SwinTiny for 32×32 without upsampling

import torch.nn as nn
from timm.models.swin_transformer import SwinTransformer

class SwinMultiHeadClassifier(nn.Module):
    """Swin backbone for 32×32 inputs feeding three independent linear heads.

    The backbone is created with ``num_classes=0`` so it is used as a feature
    extractor; its ``num_features`` attribute sets the input width of the
    heads. No activation is applied to the head outputs.

    Args:
        in_chans: Number of input image channels (default 1).
        num_alpha_classes: Output width of the ``alpha_output`` head (default 3).
        num_q0_classes: Output width of the ``q0_output`` head (default 4).

    The defaults reproduce the previously hard-coded configuration, so
    ``SwinMultiHeadClassifier()`` builds the same architecture as before.
    The spatial input size stays fixed at 32 because patch and window sizes
    below were chosen for it.
    """

    def __init__(self, in_chans: int = 1, num_alpha_classes: int = 3,
                 num_q0_classes: int = 4):
        super().__init__()
        # 1) Build a Swin that takes in_chans×32×32 inputs directly:
        #    - patch_size divides 32, e.g. 4 -> 32/4 = 8 patches per dim
        #    - window_size divides 8, e.g. 4 -> non-overlapping windows
        self.backbone = SwinTransformer(
            img_size=32,
            patch_size=4,
            in_chans=in_chans,
            embed_dim=96,
            depths=[2, 2, 6, 2],
            num_heads=[3, 6, 12, 24],
            window_size=4,
            mlp_ratio=4.,
            qkv_bias=True,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.1,
            norm_layer=nn.LayerNorm,
            patch_norm=True,
            use_checkpoint=False,
            num_classes=0  # request features instead of classification logits
        )
        # 2) Read the final feature width from the backbone rather than
        #    recomputing embed_dim * 2^(len(depths)-1) by hand.
        self.feature_dim = self.backbone.num_features

        # 3) One linear head per task.
        self.energy_head = nn.Linear(self.feature_dim, 1)
        self.alpha_head  = nn.Linear(self.feature_dim, num_alpha_classes)
        self.q0_head     = nn.Linear(self.feature_dim, num_q0_classes)

    def forward(self, x):
        """Run the backbone once and apply every head to its features.

        Args:
            x: Input batch; expected to be ``(B, in_chans, 32, 32)``.

        Returns:
            dict with keys ``'energy_loss_output'``, ``'alpha_output'`` and
            ``'q0_output'`` holding the raw (un-activated) head outputs.
        """
        # With num_classes=0 the backbone is expected to return pooled
        # features of shape (B, num_features) -- TODO confirm for the
        # installed timm version.
        feats = self.backbone(x)
        return {
            'energy_loss_output': self.energy_head(feats),
            'alpha_output':       self.alpha_head(feats),
            'q0_output':          self.q0_head(feats)
        }


In [5]:
# Cell 3: Loss Computation

def compute_loss(outputs, targets):
    """Return the unweighted sum of the three per-head losses for a batch.

    Args:
        outputs: dict of model outputs keyed by ``'energy_loss_output'``,
            ``'alpha_output'`` and ``'q0_output'``.
        targets: dict of matching targets under the same keys.

    Returns:
        Scalar tensor: binary cross-entropy (on logits) for the energy head
        plus cross-entropy for the alpha and q0 heads, each with the default
        mean reduction.
    """
    functional = nn.functional

    # Binary target -> sigmoid + BCE fused in one call.
    energy_term = functional.binary_cross_entropy_with_logits(
        outputs['energy_loss_output'],
        targets['energy_loss_output'],
    )

    # Both multi-class heads use plain cross-entropy on class indices.
    alpha_term, q0_term = (
        functional.cross_entropy(outputs[key], targets[key])
        for key in ('alpha_output', 'q0_output')
    )

    return energy_term + alpha_term + q0_term



In [6]:
# Cell 4: Training Loop Function

def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: Module called as ``model(x)``; switched to training mode here.
        loader: Iterable yielding ``(x, y)`` where ``x`` is a tensor whose
            first dimension is the batch and ``y`` is a dict of tensors.
            It does not need to support ``len()``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        float: loss averaged over samples (each batch loss is weighted by its
        batch size, so a short final batch is not over-counted). Returns
        ``nan`` when the loader yields no samples, instead of raising
        ``ZeroDivisionError``.
    """
    model.train()
    loss_sum = 0.0
    n_samples = 0
    for x, y in loader:
        x = x.to(device)
        y = {k: v.to(device) for k, v in y.items()}
        optimizer.zero_grad()
        outputs = model(x)
        loss = compute_loss(outputs, y)
        loss.backward()
        optimizer.step()
        # compute_loss yields a per-batch mean, so scale it back up by the
        # batch size before accumulating.
        batch_size = x.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size
    if n_samples == 0:
        # Nothing was trained on; report "no value" rather than a fake 0.
        return float('nan')
    return loss_sum / n_samples


In [7]:
# Cell 5: Main Training Script


# Run configuration, gathered in one place so it is easy to tune.
SEED = 42
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
NUM_EPOCHS = 5

# Seed before the model and the dataset are built so that weight init, the
# random dummy data and the shuffle order come from a known RNG state.
# (Some GPU kernels may still be non-deterministic.)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SwinMultiHeadClassifier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
train_loader = DataLoader(DummyDataset(), batch_size=BATCH_SIZE, shuffle=True)

for epoch in range(NUM_EPOCHS):
    loss = train_one_epoch(model, train_loader, optimizer, device)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")


Epoch 1, Loss: 3.5435
Epoch 2, Loss: 3.2577
Epoch 3, Loss: 3.2780
Epoch 4, Loss: 3.2319
Epoch 5, Loss: 3.2665
