In [2]:
!pip install git+https://github.com/KellerJordan/Muon

Collecting git+https://github.com/KellerJordan/Muon
  Cloning https://github.com/KellerJordan/Muon to /tmp/pip-req-build-oifnbiaf
  Running command git clone --filter=blob:none --quiet https://github.com/KellerJordan/Muon /tmp/pip-req-build-oifnbiaf
  Resolved https://github.com/KellerJordan/Muon to commit 6399c658d3c4a3356ba823fa6664b10e23871068
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: muon-optimizer
  Building wheel for muon-optimizer (setup.py) ... [?25l[?25hdone
  Created wheel for muon-optimizer: filename=muon_optimizer-0.1.0-py3-none-any.whl size=7141 sha256=12f3dd99d0372fffcf39f1c9733980145c0fec5f908bc5a3207f7ae316267028
  Stored in directory: /tmp/pip-ephem-wheel-cache-464g7e7q/wheels/6e/33/94/64d18603ba0f39064aab523d6edf493c388cfb7419bb5c9043
Successfully built muon-optimizer
Installing collected packages: muon-optimizer
Successfully installed muon-optimizer-0.1.0


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import math
import time

from muon import SingleDeviceMuonWithAuxAdam

# ==========================================
# 1. Configuration (Tuned for ViT on CIFAR)
# ==========================================
REG_MODE = 'baseline'
SIGR_ALPHA = 0.1   # Strength of the physics constraint
SKETCH_DIM = 64    # Dimension of the random observer

BATCH_SIZE = 128
LEARNING_RATE = 1e-2  # Slightly higher initial LR for AdamW with cosine schedule
EPOCHS = 400
WEIGHT_DECAY = 0.05
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
if torch.backends.mps.is_available(): DEVICE = 'mps'

# Regularization Config
MIXUP_ALPHA = 0.8
CUTMIX_ALPHA = 1.0
DROP_PATH_RATE = 0.1

print(f"Training on device: {DEVICE}")

# ==========================================
# 2. Data Preparation (THE FIX: Strong Augmentation)
# ==========================================
def get_data_loaders():
    print('==> Preparing data with Strong Augmentation...')

    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)

    # FIX 1: Add RandAugment
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(num_ops=2, magnitude=9), # <--- CRITICAL FOR ViT
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    # Increase workers to handle augmentation load
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

    testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    return trainloader, testloader


# ------------------------------------------
# Physics Engine: The Regularizers
# ------------------------------------------

def sigreg_weak_loss(x, sketch_dim=64):
    """
    Forces Covariance(x) ~ Identity.
    Matches the 2nd Moment (Spherical Cloud).
    """
    N, C = x.size()
    # 1. Sketching (Optional for C=512, but good for consistency)
    if C > sketch_dim:
        S = torch.randn(sketch_dim, C, device=x.device) / (C ** 0.5)
        x = x @ S.T  # [N, sketch_dim]
    else:
        sketch_dim = C

    # 2. Centering & Covariance
    x = x - x.mean(dim=0, keepdim=True)
    cov = (x.T @ x) / (N - 1 + 1e-6)

    # 3. Target Identity
    target = torch.eye(sketch_dim, device=x.device)

    # 4. Off-diagonal suppression + Diagonal maintenance
    return torch.norm(cov - target, p='fro')

def sigreg_strong_loss(x, sketch_dim=64):
    """
    Forces ECF(x) ~ ECF(Gaussian).
    Matches ALL Moments (Maximum Entropy Cloud).
    Exact implementation of LeJEPA Algorithm 1.
    """
    N, C = x.size()

    # 1. Projection (The Observer)
    # Project channels down to sketch_dim
    A = torch.randn(C, sketch_dim, device=x.device)
    A = A / (A.norm(p=2, dim=0, keepdim=True) + 1e-6)

    # 2. Integration Points
    t = torch.linspace(-5, 5, 17, device=x.device)

    # 3. Theoretical Gaussian CF
    exp_f = torch.exp(-0.5 * t**2)

    # 4. Empirical CF
    # proj: [N, sketch_dim]
    proj = x @ A

    # args: [N, sketch_dim, T]
    args = proj.unsqueeze(2) * t.view(1, 1, -1)

    # ecf: [sketch_dim, T] (Mean over batch)
    ecf = torch.exp(1j * args).mean(dim=0)

    # 5. Weighted L2 Distance
    # |ecf - gauss|^2 * gauss_weight
    diff_sq = (ecf - exp_f.unsqueeze(0)).abs().square()
    err = diff_sq * exp_f.unsqueeze(0)

    # 6. Integrate
    loss = torch.trapz(err, t, dim=1) * N

    return loss.mean()

# ==========================================
# 3. Mixup / CutMix Utilities
# ==========================================
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training: return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()
    return x.div(keep_prob) * random_tensor

class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
    def forward(self, x): return drop_path(x, self.drop_prob, self.training)

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=192):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class MLP(nn.Module):
    def __init__(self, dim, hidden_dim, drop=0.):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.drop = nn.Dropout(drop)
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., reg_mode='strong', sketch_dim=64):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=True, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio), drop=drop)

        self.reg_mode = reg_mode
        self.sketch_dim = sketch_dim
    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))


        # --- PHYSICS INJECTION ---
        reg_loss = torch.tensor(0.0, device=x.device)
        if self.reg_mode != 'baseline':
            # Global Average Pool of the tokens [B, N, C] -> [B, C]
            # This represents the "Image Vector" at this depth
            flat_rep = x.mean(dim=1)

            # Crucial: Pre-Norm vs Post-Norm context.
            # LayerNorm forces variance=1. SIGReg forces Distribution=Gaussian.
            # They are compatible.
            if self.reg_mode == 'weak':
                reg_loss = sigreg_weak_loss(flat_rep, self.sketch_dim)
            elif self.reg_mode == 'strong':
                reg_loss = sigreg_strong_loss(flat_rep, self.sketch_dim)

        return x, reg_loss

