In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision import transforms as tf

from matplotlib import pyplot as plt

import warnings
# Silences ALL Python warnings for the rest of the kernel session.
# NOTE(review): this also hides library deprecation warnings (e.g. from torchvision);
# consider filtering only specific categories/modules instead of "ignore" globally.
warnings.filterwarnings("ignore")

In [None]:
def set_seeds(seed: int = 42) -> None:
    """
    Seed every random number generator this notebook draws from, for reproducibility.

    Seeds Python's built-in ``random`` module (used in this notebook to pick which
    images to display) as well as PyTorch's CPU and CUDA generators.

    :param seed: (int, optional): Random seed. Defaults to 42
    :return: None
    """
    import random

    random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call on CPU-only machines: PyTorch ignores CUDA seeding when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)

In [None]:
# Shell command: show the visible NVIDIA GPUs and driver/CUDA versions
# (needs the `nvidia-smi` CLI; on machines without NVIDIA drivers the command will fail).
!nvidia-smi

In [None]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

# Create the Celeb datasets

In [None]:
from dataset.celeb import CelebDataset

In [None]:
# One CelebDataset per split; the split is selected through the `mode` argument.
train_dataset, val_dataset, test_dataset = (
    CelebDataset(mode=split) for split in ('train', 'val', 'test')
)

# Let's look at some cropped images from the Celeb dataset

In [None]:
import random

# Show a strip of randomly chosen training images to sanity-check the cropping.
nrows, ncols = 1, 10
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 3))
ax = ax.flatten()
to_pil = tf.ToPILImage()
# `random.sample` draws distinct indices, so the same face is never shown twice;
# `min` keeps this working for datasets smaller than the grid.
sample_indices = random.sample(range(len(train_dataset)), k=min(nrows * ncols, len(train_dataset)))
for axis, sample_index in zip(ax, sample_indices):
    # Index the dataset directly rather than calling the `__getitem__` dunder;
    # element 0 of each sample is the image tensor.
    image_tensor = train_dataset[sample_index][0]
    axis.imshow(to_pil(image_tensor))
    axis.axis("off")
plt.show()

In [None]:
batch_size = 64

# Only the training split is shuffled; validation and test keep a fixed order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader, test_dataloader = (
    DataLoader(dataset=split_dataset, batch_size=batch_size)
    for split_dataset in (val_dataset, test_dataset)
)

## Train a simple classifier


# ResNet + Cross-Entropy Loss

Let's use a pretrained ResNet. I replace its fully-connected head with a new one and freeze the convolutional layers, so only the new head is trained to recognize the faces.

In [None]:
from torchvision.models import resnet34, ResNet34_Weights, ResNet

# The weights enum must match the architecture: `resnet34` only accepts
# `ResNet34_Weights` (torchvision raises if given another model's weights enum).
model: ResNet = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)

# Freeze every pretrained parameter (the whole backbone, including the old fc);
# only the new head assigned below will be trained.
for param in model.parameters():
    param.requires_grad = False

set_seeds()
# Replace the ImageNet classifier with a fresh MLP head. Newly created layers have
# requires_grad=True by default, so they are the only trainable part of the model.
num_features = model.fc.in_features
num_classes = 500  # NOTE(review): presumably the number of identities in CelebDataset — confirm

model.fc = nn.Sequential(
    nn.Linear(in_features=num_features, out_features=num_features * 2),
    nn.ReLU(inplace=True),
    nn.Linear(in_features=num_features * 2, out_features=num_features),
    nn.ReLU(inplace=True),
    nn.Linear(in_features=num_features, out_features=128),
    nn.ReLU(inplace=True),
    nn.Linear(in_features=128, out_features=128),
    nn.ReLU(inplace=True),
    nn.Linear(in_features=128, out_features=num_classes)
)

In [None]:
# Layer-by-layer overview of the network (shapes, parameter counts, trainability)
# for an input batch of 32 RGB 224x224 images; the result is shown as the cell output.
from torchinfo import summary

summary(
    model,
    input_size=(32, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    verbose=0,
)

## Loss Function and Optimizer

In [None]:
criterion = F.cross_entropy
# Only the new head has requires_grad=True, so hand SGD just those parameters.
# momentum=0.9 is the conventional SGD momentum — NOTE(review): tune if needed.
optimizer = optim.SGD(
    [param for param in model.parameters() if param.requires_grad],
    lr=1e-3,
    momentum=0.9,
)

## Run Train Loop

In [None]:
from train.train import train

# Re-seed right before training so torch-driven randomness (e.g. the shuffled
# train_dataloader) is reproducible across runs.
set_seeds()

history = train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=device,
    epochs=50,
)

# Triplet Loss 

## Triplet Celeb Dataset

In [None]:
from dataset.celeb_triplet import CelebTripletDataset

In [None]:
# Train and validation triplet datasets, both read with file_location='Local'.
triplet_train_dataset, triplet_val_dataset = (
    CelebTripletDataset(mode=split, file_location='Local') for split in ('train', 'val')
)

## Examples of random triplets

In [None]:
import random

# Each row shows one random training triplet: anchor, positive and negative image.
nrows, ncols = 5, 3
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(6, 10))
ax = ax.flatten()
to_pil = tf.ToPILImage()
column_titles = ("anchor", "positive", "negative")
for row in range(nrows):
    rand_index = random.randint(0, len(triplet_train_dataset) - 1)
    anchor, positive, negative = triplet_train_dataset[rand_index]
    for col, image in enumerate((anchor, positive, negative)):
        axis = ax[row * ncols + col]
        axis.imshow(to_pil(image))
        axis.axis("off")
        # Label the columns once so the reader can tell which image is which.
        if row == 0:
            axis.set_title(column_titles[col])
plt.tight_layout()
plt.show()
    

## Let's create dataloaders

In [ ]:
# Batch and shuffle like the classifier loaders above: DataLoader's default
# (batch_size=1, no shuffling) would train on one fixed-order triplet per step.
triplet_train_dataloader = DataLoader(dataset=triplet_train_dataset, batch_size=batch_size, shuffle=True)
triplet_val_dataloader = DataLoader(dataset=triplet_val_dataset, batch_size=batch_size)

# Choosing the model

In [ ]:
from torchvision.models import resnet18, ResNet, ResNet18_Weights

# `pretrained=` is the deprecated torchvision flag superseded by `weights=`;
# pass only `weights` to load the ImageNet checkpoint.
triplet_model: ResNet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

In [ ]:
# Optimise the triplet network's own parameters, not the classifier `model` from above.
triplet_optimizer = optim.Adam(triplet_model.parameters())

In [ ]:
from train import triplet_loss_train as triplet_train

# The triplet loss is computed inside triplet_train.train, so no criterion is passed here.
loss_history = triplet_train.train(
    model=triplet_model,
    optimizer=triplet_optimizer,
    train_dataloader=triplet_train_dataloader,
    val_dataloader=triplet_val_dataloader,
    device=device,
    epochs=20,
)