### Installing Dependencies

In [51]:
# %pip install timm
# %pip install onnx

TO DO
- [X] Train/Test Split
- [X] Normalization
- [X] Data Augmentation
- [ ] Hyperparameter Tuning
- [ ] Figure out ONNX Verification, Inference
- [ ] Export to TensorFlow.js?

In [52]:
"""
Outliers?
https://kevinmusgrave.github.io/pytorch-metric-learning/
https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/README.md
https://colab.research.google.com/github/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/SubCenterArcFaceMNIST.ipynb#scrollTo=GJ_L0TrTDnEA
---> Get_Outliers()
"""

'\nOutliers?\nhttps://kevinmusgrave.github.io/pytorch-metric-learning/\nhttps://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/README.md\nhttps://colab.research.google.com/github/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/SubCenterArcFaceMNIST.ipynb#scrollTo=GJ_L0TrTDnEA\n---> Get_Outliers()\n'

## Training ConvNext-Atto

In [53]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter

from pytorch_metric_learning import losses, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

import timm

import onnx
import onnxruntime

from PIL import UnidentifiedImageError

from pathlib import Path

In [54]:
log = False

In [55]:
if log:
    writer = SummaryWriter()

batch_size = 64
epochs = 500
learning_rate = 1e-3
loss_lr = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

num_classes = 20 # ~100*12
embedding_size = 320

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

if log:
    writer.add_scalar('Hyperparameters/Batch_size', batch_size, 0)
    writer.add_scalar('Hyperparameters/Epochs', epochs, 0)
    writer.add_scalar('Hyperparameters/Learning_rate', learning_rate, 0)
    writer.add_scalar('Hyperparameters/Loss_lr', loss_lr, 0)
    writer.add_scalar('Hyperparameters/Num_classes', num_classes, 0)
    writer.add_scalar('Hyperparameters/Embedding_size', embedding_size, 0)

In [56]:
# Shouldn't really throw an error, but just in case
class RobustImageFolder(datasets.ImageFolder):
    def __getitem__(self, index):
        path, target = self.samples[index]
        try:
            sample = self.loader(path)
        except UnidentifiedImageError:
            print(f"\033[91mSkipping Corrupt Image:\033[0m {Path(path)}")            
            # return None, None
            return self.__getitem__(index + 1)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