class VisionTransformer(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, num_classes=100, embed_dim=192, depth=9, num_heads=3, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, reg_mode='strong', sketch_dim=64):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.proj.weight.shape[2] # Just a hack to get patch count logic
        num_patches = (img_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, drop_rate, attn_drop_rate, dpr[i], reg_mode, sketch_dim)
            for i in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.xavier_uniform_(self.head.weight)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        total_phys_loss = 0.0

        # Pass through blocks
        for blk in self.blocks:
            x, l_loss = blk(x, )
            total_phys_loss += l_loss

        x = self.norm(x)
        out = self.head(x[:, 0])
        return out, (total_phys_loss / len(self.blocks))

# ==========================================
# 5. Training Engine (Updated for Mixup/CutMix)
# ==========================================
def train(epoch, net, trainloader, optimizer, criterion):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    phys_loss_meter = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        # Apply Mixup/CutMix
        r = np.random.rand(1)
        if r < 0.5: # Mixup
            lam = np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA)
            index = torch.randperm(inputs.size(0)).to(DEVICE)
            inputs = lam * inputs + (1 - lam) * inputs[index, :]
            targets_a, targets_b = targets, targets[index]
        else: # CutMix
            lam = np.random.beta(CUTMIX_ALPHA, CUTMIX_ALPHA)
            rand_index = torch.randperm(inputs.size(0)).to(DEVICE)
            target_a = targets
            target_b = targets[rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
            inputs[:, :, bbx1:bbx2, bby1:bby2] = inputs[rand_index, :, bbx1:bbx2, bby1:bby2]
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (inputs.size()[-1] * inputs.size()[-2]))
            targets_a, targets_b = target_a, target_b

        optimizer.zero_grad()

        # Forward
        outputs, p_loss = net(inputs)

        # Task Loss
        c_loss = criterion(outputs, targets_a) * lam + criterion(outputs, targets_b) * (1. - lam)

        # Total Loss
        loss = (1 - SIGR_ALPHA) * c_loss + (SIGR_ALPHA * p_loss)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
        optimizer.step()

        train_loss += ((1 - SIGR_ALPHA) * c_loss).item() # Log only task loss for comparison
        phys_loss_meter += (SIGR_ALPHA * p_loss).item()

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += (lam * predicted.eq(targets_a).float() + (1 - lam) * predicted.eq(targets_b).float()).sum().item()

    acc = 100. * correct / total
    return train_loss / (batch_idx + 1), acc, phys_loss_meter / (batch_idx + 1)

