In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import time
import numpy as np

# ==========================================
# 1. Configuration & Physics
# ==========================================
# OPTIONS: 'baseline', 'weak' (Covariance), 'strong' (LeJEPA/Epps-Pulley)
REG_MODE = 'baseline'
SIGR_ALPHA = 0.1   # Strength of the physics constraint
SKETCH_DIM = 64    # Dimension of the random observer

BATCH_SIZE = 128
LEARNING_RATE = 1e-3
EPOCHS = 400
WEIGHT_DECAY = 5e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DROPOUT = 0.0

MIXUP_ALPHA = 0.8
CUTMIX_ALPHA = 1.0

# Check for Apple Silicon (MPS)
if torch.backends.mps.is_available():
    DEVICE = 'mps'

print(f"Training on: {DEVICE} | Mode: {REG_MODE} | Alpha: {SIGR_ALPHA}")

def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

# ------------------------------------------
# Physics Engine: The Regularizers
# ------------------------------------------

def sigreg_weak_loss(x, sketch_dim=64):
    """
    Forces Covariance(x) ~ Identity.
    Matches the 2nd Moment (Spherical Cloud).
    """
    N, C = x.size()
    # 1. Sketching (Optional for C=512, but good for consistency)
    if C > sketch_dim:
        S = torch.randn(sketch_dim, C, device=x.device) / (C ** 0.5)
        x = x @ S.T  # [N, sketch_dim]
    else:
        sketch_dim = C

    # 2. Centering & Covariance
    x = x - x.mean(dim=0, keepdim=True)
    cov = (x.T @ x) / (N - 1 + 1e-6)

    # 3. Target Identity
    target = torch.eye(sketch_dim, device=x.device)

    # 4. Off-diagonal suppression + Diagonal maintenance
    return torch.norm(cov - target, p='fro')

def sigreg_strong_loss(x, sketch_dim=64):
    """
    Forces ECF(x) ~ ECF(Gaussian).
    Matches ALL Moments (Maximum Entropy Cloud).
    Exact implementation of LeJEPA Algorithm 1.
    """
    N, C = x.size()

    # 1. Projection (The Observer)
    # Project channels down to sketch_dim
    A = torch.randn(C, sketch_dim, device=x.device)
    A = A / (A.norm(p=2, dim=0, keepdim=True) + 1e-6)

    # 2. Integration Points
    t = torch.linspace(-5, 5, 17, device=x.device)

    # 3. Theoretical Gaussian CF
    exp_f = torch.exp(-0.5 * t**2)

    # 4. Empirical CF
    # proj: [N, sketch_dim]
    proj = x @ A

    # args: [N, sketch_dim, T]
    args = proj.unsqueeze(2) * t.view(1, 1, -1)

    # ecf: [sketch_dim, T] (Mean over batch)
    ecf = torch.exp(1j * args).mean(dim=0)

    # 5. Weighted L2 Distance
    # |ecf - gauss|^2 * gauss_weight
    diff_sq = (ecf - exp_f.unsqueeze(0)).abs().square()
    err = diff_sq * exp_f.unsqueeze(0)

    # 6. Integrate
    loss = torch.trapz(err, t, dim=1) * N

    return loss.mean()

import torch
import torch.nn as nn
import torch.nn.functional as F

class RotaryEmbedding2D(nn.Module):
    def __init__(self, dim, max_shape=(32, 32)):
        super().__init__()
        self.dim = dim
        # We split dim into two for x and y frequencies
        self.dim_x = dim // 2
        self.dim_y = dim - self.dim_x

        # Precompute frequencies
        inv_freq_x = 1.0 / (10000 ** (torch.arange(0, self.dim_x, 2).float() / self.dim_x))
        inv_freq_y = 1.0 / (10000 ** (torch.arange(0, self.dim_y, 2).float() / self.dim_y))

        self.register_buffer("inv_freq_x", inv_freq_x)
        self.register_buffer("inv_freq_y", inv_freq_y)

    def forward(self, h, w, device):
        # Generate grid
        seq_y = torch.arange(h, device=device, dtype=self.inv_freq_y.dtype)
        seq_x = torch.arange(w, device=device, dtype=self.inv_freq_x.dtype)

        # Outer product to get (H, W, dim/2)
        freqs_x = torch.einsum("i,j->ij", seq_x, self.inv_freq_x)
        freqs_y = torch.einsum("i,j->ij", seq_y, self.inv_freq_y)

        # Combine to (H, W, dim/2) -> repeat for cos/sin format
        emb_x = torch.cat((freqs_x, freqs_x), dim=-1)
        emb_y = torch.cat((freqs_y, freqs_y), dim=-1)

        # We need to construct the full 2D embeddings
        # Assuming we split the head dim: [x_part, y_part]
        # We broaden to fit the sequence length
        # Result shape: [H*W, 1, Dim] for broadcasting

        # Broadcast x along height, y along width
        # freqs_x: [W, dim_x] -> [H, W, dim_x]
        emb_x = emb_x.unsqueeze(0).repeat(h, 1, 1)
        # freqs_y: [H, dim_y] -> [H, W, dim_y]
        emb_y = emb_y.unsqueeze(1).repeat(1, w, 1)

        # Concatenate x and y frequencies: [H, W, dim]
        freqs = torch.cat([emb_x, emb_y], dim=-1)

        # Flatten: [H*W, dim]
        freqs = freqs.flatten(0, 1)
        return freqs[None, :, :] # [1, Seq, Dim]