In [57]:
train_dataset = RobustImageFolder('../faces/split/test', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = RobustImageFolder('../faces/split/val', transform=transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [58]:
class ConvNeXtArcFace(nn.Module):
    def __init__(self, model_name, embedding_size, pretrained=False):
        super(ConvNeXtArcFace, self).__init__()
        self.convnext = timm.create_model(model_name, pretrained=pretrained)
        self.convnext.reset_classifier(num_classes=0, global_pool='avg')
      
    def forward(self, x):
        x = self.convnext.forward_features(x) # 
        x = F.avg_pool2d(x, 7).flatten(1)
        print("Embeddings:", x.shape)
        return x

In [60]:
model_name = 'convnextv2_atto'
model = ConvNeXtArcFace(model_name, embedding_size)

In [61]:
x = torch.randn(1, 3, 224, 224)
model(x)

Embeddings: torch.Size([1, 320])


tensor([[-6.8194e-02,  5.1333e-02, -6.1956e-02, -3.6360e-02,  5.3770e-02,
         -8.9238e-02, -1.2089e-01, -7.4126e-02, -1.1603e-02, -5.8732e-02,
         -1.4981e-02,  3.6882e-02, -1.0532e-02,  8.5154e-02,  1.1332e-01,
         -8.5254e-02,  2.5808e-03, -1.2512e-01,  6.8433e-02, -1.1143e-01,
         -5.7732e-02,  6.4791e-02, -2.9316e-03,  5.4968e-02, -7.3018e-03,
          1.0388e-01,  1.8839e-01,  5.7166e-02, -2.2003e-01, -4.3469e-02,
          1.7603e-02, -5.1367e-02,  5.9913e-02, -4.6071e-03, -1.9538e-02,
         -1.6529e-05,  2.3502e-02, -4.2935e-02,  6.5636e-02, -2.4650e-01,
         -9.0053e-02, -1.6097e-01, -1.6759e-03, -6.9118e-03,  5.4979e-02,
         -7.1678e-03, -6.4141e-02,  5.7802e-02, -7.1169e-02, -2.8143e-01,
          2.3036e-01,  4.9556e-02, -8.6411e-02,  1.8848e-03, -8.5480e-02,
         -1.0621e-01, -1.9180e-01,  2.6773e-02, -1.0937e-01, -5.4265e-02,
         -8.3141e-02,  2.5344e-01, -2.3842e-02, -7.4361e-02,  7.1475e-02,
          6.3373e-03,  7.4410e-02, -6.

In [62]:
model_name = 'convnextv2_atto'
model = ConvNeXtArcFace(model_name, embedding_size)

model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

criterion = losses.ArcFaceLoss(num_classes=num_classes, embedding_size=embedding_size, margin=4).to(device)
loss_optimizer = optim.Adam(criterion.parameters(), lr=loss_lr)

start_epoch = 1

def load_checkpoint(filepath, model, optimizer, loss_optimizer, criterion):
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    criterion.load_state_dict(checkpoint['criterion_state_dict'])
    loss_optimizer.load_state_dict(checkpoint['loss_optimizer_state_dict'])
    epoch = checkpoint['epoch'] + 1
    loss = checkpoint['loss']
    return model, optimizer, loss_optimizer, criterion, epoch, loss

In [63]:
checkpoint = None # "epoch_5.pth"
if checkpoint:
    model, optimizer, loss_optimizer, criterion, start_epoch, loss = load_checkpoint(
        f"checkpoints/{checkpoint}", model, optimizer, loss_optimizer, criterion
        )

In [64]:
def get_all_embeddings(dataset, model):
    tester = testers.BaseTester()
    return tester.get_all_embeddings(dataset, model)

accuracy_calculator = AccuracyCalculator(include=("precision_at_1",), k=1)

In [65]:
ckpt = [] # [1, 3, 5, 10, 15, 25, 40, 60, 80, 90, 110, 130, 150, 175]
for epoch in range(start_epoch, epochs+1):
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        print(f"Epoch: {epoch}, Batch: {batch_idx + 1}/{len(train_loader)}")
        inputs = inputs.to(device)        
        targets = targets.to(device)
        inputs = inputs.float()

        optimizer.zero_grad()
        loss_optimizer.zero_grad()

        embeddings = model(inputs)
        # print("Embeddings:", embeddings.shape)
        loss = criterion(embeddings, targets)
        if log:    
            writer.add_scalar('Loss/train', loss.item(), (epoch-1) * len(train_loader) + batch_idx + 1)  
          
        loss.backward()
        optimizer.step()
        loss_optimizer.step()
        
        running_loss += loss.item()


    # train_embeddings, train_labels = get_all_embeddings(train_dataset, model)
    # val_embeddings, val_labels = get_all_embeddings(val_dataset, model)

    # train_labels = train_labels.squeeze(1)
    # val_labels = val_labels.squeeze(1)

    # accuracies = accuracy_calculator.get_accuracy(
    #         train_embeddings, train_labels, train_embeddings, train_labels, False
    #     )
    # training_accuracy = accuracies['precision_at_1']
    # if log:
    #     writer.add_scalar('Accuracy/Training', training_accuracy, epoch)
    # print(f"\033[92mTrain Set Accuracy = {training_accuracy}\033[0m")

    # accuracies = accuracy_calculator.get_accuracy(
    #         val_embeddings, val_labels, train_embeddings, train_labels, False
    #     )
    # validation_accuracy = accuracies['precision_at_1']
    # if log:
    #     writer.add_scalar('Accuracy/Validation', validation_accuracy, epoch)
    # print(f"\033[92mTest Set Accuracy = {validation_accuracy}\033[0m")


    if (epoch) in ckpt:
        torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss_optimizer_state_dict': loss_optimizer.state_dict(),
                    'criterion_state_dict': criterion.state_dict(),
                    'loss': running_loss,
                    }, f"checkpoints/epoch_{epoch}.pth")
        if log:
            writer.flush()


    print(f"\033[91mEpoch [{epoch}/{epochs}], Loss: {loss.item()}\033[0m")

if log:
    writer.close()

Epoch: 1, Batch: 1/13
Embeddings: torch.Size([64, 320])
Epoch: 1, Batch: 2/13
Embeddings: torch.Size([64, 320])
Epoch: 1, Batch: 3/13
Embeddings: torch.Size([64, 320])
Epoch: 1, Batch: 4/13
Embeddings: torch.Size([64, 320])
Epoch: 1, Batch: 5/13
Embeddings: torch.Size([64, 320])
Epoch: 1, Batch: 6/13
Embeddings: torch.Size([64, 320])
Epoch: 1, Batch: 7/13
Embeddings: torch.Size([64, 320])
Epoch: 1, Batch: 8/13
Embeddings: torch.Size([64, 320])
Epoch: 1, Batch: 9/13
Embeddings: torch.Size([64, 320])
Epoch: 1, Batch: 10/13
Embeddings: torch.Size([64, 320])
Epoch: 1, Batch: 11/13
Embeddings: torch.Size([64, 320])
Epoch: 1, Batch: 12/13
Embeddings: torch.Size([64, 320])
Epoch: 1, Batch: 13/13
Embeddings: torch.Size([50, 320])
[91mEpoch [1/500], Loss: 6.559036731719971[0m
Epoch: 2, Batch: 1/13
Embeddings: torch.Size([64, 320])
Epoch: 2, Batch: 2/13
Embeddings: torch.Size([64, 320])
Epoch: 2, Batch: 3/13
Embeddings: torch.Size([64, 320])
Epoch: 2, Batch: 4/13
Embeddings: torch.Size([64, 32

KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), 'convnext_atto_arcface.pth')

### Testing Inference

In [None]:
model = ConvNeXtArcFace(model_name, embedding_size)
model = model.to(device)
model.load_state_dict(torch.load('convnext_atto_arcface.pth'))

<All keys matched successfully>

In [None]:
test_dataset = RobustImageFolder('../faces/split/test', transform=transform)
test_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [None]:
def get_all_embeddings(dataset, model):
    tester = testers.BaseTester()
    return tester.get_all_embeddings(dataset, model)

accuracy_calculator = AccuracyCalculator(include=("precision_at_1",), k=1)

In [None]:
train_embeddings, train_labels = get_all_embeddings(train_dataset, model)
test_embeddings, test_labels = get_all_embeddings(test_dataset, model)

train_labels = train_labels.squeeze(1)
test_labels = test_labels.squeeze(1)

In [None]:
print("Computing accuracy...")
# accuracies = accuracy_calculator.get_accuracy(
        # test_embeddings, test_labels, train_embeddings, train_labels, False
    # )
# print("Test set accuracy = {}".format(accuracies["precision_at_1"]))

### Saving Model

In [None]:
model.to(device)
dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True).to(device)
dummy_output = model(dummy_input).to(device)
torch.onnx.export(model, dummy_input, "convnext_atto_arcface.onnx", export_params=True)

Features: torch.Size([1, 320])
Features: torch.Size([1, 320])


In [None]:
# onnx_model = onnx.load("convnext_atto_arcface.onnx")
# onnx.checker.check_model(onnx_model) # 

In [None]:
# ort_session = onnxruntime.InferenceSession("convnext_atto_arcface.onnx", providers=["CUDAExecutionProvider"])

# def to_numpy(tensor):
#     return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(dummy_input)}
# ort_outs = ort_session.run(None, ort_inputs)

# # np.testing.assert_allclose(to_numpy(dummy_output), ort_outs[0], rtol=1e-03, atol=1e-05)
# # print("Exported model has been tested with ONNXRuntime, and the result looks good!")