In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
import seaborn as sns
from tqdm import tqdm
import os
import random

# Seed every RNG this notebook draws from (Python `random` for subset sampling,
# NumPy for query selection, torch for weight init / shuffling / augmentation).
random.seed(47)
np.random.seed(47)
torch.manual_seed(47)


# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
def load_fashion_mnist(batch_size=64, subset_fraction=0.2, selected_classes=None,
                       root="./data", val_fraction=0.1):
    """Build Fashion-MNIST train / validation / test ``DataLoader``s.

    Images are resized to 32x32 and replicated into 3 identical channels. The training
    split additionally gets random flip / rotation / brightness-contrast augmentation;
    the validation split (carved out of the training images) and the test split do not.

    Args:
        batch_size: mini-batch size for all three loaders.
        subset_fraction: fraction of the (class-filtered) train and test images to keep,
            drawn with Python's ``random`` module; values ``>= 1.0`` keep everything.
        selected_classes: optional sequence of original Fashion-MNIST class ids (0-9).
            Only those classes are kept, and their labels are remapped to
            ``0 .. len(selected_classes) - 1`` in the given order. That way a label
            indexes the returned ``class_names`` and fits a ``len(class_names)``-way
            classifier. ``None`` keeps all 10 classes with their original labels.
        root: directory the dataset is downloaded to / read from.
        val_fraction: fraction of the kept training images held out for validation.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)``; only the training
        loader shuffles.

    Raises:
        ValueError: if ``selected_classes`` is empty, contains duplicates, or contains
            ids outside 0-9.
    """
    all_class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                       'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

    if selected_classes is None:
        class_names = list(all_class_names)
        target_transform = None
    else:
        selected_classes = list(selected_classes)
        if not selected_classes:
            raise ValueError("selected_classes must be non-empty (use None for all classes)")
        if len(set(selected_classes)) != len(selected_classes):
            raise ValueError(f"selected_classes contains duplicates: {selected_classes}")
        if any(not 0 <= c < len(all_class_names) for c in selected_classes):
            raise ValueError(
                f"selected_classes must be ids in 0..{len(all_class_names) - 1}: {selected_classes}"
            )
        class_names = [all_class_names[c] for c in selected_classes]
        # Remap original ids to contiguous 0..K-1 so that CrossEntropyLoss over K logits
        # and class_names[label] both work for any choice of classes (e.g. [2, 5, 7]).
        label_map = {c: i for i, c in enumerate(selected_classes)}
        target_transform = label_map.__getitem__

    # Grayscale(num_output_channels=3) copies the single channel into 3 identical ones.
    # This matches x.repeat(3, 1, 1) after ToTensor, without needing a lambda.
    transform_train = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ])

    transform_eval = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ])

    # Only the `torchvision` package is imported at the top of the notebook, so reach
    # the dataset class through it.
    fmnist = torchvision.datasets.FashionMNIST
    train_aug = fmnist(root=root, train=True, download=True,
                       transform=transform_train, target_transform=target_transform)
    # Same training images without augmentation; the validation split is served from here.
    train_eval = fmnist(root=root, train=True, download=True,
                        transform=transform_eval, target_transform=target_transform)
    test_full = fmnist(root=root, train=False, download=True,
                       transform=transform_eval, target_transform=target_transform)

    def _class_indices(targets):
        # Filter on the raw label tensor instead of decoding + transforming every image.
        raw_labels = targets.tolist()
        if selected_classes is None:
            return list(range(len(raw_labels)))
        wanted = set(selected_classes)
        return [i for i, y in enumerate(raw_labels) if y in wanted]

    train_idx = _class_indices(train_aug.targets)
    test_idx = _class_indices(test_full.targets)

    if subset_fraction < 1.0:
        train_idx = random.sample(train_idx, int(len(train_idx) * subset_fraction))
        test_idx = random.sample(test_idx, int(len(test_idx) * subset_fraction))

    # Split by index so the training part uses the augmented copy and the validation
    # part the clean copy of the same underlying images.
    n_fit = int((1.0 - val_fraction) * len(train_idx))
    perm = torch.randperm(len(train_idx)).tolist()
    fit_idx = [train_idx[p] for p in perm[:n_fit]]
    val_idx = [train_idx[p] for p in perm[n_fit:]]

    train_dataset = Subset(train_aug, fit_idx)
    valid_dataset = Subset(train_eval, val_idx)
    test_subset = Subset(test_full, test_idx)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(valid_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_subset, batch_size=batch_size)

    print(f"Training set size: {len(train_dataset)}")
    print(f"Validation set size: {len(valid_dataset)}")
    print(f"Test set size: {len(test_subset)}")

    return train_loader, val_loader, test_loader, class_names