def apply_rotary_pos_emb(q, k, freqs):
    # q, k: [B, H, Seq, Dim]
    # freqs: [1, Seq, Dim]

    # Split into pairs for rotation
    q_len = q.shape[-1]

    # Cos/Sin
    cos = freqs.cos()
    sin = freqs.sin()

    # Apply rotation
    # (x, y) -> (x cos - y sin, x sin + y cos)
    # Standard rotate_half implementation
    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

import torch.nn.functional as F

def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training: return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()
    return x.div(keep_prob) * random_tensor

class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
    def forward(self, x): return drop_path(x, self.drop_prob, self.training)

class ThermoAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Note: self.scale is handled automatically by SDPA, but good to keep if needed manually
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(DROPOUT)

        # RoPE generator
        self.rope = RotaryEmbedding2D(head_dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        # Shape: [3, B, Heads, SeqLen, HeadDim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # --- Apply 2D RoPE ---
        # RoPE modifies Q and K in place or returns new tensors.
        # It operates on the HeadDim, so it's compatible with the split heads.
        freqs = self.rope(H, W, x.device) # [1, SeqLen, HeadDim]
        q, k = apply_rotary_pos_emb(q, k, freqs)
        # ---------------------

        # --- Flash Attention ---
        # PyTorch 2.0+ automatically optimizes this using FlashAttention v2 on CUDA.
        # Input shapes are already (Batch, Heads, SeqLen, Dim), which SDPA expects.
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=DROPOUT,
            is_causal=False  # ViT is bidirectional, not causal like GPT
        )

        # Reshape back: [B, Heads, N, Dim] -> [B, N, C]
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        # ---------------------
        return x

class ThermoViTBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., reg_mode='baseline', sketch_dim=64):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = ThermoAttention(dim, num_heads=num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(DROPOUT),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(DROPOUT),
        )
        self.reg_mode = reg_mode
        self.sketch_dim = sketch_dim

        self.drop_path = DropPath(DROPOUT) if DROPOUT > 0. else nn.Identity()

    def forward(self, x, H, W):
        # Attention Residual
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))

        # MLP Residual
        # Note: We apply SIGReg AFTER the block computation but BEFORE the next block.
        # This keeps the "Residual Stream" clean and Gaussian.

        mlp_out = self.drop_path(self.mlp(self.norm2(x)))
        x = x + mlp_out

        # --- PHYSICS INJECTION ---
        reg_loss = torch.tensor(0.0, device=x.device)
        if self.reg_mode != 'baseline':
            # Global Average Pool of the tokens [B, N, C] -> [B, C]
            # This represents the "Image Vector" at this depth
            flat_rep = x.mean(dim=1)

            # Crucial: Pre-Norm vs Post-Norm context.
            # LayerNorm forces variance=1. SIGReg forces Distribution=Gaussian.
            # They are compatible.
            if self.reg_mode == 'weak':
                reg_loss = sigreg_weak_loss(flat_rep, self.sketch_dim)
            elif self.reg_mode == 'strong':
                reg_loss = sigreg_strong_loss(flat_rep, self.sketch_dim)

        return x, reg_loss

