In [2]:
%cd /Users/masha/Documents/visual-reasoning

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import timm
import kornia.geometry.transform as K
import cv2
import random

# Select the compute backend in priority order: CUDA GPU, then Apple
# Metal (MPS), then plain CPU.
DEVICE = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Rendering geometry, in pixels: side of the square canvas and side of one
# polyomino cell.
IMG_SIZE, BLOCK_SIZE = 224, 25
# Loader batch size, and (presumably) the training epoch count used by a
# later cell — TODO confirm where EPOCHS is consumed.
BATCH_SIZE, EPOCHS = 32, 20

# Polyomino templates, keyed by shape letter. Each value lists (dx, dy) cell
# offsets on a unit grid relative to cell (0, 0), which draw_base_shape
# places at the canvas centre; dx moves right (columns) and dy moves down
# (rows). L/J and S/Z are mirror-image pairs. F and P are chiral too, but
# their mirror images are not included here.
CHIRAL_SHAPES = dict(
    L=[(0, -1), (0, 0), (0, 1), (1, 1)],             # tetromino
    J=[(0, -1), (0, 0), (0, 1), (-1, 1)],            # tetromino, mirror of L
    S=[(0, 0), (1, 0), (0, 1), (-1, 1)],             # tetromino
    Z=[(0, 0), (-1, 0), (0, 1), (1, 1)],             # tetromino, mirror of S
    F=[(0, 0), (0, -1), (1, -1), (-1, 0), (0, 1)],   # pentomino
    P=[(0, 0), (0, -1), (1, -1), (1, 0), (0, 1)],    # pentomino
)

def draw_base_shape(shape_name):
    """Render one shape from ``CHIRAL_SHAPES`` as a white-on-black image.

    Each (dx, dy) offset of the shape becomes a filled square of exactly
    ``BLOCK_SIZE`` x ``BLOCK_SIZE`` pixels. Cell (0, 0) is centred on the
    canvas (exactly when ``BLOCK_SIZE`` is odd), and neighbouring cells
    touch without overlapping.

    Args:
        shape_name: Key into ``CHIRAL_SHAPES``.

    Returns:
        A ``(IMG_SIZE, IMG_SIZE)`` ``np.uint8`` array with 0 for background
        and 255 for the shape's cells.

    Raises:
        KeyError: If ``shape_name`` is not a key of ``CHIRAL_SHAPES``.
    """
    img = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
    center = IMG_SIZE // 2
    half = BLOCK_SIZE // 2
    for dx, dy in CHIRAL_SHAPES[shape_name]:
        # Top-left pixel of this cell (x = column, y = row).
        x0 = center + dx * BLOCK_SIZE - half
        y0 = center + dy * BLOCK_SIZE - half
        # cv2.rectangle treats pt2 as an *inclusive* corner. Passing
        # x0 + BLOCK_SIZE would draw (BLOCK_SIZE + 1)-pixel cells that
        # overlap their neighbours by one pixel.
        x1 = x0 + BLOCK_SIZE - 1
        y1 = y0 + BLOCK_SIZE - 1
        cv2.rectangle(img, (x0, y0), (x1, y1), 255, -1)  # thickness -1 = filled
    return img

class TetrisBaseDataset(Dataset):
    """Map-style dataset of randomly chosen, pre-rendered chiral shapes.

    Each of the ``n_samples`` items is built as follows:

    1. A shape name is drawn uniformly from ``CHIRAL_SHAPES`` using the
       global ``random`` module.
    2. The shape is rendered with ``draw_base_shape`` and scaled to [0, 1].
    3. The image is replicated to 3 channels and normalized.

    ``__getitem__`` returns only that ``(3, IMG_SIZE, IMG_SIZE)`` float
    tensor, with no label.

    There are only ``len(CHIRAL_SHAPES)`` distinct images, so each one is
    rendered once. Every sample of the same shape references the *same*
    tensor object. Treat items as read-only and ``clone()`` before any
    in-place edit. DataLoader's default collate stacks items into a new
    tensor, so batches are unaffected.

    Args:
        n_samples: Number of items in the dataset.
    """

    def __init__(self, n_samples=2000):
        keys = list(CHIRAL_SHAPES.keys())
        # Constructed once instead of once per sample.
        # NOTE(review): these are ImageNet's red-channel statistics applied to
        # all three (identical) channels. The full ImageNet per-channel stats
        # are mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]. Confirm
        # which is intended for the downstream model.
        normalize = transforms.Normalize(mean=[0.485], std=[0.229])
        # shape name -> normalized (3, IMG_SIZE, IMG_SIZE) tensor. Rendering
        # per sample would keep a separate float32 copy (~0.6 MB at
        # IMG_SIZE=224) for every item, i.e. ~600 MB for 1000 samples.
        rendered = {}
        self.data = []
        for _ in range(n_samples):
            # Exactly one random.choice per sample, so a given global seed
            # always yields the same sequence of shapes.
            name = random.choice(keys)
            if name not in rendered:
                rendered[name] = self._render(name, normalize)
            self.data.append(rendered[name])

    @staticmethod
    def _render(name, normalize):
        """Render shape ``name`` as a normalized 3-channel float tensor."""
        t = torch.tensor(draw_base_shape(name)).float() / 255.0
        t = t.unsqueeze(0).repeat(3, 1, 1)  # grayscale -> 3 identical channels
        return normalize(t)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Build the train split before the test split. Both consume the global
# `random` stream, so construction order determines which shapes each gets.
# NOTE(review): both splits draw from the same six rendered images, so the
# test split is not held-out data.
train_loader = DataLoader(
    TetrisBaseDataset(1000),
    batch_size=BATCH_SIZE,
    shuffle=True,  # new sample order every epoch
)
test_loader = DataLoader(
    TetrisBaseDataset(100),
    batch_size=BATCH_SIZE,  # default (sequential) order for evaluation
)

/Users/masha/Documents/visual-reasoning
