In [1]:
from google.colab import drive
drive.mount('/content/drive')

# Set the directory where your images are stored
image_dir = '/content/drive/My Drive/cavallo'


Mounted at /content/drive


In [4]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import random
import numpy as np
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

class ImageTilesDataset(Dataset):
    def __init__(self, directory, grid_size=(3, 3), tile_size=100, max_images=100):
        self.directory = directory
        self.grid_size = grid_size
        self.tile_size = tile_size
        self.data = []
        self.labels = []
        self.max_images = max_images
        self.load_images()

    def load_images(self):
        file_count = 0
        for filename in os.listdir(self.directory):
            if file_count >= self.max_images:  # Stop if the maximum number of images is reached
                break
            if filename.lower().endswith(('.png', '.jpeg', '.jpg')):
                file_count += 1
                image_path = os.path.join(self.directory, filename)
                image = Image.open(image_path).convert('RGB')
                self.jumble_image(image)

        if file_count == 0:
            print("Error: No images found.")
        else:
            print(f"Loaded {file_count} images.")

    def jumble_image(self, image):
        image = Resize((self.grid_size[0] * self.tile_size, self.grid_size[1] * self.tile_size))(image)
        tiles = [image.crop((j * self.tile_size, i * self.tile_size, (j + 1) * self.tile_size, (i + 1) * self.tile_size))
                 for i in range(self.grid_size[0]) for j in range(self.grid_size[1])]
        indices = list(range(len(tiles)))
        random.shuffle(indices)
        tiles = [tiles[i] for i in indices]
        self.data.extend(tiles)
        self.labels.extend(indices)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        transform = Compose([
            Resize(224),
            ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        tile = transform(self.data[idx])
        label = self.labels[idx]
        return tile, label

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=224):
        super().__init__()
        self.projection = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_patches, emb_size))

    def forward(self, x):
        x = self.projection(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        x = x + self.position_embeddings
        return x

def get_2d_sincos_pos_embed(embed_dim, grid_size):
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega

    pos = grid.reshape(2, -1)  # (2, HW)
    out = np.einsum('d,hw->dhw', omega, pos)  # (D, H, W)

    emb_sin = np.sin(out)
    emb_cos = np.cos(out)

    emb = np.stack([emb_sin, emb_cos], dim=-1)
    emb = np.reshape(emb, [1, -1, embed_dim])
    return emb

class TransformerEncoderLayer(nn.Module):
    def __init__(self, emb_size=768, num_heads=8, forward_expansion=4, dropout_rate=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_size)
        self.attention = MultiHeadAttention(emb_size, num_heads, dropout_rate)
        self.norm2 = nn.LayerNorm(emb_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(emb_size, forward_expansion * emb_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * emb_size, emb_size),
            nn.Dropout(dropout_rate)
        )

    def forward(self, x):
        x = self.norm1(x)
        attention = self.attention(x, x, x)
        x = attention + x
        x = self.norm2(x)
        forward = self.feed_forward(x)
        out = forward + x
        return out

class VisionTransformer(nn.Module):
    def __init__(self, patch_size=16, emb_size=768, depth=6, num_heads=8, num_classes=9, img_size=224, dropout_rate=0.1):
        super().__init__()
        self.patch_embedding = PatchEmbedding(patch_size=patch_size, emb_size=emb_size, img_size=img_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.transformer = nn.Sequential(*[TransformerEncoderLayer(emb_size, num_heads, dropout_rate=dropout_rate) for _ in range(depth)])
        self.to_cls_token = nn.Identity()
        self.fc = nn.Linear(emb_size, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        b, n, _ = x.shape
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.transformer(x)
        x = self.to_cls_token(x[:, 0])
        return self.fc(x)

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size=768, num_heads=8, dropout_rate=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(emb_size, num_heads, dropout=dropout_rate)

    def forward(self, query, key, value):
        return self.attention(query, key, value)[0]

image_dir = '/content/drive/My Drive/cavallo'
dataset = ImageTilesDataset(image_dir, max_images=100)
loader = DataLoader(dataset, batch_size=9, shuffle=True)

model = VisionTransformer(img_size=224, patch_size=32, num_classes=9, depth=6, num_heads=8)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train_model(model, loader, criterion, optimizer, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        for tiles, labels in loader:
            optimizer.zero_grad()
            outputs = model(tiles)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}: Loss: {loss.item()}")

train_model(model, loader, criterion, optimizer)



Loaded 100 images.
Epoch 1: Loss: 2.5916965007781982
Epoch 2: Loss: 2.4304134845733643
Epoch 3: Loss: 2.4311647415161133
Epoch 4: Loss: 2.428215503692627
Epoch 5: Loss: 2.5313122272491455
Epoch 6: Loss: 2.362938404083252
Epoch 7: Loss: 2.39508056640625
Epoch 8: Loss: 2.241774797439575
Epoch 9: Loss: 2.202528476715088
Epoch 10: Loss: 2.140164613723755


In [6]:
def puzzle_accuracy(model, dataset):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for i in range(len(dataset)):
            tiles, labels = dataset[i]
            tiles = tiles.unsqueeze(0)
            labels = torch.tensor(labels).unsqueeze(0)
            outputs = model(tiles)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).item()
            total += 1
    return correct / total

puzzle_acc = puzzle_accuracy(model, dataset)
print(f"Puzzle Accuracy: {puzzle_acc:.4f}")


Puzzle Accuracy: 0.1111
