In [None]:
# !wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
# !unzip tiny-imagenet-200.zip

In [5]:
import torch

In [7]:
import torch
from torchvision import datasets, transforms
from PIL import Image
import numpy as np
import random

def make_jigsaw_permutations(count=30, n_pieces=9, seed=0):
    """Build a reproducible list of distinct tile permutations.

    Args:
        count: number of permutations to generate (one per class label).
        n_pieces: number of tiles being permuted.
        seed: seed for the NumPy generator, so the label -> permutation
            mapping is the same after every kernel restart.

    Returns:
        A list of ``count`` NumPy integer arrays, each a permutation of
        ``range(n_pieces)``, with no two entries equal.

    Raises:
        ValueError: if more permutations are requested than exist.
    """
    from math import factorial

    if count > factorial(n_pieces):
        raise ValueError(
            f"cannot draw {count} distinct permutations of {n_pieces} pieces"
        )

    rng = np.random.default_rng(seed)
    chosen, seen = [], set()
    while len(chosen) < count:
        candidate = rng.permutation(n_pieces)
        key = tuple(candidate.tolist())
        # Two identical permutations would give two class labels for the
        # same image layout, so duplicates are redrawn.
        if key not in seen:
            seen.add(key)
            chosen.append(candidate)
    return chosen


# Class label k corresponds to permutations[k].
permutations = make_jigsaw_permutations()

def jigsaw_transform(image, perm=None):
    """Cut an image into a 3x3 grid of tiles and reassemble them in permuted order.

    The image is first resized to 64x64. Tiles are numbered row-major
    (0 = top-left, 8 = bottom-right); slot ``k`` of the output receives tile
    ``perm[k]``.

    Args:
        image: PIL image to cut up.
        perm: optional sequence of 9 tile indices. It must be equal to one of
            the entries of the module-level ``permutations`` list; lists,
            tuples and NumPy arrays are all accepted. When omitted, an entry
            of ``permutations`` is picked at random using ``random``.

    Returns:
        ``(new_im, idx)``: the reassembled RGB image and the position of the
        permutation inside ``permutations``.

    Raises:
        ValueError: if ``perm`` is given but is not one of ``permutations``.
    """
    image = image.resize((64, 64))
    w, h = image.size
    # 64 is not divisible by 3, so tiles are 21 px wide; the last pixel
    # row/column of the canvas is never pasted over and keeps the
    # background of the new image.
    s = w // 3
    pieces = [
        image.crop((j * s, i * s, (j + 1) * s, (i + 1) * s))
        for i in range(3)
        for j in range(3)
    ]

    if perm is None:
        idx = random.randint(0, len(permutations) - 1)
        perm = permutations[idx]
    else:
        # list.index() compares with ==, which is elementwise for NumPy
        # arrays and cannot be used as a truth value; compare plain tuples
        # of ints instead so any sequence type can be looked up.
        wanted = tuple(int(p) for p in perm)
        idx = next(
            (
                k
                for k, candidate in enumerate(permutations)
                if tuple(int(c) for c in candidate) == wanted
            ),
            None,
        )
        if idx is None:
            raise ValueError(f"{wanted} is not one of the known permutations")

    new_im = Image.new('RGB', (w, h))
    for slot, piece_index in enumerate(perm):
        i, j = divmod(slot, 3)
        new_im.paste(pieces[piece_index], (j * s, i * s))

    return new_im, idx


In [8]:
class JigsawDataset(torch.utils.data.Dataset):
    """Dataset of shuffled-tile images labelled by the permutation used.

    Wraps an ``ImageFolder`` rooted at ``root``. The folder's own class label
    is discarded; each item is instead passed through ``jigsaw_transform`` and
    labelled with the permutation index it returns.
    """

    def __init__(self, root, transform=None):
        """Index the images under ``root``; ``transform`` is applied to each puzzle image."""
        self.dataset = datasets.ImageFolder(root)
        self.transform = transform

    def __len__(self):
        """Number of images in the wrapped folder dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(puzzle_image, permutation_index)`` for the image at ``idx``."""
        source_image, _folder_label = self.dataset[idx]
        puzzle, perm_label = jigsaw_transform(source_image)
        if not self.transform:
            return puzzle, perm_label
        return self.transform(puzzle), perm_label


In [9]:
from torchvision import datasets, transforms

# Preprocessing applied to every assembled puzzle image: resize to 64x64,
# convert to a tensor, then normalize with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Training set and a shuffled loader yielding batches of 64.
train_data = JigsawDataset(root='tiny-imagenet-200/train', transform=transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
)

In [10]:
import torch.nn as nn
import torch.nn.functional as F

class JigsawNet(nn.Module):
    """Small CNN that scores an image against ``num_classes`` classes.

    Two 3x3 convolutions (3 -> 64 -> 128 channels, padding 1), each followed
    by ReLU and 2x2 max pooling, then two fully connected layers. The first
    linear layer takes 128 * 16 * 16 features, which matches a 64x64 input
    after the two poolings.
    """

    def __init__(self, num_classes=30):
        """Create the layers; ``num_classes`` sets the size of the output layer."""
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=128 * 16 * 16, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=num_classes)

    def forward(self, x):
        """Return unnormalized class scores of shape ``(batch, num_classes)``."""
        # Each conv stage: convolution -> ReLU -> 2x2 max pool.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(x.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [11]:
# Train JigsawNet to predict the permutation index of each shuffled image.

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = JigsawNet().to(device)

# The model returns raw scores (no softmax), which is what CrossEntropyLoss takes.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5):
    model.train()
    # Sum of per-batch losses for this epoch; reset at the start of every epoch.
    running_loss = 0.0
    for imgs, labels in train_loader:
        # Move the batch to the same device as the model.
        imgs, labels = imgs.to(device), labels.to(device)
        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # .item() extracts a Python float so no autograd graph is kept alive.
        running_loss += loss.item()
    # Report the mean per-batch loss for the epoch.
    print(f"Epoch {epoch+1}, loss: {running_loss/len(train_loader):.4f}")


Epoch 1, loss: 2.1520
Epoch 2, loss: 1.1743
Epoch 3, loss: 0.5577
Epoch 4, loss: 0.2572
Epoch 5, loss: 0.2027
