In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import math
import time

REG_MODE = 'strong'
SIGR_ALPHA = 0.01   # Strength of the physics constraint
SKETCH_DIM = 64    # Dimension of the random observer

BATCH_SIZE = 128
LEARNING_RATE = 0.1
EPOCHS = 400
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
if torch.backends.mps.is_available(): DEVICE = 'mps'

# Regularization Config
MIXUP_ALPHA = 0.8
CUTMIX_ALPHA = 1.0

print(f"Training on device: {DEVICE}")

def get_data_loaders():
    print('==> Preparing data with Strong Augmentation...')

    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)

    # FIX 1: Add RandAugment
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(num_ops=2, magnitude=9),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    # Increase workers to handle augmentation load
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

    testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    return trainloader, testloader


# ------------------------------------------
# Physics Engine: The Regularizers
# ------------------------------------------

def sigreg_weak_loss(x, sketch_dim=64):
    """
    Forces Covariance(x) ~ Identity.
    Matches the 2nd Moment (Spherical Cloud).
    """
    N, C = x.size()
    # 1. Sketching (Optional for C=512, but good for consistency)
    if C > sketch_dim:
        S = torch.randn(sketch_dim, C, device=x.device) / (C ** 0.5)
        x = x @ S.T  # [N, sketch_dim]
    else:
        sketch_dim = C

    # 2. Centering & Covariance
    x = x - x.mean(dim=0, keepdim=True)
    cov = (x.T @ x) / (N - 1 + 1e-6)

    # 3. Target Identity
    target = torch.eye(sketch_dim, device=x.device)

    # 4. Off-diagonal suppression + Diagonal maintenance
    return torch.norm(cov - target, p='fro')

def sigreg_strong_loss(x, sketch_dim=64):
    """
    Forces ECF(x) ~ ECF(Gaussian).
    Matches ALL Moments (Maximum Entropy Cloud).
    Exact implementation of LeJEPA Algorithm 1.
    """
    N, C = x.size()

    # 1. Projection (The Observer)
    # Project channels down to sketch_dim
    A = torch.randn(C, sketch_dim, device=x.device)
    A = A / (A.norm(p=2, dim=0, keepdim=True) + 1e-6)

    # 2. Integration Points
    t = torch.linspace(-5, 5, 17, device=x.device)

    # 3. Theoretical Gaussian CF
    exp_f = torch.exp(-0.5 * t**2)

    # 4. Empirical CF
    # proj: [N, sketch_dim]
    proj = x @ A

    # args: [N, sketch_dim, T]
    args = proj.unsqueeze(2) * t.view(1, 1, -1)

    # ecf: [sketch_dim, T] (Mean over batch)
    ecf = torch.exp(1j * args).mean(dim=0)

    # 5. Weighted L2 Distance
    # |ecf - gauss|^2 * gauss_weight
    diff_sq = (ecf - exp_f.unsqueeze(0)).abs().square()
    err = diff_sq * exp_f.unsqueeze(0)

    # 6. Integrate
    loss = torch.trapz(err, t, dim=1) * N

    return loss.mean()

# ==========================================
# 3. Mixup / CutMix Utilities
# ==========================================
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

class LinearBlock(nn.Module):
    def __init__(self, dim, hidden_dim, reg_mode='baseline', sketch_dim=64):
        super().__init__()
        self.fc = nn.Linear(dim, hidden_dim)
        # Note: NO BATCH NORM. We rely purely on SIGReg.
        self.reg_mode = reg_mode
        self.sketch_dim = sketch_dim

    def forward(self, x):
        pre_act = self.fc(x)

        reg_loss = torch.tensor(0.0, device=x.device)
        if self.reg_mode != 'baseline':
            if self.reg_mode == 'weak':
                reg_loss = sigreg_weak_loss(pre_act, self.sketch_dim)
            elif self.reg_mode == 'strong':
                reg_loss = sigreg_strong_loss(pre_act, self.sketch_dim)

        out = F.relu(pre_act)

        return out, reg_loss

class ThermoMLP(nn.Module):
    def __init__(self, input_dim=3072, hidden_dim=1024, num_classes=100, depth=6, reg_mode='weak', sketch_dim=64):
        super().__init__()

        layers = []
        # Input Layer
        layers.append(LinearBlock(input_dim, hidden_dim, reg_mode))

        # Deep Layers (No Residuals!)
        for _ in range(depth - 2):
            layers.append(LinearBlock(hidden_dim, hidden_dim, reg_mode, sketch_dim))

        self.layers = nn.ModuleList(layers)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        # Flatten: [B, C, H, W] -> [B, 3072]
        x = x.flatten(1)

        total_phys_loss = 0.0

        for layer in self.layers:
            x, l_loss = layer(x)
            total_phys_loss += l_loss

        out = self.classifier(x)

        # Normalize loss scale
        return out, (total_phys_loss / len(self.layers))

