In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.amp
from torch.utils.checkpoint import checkpoint
import numpy as np
from contextlib import nullcontext
import math
import os

torch.cuda.empty_cache()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# CONFIG – optimized for T4 stability (reduced for no OOM)
triality = 3
dim = 240  # Changed from 256 to 240 to be divisible by triality (3) and num_heads (8)
latent_dim = 8
seq_len = 512  # reduced
batch_size = 32  # reduced
epochs = 3000  # reduced for fast run (sigma trend early)
lr = 5e-5
use_amp = True
use_checkpoint = True

checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, "multimodal_noise_checkpoint.pth")

# Synthetic Grok multimodal noise adaptation proxy (video + audio + text features + extreme noise)
features_video = 96
features_audio = 48
features_text = 48

multimodal_data = []
for b in range(batch_size):
    t = torch.linspace(0, 10*math.pi, seq_len, device=device)

    video = torch.sin(t.unsqueeze(-1) * torch.arange(features_video, device=device)) * 0.5
    audio = torch.cos(t.unsqueeze(-1) * torch.arange(features_audio, device=device) * 1.2) * 0.4
    text = torch.sin(t.unsqueeze(-1) * torch.arange(features_text, device=device) * 1.5) * 0.3

    frame = torch.cat([video, audio, text], dim=-1)
    # Extreme noise for adaptation test
    noise_amp = torch.linspace(0.5, 1.0, batch_size, device=device)[b]
    frame += torch.randn_like(frame) * noise_amp
    multimodal_data.append(frame)

multimodal_data = torch.stack(multimodal_data).to(device)

# Project to shared dim - proj layer itself should be defined once
proj = nn.Linear(features_video + features_audio + features_text, dim).to(device)

# target is derived from clean_data, so it should also be created once
# Detach target from the graph as it's ground truth and shouldn't receive gradients
target = proj(multimodal_data).detach()

# E8 roots – precompute
def get_e8_roots():
    roots = []
    for i in range(8):
        for j in range(i+1, 8):
            for signs in [(1,1), (1,-1), (-1,1), (-1,-1)]:
                v = torch.zeros(8)
                v[i] = signs[0]; v[j] = signs[1]
                roots.append(v); roots.append(-v)
    for signs in range(1 << 8):
        v = torch.tensor([(1 if (signs & (1<<k)) else -1) for k in range(8)], dtype=torch.float32) * 0.5
        if bin(signs).count('1') % 2 == 0:
            roots.append(v); roots.append(-v)
    roots = torch.stack(roots[:240])
    return roots / roots.norm(dim=-1, keepdim=True)

e8_roots = get_e8_roots().to(device)

