In [15]:
# NOTE(review): `%cd ..` is relative to the current working directory, so
# re-running this cell on a live kernel moves up one more level each time.
# The relative './data/...' paths loaded below depend on where this lands.
%cd ..

import torch
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

import marimo as mo

/Users/masha/Documents


In [14]:
# Load the raw pair datasets. prepare_data() below iterates over each of them
# and reads the 'x0' / 'x1' entries (plus 'label' for the test split).
# SECURITY: allow_pickle=True unpickles arbitrary Python objects while loading,
# which can execute code -- only use it on files from a trusted source.
train_raw = np.load('./data/train_pairs.npy', allow_pickle=True)
test_raw = np.load('./data/test_balanced.npy', allow_pickle=True)

# Per-channel (R, G, B) mean / std handed to transforms.Normalize in prepare_data().
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

def _to_backbone_input(raw_list, key, normalize):
    """Stack ``d[key]`` over ``raw_list`` into a normalized 3-channel float batch."""
    batch = torch.tensor(np.stack([d[key] for d in raw_list])).float()

    # Accept grayscale images stored without a channel axis: (N, H, W) -> (N, 1, H, W).
    if batch.dim() == 3:
        batch = batch.unsqueeze(1)
    if batch.dim() != 4 or batch.shape[1] not in (1, 3):
        raise ValueError(
            f"expected '{key}' images shaped (N, 1|3, H, W) or (N, H, W), "
            f"got {tuple(batch.shape)}"
        )

    # The backbone needs 3-channel input, so replicate the single gray channel.
    # Data that already has 3 channels is left untouched.
    if batch.shape[1] == 1:
        batch = batch.repeat(1, 3, 1, 1)

    # Rescale from [-1, 1] to [0, 1]; the ImageNet statistics below are defined
    # on that range.  # NOTE(review): assumes the stored images are in [-1, 1].
    batch = (batch + 1) * 0.5

    return normalize(batch)


def prepare_data(raw_list, is_train=False):
    """Turn a sequence of pair records into model-ready tensors.

    Args:
        raw_list: non-empty sequence of dict-like records holding image arrays
            under ``'x0'`` and ``'x1'``; when ``is_train`` is False, ``'label'``
            is also read.
        is_train: if True every pair is labelled 1.0; otherwise the label is
            1.0 where ``record.get('label') == 'same'`` and 0.0 elsewhere.

    Returns:
        ``(x0, x1, y)``: two float tensors of shape (N, 3, H, W), ImageNet
        normalized, and a float label tensor of shape (N,).

    Raises:
        ValueError: if ``raw_list`` is empty or the images are not shaped
            (N, 1|3, H, W) / (N, H, W) after stacking.
    """
    if len(raw_list) == 0:
        raise ValueError("prepare_data() needs at least one record")

    normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    x0 = _to_backbone_input(raw_list, 'x0', normalize)
    x1 = _to_backbone_input(raw_list, 'x1', normalize)

    if is_train:
        y = torch.ones(len(raw_list))
    else:
        y = torch.tensor([1.0 if d.get('label') == 'same' else 0.0 for d in raw_list])

    return x0, x1, y

# Build normalized tensors for both splits. With is_train=True every training
# pair gets label 1.0; test labels come from each record's 'label' field.
train_x0, train_x1, train_y = prepare_data(train_raw, is_train=True)
test_x0, test_x1, test_y = prepare_data(test_raw, is_train=False)

# NOTE(review): shuffle=True draws from torch's global RNG and no seed is set in
# the cells shown here, so batch order will differ between kernel restarts --
# consider torch.manual_seed(...) in a config cell.
train_loader = DataLoader(TensorDataset(train_x0, train_x1, train_y), batch_size=32, shuffle=True)
test_loader  = DataLoader(TensorDataset(test_x0, test_x1, test_y), batch_size=32)

# Sanity check: pull one training batch and print the shapes of the first
# image tensor (x0) and of the labels.
batch = next(iter(train_loader))
print(f"Batch Shape: {batch[0].shape}")
print(f"Label Shape: {batch[2].shape}")

Batch Shape: torch.Size([32, 3, 64, 64])
Label Shape: torch.Size([32])


