In [32]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
import numpy as np

In [33]:
from models import ArcFaceModel
from visualization import visualize_embeddings

In [34]:
#device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create an instance of ArcFaceModel
model = ArcFaceModel(num_classes=10, embedding_size=2).to(device)

In [35]:
model_path = "models/best_model.pth"

#load the model weights
model.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [36]:
#Load cifa10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

dataset = CIFAR10(root='data/cifar10', download=True, transform=transform)

Files already downloaded and verified


In [37]:
#sample some examples belonging to class 1 and 2
# dataset.targets is a list, so need to first convert it to a tensor
class1 = dataset.data[np.array(dataset.targets) == 1][:10]

class2 = dataset.data[np.array(dataset.targets) == 2][:10]

In [38]:
#Embedding of class 1 and 2
embeddings1 = model.get_embedding(torch.tensor(class1).permute(0,3,1,2).float().to(device))
embeddings2 = model.get_embedding(torch.tensor(class2).permute(0,3,1,2).float().to(device))

In [39]:
#Cosine similarity between embeddings of class 1 and 2
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
print(cos(embeddings1, embeddings2))

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000], device='cuda:0')


In [1]:
# Read celeba annotations

import pandas as pd

# Read the annotations file
df = pd.read_csv("data/identity_CelebA.txt", sep=" ", header=None)

In [2]:
df.head()

Unnamed: 0,0,1
0,000001.jpg,2880
1,000002.jpg,2937
2,000003.jpg,8692
3,000004.jpg,5805
4,000005.jpg,9295


In [5]:
# Count the number of images for each identity sorted descending

df[1].value_counts()


1
3227    35
2820    35
3782    35
3745    34
3699    34
        ..
8815     1
9770     1
2264     1
9075     1
3481     1
Name: count, Length: 10177, dtype: int64

In [None]:
# The images are saved in the img_align_celeba folder

# Load images and labels and put them in a torch dataset

import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import pandas as pd
from torch.utils.data import Subset


class CelebADataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, annotations_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotations_file, sep=" ", header=None)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
        image = Image.open(img_path)
        label = torch.tensor(self.annotations.iloc[index, 1])
        
        if self.transform:
            image = self.transform(image)
            
        return (image, label)

# Define transformations
transform = transforms.Compose([
    transforms.CenterCrop(178),
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
    ])

    
# Create an instance of the dataset
dataset = CelebADataset(root_dir="data/img_align_celeba", annotations_file="data/identity_CelebA.txt", transform=transform)

#Take only four identities with the most images
dataset = Subset(dataset, np.where(np.array(dataset.annotations[1].value_counts().sort_values(ascending=False)[:4]))[0])