In [1]:
# NOTE(review): hardcoded absolute, user-specific path — this notebook only runs on this machine.
# The relative './data/...' paths used later resolve against this directory; consider a configurable
# PROJECT_ROOT (e.g. from an environment variable) instead.
%cd /Users/masha/Documents/visual-reasoning

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import timm
import kornia.geometry.transform as K # Crucial for 2D rotation heuristic

# Pick the accelerator: CUDA first, then Apple MPS, falling back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# Reproducibility: weight init and DataLoader(shuffle=True) draw from torch's RNG,
# so seed it (torch.manual_seed seeds all devices) together with numpy's global RNG.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

BATCH_SIZE = 32
LR = 1e-4
EPOCHS = 30
SUBDIVISION_CAP = 4  # Split the rotation path into 4 segments

# DINOv3 input normalization: standard ImageNet per-channel mean / std (RGB order)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

/Users/masha/Documents/visual-reasoning


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory holding the .npy pair files (relative to the working directory set above).
DATA_DIR = './data'

# The records are dict-like objects (later indexed as d['x0'], d.get('label')), so the
# object arrays need allow_pickle=True. Unpickling can execute arbitrary code:
# only load files from a trusted source.
train_raw = np.load(f'{DATA_DIR}/train_pairs.npy', allow_pickle=True)
test_raw = np.load(f'{DATA_DIR}/test_balanced.npy', allow_pickle=True)

def prepare_data(raw_list, is_train=False):
    """Convert raw image-pair records into normalized tensors and labels.

    Args:
        raw_list: sequence of dict-like records with array-valued ``'x0'`` and
            ``'x1'`` entries (one image each; all x0 images stackable, likewise
            x1) and, for evaluation data, a ``'label'`` entry.
        is_train: if True, every pair gets label 1.0 and labels are not read.
            Otherwise the label is 1.0 when ``label == 'same'`` and 0.0 for
            anything else, including a missing label.

    Returns:
        ``(x0, x1, y)``:
        - x0 and x1 are float tensors. Each is rescaled from [0, 255] to
          [0, 1] when needed, expanded to 3 channels if single-channel, and
          then ImageNet-normalized.
        - y is a float tensor with one entry per record.

    Raises:
        ValueError: from ``np.stack`` when ``raw_list`` is empty.
    """
    x0 = torch.tensor(np.stack([d['x0'] for d in raw_list])).float()
    x1 = torch.tensor(np.stack([d['x1'] for d in raw_list])).float()

    # A bare (N, H, W) stack is treated as single-channel images.
    if x0.dim() == 3:
        x0 = x0.unsqueeze(1)
    if x1.dim() == 3:
        x1 = x1.unsqueeze(1)

    # Heuristic: values above 1 mean 0-255 pixel data. Decide once from BOTH
    # tensors so x0 and x1 always end up on the same scale.
    if max(x0.max().item(), x1.max().item()) > 1.0:
        x0 = x0 / 255.0
        x1 = x1 / 255.0

    # Grayscale -> 3 channels to match the 3-channel ImageNet statistics;
    # checked per tensor so a 3-channel side is never repeated.
    if x0.shape[1] == 1:
        x0 = x0.repeat(1, 3, 1, 1)
    if x1.shape[1] == 1:
        x1 = x1.repeat(1, 3, 1, 1)

    normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    x0 = normalize(x0)
    x1 = normalize(x1)

    if is_train:
        y = torch.ones(len(raw_list))
    else:
        y = torch.tensor([1.0 if d.get('label') == 'same' else 0.0 for d in raw_list])

    return x0, x1, y

train_x0, train_x1, train_y = prepare_data(train_raw, is_train=True)
test_x0, test_x1, test_y = prepare_data(test_raw, is_train=False)

# prepare_data maps every label other than 'same' (including a missing 'label' key)
# to 0.0, so a renamed label field would silently give an all-negative test set.
# The balanced test file should contain both classes: fail loudly otherwise.
assert 0 < int(test_y.sum()) < len(test_y), (
    f"degenerate test labels: {int(test_y.sum())}/{len(test_y)} pairs labelled 'same'; "
    "check the 'label' field of the test records"
)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = DataLoader(TensorDataset(train_x0, train_x1, train_y), batch_size=BATCH_SIZE, shuffle=True)
test_loader  = DataLoader(TensorDataset(test_x0, test_x1, test_y), batch_size=BATCH_SIZE)

print(f"Train Size: {len(train_loader.dataset)}")
print(f"Test Size: {len(test_loader.dataset)}")

Train Size: 153