In [None]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 conv -> BatchNorm -> ReLU stages.

    Both convolutions use ``padding=1``, so height and width are preserved;
    only the channel count changes from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage out_ch -> out_ch.
        for src_ch in (in_ch, out_ch):
            stages += [
                # No conv bias: the BatchNorm that follows has its own shift.
                nn.Conv2d(src_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x`` of shape (B, in_ch, H, W)."""
        return self.net(x)

class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halves H and W) followed by a DoubleConv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2)
        conv = DoubleConv(in_ch, out_ch)
        self.net = nn.Sequential(pool, conv)

    def forward(self, x):
        """Downsample ``x`` and map it from ``in_ch`` to ``out_ch`` channels."""
        return self.net(x)

class Up(nn.Module):
    """Decoder stage: 2x bilinear upsample, concatenate with a skip map, DoubleConv.

    Args:
        in_ch: channel count *after* concatenation, i.e. skip channels plus
            the channels of the map being upsampled.
        out_ch: channel count produced by the DoubleConv.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it to the skip map ``x2`` and fuse both.

        Args:
            x1: coarse feature map, (B, C1, h, w).
            x2: skip connection from the encoder, (B, C2, H, W).

        Returns:
            Tensor of shape (B, out_ch, H, W).
        """
        x1 = self.up(x1)

        # When the encoder pooled an odd-sized map, the 2x upsampled tensor is
        # one pixel short of the skip map and torch.cat would fail. Pad (or,
        # for negative differences, crop) x1 to the skip map's spatial size.
        diff_h = x2.shape[-2] - x1.shape[-2]
        diff_w = x2.shape[-1] - x1.shape[-1]
        if diff_h or diff_w:
            x1 = F.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

In [17]:
class FlowReasoningModel(nn.Module):
    """Two-level UNet over ``x_t`` conditioned on time and a frozen backbone.

    A scalar time ``t`` and the CLS token the backbone extracts from
    ``x0_clean`` are each projected to the bottleneck width
    (``flow_dim * 4``), summed, and added to the bottleneck feature map.

    Args:
        backbone: feature extractor exposing ``forward_features(x)`` that
            returns a mapping with an ``'x_norm_clstoken'`` entry. Its
            parameters are frozen in place and it is kept in eval mode.
        backbone_dim: width of the backbone CLS token.
        flow_dim: base channel count of the UNet.
    """

    def __init__(self, backbone, backbone_dim=384, flow_dim=64):
        super().__init__()

        # Frozen feature extractor: no gradients, and eval mode so that any
        # dropout / stochastic-depth layers inside it stay deterministic.
        self.backbone = backbone
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

        # Time embedding: scalar t -> (B, flow_dim * 4)
        self.time_mlp = nn.Sequential(
            nn.Linear(1, flow_dim * 4),
            nn.GELU(),
            nn.Linear(flow_dim * 4, flow_dim * 4)
        )

        # Flow UNet encoder
        self.inc = DoubleConv(3, flow_dim)
        self.down1 = Down(flow_dim, flow_dim * 2)
        self.down2 = Down(flow_dim * 2, flow_dim * 4)  # bottleneck (flow_dim * 4 channels)

        # Project backbone CLS features to the bottleneck width
        self.cond_proj = nn.Linear(backbone_dim, flow_dim * 4)

        # Decoder. in_ch of each Up = upsampled channels + skip channels.
        self.up1 = Up(flow_dim * 8, flow_dim * 2)
        self.up2 = Up(flow_dim * 4, flow_dim)
        self.outc = nn.Conv2d(flow_dim, 3, kernel_size=1)

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen backbone in eval.

        Without this override ``model.train()`` would also put the backbone
        back into training mode even though its weights are frozen.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x_t, t, x0_clean):
        """Run the conditional UNet.

        Args:
            x_t: input image batch, (B, 3, H, W). H and W should be divisible
                by 4 because of the two 2x down/up stages.
            t: time values, shaped (B, 1), (B,) or a 0-dim scalar tensor that
                is shared by the whole batch.
            x0_clean: image batch fed to the backbone for conditioning.

        Returns:
            Tensor with 3 channels at the spatial size of ``x_t``.
        """
        # A. Conditioning features from the frozen backbone
        with torch.no_grad():
            # Extract features (adjust key based on specific DINOv3 output)
            feats = self.backbone.forward_features(x0_clean)
            cls_token = feats['x_norm_clstoken']  # standard DINO output key

        # Accept scalar / (B,) time tensors as well as the original (B, 1).
        if t.dim() == 0:
            t = t.expand(x_t.shape[0])
        if t.dim() == 1:
            t = t.unsqueeze(-1)

        # B. Embed time and condition
        t_emb = self.time_mlp(t)                 # (B, flow_dim * 4)
        cond = self.cond_proj(cls_token)         # (B, flow_dim * 4)
        # Trailing singleton dims let the vector broadcast over H and W.
        global_cond = (t_emb + cond).unsqueeze(-1).unsqueeze(-1)

        # C. Encoder
        x1 = self.inc(x_t)
        x2 = self.down1(x1)
        x3 = self.down2(x2)

        # Inject conditioning (time + backbone features) at the bottleneck
        x3 = x3 + global_cond

        # D. Decoder with skip connections
        x = self.up1(x3, x2)
        x = self.up2(x, x1)
        return self.outc(x)

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DINOv3 ViT hub entry points use patch size 16 (e.g. `dinov3_vits16`); the
# `*_vits14` name is the DINOv2 naming scheme. A patch size of 16 also divides
# the 64x64 inputs used in this notebook.
# NOTE: the dinov3 hub code imports `termcolor`; the ModuleNotFoundError
# previously raised by this cell means it must be installed in the kernel
# environment first.
# NOTE(review): the DINOv3 README loads checkpoints through the `weights=`
# argument of torch.hub.load -- confirm the default download works here or
# pass a local checkpoint path / URL.
dinov3 = torch.hub.load('facebookresearch/dinov3', 'dinov3_vits16')
dinov3.to(device)
dinov3.eval()  # the backbone is only used as a frozen feature extractor

# backbone_dim=384 for ViT-Small, 768 for ViT-Base
model = FlowReasoningModel(backbone=dinov3, backbone_dim=384)
model.to(device)

# Sanity check: count trainable parameters and run one dummy forward pass.
print(f"Model Parameters (Trainable): {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
dummy_in = torch.randn(2, 3, 64, 64, device=device)
dummy_t = torch.rand(2, 1, device=device)
# No autograd graph is needed for a shape check.
with torch.no_grad():
    out = model(dummy_in, dummy_t, dummy_in)
print(f"Output Shape: {out.shape}")

Using cache found in /Users/masha/.cache/torch/hub/facebookresearch_dinov3_main


ModuleNotFoundError: No module named 'termcolor'