In [1]:
import os, glob, random
from PIL import Image
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models

# Pick the fastest available backend: CUDA, then Apple-silicon MPS, then CPU.
# The getattr guard keeps this working on torch builds without `torch.backends.mps`.
if torch.cuda.is_available():
    device = 'cuda'
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# class name -> directory holding that class's PNG images.
# Insertion order matters: PairRPDataset shuffles each class in this order with one
# shared seeded RNG, so reordering these names changes the train/val/test splits.
roots = {cls: f'dataset_{cls}/images' for cls in ('sin', 'square', 'random', 'red', 'drw')}


In [2]:
class PairRPDataset(Dataset):
    """Virtual dataset of grayscale image pairs labelled same-class (1.0) or different-class (0.0).

    Each directory in `roots` is globbed for ``*.png`` files. The sorted list is shuffled
    with a seeded RNG and cut into train/val/test by `split_ratio`; only the requested
    `split` is sampled from. Pairs are drawn on the fly: even indices give a same-class
    pair and odd indices a different-class pair, so a full pass is balanced 50/50.

    For the 'train' split pairs come from the global `random` module, so every epoch sees
    fresh pairs. For any other split, pairs come from an RNG seeded by (`seed`, idx), so
    validation/test sets are fixed across epochs and their losses are comparable.

    Args:
        roots: mapping class name -> directory containing that class's PNG files.
        split: 'train', 'val' or 'test'.
        split_ratio: (train, val, test) fractions; only the first two are read and the
            test split receives whatever remains after flooring.
        seed: seeds the file split and, for non-train splits, the per-index pair RNG.
        img_size: images are resized to (img_size, img_size).
        n_pairs: number of virtual pairs reported by ``len()`` (default 20000).

    Raises:
        ValueError: if fewer than two classes are given, or if any class has fewer than
            two images in the requested split (same-class pairs need two distinct files).
    """

    def __init__(self, roots, split='train', split_ratio=(0.7,0.15,0.15), seed=42, img_size=224,
                 n_pairs=20000):
        self.transform = T.Compose([T.Resize((img_size,img_size)), T.ToTensor()])
        self.split = split
        self.seed = seed
        self.n_pairs = n_pairs
        rng = random.Random(seed)

        self.files_by_cls = {cls: sorted(glob.glob(os.path.join(path, "*.png"))) for cls, path in roots.items()}
        self.split_files = {}
        # Classes consume the shared RNG in `roots` order, so datasets for different splits
        # must be built from the same `roots` and `seed` to get disjoint splits.
        for cls, files in self.files_by_cls.items():
            n = len(files)
            n_tr = int(split_ratio[0] * n); n_va = int(split_ratio[1] * n)
            shuffled = list(files)  # shuffle a copy so files_by_cls stays sorted
            rng.shuffle(shuffled)
            self.split_files[cls] = {
                'train': shuffled[:n_tr],
                'val':   shuffled[n_tr:n_tr+n_va],
                'test':  shuffled[n_tr+n_va:],
            }

        self.cls_names = sorted(self.files_by_cls.keys())
        self.pool = {cls: self.split_files[cls][split] for cls in self.cls_names}

        # Fail here with a clear message instead of deep inside random.sample/choice later.
        if len(self.cls_names) < 2:
            raise ValueError(
                f"need at least 2 classes to form different-class pairs, got {len(self.cls_names)}")
        too_small = {cls: len(files) for cls, files in self.pool.items() if len(files) < 2}
        if too_small:
            raise ValueError(
                f"classes with fewer than 2 images in split {split!r} "
                f"(check paths / split_ratio): {too_small}")

    def __len__(self):
        return self.n_pairs  # virtual pairs per epoch

    def __getitem__(self, idx):
        """Return (x1, x2, y): two (1, img_size, img_size) tensors and a float32 label of shape (1,)."""
        f1, f2, y = self._sample_pair(idx)
        x1 = self._load(f1)
        x2 = self._load(f2)
        return x1, x2, torch.tensor([y], dtype=torch.float32)

    def _pair_rng(self, idx):
        """RNG for pair `idx`: the global `random` module for training, a fixed per-index RNG otherwise."""
        if self.split == 'train':
            return random
        return random.Random(self.seed * 1_000_003 + idx)

    def _sample_pair(self, idx):
        """Pick the two file paths and the label for pair `idx` (even -> same class, odd -> different)."""
        rng = self._pair_rng(idx)
        if idx % 2 == 0:
            cls = rng.choice(self.cls_names)
            f1, f2 = rng.sample(self.pool[cls], 2)
            return f1, f2, 1.0
        cls1, cls2 = rng.sample(self.cls_names, 2)
        return rng.choice(self.pool[cls1]), rng.choice(self.pool[cls2]), 0.0

    def _load(self, path):
        """Open `path` as single-channel ('L') and apply the resize/to-tensor transform; closes the file."""
        with Image.open(path) as img:
            return self.transform(img.convert('L'))