In [None]:
# Patch embedding: image -> sequence of patch tokens.
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and project each one to a token.

    A Conv2d whose kernel size and stride both equal ``patch_size`` applies the same
    linear projection to every patch independently.

    Attributes:
        patch_size: side length of each patch.
        num_patches: tokens produced per image, ``(img_size // patch_size) ** 2``.
        proj: the patch-projection convolution.
    """

    def __init__(self, in_channels=3, patch_size=4, embed_dim=64, img_size=32):
        super().__init__()
        self.patch_size = patch_size
        grid_side = img_size // patch_size
        self.num_patches = grid_side * grid_side
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, H/p, W/p) -> (B, embed_dim, N) -> (B, N, embed_dim)
        return self.proj(x).flatten(start_dim=2).permute(0, 2, 1)

class VisionTransformer(nn.Module):
    """ViT-style image classifier that also exposes a learned embedding vector.

    Pipeline: patch embedding -> prepend a learnable class token -> add learnable
    positional embeddings -> dropout -> ``nn.TransformerEncoder`` -> take the class-token
    output -> LayerNorm/Linear/GELU "embedding" projection -> linear classification head.

    Args:
        img_size: side length of the square input images (fixes the positional-embedding size).
        patch_size: side length of each square patch.
        in_channels: number of input image channels.
        embed_dim: token width inside the transformer.
        depth: number of encoder layers.
        num_heads: attention heads per layer (``nn.MultiheadAttention`` requires it to divide ``embed_dim``).
        mlp_ratio: feed-forward hidden width as a multiple of ``embed_dim``.
        dropout: dropout rate given to each ``nn.TransformerEncoderLayer``; with the
            override below it governs every dropout in the layer except the attention weights.
        attn_dropout: dropout rate on the attention weights of each encoder layer.
        embedding_dropout: dropout applied to the token sequence before the encoder.
        embedding_dim: width of the embedding returned with ``return_embeddings=True``.
        num_classes: number of output logits.
    """

    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_channels=3,
        embed_dim=64,
        depth=6,
        num_heads=4,
        mlp_ratio=4.0,
        dropout=0.1,
        attn_dropout=0.1,
        embedding_dropout=0.1,
        embedding_dim=128,
        num_classes=10
    ):
        super().__init__()

        # patch embedding: image -> num_patches tokens of width embed_dim
        self.patch_embed = PatchEmbedding(
            in_channels=in_channels,
            patch_size=patch_size,
            embed_dim=embed_dim,
            img_size=img_size
        )
        num_patches = self.patch_embed.num_patches

        # learnable class token and one positional embedding per token (patches + class token)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        # embedding dropout
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout,
            activation="gelu",
            batch_first=True
        )
        # TransformerEncoderLayer passes its single `dropout` to the attention module
        # as well; override the attention-weight rate so `attn_dropout` takes effect.
        # This must happen before TransformerEncoder clones the layer `depth` times.
        encoder_layer.self_attn.dropout = attn_dropout
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # mlp head producing the embedding vector
        self.embedding_layer = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embedding_dim),
            nn.GELU()
        )

        # classification head
        self.head = nn.Linear(embedding_dim, num_classes)

        # initialize weights
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialisation applied recursively via ``self.apply``."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.MultiheadAttention) and m.in_proj_weight is not None:
            # The packed Q/K/V projection is a bare Parameter, not an nn.Linear. Because
            # TransformerEncoder deep-copies one template layer, every layer would otherwise
            # start with identical Q/K/V weights. Re-draw them per layer, using the same
            # scheme as the Linear layers.
            nn.init.trunc_normal_(m.in_proj_weight, std=0.02)
            if m.in_proj_bias is not None:
                nn.init.constant_(m.in_proj_bias, 0)

    def forward(self, x, return_embeddings=False):
        """Classify a batch of images.

        Args:
            x: image batch, expected (batch, in_channels, img_size, img_size); the
                positional embedding only matches that resolution.
            return_embeddings: if True, also return the per-image embedding vector.

        Returns:
            logits of shape (batch, num_classes), or ``(logits, embeddings)`` with
            embeddings of shape (batch, embedding_dim) when ``return_embeddings`` is True.
        """
        # patch embedding
        x = self.patch_embed(x)  # [batch_size, num_patches, embed_dim]

        # prepend the class token
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)  # [batch_size, num_patches + 1, embed_dim]

        # add positional embedding
        x = x + self.pos_embed
        x = self.dropout(x)

        # transformer
        x = self.transformer(x)

        # class-token output summarises the image
        cls_token_final = x[:, 0]

        # embedding layer
        embeddings = self.embedding_layer(cls_token_final)

        # classification head
        logits = self.head(embeddings)

        if return_embeddings:
            return logits, embeddings
        return logits

# Triplet loss on squared Euclidean distances. TODO (author note): marked to be revised.
class TripletLoss(nn.Module):
    """Triplet margin loss using squared Euclidean distances.

    For each row i: ``max(0, ||a_i - p_i||^2 - ||a_i - n_i||^2 + margin)``; the result
    is the mean over rows.

    Args:
        margin: required gap between the anchor-negative and anchor-positive distances.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        # Per-row squared distances; rows are reduced along dim 1.
        d_anchor_pos = (anchor - positive).pow(2).sum(dim=1)
        d_anchor_neg = (anchor - negative).pow(2).sum(dim=1)
        return F.relu(d_anchor_pos - d_anchor_neg + self.margin).mean()


def _sample_triplet_loss(embeddings, labels, criterion):
    """Sum of `criterion` over one random (anchor, positive, negative) triplet per class.

    Only classes with at least two samples in the batch contribute, and nothing is
    added when the batch holds a single class (no negative exists). Returns a 0-dim
    tensor with the embeddings' dtype and device.
    """
    total = embeddings.new_zeros(())
    batch_classes = torch.unique(labels)
    if len(batch_classes) < 2:
        return total
    for label in batch_classes:
        same = torch.where(labels == label)[0]
        if len(same) < 2:
            continue
        other = torch.where(labels != label)[0]
        # Draw two *distinct* same-class samples: [0] is the anchor, [1] the positive.
        # This always yields a valid pair. A shuffle-and-drop-fixed-points scheme would
        # skip the class whenever the shuffle is the identity (1 in 2 for a pair).
        pair = same[torch.randperm(len(same), device=same.device)[:2]]
        negative = other[torch.randint(len(other), (1,), device=other.device)]
        total = total + criterion(embeddings[pair[:1]], embeddings[pair[1:2]], embeddings[negative])
    return total


def _validate_epoch(model, loader, criterion, device, desc):
    """Evaluate `model` on `loader` without gradients; return (mean loss, accuracy)."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        val_bar = tqdm(loader, desc=desc)
        for inputs, labels in val_bar:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            val_bar.set_postfix(loss=loss.item(), acc=correct/total)
    return running_loss / len(loader.dataset), correct / total


def train_with_metric_learning(model, train_loader, val_loader, optimizer, scheduler, num_epochs=15,
                               save_path='best_model.pth', triplet_weight=0.5, triplet_margin=1.0):
    """Train `model` with cross-entropy plus a weighted triplet loss on its embeddings.

    `model(inputs, return_embeddings=True)` must return ``(logits, embeddings)``. Batches
    are moved to the device of the model's first parameter. After every epoch the model
    is evaluated on `val_loader` (cross-entropy only) and `scheduler` (if given) is
    stepped. The ``state_dict`` is saved to `save_path` whenever validation accuracy
    improves, so a checkpoint exists there after at least one epoch.

    Args:
        model: network to train (updated in place).
        train_loader: yields ``(inputs, labels)`` training batches.
        val_loader: yields ``(inputs, labels)`` validation batches.
        optimizer: optimizer over `model`'s parameters.
        scheduler: LR scheduler stepped once per epoch, or None.
        num_epochs: number of epochs.
        save_path: checkpoint path for the best model.
        triplet_weight: multiplier on the triplet term (loss = CE + weight * triplet).
        triplet_margin: margin of the squared-distance triplet loss.

    Returns:
        ``(train_losses, val_losses, train_accs, val_accs)``, one entry per epoch.
    """
    device = next(model.parameters()).device
    classification_criterion = nn.CrossEntropyLoss()
    triplet_criterion = TripletLoss(margin=triplet_margin)

    # -inf rather than 0.0 so the first epoch always writes a checkpoint, even at 0% accuracy.
    best_val_acc = float("-inf")
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_cls_loss = 0.0
        running_triplet_loss = 0.0
        correct = 0
        total = 0

        train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
        for inputs, labels in train_bar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs, embeddings = model(inputs, return_embeddings=True)
            cls_loss = classification_criterion(outputs, labels)
            triplet_loss = _sample_triplet_loss(embeddings, labels, triplet_criterion)
            loss = cls_loss + triplet_weight * triplet_loss

            loss.backward()
            optimizer.step()

            # training statistics (sample-weighted sums, divided by dataset size per epoch)
            batch_n = inputs.size(0)
            running_loss += loss.item() * batch_n
            running_cls_loss += cls_loss.item() * batch_n
            running_triplet_loss += triplet_loss.item() * batch_n
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            train_bar.set_postfix(
                loss=loss.item(),
                cls_loss=cls_loss.item(),
                triplet=triplet_loss.item(),
                acc=correct/total
            )

        epoch_train_loss = running_loss / len(train_loader.dataset)
        epoch_cls_loss = running_cls_loss / len(train_loader.dataset)
        epoch_triplet_loss = running_triplet_loss / len(train_loader.dataset)
        epoch_train_acc = correct / total
        train_losses.append(epoch_train_loss)
        train_accs.append(epoch_train_acc)

        print(f"Train - Cls Loss: {epoch_cls_loss:.4f}, Triplet Loss: {epoch_triplet_loss:.4f}")

        epoch_val_loss, epoch_val_acc = _validate_epoch(
            model, val_loader, classification_criterion, device,
            desc=f"Epoch {epoch+1}/{num_epochs} [Val]"
        )
        val_losses.append(epoch_val_loss)
        val_accs.append(epoch_val_acc)

        print(f"Epoch {epoch+1}/{num_epochs}: "
              f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}, "
              f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        if scheduler is not None:
            scheduler.step()

        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            torch.save(model.state_dict(), save_path)
            print(f"New best model saved with validation accuracy: {best_val_acc:.4f}")

    print(f"Best validation accuracy: {best_val_acc:.4f}")
    return train_losses, val_losses, train_accs, val_accs

def evaluate_model(model, test_loader):
    """Measure top-1 accuracy of `model` on `test_loader` and collect its predictions.

    Batches are moved to the module-level `device`. The accuracy is printed as well as
    returned.

    Returns:
        (accuracy, predictions, labels): accuracy as a float; predictions and true
        labels as numpy arrays in loader order.
    """
    model.eval()
    pred_batches = []
    label_batches = []
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels in tqdm(test_loader, desc="Evaluating"):
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            batch_preds = model(batch_inputs).max(dim=1).indices

            n_seen += batch_labels.size(0)
            n_correct += (batch_preds == batch_labels).sum().item()

            pred_batches.append(batch_preds.cpu().numpy())
            label_batches.append(batch_labels.cpu().numpy())

    accuracy = n_correct / n_seen
    print(f"Test Accuracy: {accuracy:.4f}")

    return accuracy, np.concatenate(pred_batches), np.concatenate(label_batches)

def extract_embeddings(model, data_loader):
    """Run `model` over `data_loader` and stack its embedding vectors and the labels.

    `model(inputs, return_embeddings=True)` must return ``(logits, embeddings)``; inputs
    are moved to the module-level `device`.

    Returns:
        (embeddings, labels): a 2-D numpy array (one row per sample, in loader order)
        and the matching 1-D label array.
    """
    model.eval()
    feature_batches = []
    target_batches = []

    with torch.no_grad():
        for batch_inputs, batch_targets in tqdm(data_loader, desc="Extracting embeddings"):
            _, batch_features = model(batch_inputs.to(device), return_embeddings=True)
            feature_batches.append(batch_features.cpu().numpy())
            target_batches.append(batch_targets.numpy())

    return np.vstack(feature_batches), np.concatenate(target_batches)

def compute_distance_matrix(embeddings, metric='euclidean', squared=True):
    """Compute pairwise distances between the rows of `embeddings`.

    Args:
        embeddings: 2-D array-like, one embedding per row.
        metric: ``'euclidean'`` or ``'cosine'`` (``1 - cosine similarity``).
        squared: for ``'euclidean'`` only. When True (default) return squared L2
            distances, which is enough for nearest-neighbour ranking; when False return
            plain L2 distances. Ignored for ``'cosine'``.

    Returns:
        (n, n) numpy array whose diagonal is exactly 0.

    Raises:
        ValueError: if `metric` is not recognised.
    """
    if metric == 'euclidean':
        emb = torch.as_tensor(embeddings)
        dists = torch.cdist(emb, emb, p=2)
        if squared:
            dists = dists.square()
        dists = dists.numpy()
    elif metric == 'cosine':
        emb = np.asarray(embeddings)
        norms = np.linalg.norm(emb, axis=1, keepdims=True)
        # Guard all-zero rows: they get similarity 0 to everything instead of NaN.
        normalized = emb / np.maximum(norms, 1e-12)
        dists = 1 - normalized @ normalized.T
    else:
        raise ValueError(f"Unknown metric: {metric}")

    # cdist's matrix-multiply path and float round-off can leave small non-zero
    # self-distances; a point is always at distance 0 from itself.
    np.fill_diagonal(dists, 0)
    return dists

def precision_at_k(distance_matrix, labels, k=5):
    """Mean retrieval precision@k: fraction of each sample's k nearest *other* samples sharing its label.

    Args:
        distance_matrix: (n, n) pairwise distances; smaller means closer.
        labels: length-n array of class labels.
        k: neighbours per query; clipped to ``n - 1`` when fewer other samples exist.

    Returns:
        Precision averaged over all n queries, as a float in [0, 1].

    Raises:
        ValueError: if fewer than two samples are given or ``k < 1``.
    """
    dist = np.asarray(distance_matrix)
    # Copy (we overwrite the diagonal) into a float dtype that can hold inf.
    dist = dist.astype(np.promote_types(dist.dtype, np.float32), copy=True)
    n = dist.shape[0]
    if n < 2:
        raise ValueError("precision_at_k needs at least two samples")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    # Exclude each query explicitly. Dropping the first argsort column would assume the
    # query sorts first, which fails when other samples tie at distance 0 (duplicates).
    np.fill_diagonal(dist, np.inf)
    k_eff = min(k, n - 1)
    # argpartition picks the k_eff smallest per row; their relative order is irrelevant here.
    neighbors = np.argpartition(dist, k_eff - 1, axis=1)[:, :k_eff]

    labels = np.asarray(labels)
    hits = labels[neighbors] == labels[:, None]
    return float(hits.mean(axis=1).mean())

def visualize_neighbors(data_loader, embeddings, labels, class_names, num_queries=5, k=5,
                        save_path='fashion_mnist_neighbors.png'):
    """Show random query images next to their k nearest neighbours in embedding space.

    Row i of `embeddings` / `labels` must describe the i-th image yielded by
    `data_loader`, so the loader must iterate in a fixed order (no shuffling); the
    loader previously passed to `extract_embeddings` satisfies this. Neighbours are
    ranked by squared Euclidean distance and never include the query itself. A
    neighbour's title is green when its label matches the query's, red otherwise.
    Only channel 0 of each image is displayed.

    Args:
        data_loader: unshuffled loader yielding ``(images, labels)`` batches.
        embeddings: 2-D array, one row per image.
        labels: 1-D array of integer labels indexing `class_names`.
        class_names: display name per label.
        num_queries: number of query rows (clipped to the number of samples).
        k: neighbours per query (clipped to ``n - 1``).
        save_path: file the figure is written to before being shown.

    Raises:
        ValueError: if the loader yields a different number of images than there are
            embeddings, or fewer than two samples are available.
    """
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)

    # Gather *every* image so any neighbour index is valid (collecting only the first
    # ~1000 made neighbour lookups overflow). Keep just channel 0, the only one shown.
    images = torch.cat([batch_images[:, 0] for batch_images, _ in data_loader], dim=0)
    n = len(embeddings)
    if len(images) != n:
        raise ValueError(f"data_loader yielded {len(images)} images but got {n} embeddings")
    if n < 2:
        raise ValueError("need at least two samples to show neighbours")

    num_queries = min(num_queries, n)
    k = min(k, n - 1)
    query_indices = np.random.choice(n, num_queries, replace=False)

    fig, axes = plt.subplots(num_queries, k + 1, figsize=(15, num_queries * 2), squeeze=False)

    for row, query_idx in enumerate(query_indices):
        query_label = labels[query_idx]

        # Distances from this query only, avoiding a full n x n matrix.
        sq_dists = np.sum((embeddings - embeddings[query_idx]) ** 2, axis=1)
        sq_dists[query_idx] = np.inf  # never return the query as its own neighbour
        neighbor_indices = np.argsort(sq_dists)[:k]

        ax = axes[row, 0]
        ax.imshow(images[query_idx].numpy(), cmap='gray')
        ax.set_title(f"Query: {class_names[query_label]}", fontsize=8)
        ax.axis('off')

        for col, neighbor_idx in enumerate(neighbor_indices, start=1):
            neighbor_label = labels[neighbor_idx]
            color = 'green' if neighbor_label == query_label else 'red'
            ax = axes[row, col]
            ax.imshow(images[neighbor_idx].numpy(), cmap='gray')
            ax.set_title(f"{class_names[neighbor_label]}", fontsize=8, color=color)
            ax.axis('off')

    fig.tight_layout()
    fig.savefig(save_path)
    plt.show()

def visualize_embeddings(embeddings, labels, class_names,
                         save_path='fashion_mnist_embeddings_pytorch.png'):
    """Project embeddings to 2-D with t-SNE and scatter-plot them coloured by class.

    Args:
        embeddings: 2-D array, one embedding per row.
        labels: 1-D array of integer labels indexing `class_names`.
        class_names: display name per label.
        save_path: file the figure is written to before being shown.
    """
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)

    # t-SNE requires perplexity < n_samples; keep 30 whenever there is enough data.
    perplexity = min(30, len(embeddings) - 1)
    tsne_kwargs = dict(n_components=2, random_state=42, perplexity=perplexity)
    try:
        # scikit-learn >= 1.5 names the iteration cap `max_iter` (`n_iter` was removed in 1.7).
        tsne = TSNE(max_iter=300, **tsne_kwargs)
    except TypeError:
        # Older scikit-learn only accepts `n_iter`.
        tsne = TSNE(n_iter=300, **tsne_kwargs)
    reduced_embeddings = tsne.fit_transform(embeddings)

    fig, ax = plt.subplots(figsize=(12, 10))
    scatter = ax.scatter(
        reduced_embeddings[:, 0],
        reduced_embeddings[:, 1],
        c=labels,
        cmap='tab10',
        alpha=0.7,
        s=10
    )
    fig.colorbar(scatter, ax=ax, ticks=range(len(class_names)))
    ax.set_title('t-SNE visualization of Fashion MNIST embeddings')

    # num=None -> exactly one handle per distinct label value, in sorted order. The
    # default "auto" subsamples to ~9 entries when 10 classes are present, which would
    # misalign handles and names.
    handles, _ = scatter.legend_elements(num=None)
    legend_names = [class_names[int(v)] for v in np.unique(labels)]
    ax.legend(handles, legend_names, loc="upper right", title="Classes")

    fig.tight_layout()
    fig.savefig(save_path)
    plt.show()

def plot_training_history(train_losses, val_losses, train_accs, val_accs):
    """Plot per-epoch loss (left panel) and accuracy (right panel) curves.

    Each panel overlays the training and validation series. The figure is saved to
    ``fashion_mnist_training_history_pytorch.png`` and then shown.
    """
    panels = (
        ('Loss', train_losses, val_losses, 'Training Loss', 'Validation Loss'),
        ('Accuracy', train_accs, val_accs, 'Training Accuracy', 'Validation Accuracy'),
    )

    plt.figure(figsize=(12, 5))
    for position, (metric, train_curve, val_curve, train_label, val_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(train_curve, label=train_label)
        plt.plot(val_curve, label=val_label)
        plt.xlabel('Epoch')
        plt.ylabel(metric)
        plt.legend()

    plt.tight_layout()
    plt.savefig('fashion_mnist_training_history_pytorch.png')
    plt.show()

# Main function
def main():
    """End-to-end run: load data, train the ViT with metric learning, then evaluate.

    The best checkpoint is written to ``results/best_model.pth``. The function then
    plots the training curves, reports test accuracy and retrieval precision@5, and
    plots a t-SNE of the test embeddings and a nearest-neighbour grid.
    """
    batch_size = 128
    num_epochs = 50
    learning_rate = 3e-4
    weight_decay = 1e-4

    # data subset (tune to trade run time against accuracy)
    subset_fraction = 0.7
    selected_classes = [0, 1, 2, 3, 4, 5]

    train_loader, val_loader, test_loader, class_names = load_fashion_mnist(
        batch_size=batch_size,
        subset_fraction=subset_fraction,
        selected_classes=selected_classes
    )

    num_classes = len(class_names)

    model = VisionTransformer(
        img_size=32,
        patch_size=4,
        in_channels=3,
        embed_dim=64,
        depth=6,
        num_heads=4,
        mlp_ratio=4.0,
        dropout=0.1,
        embedding_dim=128,
        num_classes=num_classes
    ).to(device)

    print(model)
    num_params = sum(p.numel() for p in model.parameters())
    print(f"number of parameters: {num_params:,}")

    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    os.makedirs('results', exist_ok=True)
    model_save_path = 'results/best_model.pth'

    train_losses, val_losses, train_accs, val_accs = train_with_metric_learning(
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        num_epochs=num_epochs,
        save_path=model_save_path
    )

    plot_training_history(train_losses, val_losses, train_accs, val_accs)

    if os.path.exists(model_save_path):
        # The checkpoint is a plain state_dict: weights_only=True refuses arbitrary pickled
        # objects, and map_location keeps loading working if it was saved on another device.
        model.load_state_dict(torch.load(model_save_path, map_location=device, weights_only=True))
    else:
        print(f"No checkpoint at {model_save_path}; evaluating the final-epoch weights.")

    accuracy, all_preds, all_labels = evaluate_model(model, test_loader)
    test_embeddings, test_labels = extract_embeddings(model, test_loader)
    visualize_embeddings(test_embeddings, test_labels, class_names)

    distances = compute_distance_matrix(test_embeddings)
    prec_at_5 = precision_at_k(distances, test_labels, k=5)
    print(f"precision for retrieval: {prec_at_5:.4f}")

    visualize_neighbors(test_loader, test_embeddings, test_labels, class_names)


if __name__ == "__main__":
    main()