def test(epoch, net, testloader, criterion):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            outputs, _ = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100. * correct / total
    return test_loss / (batch_idx + 1), acc

if __name__ == '__main__':
    trainloader, testloader = get_data_loaders()

    net = VisionTransformer(img_size=32, patch_size=4, embed_dim=192, depth=9, num_heads=3, drop_path_rate=0.1, reg_mode=REG_MODE, sketch_dim=SKETCH_DIM)
    net = net.to(DEVICE)

    hidden_weights = [p for p in net.parameters() if p.ndim == 2]
    non_hidden_params = [p for p in net.parameters() if p.ndim != 2]

    param_groups = [
      dict(params=hidden_weights, use_muon=True, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY),
      dict(params=non_hidden_params, use_muon=False, lr=1e-3, betas=(0.9, 0.95), weight_decay=WEIGHT_DECAY)
    ]

    # Standard CrossEntropy for final eval, SoftLabel for training is handled by Mixup logic
    criterion = nn.CrossEntropyLoss()

    optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-5)

    print(f"Starting training for {EPOCHS} epochs with RandAugment + Mixup/CutMix...")

    best_acc = 0

    for epoch in range(EPOCHS):
        start_time = time.time()

        train_loss, train_acc, physics_loss = train(epoch, net, trainloader, optimizer, criterion)
        test_loss, test_acc = test(epoch, net, testloader, criterion)

        scheduler.step()

        if test_acc > best_acc:
            best_acc = test_acc
            # torch.save(net.state_dict(), f'thermo_resnet_{REG_MODE}.pth')

        epoch_time = time.time() - start_time

        print(f"Epoch {epoch+1} | T: {epoch_time:.0f}s | "
              f"Train: {train_loss:.4f} ({train_acc:.1f}%) | "
              f"Phys: {physics_loss:.2f} | "
              f"Val: {test_loss:.4f} ({test_acc:.2f}%) | "
              f"Best: {best_acc:.2f}%")

Training on device: cuda
==> Preparing data with Strong Augmentation...
Starting training for 400 epochs with RandAugment + Mixup/CutMix...
Epoch 1 | T: 37s | Train: 3.8585 (6.1%) | Phys: 0.00 | Val: 3.6131 (13.88%) | Best: 13.88%
Epoch 2 | T: 39s | Train: 3.6574 (9.2%) | Phys: 0.00 | Val: 3.2959 (20.19%) | Best: 20.19%
Epoch 3 | T: 37s | Train: 3.5778 (11.0%) | Phys: 0.00 | Val: 3.1777 (22.81%) | Best: 22.81%
Epoch 4 | T: 37s | Train: 3.5181 (12.4%) | Phys: 0.00 | Val: 3.0726 (24.68%) | Best: 24.68%
Epoch 5 | T: 37s | Train: 3.4263 (14.0%) | Phys: 0.00 | Val: 2.9215 (29.43%) | Best: 29.43%
Epoch 6 | T: 37s | Train: 3.3413 (16.1%) | Phys: 0.00 | Val: 2.7626 (31.64%) | Best: 31.64%
Epoch 7 | T: 37s | Train: 3.2502 (17.9%) | Phys: 0.00 | Val: 2.5451 (36.59%) | Best: 36.59%
Epoch 8 | T: 37s | Train: 3.1865 (19.5%) | Phys: 0.00 | Val: 2.3575 (40.44%) | Best: 40.44%
Epoch 9 | T: 37s | Train: 3.1738 (20.0%) | Phys: 0.00 | Val: 2.5585 (37.12%) | Best: 40.44%
Epoch 10 | T: 37s | Train: 3.1352 