In [None]:
# Generate new 2D L-shapes and rotated/mirrored L-shapes with n intermediate steps.
# Do not save them as images for now.
# Do save .npy files for zero-shot testing with DINOv3.
# Load and train on them the same way as the flow-matching (FM) model.
# Evaluate on this dataset.
# If it works and beats the baseline, check whether this model also does better on the 3D dataset.

In [21]:
%cd /Users/masha/Documents/visual-reasoning

import torch
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt

import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score

# For dinov3
import timm

/Users/masha/Documents/visual-reasoning


In [22]:
# --- Configuration ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
OUTPUT_SIZE = (224, 224)
BATCH_SIZE = 16   # Keep small for 224x224 training
LR = 1e-4
EPOCHS = 20

# Seed the RNGs used below (weight init, DataLoader shuffling, flow-time sampling)
# so that Restart & Run All reproduces the same numbers.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# DINOv3 input normalization (standard ImageNet statistics)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

In [23]:
def prepare_tetris_data(file_path, is_train=False, mean=None, std=None):
    """Load paired Tetris images from a pickled ``.npy`` file and normalize them.

    Each record is a dict with ``'x0'`` and ``'x1'`` images whose values are in
    [-1, 1] (generator output). Images may be ``(1, H, W)``, ``(H, W)`` or
    ``(3, H, W)``. Evaluation files also carry a ``'label'`` where ``'same'``
    marks a valid pair.

    Images are mapped to [0, 1], replicated to 3 channels when single-channel,
    and normalized per channel with ``mean``/``std`` so they can feed a
    pretrained RGB backbone.

    Args:
        file_path: path to the ``.npy`` file (object array of dicts).
        is_train: if True, every pair is labelled 1.0 and ``'label'`` is not read.
        mean, std: per-channel normalization statistics (length 3). They default
            to ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Returns:
        ``(x0, x1, y)``: float32 tensors of shape (N, 3, H, W), (N, 3, H, W) and
        (N,), with y = 1.0 for "same" pairs and 0.0 otherwise.

    Raises:
        ValueError: if the file holds no records, or an image is neither
            1- nor 3-channel.
    """
    mean = IMAGENET_MEAN if mean is None else mean
    std = IMAGENET_STD if std is None else std

    print(f"Loading {file_path}...")
    # allow_pickle can execute arbitrary code on load: only use it on trusted,
    # self-generated files.
    raw_data = np.load(file_path, allow_pickle=True)
    if len(raw_data) == 0:
        raise ValueError(f"{file_path} contains no records")

    mean_t = torch.as_tensor(mean, dtype=torch.float32).view(1, -1, 1, 1)
    std_t = torch.as_tensor(std, dtype=torch.float32).view(1, -1, 1, 1)

    def _to_normalized_rgb(key):
        # Stack one image field of every record into an (N, 3, H, W) normalized tensor.
        imgs = torch.as_tensor(np.stack([d[key] for d in raw_data]), dtype=torch.float32)
        if imgs.ndim == 3:
            # (N, H, W) -> (N, 1, H, W). Without this, repeat() would silently
            # fold the batch into the channel axis.
            imgs = imgs.unsqueeze(1)
        if imgs.ndim != 4 or imgs.shape[1] not in (1, 3):
            raise ValueError(
                f"expected 1- or 3-channel '{key}' images, got batch shape {tuple(imgs.shape)}"
            )
        # Generator output is [-1, 1]. Map to [0, 1] first.
        imgs = (imgs + 1) * 0.5
        if imgs.shape[1] == 1:
            imgs = imgs.repeat(1, 3, 1, 1)
        # Same arithmetic as torchvision's transforms.Normalize.
        return (imgs - mean_t) / std_t

    x0 = _to_normalized_rgb('x0')
    x1 = _to_normalized_rgb('x1')

    # Labels (1.0 = Same, 0.0 = Different); training pairs are all "same".
    if is_train:
        y = torch.ones(len(raw_data))
    else:
        y = torch.tensor([1.0 if d['label'] == 'same' else 0.0 for d in raw_data])

    return x0, x1, y