class ThermoViT(nn.Module):
    def __init__(self, img_size=32, patch_size=4, num_classes=100,
                 dim=384, depth=12, heads=12, mlp_ratio=4,
                 reg_mode='strong', sketch_dim=64):
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.H = img_size // patch_size
        self.W = img_size // patch_size
        num_patches = self.H * self.W

        # Patch Embedding
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)

        # Blocks
        self.blocks = nn.ModuleList([
            ThermoViTBlock(dim, heads, mlp_ratio, reg_mode, sketch_dim)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

        # Initialize weights (trunc_normal is usually good for ViT)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        B = x.shape[0]

        # Patch Embed: [B, C, H, W] -> [B, N, C]
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)

        # Note: No absolute position embedding added here because we use RoPE

        total_phys_loss = 0.0

        # Pass through blocks
        for blk in self.blocks:
            x, l_loss = blk(x, self.H, self.W)
            total_phys_loss += l_loss

        # Classifier
        x = self.norm(x)
        x = x.mean(dim=1) # Global Average Pooling
        out = self.head(x)

        return out, (total_phys_loss / len(self.blocks))

# A "Small" ViT suitable for CIFAR-100 (approx ResNet18 parameter count)
def ViT_Small():
    return ThermoViT(
        img_size=32,
        patch_size=4,
        num_classes=100,
        dim=256,        # Hidden dimension
        depth=6,        # Fewer layers for speed testing
        heads=8,
        mlp_ratio=2,
        reg_mode=REG_MODE,
        sketch_dim=SKETCH_DIM
    )

# ==========================================
# 3. Data Preparation
# ==========================================
def get_data_loaders():
    print('==> Preparing data...')
    mean = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    std = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(num_ops=2, magnitude=9), # <--- CRITICAL FOR ViT
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    return trainloader, testloader

# ==========================================
# 4. Training Engine
# ==========================================
def train(epoch, net, trainloader, optimizer, criterion):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    phys_loss_meter = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        # Apply Mixup/CutMix
        r = np.random.rand(1)
        if r < 0.5: # Mixup
            lam = np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA)
            index = torch.randperm(inputs.size(0)).to(DEVICE)
            inputs = lam * inputs + (1 - lam) * inputs[index, :]
            targets_a, targets_b = targets, targets[index]
        else: # CutMix
            lam = np.random.beta(CUTMIX_ALPHA, CUTMIX_ALPHA)
            rand_index = torch.randperm(inputs.size(0)).to(DEVICE)
            target_a = targets
            target_b = targets[rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
            inputs[:, :, bbx1:bbx2, bby1:bby2] = inputs[rand_index, :, bbx1:bbx2, bby1:bby2]
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (inputs.size()[-1] * inputs.size()[-2]))
            targets_a, targets_b = target_a, target_b

        optimizer.zero_grad()

        # Forward
        outputs, p_loss = net(inputs)

        # Task Loss
        c_loss = criterion(outputs, targets_a) * lam + criterion(outputs, targets_b) * (1. - lam)

        # Total Loss
        loss = (1 - SIGR_ALPHA) * c_loss + (SIGR_ALPHA * p_loss)

        loss.backward()
        optimizer.step()

        train_loss += ((1 - SIGR_ALPHA) * c_loss).item() # Log only task loss for comparison
        phys_loss_meter += (SIGR_ALPHA * p_loss).item()

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += (lam * predicted.eq(targets_a).float() + (1 - lam) * predicted.eq(targets_b).float()).sum().item()

    acc = 100. * correct / total
    return train_loss / (batch_idx + 1), acc, phys_loss_meter / (batch_idx + 1)

def test(epoch, net, testloader, criterion):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            outputs, _ = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100. * correct / total
    return test_loss / (batch_idx + 1), acc

# ==========================================
# 5. Main Execution
# ==========================================
if __name__ == '__main__':
    trainloader, testloader = get_data_loaders()

    print(f'==> Building model (Mode: {REG_MODE})...')
    net = ViT_Small()
    net = net.to(DEVICE)

    if DEVICE == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    print(f"Starting training for {EPOCHS} epochs...")
    best_acc = 0

    for epoch in range(EPOCHS):
        start_time = time.time()

        tr_loss, tr_acc, phys_loss = train(epoch, net, trainloader, optimizer, criterion)
        te_loss, te_acc = test(epoch, net, testloader, criterion)

        scheduler.step()

        if te_acc > best_acc:
            best_acc = te_acc
            # torch.save(net.state_dict(), f'thermo_resnet_{REG_MODE}.pth')

        epoch_time = time.time() - start_time

        print(f"Epoch {epoch+1:03d} | T: {epoch_time:.0f}s | "
              f"Train: {tr_loss:.4f} ({tr_acc:.1f}%) | "
              f"Phys: {phys_loss:.2f} | "
              f"Val: {te_loss:.4f} ({te_acc:.2f}%) | "
              f"Best: {best_acc:.2f}%")

    print(f"Final Best: {best_acc:.2f}%")

Training on: cuda | Mode: baseline | Alpha: 0.1
==> Preparing data...
==> Building model (Mode: baseline)...
Starting training for 400 epochs...
Epoch 001 | T: 27s | Train: 3.9753 (3.8%) | Phys: 0.00 | Val: 4.0502 (6.85%) | Best: 6.85%
Epoch 002 | T: 27s | Train: 3.8729 (5.5%) | Phys: 0.00 | Val: 3.9293 (9.83%) | Best: 9.83%
Epoch 003 | T: 27s | Train: 3.8653 (5.7%) | Phys: 0.00 | Val: 3.8961 (10.10%) | Best: 10.10%
Epoch 004 | T: 27s | Train: 3.8390 (5.9%) | Phys: 0.00 | Val: 3.8423 (10.27%) | Best: 10.27%
Epoch 005 | T: 27s | Train: 3.8255 (6.3%) | Phys: 0.00 | Val: 3.8616 (10.60%) | Best: 10.60%
Epoch 006 | T: 27s | Train: 3.8453 (6.0%) | Phys: 0.00 | Val: 3.8469 (11.48%) | Best: 11.48%
Epoch 007 | T: 27s | Train: 3.8588 (6.0%) | Phys: 0.00 | Val: 3.8347 (10.48%) | Best: 11.48%
Epoch 008 | T: 27s | Train: 3.8194 (6.7%) | Phys: 0.00 | Val: 3.7806 (11.41%) | Best: 11.48%
Epoch 009 | T: 27s | Train: 3.8205 (6.5%) | Phys: 0.00 | Val: 3.8335 (10.69%) | Best: 11.48%
Epoch 010 | T: 27s | T