In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.amp
from torch.utils.checkpoint import checkpoint
from datasets import load_dataset
from transformers import AutoTokenizer, CLIPProcessor, CLIPModel
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import requests
from io import BytesIO
from contextlib import nullcontext
import math

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# ────────────────────────────────────────────────
# CONFIG – optimized for Pro A100
# ────────────────────────────────────────────────
triality = 8 # Changed from 3 to 8 to be a divisor of dim (512)
dim = 512  # CLIP embedding dim (ViT-B/32)
latent_dim = 8
seq_len = 1  # one image + caption pair per "sequence" (batch of pairs)
batch_size = 64  # large for Pro
epochs = 10000  # test — increase to 50k+ on Pro
lr = 5e-5
use_amp = True

# ────────────────────────────────────────────────
# LAION-5B subset loader (small for Colab — increase on Pro)
# ────────────────────────────────────────────────
# Use LAION-Aesthetics subset (high-quality image-text pairs)
dataset = load_dataset("conceptual_captions", split="train[:10000]")  # Using conceptual_captions as a fallback

# CLIP model + processor
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Process batch (images + captions)
def get_clip_embeddings(batch):
    images = []
    for url in batch["image_url"]:
        try:
            response = requests.get(url, timeout=10)
            img = Image.open(BytesIO(response.content)).convert("RGB")
            images.append(img)
        except:
            images.append(Image.new("RGB", (224, 224)))  # blank fallback

    inputs = clip_processor(text=batch["caption"], images=images, return_tensors="pt", padding=True, truncation=True).to(device)

    with torch.no_grad():
        outputs = clip_model(**inputs)
        image_emb = outputs.image_embeds  # (batch, dim)
        text_emb = outputs.text_embeds    # (batch, dim)

    fused = (image_emb + text_emb) / 2  # simple fusion proxy
    return fused.unsqueeze(1)  # (batch, seq=1, dim)

batch_data = get_clip_embeddings(dataset[:batch_size])
real_data = batch_data.to(device)

# Apply masking (40–70% on fused embeddings)
missing = torch.linspace(0.4, 0.7, batch_size, device=device).view(batch_size, 1, 1)
mask = torch.rand_like(real_data) < missing
real_data[mask] = 0

target = batch_data.to(device)  # clean for reconstruction

# E8 roots – precompute
def get_e8_roots():
    roots = []
    for i in range(8):
        for j in range(i+1, 8):
            for signs in [(1,1), (1,-1), (-1,1), (-1,-1)]:
                v = torch.zeros(8)
                v[i] = signs[0]; v[j] = signs[1]
                roots.append(v); roots.append(-v)
    for signs in range(1 << 8):
        v = torch.tensor([(1 if (signs & (1<<k)) else -1) for k in range(8)], dtype=torch.float32) * 0.5
        if bin(signs).count('1') % 2 == 0:
            roots.append(v); roots.append(-v)
    roots = torch.stack(roots[:240])
    return roots / roots.norm(dim=-1, keepdim=True)

e8_roots = get_e8_roots().to(device)

# Triality Cycle Block
class MultimodalCycleBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(latent_dim, dim // triality, bias=False)
        self.register_buffer('roots', e8_roots)

    def forward(self, x, step):
        pos_emb = self.roots[torch.arange(x.shape[1], device=device) % 240]
        low_dim = self.proj(pos_emb)
        emb = low_dim.repeat(1, triality)
        pump = 0.8 * torch.sin(torch.tensor(step, device=device, dtype=torch.float32) * 0.006 * 2 * math.pi)
        x_rot1 = x * (emb.cos() + pump)
        x_rot2 = torch.roll(x_rot1, shifts=1, dims=-1) * emb.sin()
        x_rot3 = torch.roll(x_rot2, shifts=1, dims=-1) * emb.cos()
        fused = (x_rot1 + x_rot2 + x_rot3) / triality
        return fused

# Model
class E8MultimodalFusion(nn.Module):
    def __init__(self, depth=64):
        super().__init__()
        self.cycle = MultimodalCycleBlock()
        self.layers = nn.ModuleList([nn.MultiheadAttention(dim, triality, batch_first=True) for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, dim)

    def forward(self, x, step):
        x = self.cycle(x, step)
        for layer in self.layers:
            attn, _ = layer(x, x, x)
            x = x + self.norm(attn)
        return x

model = E8MultimodalFusion().to(device)
model = torch.compile(model)

opt = torch.optim.AdamW(model.parameters(), lr=lr)
scaler = torch.amp.GradScaler('cuda') if use_amp else nullcontext()
loss_fn = nn.MSELoss()

for epoch in range(epochs):
    opt.zero_grad(set_to_none=True)

    with torch.amp.autocast(device_type='cuda', dtype=torch.float16) if use_amp else nullcontext():
        recon = model(real_data, epoch)
        loss = loss_fn(recon, target)

    scaler.scale(loss).backward() if use_amp else loss.backward()
    scaler.unscale_(opt) if use_amp else None
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1e6)
    scaler.step(opt) if use_amp else opt.step()
    scaler.update() if use_amp else None

    if epoch % 500 == 0:
        print(f"Epoch {epoch} | Loss {loss.item():.6f}")

# Visualization (example image + caption reconstruction proxy)
with torch.no_grad():
    recon = model(real_data, 0).cpu()
    original = real_data.cpu()

# Simple proxy viz (embedding similarity heatmap or sample images)
print("Multimodal fusion complete — real LAION image-text validation")

# For full viz on Pro: decode recon to images/text (advanced — add if needed)


Using device: cuda


Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.

Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.



Epoch 0 | Loss 63.818089


Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.



Epoch 500 | Loss 0.035762
Epoch 1000 | Loss 0.018477
Epoch 1500 | Loss 0.009374
Epoch 2000 | Loss 0.007191
Epoch 2500 | Loss 0.009173
Epoch 3000 | Loss 0.004572
Epoch 3500 | Loss 0.007187
Epoch 4000 | Loss 0.005768
Epoch 4500 | Loss 0.006557
Epoch 5000 | Loss 0.004437
Epoch 5500 | Loss 0.001778
Epoch 6000 | Loss 0.004872
Epoch 6500 | Loss 0.006077
Epoch 7000 | Loss 0.005186
Epoch 7500 | Loss 0.005334
Epoch 8000 | Loss 0.005028
Epoch 8500 | Loss 0.001707
Epoch 9000 | Loss 0.004413
Epoch 9500 | Loss 0.004422


  return torch._C._get_cublas_allow_tf32()


Multimodal fusion complete — real LAION image-text validation