# Load each split as (x0, x1, y) tensors; training pairs are all labelled "same".
train_x0, train_x1, train_y = prepare_tetris_data("data_tetris/train.npy", is_train=True)
test_x0, test_x1, test_y = prepare_tetris_data("data_tetris/test.npy", is_train=False)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(
    TensorDataset(train_x0, train_x1, train_y),
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_loader = DataLoader(
    TensorDataset(test_x0, test_x1, test_y),
    batch_size=BATCH_SIZE,
)

print(f"Train Size: {len(train_loader.dataset)}")
print(f"Test Size:  {len(test_loader.dataset)}")

Loading data_tetris/train.npy...
Loading data_tetris/test.npy...
Train Size: 1000
Test Size:  100


In [24]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Padding of 1 keeps the spatial size unchanged; channels go in_ch -> out_ch.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for src, dst in ((in_ch, out_ch), (out_ch, out_ch)):
            # Conv bias is omitted because the following BatchNorm has its own shift.
            stages += [
                nn.Conv2d(src, dst, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(dst),
                nn.ReLU(inplace=True),
            ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv stages."""
        return self.net(x)

class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halving H and W), then a DoubleConv in_ch -> out_ch."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2)
        convs = DoubleConv(in_ch, out_ch)
        self.net = nn.Sequential(pool, convs)

    def forward(self, x):
        """Downsample ``x`` and refine it with the conv stages."""
        return self.net(x)

class Up(nn.Module):
    """Decoder stage: 2x bilinear upsample, align to the skip tensor, concat, DoubleConv.

    ``in_ch`` must equal the channel count of the upsampled input plus that of
    the skip tensor, because the two are concatenated along dim 1.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1`` and fuse it with the skip connection ``x2``."""
        upsampled = self.up(x1)
        # Pad so the spatial dims match x2; an odd difference puts the extra
        # pixel on the right/bottom.
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        left, top = dx // 2, dy // 2
        upsampled = F.pad(upsampled, [left, dx - left, top, dy - top])
        return self.conv(torch.cat([x2, upsampled], dim=1))

# --- Main Flow Reasoning Model ---
class FlowReasoningModel(nn.Module):
    """Small U-Net velocity field for flow matching, conditioned on a frozen backbone.

    The backbone embeds the clean source image ``x0_clean``. Its first token
    (``feats[:, 0, :]``) is projected to the bottleneck width, summed with a
    learned time embedding, and added as a per-channel bias at the U-Net
    bottleneck.

    Args:
        backbone: module exposing ``forward_features`` (a timm ViT in this
            notebook). Its parameters are frozen and it is kept in eval mode.
        backbone_dim: width of the backbone token embedding (384 for ViT-S).
        flow_dim: base channel width of the U-Net (bottleneck is ``4 * flow_dim``).
    """

    def __init__(self, backbone, backbone_dim=384, flow_dim=64):
        super().__init__()

        # 1. Frozen backbone: no gradients, and eval mode (re-enforced in ``train``).
        self.backbone = backbone
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

        # 2. Time embedding: scalar t -> 4 * flow_dim (matches the bottleneck width).
        self.time_mlp = nn.Sequential(
            nn.Linear(1, flow_dim * 4),
            nn.GELU(),
            nn.Linear(flow_dim * 4, flow_dim * 4)
        )

        # 3. Encoder (channel counts in comments assume flow_dim=64)
        self.inc = DoubleConv(3, flow_dim)            # 64
        self.down1 = Down(flow_dim, flow_dim * 2)     # 128
        self.down2 = Down(flow_dim * 2, flow_dim * 4) # 256 (Bottleneck)

        # 4. Condition projection (backbone token -> bottleneck channels)
        self.cond_proj = nn.Linear(backbone_dim, flow_dim * 4)

        # 5. Decoder: each Up's in_ch = upsampled channels + skip channels.
        # Up1: 256 (bottleneck) + 128 (skip) = 384 -> 128
        self.up1 = Up(flow_dim * 6, flow_dim * 2)
        # Up2: 128 (up1) + 64 (skip) = 192 -> 64
        self.up2 = Up(flow_dim * 3, flow_dim)

        # Velocity prediction with the same 3 channels as the input images.
        self.outc = nn.Conv2d(flow_dim, 3, kernel_size=1)

    def train(self, mode=True):
        """Set train/eval mode on the flow network while keeping the backbone in eval.

        ``nn.Module.train`` recurses into every child, and ``eval()`` calls
        ``train(False)``. Without this override, ``model.train()`` would also put
        the frozen backbone into training mode. That would re-enable any dropout
        or stochastic-depth layers it may contain and make its conditioning
        features noisy.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x_t, t, x0_clean):
        """Predict the flow velocity at ``(x_t, t)`` conditioned on ``x0_clean``.

        Args:
            x_t: interpolated image batch, (B, 3, H, W).
            t: flow time, (B, 1).
            x0_clean: clean source images fed to the frozen backbone.

        Returns:
            Velocity tensor with the same shape as ``x_t``.
        """
        # A. Backbone identity embedding (no gradients through the frozen backbone).
        with torch.no_grad():
            feats = self.backbone.forward_features(x0_clean)
            # Assumes feats is (B, tokens, D) with the class token first, as in
            # timm ViTs. TODO confirm for this backbone.
            cls_token = feats[:, 0, :]

        # B. Embeddings, broadcast to (B, C, 1, 1) for addition at the bottleneck.
        t_emb = self.time_mlp(t)
        cond = self.cond_proj(cls_token)
        global_cond = (t_emb + cond).unsqueeze(-1).unsqueeze(-1)

        # C. U-Net pass
        x1 = self.inc(x_t)
        x2 = self.down1(x1)
        x3 = self.down2(x2)

        # Inject conditioning at the bottleneck.
        x3 = x3 + global_cond

        x = self.up1(x3, x2)
        x = self.up2(x, x1)
        return self.outc(x)

In [25]:
# Frozen DINOv3 ViT-S/16 backbone (num_classes=0: no classifier head, per timm's API).
dinov3 = timm.create_model("vit_small_patch16_dinov3", pretrained=True, num_classes=0)
dinov3 = dinov3.to(DEVICE).eval()

model = FlowReasoningModel(backbone=dinov3, backbone_dim=384, flow_dim=64).to(DEVICE)

# Optimize only parameters that still require gradients (the backbone's are disabled).
optimizer = optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=LR,
)

print(f"Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

Trainable Parameters: 2,049,411


In [27]:
# Train the velocity field with flow matching on (x0, x1) pairs.
# DEVICE comes from the config cell and the model was already moved there; it is
# deliberately not re-derived here, because redefining it after model.to(DEVICE)
# could leave the model and the batches on different devices.
print(f"Training on {DEVICE}")

train_losses = []

for epoch in range(EPOCHS):
    model.train()
    # The backbone is frozen; keep it in eval mode so its conditioning is deterministic.
    model.backbone.eval()
    epoch_loss = 0.0
    n_seen = 0

    for x0, x1, _ in train_loader:
        x0, x1 = x0.to(DEVICE), x1.to(DEVICE)
        batch = x0.shape[0]

        # 1. Sample t ~ U[0, 1) per example and interpolate linearly.
        t = torch.rand(batch, 1, device=DEVICE)
        t_view = t.view(-1, 1, 1, 1)
        x_t = (1 - t_view) * x0 + t_view * x1

        # 2. Target velocity: the straight-line path has constant velocity x1 - x0.
        target_v = x1 - x0

        # 3. Predict & optimize
        pred_v = model(x_t, t, x0)
        loss = F.mse_loss(pred_v, target_v)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch doesn't skew the epoch mean.
        epoch_loss += loss.item() * batch
        n_seen += batch

    avg_loss = epoch_loss / max(n_seen, 1)
    train_losses.append(avg_loss)

    if (epoch + 1) % 5 == 0:
        print(f"Epoch [{epoch+1}/{EPOCHS}] Loss: {avg_loss:.6f}")

print("Training Complete.")
fig, ax = plt.subplots()
ax.plot(range(1, len(train_losses) + 1), train_losses)
ax.set(title="Flow Matching Loss (Tetris)", xlabel="Epoch", ylabel="Mean MSE")
plt.show()

Running on CPU


KeyboardInterrupt: 

In [None]:
from sklearn.metrics import roc_auc_score

# --- ODE Solver Helper ---
@torch.no_grad()
def solve_flow(model, x_start, steps=10):
    """Integrate the learned velocity field from t=0 to t=1 with forward Euler.

    The state starts at ``x_start`` and is advanced by ``v * dt`` each step.
    The model is always conditioned on the original ``x_start``.

    Args:
        model: module called as ``model(x_t, t, x0) -> velocity``. It is put in
            eval mode and left there.
        x_start: batch of source images (first dim is the batch size).
        steps: number of Euler steps; must be >= 1.

    Returns:
        Tensor shaped like ``x_start``: the estimated state at t=1.

    Raises:
        ValueError: if ``steps`` < 1.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    model.eval()
    dt = 1.0 / steps
    curr = x_start.clone()
    batch = curr.shape[0]

    for i in range(steps):
        # Build t on the inputs' device rather than relying on a global DEVICE.
        t = torch.full((batch, 1), i / steps, device=x_start.device)
        v = model(curr, t, x_start)
        curr = curr + v * dt
    return curr

# --- Run Eval ---
# Integrate the learned flow from each test x0 and measure how far the endpoint
# lands from the paired x1. The expectation is low error for valid ("same") pairs
# and high error for mirrored ones.
print("Evaluating on Test Set...")
batch_errors = []
batch_labels = []

for x0, x1, labels in test_loader:
    x0, x1 = x0.to(DEVICE), x1.to(DEVICE)

    # Generate prediction by integrating the velocity field from x0.
    x_pred = solve_flow(model, x0, steps=10)

    # Reconstruction error (MSE per image)
    batch_errors.append(torch.mean((x_pred - x1) ** 2, dim=[1, 2, 3]).cpu().numpy())
    batch_labels.append(labels.numpy())

if not batch_errors:
    raise RuntimeError("test_loader produced no batches")

all_errors = np.concatenate(batch_errors)
all_labels = np.concatenate(batch_labels)

# --- Results ---
errors_same = all_errors[all_labels == 1.0]
errors_diff = all_errors[all_labels != 1.0]

print("\n--- Synth-Tetris Results ---")
print(f"Mean MSE (Same/Valid):     {np.mean(errors_same):.5f} (Target: ~0.0)")
print(f"Mean MSE (Diff/Mirrored):  {np.mean(errors_diff):.5f} (Target: HIGH)")

# AUC: high error should mean class 0 (different), so score by -error.
# roc_auc_score is undefined when only one class is present.
if errors_same.size and errors_diff.size:
    auc = roc_auc_score(all_labels, -all_errors)
    print(f"Reasoning AUC Score:       {auc:.4f}")
else:
    auc = float("nan")
    print("Reasoning AUC Score:       undefined (test set contains a single class)")

# Plot both error distributions on a shared range.
hist_range = (0, float(all_errors.max()))
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(errors_same, bins=30, alpha=0.7, label='Valid Rotation', color='green', range=hist_range)
ax.hist(errors_diff, bins=30, alpha=0.7, label='Impossible (Mirrored)', color='red', range=hist_range)
ax.set(title="Reasoning Gap: Synth-Tetris", xlabel="Reconstruction MSE", ylabel="Count")
ax.legend()
plt.show()