In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import time
import numpy as np

# ==========================================
# 1. Configuration & Physics
# ==========================================
# OPTIONS: 'baseline', 'weak' (Covariance), 'strong' (LeJEPA/Epps-Pulley)
REG_MODE = 'strong'
SIGR_ALPHA = 0.01   # Strength of the physics constraint
SKETCH_DIM = 64    # Dimension of the random observer

BATCH_SIZE = 128
LEARNING_RATE = 0.1
EPOCHS = 400
WEIGHT_DECAY = 5e-4
MOMENTUM = 0.9
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

MIXUP_ALPHA = 0.8
CUTMIX_ALPHA = 1.0

# Check for Apple Silicon (MPS)
if torch.backends.mps.is_available():
    DEVICE = 'mps'

print(f"Training on: {DEVICE} | Mode: {REG_MODE} | Alpha: {SIGR_ALPHA}")

def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

# ------------------------------------------
# Physics Engine: The Regularizers
# ------------------------------------------

def sigreg_weak_loss(x, sketch_dim=64):
    """
    Forces Covariance(x) ~ Identity.
    Matches the 2nd Moment (Spherical Cloud).
    """
    N, C = x.size()
    # 1. Sketching (Optional for C=512, but good for consistency)
    if C > sketch_dim:
        S = torch.randn(sketch_dim, C, device=x.device) / (C ** 0.5)
        x = x @ S.T  # [N, sketch_dim]
    else:
        sketch_dim = C

    # 2. Centering & Covariance
    x = x - x.mean(dim=0, keepdim=True)
    cov = (x.T @ x) / (N - 1 + 1e-6)

    # 3. Target Identity
    target = torch.eye(sketch_dim, device=x.device)

    # 4. Off-diagonal suppression + Diagonal maintenance
    return torch.norm(cov - target, p='fro')

def sigreg_strong_loss(x, sketch_dim=64):
    """
    Forces ECF(x) ~ ECF(Gaussian).
    Matches ALL Moments (Maximum Entropy Cloud).
    Exact implementation of LeJEPA Algorithm 1.
    """
    N, C = x.size()

    # 1. Projection (The Observer)
    # Project channels down to sketch_dim
    A = torch.randn(C, sketch_dim, device=x.device)
    A = A / (A.norm(p=2, dim=0, keepdim=True) + 1e-6)

    # 2. Integration Points
    t = torch.linspace(-5, 5, 17, device=x.device)

    # 3. Theoretical Gaussian CF
    exp_f = torch.exp(-0.5 * t**2)

    # 4. Empirical CF
    # proj: [N, sketch_dim]
    proj = x @ A

    # args: [N, sketch_dim, T]
    args = proj.unsqueeze(2) * t.view(1, 1, -1)

    # ecf: [sketch_dim, T] (Mean over batch)
    ecf = torch.exp(1j * args).mean(dim=0)

    # 5. Weighted L2 Distance
    # |ecf - gauss|^2 * gauss_weight
    diff_sq = (ecf - exp_f.unsqueeze(0)).abs().square()
    err = diff_sq * exp_f.unsqueeze(0)

    # 6. Integrate
    loss = torch.trapz(err, t, dim=1) * N

    return loss.mean()

# ==========================================
# 2. Thermodynamic ResNet-18
# ==========================================

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)

        # --- PHYSICS INJECTION ---
        # We calculate loss here, but we need to return it to the main model.
        # We use Global Average Pooling to get a (N, C) vector for regularization
        # This forces the semantic features to be isotropic.
        reg_loss = torch.tensor(0.0, device=x.device)

        if REG_MODE != 'baseline':
            # Pool spatial dims: [N, C, H, W] -> [N, C]
            flat = F.adaptive_avg_pool2d(out, (1, 1)).view(out.size(0), -1)

            if REG_MODE == 'weak':
                reg_loss = sigreg_weak_loss(flat, SKETCH_DIM)
            elif REG_MODE == 'strong':
                reg_loss = sigreg_strong_loss(flat, SKETCH_DIM)

        return out, reg_loss

class ThermoResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=100):
        super(ThermoResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.ModuleList(layers) # Changed to ModuleList to iterate manually

    def forward(self, x):
        total_phys_loss = 0.0

        out = F.relu(self.bn1(self.conv1(x)))

        # Manually iterate through layers to collect physics losses
        for layer_group in [self.layer1, self.layer2, self.layer3, self.layer4]:
            for block in layer_group:
                out, l_loss = block(out)
                total_phys_loss += l_loss

        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)

        # Normalize p_loss by number of layers to keep scale consistent
        return out, (total_phys_loss / 8.0)

def ResNet18():
    return ThermoResNet(BasicBlock, [2, 2, 2, 2])

# ==========================================
# 3. Data Preparation
# ==========================================
def get_data_loaders():
    print('==> Preparing data...')
    mean = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    std = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    return trainloader, testloader