In [3]:
class SiameseNet(nn.Module):
    """Shared encoder for a Siamese setup: image -> L2-normalised embedding.

    A randomly initialised ResNet-18 whose stem conv is replaced to take 1-channel input,
    with its final fc layer dropped, followed by a 512 -> 256 -> `embed_dim` MLP head.

    Args:
        embed_dim: size of the output embedding (default 64).
    """

    def __init__(self, embed_dim=64):
        super().__init__()
        resnet = models.resnet18(weights=None)
        # Replace the stem so the network accepts single-channel images.
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Keep every child except the last (the classifier fc); output is (B, 512, 1, 1).
        trunk = list(resnet.children())[:-1]
        self.backbone = nn.Sequential(*trunk)
        self.proj = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, embed_dim),
        )

    def forward(self, x):
        """Embed a batch `x` (assumed (B, 1, H, W)); returns unit-length rows of size `embed_dim`."""
        features = self.backbone(x)
        embedding = self.proj(features)
        return nn.functional.normalize(embedding, dim=1)

class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss on embedding pairs.

    Same-class pairs (y = 1) are penalised by their squared distance; different-class
    pairs (y = 0) are penalised by the squared shortfall below `margin`, i.e.
    max(margin - d, 0)^2. The result is averaged over the batch.

    Args:
        margin: distance beyond which different-class pairs incur no loss.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, z1, z2, y):
        """`z1`, `z2`: (B, D) embeddings; `y`: labels (e.g. (B, 1)) with 1 = same, 0 = different."""
        target = y.squeeze()
        dist = torch.nn.functional.pairwise_distance(z1, z2)
        pull = target * dist.pow(2)
        push = (1 - target) * (self.margin - dist).clamp(min=0).pow(2)
        return torch.mean(pull + push)

In [4]:
train_ds = PairRPDataset(roots, split='train')
val_ds   = PairRPDataset(roots, split='val')

# Both datasets re-derive the split from the same roots + seed, so the per-class train
# and val file lists must be disjoint; fail loudly if that ever stops being true.
for cls, subdict in train_ds.split_files.items():
    leaked = set(subdict['train']) & set(val_ds.split_files[cls]['val'])
    assert not leaked, f"{cls}: {len(leaked)} file(s) appear in both train and val"

# show how many files landed in each split
print("TRAIN split sizes:")
for cls, subdict in train_ds.split_files.items():
    print(f"  {cls}: {len(subdict['train'])} images")

print("\nVAL split sizes:")
for cls, subdict in val_ds.split_files.items():
    print(f"  {cls}: {len(subdict['val'])} images")

# Pinned host memory only helps host->CUDA copies; elsewhere it is ignored with a warning.
pin = (device == 'cuda')
# In notebooks/macOS, set num_workers=0 to avoid multiprocessing issues
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True,  num_workers=0, pin_memory=pin)
val_loader   = DataLoader(val_ds,   batch_size=64, shuffle=False, num_workers=0, pin_memory=pin)

TRAIN split sizes:
  sin: 70 images
  square: 70 images
  random: 70 images
  red: 70 images
  drw: 70 images

VAL split sizes:
  sin: 15 images
  square: 15 images
  random: 15 images
  red: 15 images
  drw: 15 images


In [5]:
import time

In [None]:
# Train the Siamese network; keep the checkpoint with the lowest validation loss.
SEED = 42
num_epochs = 10
ckpt_path = "siamese_best.pt"

# Seed both RNGs that drive this run: torch (weight init, DataLoader shuffling) and the
# global `random` module (PairRPDataset.__getitem__ draws its pairs from it).
random.seed(SEED)
torch.manual_seed(SEED)

model = SiameseNet(embed_dim=64).to(device)
crit  = ContrastiveLoss(margin=1.0)
opt   = torch.optim.Adam(model.parameters(), lr=1e-3)


def run_epoch(model, loader, crit, optimizer=None):
    """One pass over `loader`.

    With an `optimizer`, takes one gradient step per batch (training); without one, runs
    gradient-free evaluation. Returns the mean loss per pair over all pairs seen.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to model.eval()
    total_loss, n_seen = 0.0, 0
    with torch.set_grad_enabled(training):
        for x1, x2, y in loader:
            x1, x2, y = x1.to(device), x2.to(device), y.to(device)
            loss = crit(model(x1), model(x2), y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch = x1.size(0)
            total_loss += loss.item() * batch  # crit returns a batch mean; re-weight by size
            n_seen += batch
    return total_loss / max(n_seen, 1)


best_val = float('inf')
total_start = time.perf_counter()

for epoch in range(1, num_epochs + 1):
    epoch_start = time.perf_counter()
    train_loss = run_epoch(model, train_loader, crit, opt)
    val_loss = run_epoch(model, val_loader, crit)
    print(f"Epoch {epoch:02d} | train {train_loss:.4f} | val {val_loss:.4f} | "
          f"time {(time.perf_counter() - epoch_start):.1f}s")

    # save best
    if val_loss < best_val:
        best_val = val_loss
        torch.save(model.state_dict(), ckpt_path)
        print(f"  ✓ saved {ckpt_path}")

print(f"TOTAL TRAINING TIME: {(time.perf_counter() - total_start)/60:.1f} minutes")



Epoch 01 | train 0.0190 | val 0.0001 | time 2252.6s
  ✓ saved siamese_best.pt
Epoch 02 | train 0.0027 | val 0.0075 | time 5003.7s
Epoch 03 | train 0.0004 | val 0.0000 | time 3949.0s
  ✓ saved siamese_best.pt
