In [9]:
%cd /Users/masha/Documents/visual-reasoning

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import timm
import kornia.geometry.transform as K
import cv2
import random

# Compute device, in order of preference: CUDA GPU, then Apple MPS, then CPU.
# The MPS probe is only evaluated when CUDA is unavailable.
DEVICE = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# --- Run configuration ---
BATCH_SIZE = 32       # samples per DataLoader batch
IMG_SIZE = 224        # height and width (pixels) of every generated image
BLOCK_SIZE = 20       # pixel pitch of one grid cell when drawing a shape
LR = 1e-3             # presumably the optimiser learning rate -- not used in the visible cells
EPOCHS = 30           # presumably the number of training epochs -- not used in the visible cells
SUBDIVISION_CAP = 4   # NOTE(review): purpose not visible in these cells -- confirm where it is consumed

# Per-channel statistics passed to transforms.Normalize when images are tensorised.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

/Users/masha/Documents/visual-reasoning


In [13]:
# --- CHIRAL SHAPES (asymmetric only) ---
# Each shape is a list of (dx, dy) grid offsets, one per cell, measured from
# the anchor cell at (0, 0). draw_shape multiplies the offsets by BLOCK_SIZE
# and places the anchor cell over the centre of the image.
CHIRAL_SHAPES = dict(
    Tetris_L=[(0, -1), (0, 0), (0, 1), (1, 1)],
    Tetris_J=[(0, -1), (0, 0), (0, 1), (-1, 1)],
    Tetris_S=[(0, 0), (1, 0), (0, 1), (-1, 1)],
    Tetris_Z=[(0, 0), (-1, 0), (0, 1), (1, 1)],
    Pento_F=[(0, 0), (0, -1), (1, -1), (-1, 0), (0, 1)],
    Pento_P=[(0, 0), (0, -1), (1, -1), (1, 0), (0, 1)],
)

def draw_shape(shape_name):
    """Render one chiral shape as a white-on-black mask.

    Args:
        shape_name: Key into ``CHIRAL_SHAPES``.

    Returns:
        A ``(IMG_SIZE, IMG_SIZE)`` ``uint8`` array: 0 for background and 255
        for every pixel covered by one of the shape's cells.

    Raises:
        KeyError: If ``shape_name`` is not a key of ``CHIRAL_SHAPES``.
    """
    canvas = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)

    # Top-left corner of the anchor cell (offset (0, 0)), chosen so that the
    # cell sits over the image centre.
    origin = IMG_SIZE // 2 - BLOCK_SIZE // 2

    for col, row in CHIRAL_SHAPES[shape_name]:
        left = origin + col * BLOCK_SIZE
        top = origin + row * BLOCK_SIZE
        # NOTE(review): cv2.rectangle appears to treat both corners as
        # inclusive, so each cell spans BLOCK_SIZE + 1 pixels per side and
        # neighbouring cells overlap by one pixel -- confirm this is intended.
        cv2.rectangle(
            canvas,
            (left, top),
            (left + BLOCK_SIZE, top + BLOCK_SIZE),
            255,
            cv2.FILLED,
        )

    return canvas

class TetrisPairDataset(Dataset):
    """In-memory dataset of rotated chiral-shape image pairs.

    Every sample is a tuple ``(x0, x1, label)``:

    * ``x0`` -- a shape drawn from ``CHIRAL_SHAPES`` and rotated about the
      image centre by a random start angle in ``[0, 360)`` degrees.
    * ``x1`` -- the same shape rotated a further ``[30, 150)`` degrees. When
      ``is_train`` is false, a pair is additionally mirrored horizontally
      whenever a uniform draw exceeds 0.5.
    * ``label`` -- scalar tensor: ``1.0`` for a pure rotation ("same"),
      ``0.0`` for a mirrored pair.

    Both images are ``(3, IMG_SIZE, IMG_SIZE)`` float tensors: the binary mask
    is scaled to ``[0, 1]``, copied to three channels, and normalised with
    ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    All samples are generated eagerly in ``__init__`` and kept in
    ``self.data``.

    Args:
        n_samples: Number of pairs to generate.
        is_train: When true, every pair is a valid rotation (label ``1.0``).
            When false, mirrored pairs (label ``0.0``) are mixed in.
        seed: Optional integer seed. When given, sampling uses private
            ``random.Random`` / ``numpy.random.RandomState`` generators so the
            dataset is reproducible and the global RNG state is untouched.
            When ``None`` (default), the global ``random`` and ``np.random``
            states are used, as before.
    """

    def __init__(self, n_samples=2000, is_train=True, seed=None):
        # Both the ``random`` module and ``random.Random`` expose ``choice``;
        # both ``np.random`` and ``RandomState`` expose ``randint`` / ``rand``.
        py_rng = random if seed is None else random.Random(seed)
        np_rng = np.random if seed is None else np.random.RandomState(seed)

        # Loop-invariant work, done once instead of once per sample.
        normalize = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
        keys = list(CHIRAL_SHAPES.keys())
        # Rotation never writes into its input, so one base mask per shape
        # can be shared by every sample that uses that shape.
        base_images = {name: draw_shape(name) for name in keys}

        self.data = []

        print(f"Generating {n_samples} pairs...")
        for _ in range(n_samples):
            base_img = base_images[py_rng.choice(keys)]

            angle_start = np_rng.randint(0, 360)
            angle_end = angle_start + np_rng.randint(30, 150)

            # Mirrored ("impossible") pairs are only produced outside
            # training; the uniform draw is skipped entirely when is_train.
            mirrored = (not is_train) and np_rng.rand() > 0.5
            label = 0.0 if mirrored else 1.0

            x0_img = self._rotate(base_img, angle_start)
            x1_img = self._rotate(base_img, angle_end, flip=mirrored)

            self.data.append((
                self._to_tensor(x0_img, normalize),
                self._to_tensor(x1_img, normalize),
                torch.tensor(label),
            ))

    @staticmethod
    def _rotate(img, angle, flip=False):
        """Rotate ``img`` about the image centre, optionally mirror, and re-binarise."""
        centre = (IMG_SIZE // 2, IMG_SIZE // 2)
        matrix = cv2.getRotationMatrix2D(centre, angle, 1.0)
        out = cv2.warpAffine(img, matrix, (IMG_SIZE, IMG_SIZE))
        if flip:
            out = cv2.flip(out, 1)  # flipCode 1: left-right mirror
        # Snap any intermediate grey values produced by warping back to {0, 255}.
        _, out = cv2.threshold(out, 127, 255, cv2.THRESH_BINARY)
        return out

    @staticmethod
    def _to_tensor(img, normalize):
        """Convert a uint8 mask to a normalised 3-channel float tensor."""
        t = torch.tensor(img).float() / 255.0
        t = t.unsqueeze(0).repeat(3, 1, 1)
        return normalize(t)

    def __len__(self):
        """Return the number of generated pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(x0, x1, label)`` tuple stored at ``idx``."""
        return self.data[idx]

# Generate both splits eagerly. With is_train=True every pair is a valid
# rotation; with is_train=False mirrored pairs are mixed in.
train_ds = TetrisPairDataset(is_train=True, n_samples=800)
test_ds = TetrisPairDataset(is_train=False, n_samples=200)

# Only the training loader shuffles; the test loader keeps generation order.
train_loader = DataLoader(dataset=train_ds, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(dataset=test_ds, shuffle=False, batch_size=BATCH_SIZE)

Generating 800 pairs...
Generating 200 pairs...
