In [None]:
import timm
print("\n".join(timm.list_models('vi*')))

In [1]:
# Cell 1: Imports and Dummy Dataset

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import SwinModel, SwinConfig
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import os

# Dummy dataset loader (replace with ML-JET loader)
class DummyDataset(torch.utils.data.Dataset):
    def __init__(self, size=1000):
        self.data = torch.rand(size, 1, 32, 32)
        self.labels_energy = torch.randint(0, 2, (size, 1)).float()
        self.labels_alpha = torch.randint(0, 3, (size,))
        self.labels_q0 = torch.randint(0, 4, (size,))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], {
            'energy_loss_output': self.labels_energy[idx],
            'alpha_output': self.labels_alpha[idx],
            'q0_output': self.labels_q0[idx]
        }


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
import torch.nn as nn
from timm.models.vision_transformer import VisionTransformer

class ViTTinyMultiHeadClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # 1) Build a ViT that takes 1×32×32 inputs directly:
        #    - img_size=32, patch_size=4 → (32/4=8)^2 = 64 patches
        #    - embed_dim small for speed, e.g. 192
        #    - depth=4 layers, num_heads=3 (must divide 192)
        self.backbone = VisionTransformer(
            img_size=32,
            patch_size=4,
            in_chans=1,
            embed_dim=192,
            depth=4,
            num_heads=3,
            mlp_ratio=4.0,
            qkv_bias=True,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            drop_path_rate=0.1,
            norm_layer=nn.LayerNorm,
            num_classes=0,        # returns features
        )

        # 2) Feature dimension from ViT’s output:
        self.feature_dim = self.backbone.embed_dim  # 192

        # 3) Multi‑head classification layers:
        self.energy_head = nn.Linear(self.feature_dim, 1)
        self.alpha_head  = nn.Linear(self.feature_dim, 3)
        self.q0_head     = nn.Linear(self.feature_dim, 4)

    def forward(self, x):
        # x: (B,1,32,32)
        feats = self.backbone(x)    # → (B, feature_dim)
        return {
            'energy_loss_output': self.energy_head(feats),
            'alpha_output':       self.alpha_head(feats),
            'q0_output':          self.q0_head(feats)
        }


In [3]:
# Cell 3: Loss Computation

def compute_loss(outputs, targets):
    bce = nn.BCEWithLogitsLoss()
    ce = nn.CrossEntropyLoss()
    loss_energy = bce(outputs['energy_loss_output'], targets['energy_loss_output'])
    loss_alpha = ce(outputs['alpha_output'], targets['alpha_output'])
    loss_q0 = ce(outputs['q0_output'], targets['q0_output'])
    return loss_energy + loss_alpha + loss_q0



In [4]:
# Cell 4: Training Loop Function

def train_one_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0
    for x, y in loader:
        x = x.to(device)
        y = {k: v.to(device) for k, v in y.items()}
        optimizer.zero_grad()
        outputs = model(x)
        loss = compute_loss(outputs, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


In [7]:
# Cell 5: Main Training Script


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ViTTinyMultiHeadClassifier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
train_loader = DataLoader(DummyDataset(), batch_size=32, shuffle=True)

for epoch in range(5):
    loss = train_one_epoch(model, train_loader, optimizer, device)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")


Epoch 1, Loss: 3.3179
Epoch 2, Loss: 3.2251
Epoch 3, Loss: 3.1911
Epoch 4, Loss: 3.1922
Epoch 5, Loss: 3.2123


In [None]:
import sys, os
# point to the parent directory
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))
from models.model import create_model
class ViTParamClassifier(nn.Module):
    def __init__(self, vit_name='vit_base_patch16_224', pretrained=True, num_alpha=3, num_q0=4):
        super().__init__()
        # Load a ViT without its default head
        self.vit = create_model(vit_name, pretrained=pretrained, num_classes=0)
        embed_dim = self.vit.num_features  # e.g. 768

        # Single binary head for Energy‑Loss Module
        self.energy_head = nn.Linear(embed_dim, 1)

        # Categorical heads for alpha_s and Q0
        self.alpha_head  = nn.Linear(embed_dim, num_alpha)
        self.q0_head     = nn.Linear(embed_dim, num_q0)

    def forward(self, x):
        # x shape: (B, 1, 32, 32)  → expand to 3 channels & resize if needed
        if x.shape[1] == 1:
            x = x.repeat(1,3,1,1)
        x = nn.functional.interpolate(x, size=(224,224), mode='bilinear', align_corners=False)

        features = self.vit(x)  # → (B, embed_dim)
        return {
            'energy': self.energy_head(features).squeeze(1),        # (B,)
            'alpha' : self.alpha_head(features),                    # (B,3)
            'q0'    : self.q0_head(features)                        # (B,4)
        }


In [None]:

# Example losses
bce   = nn.BCEWithLogitsLoss()
ce    = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(start_epoch, cfg.epochs):
    model.train()
    for imgs, (y_energy, y_alpha, y_q0) in train_loader:
        imgs = imgs.to(device)
        out = model(imgs)
        loss = (
            bce(out['energy'], y_energy.to(device).float()) +
            ce(out['alpha'],  y_alpha.to(device)) +
            ce(out['q0'],     y_q0.to(device))
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # … validation, logging, early‑stop checks …