# Triality Cycle Block
class NoiseAdaptCycleBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(latent_dim, dim // triality, bias=False)
        self.register_buffer('roots', e8_roots)

    def forward(self, x, step):
        pos_emb = self.roots[torch.arange(x.shape[1], device=device) % 240]
        low_dim = self.proj(pos_emb)
        emb = low_dim.repeat(1, triality)
        with torch.no_grad():
            pump_scalar = 0.8 * math.sin(step * 0.006 * 2 * math.pi)
        pump = torch.full((1, x.shape[1], 1), pump_scalar, device=device)
        emb_broadcast = emb.unsqueeze(0)
        x_rot1 = x * (emb_broadcast.cos() + pump)
        x_rot2 = torch.roll(x_rot1, shifts=1, dims=1) * emb_broadcast.sin()
        x_rot3 = torch.roll(x_rot2, shifts=1, dims=1) * emb_broadcast.cos()
        fused = (x_rot1 + x_rot2 + x_rot3) / triality
        return fused

# Dummy cycle for ablation
class DummyCycle(nn.Module):
    def forward(self, x, step=None):
        return x

# Model with ablation support (reduced depth)
class E8NoiseAdaptFusion(nn.Module):
    def __init__(self, depth=16, use_triality=True):  # reduced
        super().__init__()
        self.use_triality = use_triality
        self.cycle = NoiseAdaptCycleBlock() if use_triality else DummyCycle()
        # Fix: ensure num_heads (8) divides embed_dim (dim=256)
        self.layers = nn.ModuleList([nn.MultiheadAttention(dim, 8, batch_first=True) for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, dim)

    def forward(self, x, step):
        x = self.cycle(x, step)
        for layer in self.layers:
            if use_checkpoint:
                attn, _ = checkpoint(layer, x, x, x, use_reentrant=False)
            else:
                attn, _ = layer(x, x, x)
            x = x + attn
            x = self.norm(x)
        return x

# Models
model = E8NoiseAdaptFusion(use_triality=True).to(device)
model_ablation = E8NoiseAdaptFusion(use_triality=False).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=lr)
scaler = torch.amp.GradScaler('cuda') if use_amp else nullcontext()

opt_ablation = torch.optim.AdamW(model_ablation.parameters(), lr=lr)
scaler_ablation = torch.amp.GradScaler('cuda') if use_amp else nullcontext()

loss_fn = nn.MSELoss()

loss_hist = []
loss_abl_hist = []

start_epoch = 0

# Load checkpoint if exists (resume on disconnect)
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model'])
    model_ablation.load_state_dict(checkpoint['model_ablation'])
    opt.load_state_dict(checkpoint['opt'])
    opt_ablation.load_state_dict(checkpoint['opt_ablation'])
    scaler.load_state_dict(checkpoint['scaler'])
    scaler_ablation.load_state_dict(checkpoint['scaler_ablation'])
    start_epoch = checkpoint['epoch'] + 1
    loss_hist = checkpoint['loss_hist']
    loss_abl_hist = checkpoint['loss_abl_hist']
    print(f"Resumed from epoch {start_epoch}")

for epoch in range(start_epoch, epochs):
    opt.zero_grad(set_to_none=True)
    opt_ablation.zero_grad(set_to_none=True)

    # High masking (70–90% — partial input + noise adaptation proxy)
    # Recompute clean_data, mask, and real_data each epoch to ensure fresh computational graph
    # Detach target from proj's graph, then re-enable grad tracking for real_data_current_epoch
    clean_data_current_epoch = proj(multimodal_data).detach() # Detach from proj's graph
    missing_rate_current_epoch = torch.linspace(0.7, 0.9, batch_size, device=device).view(batch_size, 1, 1)
    mask_current_epoch = torch.rand_like(clean_data_current_epoch) < missing_rate_current_epoch

    # Apply masking using masked_fill to create a new tensor, then make it require gradients
    real_data_current_epoch = clean_data_current_epoch.masked_fill(mask_current_epoch, 0.0).requires_grad_(True)

    with torch.amp.autocast(device_type='cuda', dtype=torch.float16) if use_amp else nullcontext():
        # Pass the newly created real_data for the current epoch
        recon = model(real_data_current_epoch, epoch)
        loss = loss_fn(recon, target)

        recon_abl = model_ablation(real_data_current_epoch, epoch)
        loss_abl = loss_fn(recon_abl, target)

    # With detached inputs from proj, the graphs for model and model_ablation are independent.
    # So, retain_graph=True is not necessary for the first backward pass.
    scaler.scale(loss).backward() if use_amp else loss.backward()
    scaler.unscale_(opt) if use_amp else None
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1e6)
    scaler.step(opt) if use_amp else opt.step()
    scaler.update() if use_amp else None

    scaler_ablation.scale(loss_abl).backward() if use_amp else loss_abl.backward()
    scaler_ablation.unscale_(opt_ablation) if use_amp else None
    torch.nn.utils.clip_grad_norm_(model_ablation.parameters(), 1e6)
    scaler_ablation.step(opt_ablation) if use_amp else opt_ablation.step()
    scaler_ablation.update() if use_amp else None

    loss_hist.append(loss.item())
    loss_abl_hist.append(loss_abl.item())

    if epoch % 500 == 0:
        print(f"Epoch {epoch} | Triality Loss {loss.item():.6f} | Ablation Loss {loss_abl.item():.6f}")

    # Checkpoint every 1000 epochs
    if (epoch + 1) % 1000 == 0:
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'model_ablation': model_ablation.state_dict(),
            'opt': opt.state_dict(),
            'opt_ablation': opt_ablation.state_dict(),
            'scaler': scaler.state_dict(),
            'scaler_ablation': scaler_ablation.state_dict(),
            'loss_hist': loss_hist,
            'loss_abl_hist': loss_abl_hist,
        }, checkpoint_path)
        print(f"Checkpoint saved at epoch {epoch}")

# Final Sigma Test
triality_mean = np.mean(loss_hist)
abl_mean = np.mean(loss_abl_hist)
std = np.std(loss_hist + loss_abl_hist)
sigma = (abl_mean - triality_mean) / std if std > 0 else 0

print(f"Final Sigma (Triality vs Ablation): {sigma:.2f} (higher = triality advantage)")

print("Sim complete — epochs + sigma test done")

Using device: cuda
Epoch 0 | Triality Loss 0.850489 | Ablation Loss 0.836162
Epoch 500 | Triality Loss 0.688533 | Ablation Loss 0.676749
Checkpoint saved at epoch 999
Epoch 1000 | Triality Loss 0.634558 | Ablation Loss 0.625433
Epoch 1500 | Triality Loss 0.589428 | Ablation Loss 0.581777
Checkpoint saved at epoch 1999
Epoch 2000 | Triality Loss 0.545486 | Ablation Loss 0.539314
Epoch 2500 | Triality Loss 0.502906 | Ablation Loss 0.497996
Checkpoint saved at epoch 2999
Final Sigma (Triality vs Ablation): -0.09 (higher = triality advantage)
Sim complete — epochs + sigma test done