# ==========================================
# 4. Training Engine
# ==========================================
def train(epoch, net, trainloader, optimizer, criterion):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    phys_loss_meter = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        # Apply Mixup/CutMix
        r = np.random.rand(1)
        if r < 0.5: # Mixup
            lam = np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA)
            index = torch.randperm(inputs.size(0)).to(DEVICE)
            inputs = lam * inputs + (1 - lam) * inputs[index, :]
            targets_a, targets_b = targets, targets[index]
        else: # CutMix
            lam = np.random.beta(CUTMIX_ALPHA, CUTMIX_ALPHA)
            rand_index = torch.randperm(inputs.size(0)).to(DEVICE)
            target_a = targets
            target_b = targets[rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
            inputs[:, :, bbx1:bbx2, bby1:bby2] = inputs[rand_index, :, bbx1:bbx2, bby1:bby2]
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (inputs.size()[-1] * inputs.size()[-2]))
            targets_a, targets_b = target_a, target_b

        optimizer.zero_grad()

        # Forward
        outputs, p_loss = net(inputs)

        # Task Loss
        c_loss = criterion(outputs, targets_a) * lam + criterion(outputs, targets_b) * (1. - lam)

        # Total Loss
        loss = (1 - SIGR_ALPHA) * c_loss + (SIGR_ALPHA * p_loss)

        loss.backward()
        optimizer.step()

        train_loss += ((1 - SIGR_ALPHA) * c_loss).item() # Log only task loss for comparison
        phys_loss_meter += (SIGR_ALPHA * p_loss).item()

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += (lam * predicted.eq(targets_a).float() + (1 - lam) * predicted.eq(targets_b).float()).sum().item()

    acc = 100. * correct / total
    return train_loss / (batch_idx + 1), acc, phys_loss_meter / (batch_idx + 1)

def test(epoch, net, testloader, criterion):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            outputs, _ = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100. * correct / total
    return test_loss / (batch_idx + 1), acc

# ==========================================
# 5. Main Execution
# ==========================================
if __name__ == '__main__':
    trainloader, testloader = get_data_loaders()

    print(f'==> Building model (Mode: {REG_MODE})...')
    net = ResNet18()
    net = net.to(DEVICE)

    if DEVICE == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    print(f"Starting training for {EPOCHS} epochs...")
    best_acc = 0

    for epoch in range(EPOCHS):
        start_time = time.time()

        tr_loss, tr_acc, phys_loss = train(epoch, net, trainloader, optimizer, criterion)
        te_loss, te_acc = test(epoch, net, testloader, criterion)

        scheduler.step()

        if te_acc > best_acc:
            best_acc = te_acc
            # torch.save(net.state_dict(), f'thermo_resnet_{REG_MODE}.pth')

        epoch_time = time.time() - start_time

        print(f"Epoch {epoch+1:03d} | T: {epoch_time:.0f}s | "
              f"Train: {tr_loss:.4f} ({tr_acc:.1f}%) | "
              f"Phys: {phys_loss:.2f} | "
              f"Val: {te_loss:.4f} ({te_acc:.2f}%) | "
              f"Best: {best_acc:.2f}%")

    print(f"Final Best: {best_acc:.2f}%")

Training on: cuda | Mode: strong | Alpha: 0.01
==> Preparing data...


  entry = pickle.load(f, encoding="latin1")


==> Building model (Mode: strong)...
Starting training for 400 epochs...
Epoch 001 | T: 24s | Train: 4.4491 (3.2%) | Phys: 0.38 | Val: 4.1227 (5.03%) | Best: 5.03%
Epoch 002 | T: 23s | Train: 4.2122 (5.2%) | Phys: 0.32 | Val: 3.9445 (7.92%) | Best: 7.92%
Epoch 003 | T: 23s | Train: 4.1147 (7.2%) | Phys: 0.32 | Val: 3.7519 (12.22%) | Best: 12.22%
Epoch 004 | T: 23s | Train: 4.0154 (9.0%) | Phys: 0.32 | Val: 3.6440 (13.81%) | Best: 13.81%
Epoch 005 | T: 23s | Train: 3.9552 (10.1%) | Phys: 0.32 | Val: 3.4771 (16.93%) | Best: 16.93%
Epoch 006 | T: 23s | Train: 3.8342 (12.1%) | Phys: 0.33 | Val: 3.3444 (19.03%) | Best: 19.03%
Epoch 007 | T: 23s | Train: 3.8018 (13.1%) | Phys: 0.33 | Val: 3.4722 (18.52%) | Best: 19.03%
Epoch 008 | T: 24s | Train: 3.7277 (14.6%) | Phys: 0.34 | Val: 3.1741 (23.45%) | Best: 23.45%
Epoch 009 | T: 23s | Train: 3.6949 (15.4%) | Phys: 0.35 | Val: 3.0692 (23.69%) | Best: 23.69%
Epoch 010 | T: 23s | Train: 3.6194 (17.0%) | Phys: 0.35 | Val: 3.0290 (25.25%) | Best: 25