In [1]:
%cd /Users/masha/Documents/visual-reasoning

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import timm
import kornia.geometry.transform as K # Crucial for 2D rotation heuristic

# Prefer a CUDA GPU, then Apple-silicon MPS, and fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# Training hyperparameters.
BATCH_SIZE = 32
LR = 1e-4
EPOCHS = 30
SUBDIVISION_CAP = 4  # Split the rotation path into 4 segments

# Seed the global RNGs so Restart & Run All reproduces the same results:
# torch's default generator drives DataLoader shuffling (and, e.g., weight init),
# numpy's global state covers any np.random use.
SEED = 42
torch.manual_seed(SEED)  # seeds the generator on all devices
np.random.seed(SEED)

# DINOv3 Constants: standard ImageNet per-channel normalization statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

/Users/masha/Documents/visual-reasoning


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the raw train / test splits. The records are indexed like dicts
# downstream (d['x0'], d.get('label')), so presumably each file is an object
# array that requires allow_pickle=True — confirm against the data writer.
# NOTE(review): unpickling can execute arbitrary code; only load trusted files.
train_raw, test_raw = (
    np.load(path, allow_pickle=True)
    for path in ('./data/train_pairs.npy', './data/test_balanced.npy')
)

def _stack_images(raw_list, key):
    """Stack ``d[key]`` from every record into a single float32 tensor."""
    return torch.tensor(np.stack([d[key] for d in raw_list])).float()


def _expand_grayscale(images):
    """Repeat a single-channel batch to 3 channels along axis 1; otherwise return it unchanged."""
    # Channel axis assumed to be 1, i.e. (N, C, H, W) — TODO confirm against the saved arrays.
    if images.shape[1] == 1:
        return images.repeat(1, 3, 1, 1)
    return images


def prepare_data(raw_list, is_train=False, mean=None, std=None):
    """Turn raw pair records into normalized image tensors plus pair labels.

    Args:
        raw_list: Sequence of records supporting ``d['x0']``, ``d['x1']`` and,
            for non-training data, ``d.get('label')`` — presumably the dicts
            loaded from the ``.npy`` files; confirm.
        is_train: If True, every pair is labelled 1.0 (presumably because the
            training file holds only matching pairs). Otherwise the label is
            1.0 when ``d.get('label') == 'same'`` and 0.0 for any other value,
            including a missing label.
        mean: Per-channel normalization mean; defaults to ``IMAGENET_MEAN``.
        std: Per-channel normalization std; defaults to ``IMAGENET_STD``.

    Returns:
        ``(x0, x1, y)``: the two image tensors (float32, rescaled to [0, 1] if
        they held values above 1, grayscale expanded to 3 channels, then
        normalized with ``mean``/``std``), and a 1-D label tensor with one
        entry per record.
    """
    if mean is None:
        mean = IMAGENET_MEAN
    if std is None:
        std = IMAGENET_STD

    x0 = _stack_images(raw_list, 'x0')
    x1 = _stack_images(raw_list, 'x1')

    # Both views share one pixel scale, so decide it jointly: checking x0 alone
    # would leave x1 in 0-255 whenever only x1 exceeded 1.0.
    if max(x0.max().item(), x1.max().item()) > 1.0:
        x0 = x0 / 255.0
        x1 = x1 / 255.0

    # Expand each view on its own channel count so a mismatch between x0 and x1
    # cannot produce 9-channel or unexpanded tensors.
    x0 = _expand_grayscale(x0)
    x1 = _expand_grayscale(x1)

    normalize = transforms.Normalize(mean=mean, std=std)
    x0 = normalize(x0)
    x1 = normalize(x1)

    if is_train:
        y = torch.ones(len(raw_list))
    else:
        y = torch.tensor([1.0 if d.get('label') == 'same' else 0.0 for d in raw_list])

    return x0, x1, y

# Materialize normalized tensors for both splits; every training pair gets label 1.0.
train_x0, train_x1, train_y = prepare_data(train_raw, is_train=True)
test_x0, test_x1, test_y = prepare_data(test_raw)  # is_train defaults to False

# Training batches are reshuffled each epoch; the test loader keeps file order
# (DataLoader's shuffle defaults to False).
train_loader = DataLoader(
    TensorDataset(train_x0, train_x1, train_y),
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_loader = DataLoader(
    TensorDataset(test_x0, test_x1, test_y),
    batch_size=BATCH_SIZE,
)

print(f"Train Size: {len(train_loader.dataset)}")

Train Size: 153