In [5]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    ``padding=1`` keeps H and W unchanged; only the channel count changes
    (``in_ch`` -> ``out_ch``). The convs carry no bias because each is followed
    by a BatchNorm that has its own learnable shift.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # One flat Sequential named `net`, so state_dict keys stay net.0 ... net.5.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class Down(nn.Module):
    """Encoder step: 2x2 max-pool (halving H and W), then a DoubleConv in_ch -> out_ch."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [nn.MaxPool2d(2), DoubleConv(in_ch, out_ch)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        pooled = self.net(x)
        return pooled

class Up(nn.Module):
    """Decoder step: 2x bilinear upsample, concat with the skip tensor, then DoubleConv.

    Args:
        in_ch: channels AFTER concatenation (skip channels + upsampled channels).
        out_ch: output channels of the DoubleConv.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1`` and fuse it with the skip tensor ``x2`` (skip channels first)."""
        x1 = self.up(x1)
        # MaxPool2d floors odd sizes, so the upsampled map can be smaller than
        # the skip; pad it (split evenly on both sides) so torch.cat succeeds.
        # No-op when the sizes already match.
        diff_h = x2.size(-2) - x1.size(-2)
        diff_w = x2.size(-1) - x1.size(-1)
        if diff_h or diff_w:
            x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
                            diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class FlowReasoningModel(nn.Module):
    """U-Net over ``x_t`` whose bottleneck is conditioned on time and a frozen backbone embedding.

    The clean image ``x0_clean`` is resized to 224x224 and passed through the
    frozen backbone. Its first output token is projected to the bottleneck
    width, added to a time embedding of ``t``, and broadcast-added to the
    bottleneck features.

    Args:
        backbone: feature extractor exposing ``forward_features``. All its
            parameters are frozen, and it is kept in eval mode.
        backbone_dim: width of the backbone's token features. It must equal
            the last dim of ``forward_features`` output.
        flow_dim: base channel width of the U-Net. The bottleneck has
            ``4 * flow_dim`` channels.
    """

    def __init__(self, backbone, backbone_dim=384, flow_dim=64):
        super().__init__()

        # Frozen backbone: no gradients, and eval mode so any dropout /
        # stochastic layers it may contain stay inactive.
        self.backbone = backbone
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

        # Time embedding: one scalar t per sample -> bottleneck-width vector.
        self.time_mlp = nn.Sequential(
            nn.Linear(1, flow_dim * 4),
            nn.GELU(),
            nn.Linear(flow_dim * 4, flow_dim * 4)
        )

        # Encoder (64x64 input -> 16x16 bottleneck via two 2x downsamplings)
        self.inc = DoubleConv(3, flow_dim)
        self.down1 = Down(flow_dim, flow_dim * 2)
        self.down2 = Down(flow_dim * 2, flow_dim * 4)  # Bottleneck

        # Projects the backbone's first token to the bottleneck width.
        self.cond_proj = nn.Linear(backbone_dim, flow_dim * 4)

        # Decoder: each Up's in_ch = upsampled channels + skip channels.
        self.up1 = Up(flow_dim * 6, flow_dim * 2)  # 4f (bottleneck) + 2f (skip from down1)
        self.up2 = Up(flow_dim * 3, flow_dim)      # 2f + f (skip from inc)
        self.outc = nn.Conv2d(flow_dim, 3, kernel_size=1)

    def train(self, mode=True):
        """Set train/eval mode for the trainable parts; the frozen backbone always stays in eval."""
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x_t, t, x0_clean):
        """Predict a 3-channel map with the same spatial size as ``x_t``.

        Args:
            x_t: current input image batch, assumed (B, 3, H, W) — TODO confirm
                against the training loop.
            t: time value(s). Accepted as (B, 1), (B,), or a scalar shared by the
                whole batch.
            x0_clean: conditioning image batch (B, 3, h, w). It is resized to
                224x224 before the backbone.

        Returns:
            Output of the 1x1 ``outc`` conv: (B, 3, H, W). Presumably the flow /
            velocity field — confirm against the loss.
        """
        # The backbone is frozen, so skip building an autograd graph through it.
        with torch.no_grad():
            x0_high = F.interpolate(x0_clean, size=(224, 224), mode='bilinear', align_corners=False)
            feats = self.backbone.forward_features(x0_high)
            # Token 0 is used as the CLS token; assumes (B, N, D) token output.
            cls_token = feats[:, 0, :]

        # nn.Linear(1, ...) needs t shaped (B, 1). A (B,) tensor would fail for
        # B > 1 and silently mis-broadcast for B == 1, so normalize the shape here.
        ref_weight = self.time_mlp[0].weight
        t = torch.as_tensor(t, dtype=ref_weight.dtype, device=ref_weight.device)
        if t.dim() == 0:
            t = t.expand(x_t.shape[0])
        if t.dim() == 1:
            t = t.unsqueeze(-1)

        t_emb = self.time_mlp(t)
        cond = self.cond_proj(cls_token)
        # (B, C) -> (B, C, 1, 1) so it broadcasts over the bottleneck's H and W.
        global_cond = (t_emb + cond).unsqueeze(-1).unsqueeze(-1)

        x1 = self.inc(x_t)
        x2 = self.down1(x1)
        x3 = self.down2(x2)

        # Inject Reasoning: add time + backbone conditioning at the bottleneck.
        x3 = x3 + global_cond

        x = self.up1(x3, x2)
        x = self.up2(x, x1)
        return self.outc(x)

# Frozen DINOv3 ViT-S/16 feature extractor (pretrained=True loads timm's pretrained weights).
dinov3 = timm.create_model("vit_small_patch16_dinov3", pretrained=True).to(DEVICE).eval()
# Take the token width from the backbone rather than the 384 default, so swapping
# the backbone variant cannot silently mismatch cond_proj's input size.
model = FlowReasoningModel(backbone=dinov3, backbone_dim=dinov3.num_features).to(DEVICE)
# Backbone parameters were frozen in FlowReasoningModel; optimize only the rest.
optimizer = optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=LR)