# ==========================================
# 5. Training Engine (Updated for Mixup/CutMix)
# ==========================================
def train(epoch, net, trainloader, optimizer, criterion):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    phys_loss_meter = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        # Apply Mixup/CutMix
        r = np.random.rand(1)
        if r < 0.5: # Mixup
            lam = np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA)
            index = torch.randperm(inputs.size(0)).to(DEVICE)
            inputs = lam * inputs + (1 - lam) * inputs[index, :]
            targets_a, targets_b = targets, targets[index]
        else: # CutMix
            lam = np.random.beta(CUTMIX_ALPHA, CUTMIX_ALPHA)
            rand_index = torch.randperm(inputs.size(0)).to(DEVICE)
            target_a = targets
            target_b = targets[rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
            inputs[:, :, bbx1:bbx2, bby1:bby2] = inputs[rand_index, :, bbx1:bbx2, bby1:bby2]
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (inputs.size()[-1] * inputs.size()[-2]))
            targets_a, targets_b = target_a, target_b

        optimizer.zero_grad()

        # Forward
        outputs, p_loss = net(inputs)

        # Task Loss
        c_loss = criterion(outputs, targets_a) * lam + criterion(outputs, targets_b) * (1. - lam)

        # Total Loss
        loss = (1 - SIGR_ALPHA) * c_loss + (SIGR_ALPHA * p_loss)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
        optimizer.step()

        train_loss += ((1 - SIGR_ALPHA) * c_loss).item() # Log only task loss for comparison
        phys_loss_meter += (SIGR_ALPHA * p_loss).item()

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += (lam * predicted.eq(targets_a).float() + (1 - lam) * predicted.eq(targets_b).float()).sum().item()

    acc = 100. * correct / total
    return train_loss / (batch_idx + 1), acc, phys_loss_meter / (batch_idx + 1)

def test(epoch, net, testloader, criterion):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            outputs, _ = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100. * correct / total
    return test_loss / (batch_idx + 1), acc

if __name__ == '__main__':
    trainloader, testloader = get_data_loaders()

    net = ThermoMLP(reg_mode=REG_MODE, sketch_dim=SKETCH_DIM).to(DEVICE)
    net = net.to(DEVICE)

    # Standard CrossEntropy for final eval, SoftLabel for training is handled by Mixup logic
    criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=0.0, weight_decay=0.0)

    print(f"Starting training for {EPOCHS} epochs with RandAugment + Mixup/CutMix...")

    best_acc = 0

    for epoch in range(EPOCHS):
        start_time = time.time()

        train_loss, train_acc, physics_loss = train(epoch, net, trainloader, optimizer, criterion)
        test_loss, test_acc = test(epoch, net, testloader, criterion)

        if test_acc > best_acc:
            best_acc = test_acc
            # torch.save(net.state_dict(), f'thermo_resnet_{REG_MODE}.pth')

        epoch_time = time.time() - start_time

        print(f"Epoch {epoch+1} | T: {epoch_time:.0f}s | "
              f"Train: {train_loss:.4f} ({train_acc:.1f}%) | "
              f"Phys: {physics_loss:.2f} | "
              f"Val: {test_loss:.4f} ({test_acc:.2f}%) | "
              f"Best: {best_acc:.2f}%")

Training on device: cuda
==> Preparing data with Strong Augmentation...
Starting training for 400 epochs with RandAugment + Mixup/CutMix...
Epoch 1 | T: 19s | Train: 4.5508 (1.3%) | Phys: 0.33 | Val: 4.5239 (2.30%) | Best: 2.30%
Epoch 2 | T: 19s | Train: 4.5200 (1.8%) | Phys: 0.17 | Val: 4.4048 (3.24%) | Best: 3.24%
Epoch 3 | T: 19s | Train: 4.4558 (2.4%) | Phys: 0.13 | Val: 4.2692 (3.96%) | Best: 3.96%
Epoch 4 | T: 19s | Train: 4.4151 (3.1%) | Phys: 0.11 | Val: 4.1996 (5.53%) | Best: 5.53%
Epoch 5 | T: 19s | Train: 4.3776 (3.6%) | Phys: 0.10 | Val: 4.1133 (6.36%) | Best: 6.36%
Epoch 6 | T: 19s | Train: 4.3421 (4.2%) | Phys: 0.10 | Val: 4.0502 (7.36%) | Best: 7.36%
Epoch 7 | T: 19s | Train: 4.3210 (4.5%) | Phys: 0.09 | Val: 4.0155 (7.83%) | Best: 7.83%
Epoch 8 | T: 19s | Train: 4.3037 (4.9%) | Phys: 0.09 | Val: 3.9932 (7.78%) | Best: 7.83%
Epoch 9 | T: 19s | Train: 4.2932 (5.0%) | Phys: 0.08 | Val: 3.9408 (8.93%) | Best: 8.93%
Epoch 10 | T: 19s | Train: 4.2853 (5.0%) | Phys: 0.08 | Val