### Installing Dependencies

In [17]:
# %pip install timm
# %pip install onnx

TODO
- [X] Train/Test Split
- [X] Normalization
- [X] Data Augmentation
- [ ] Hyperparameter Tuning
- [ ] Figure out ONNX verification and inference
- [ ] Export to TensorFlow.js?

In [18]:
"""
Outliers?
https://kevinmusgrave.github.io/pytorch-metric-learning/
https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/README.md
https://colab.research.google.com/github/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/SubCenterArcFaceMNIST.ipynb#scrollTo=GJ_L0TrTDnEA
---> Get_Outliers()
"""

'\nOutliers?\nhttps://kevinmusgrave.github.io/pytorch-metric-learning/\nhttps://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/README.md\nhttps://colab.research.google.com/github/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/SubCenterArcFaceMNIST.ipynb#scrollTo=GJ_L0TrTDnEA\n---> Get_Outliers()\n'

## Training ConvNeXt V2-Atto with ArcFace Loss

In [19]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchsummary import summary
import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter

from pytorch_metric_learning import losses, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

import timm

import onnx
import onnxruntime

from PIL import UnidentifiedImageError

from pathlib import Path

def printg(string):
    """Print ``string`` wrapped in the ANSI escape codes for green text."""
    print(f"\033[92m{string}\033[00m")


def printr(string):
    """Print ``string`` wrapped in the ANSI escape codes for red text."""
    print(f"\033[91m{string}\033[00m")

In [20]:
log: bool = False  # set True to create a TensorBoard SummaryWriter and log metrics

In [21]:
# Optional TensorBoard writer (only created when `log` is enabled).
if log:
    writer = SummaryWriter()

# --- Training hyperparameters ---
batch_size = 64
epochs = 500
learning_rate = 1e-3  # Adam LR for the backbone (model.parameters())
loss_lr = 1e-4        # Adam LR for the ArcFace loss parameters (criterion.parameters())
factor = 0.3          # ReduceLROnPlateau LR decay factor
patience = 10         # epochs without improvement (LR schedulers / EarlyStopping default)
seed = 42             # fixed RNG seed so weight init, shuffling and augmentation repeat

# Seed every RNG the notebook uses (torch.manual_seed also seeds CUDA devices).
torch.manual_seed(seed)
np.random.seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
printg(f"Using device: {device}")
# ArcFaceLoss is sized by num_classes, so it must cover every class folder in the training split.
# NOTE(review): the original annotation read "~100*12" — confirm the intended class count.
num_classes = 20
embedding_size = 320  # width of the ConvNeXtV2-Atto pooled features (see model summary output)

# Training transform: fixed resize, random horizontal flip (augmentation), ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

if log:
    hyperparameters = [
        ('Batch_size', batch_size),
        ('Epochs', epochs),
        ('Learning_rate', learning_rate),
        ('Loss_lr', loss_lr),
        ('Num_classes', num_classes),
        ('Embedding_size', embedding_size),
        ('Factor', factor),
        ('Patience', patience),
        ('Seed', seed),
    ]
    for tag, value in hyperparameters:
        writer.add_scalar(f'Hyperparameters/{tag}', value, 0)

[92mUsing device: cuda[00m


In [22]:
# ImageFolder variant that tolerates image files PIL cannot decode.
class RobustImageFolder(datasets.ImageFolder):
    """``datasets.ImageFolder`` that skips images PIL cannot identify.

    When the sample at ``index`` raises ``UnidentifiedImageError``, the next sample
    (wrapping around to the start of the dataset) is returned instead, so a single
    corrupt file does not abort an epoch.
    """

    def __getitem__(self, index):
        """Return ``(sample, target)`` for ``index``, substituting the next decodable sample.

        Raises:
            IndexError: if ``index`` is out of range (same contract as list indexing).
            RuntimeError: if no sample in the dataset can be decoded.
        """
        num_samples = len(self.samples)
        if not -num_samples <= index < num_samples:
            raise IndexError(f"index {index} is out of range for {num_samples} samples")
        # Walk forward from `index`, wrapping with modulo: stepping with plain `index + 1`
        # runs past the end (IndexError) when the last image is corrupt, and recursing
        # never terminates if every image is corrupt.
        for offset in range(num_samples):
            path, target = self.samples[(index + offset) % num_samples]
            try:
                sample = self.loader(path)
            except UnidentifiedImageError:
                print(f"\033[91mSkipping Corrupt Image:\033[0m {Path(path)}")
                continue
            if self.transform is not None:
                sample = self.transform(sample)
            if self.target_transform is not None:
                target = self.target_transform(target)
            return sample, target
        raise RuntimeError("RobustImageFolder: no decodable image found in the dataset")

In [23]:
# Evaluation must be deterministic, so derive an eval transform from the training
# transform with the random augmentation removed (Resize/ToTensor/Normalize are kept).
eval_transform = transforms.Compose(
    [t for t in transform.transforms if not isinstance(t, transforms.RandomHorizontalFlip)]
)

# Train on the training split (previously this pointed at the held-out test split).
train_dataset = RobustImageFolder('../faces/split/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = RobustImageFolder('../faces/split/val', transform=eval_transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# ImageFolder assigns labels from sorted folder names, so both splits must share the same
# mapping, and ArcFaceLoss (sized by num_classes) must cover every training label.
assert val_dataset.class_to_idx == train_dataset.class_to_idx, "train/val class folders differ"
assert len(train_dataset.classes) <= num_classes, (
    f"{len(train_dataset.classes)} class folders but num_classes={num_classes}"
)

In [24]:
class ConvNeXtArcFace(nn.Module):
    """timm ConvNeXt backbone that maps image batches to embedding vectors for ArcFace.

    ``forward`` runs ``forward_features`` and global-average-pools the final feature map;
    the timm classifier head is not used in the forward pass.

    Args:
        model_name: timm model identifier, e.g. ``'convnextv2_atto'``.
        embedding_size: Width of the returned embedding. If it equals the backbone's
            ``num_features``, the pooled features are returned unchanged; otherwise a
            linear projection maps them to ``embedding_size``.
        pretrained: Passed to ``timm.create_model``.
    """

    def __init__(self, model_name, embedding_size, pretrained=False):
        super().__init__()
        self.convnext = timm.create_model(model_name, pretrained=pretrained)
        # Keep the original call so the module's state_dict keys (and saved checkpoints) match.
        self.convnext.reset_classifier(num_classes=0, global_pool='avg')
        self.embedding_size = embedding_size
        num_features = self.convnext.num_features
        # nn.Identity has no parameters, so the default configuration's state_dict is unchanged.
        self.projection = (
            nn.Identity() if embedding_size == num_features
            else nn.Linear(num_features, embedding_size)
        )

    def forward(self, x):
        """Return a ``(batch, embedding_size)`` embedding tensor for image batch ``x``."""
        features = self.convnext.forward_features(x)  # (batch, channels, H, W) feature map
        # Pooling to 1x1 equals the former avg_pool2d(x, 7) for 224x224 inputs (7x7 map)
        # but also works for other input resolutions.
        pooled = F.adaptive_avg_pool2d(features, 1).flatten(1)
        return self.projection(pooled)

In [25]:
# Build the embedding model once so its output shape can be checked in the next cell.
model_name = 'convnextv2_atto'
model = ConvNeXtArcFace(model_name=model_name, embedding_size=embedding_size)

In [None]:
# Smoke test: one random 224x224 RGB image should yield one embedding of width `embedding_size`.
sample_batch = torch.randn(1, 3, 224, 224)
with torch.no_grad():  # shape check only; no autograd graph needed
    sample_embedding = model(sample_batch)
assert sample_embedding.shape == (1, embedding_size), sample_embedding.shape
sample_embedding.shape

In [27]:
class EarlyStopping:
    """Early stopping on a "lower is better" value, checkpointing every improvement.

    The training loop passes ``-validation_accuracy`` as ``val_loss``. Each call either
    records a new best (saving ``checkpoints/best_<epoch>.pth``) or increments a counter;
    once the counter reaches ``patience``, ``early_stop`` becomes True.

    Args:
        patience: Calls without improvement tolerated before stopping (defaults to the
            notebook-level ``patience`` value at class-definition time).
        verbose: Print a message whenever a checkpoint is saved.
        delta: Minimum increase of the score (``-val_loss``) that counts as improvement.
    """

    def __init__(self, patience=patience, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model, epoch, optimizer, scheduler, criterion, loss_optimizer, loss_scheduler, running_loss):
        """Update the stopping state with this epoch's ``val_loss``; checkpoint on improvement."""
        score = -val_loss

        if self.best_score is not None and score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # First call or a genuine improvement: record it and reset the counter.
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch, optimizer, scheduler, criterion, loss_optimizer, loss_scheduler, running_loss)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, epoch, optimizer, scheduler, criterion, loss_optimizer, loss_scheduler, running_loss):
        """Save model, optimizer, scheduler and loss state to ``checkpoints/best_<epoch>.pth``."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        Path("checkpoints").mkdir(exist_ok=True)  # torch.save does not create missing directories
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),

                'criterion_state_dict': criterion.state_dict(),
                'loss_optimizer_state_dict': loss_optimizer.state_dict(),
                'loss_scheduler_state_dict': loss_scheduler.state_dict(),
                'loss': running_loss,
                }, f"checkpoints/best_{epoch}.pth")
        self.val_loss_min = val_loss

In [28]:
# Fresh model, loss and optimizers for training (re-created after the smoke test above).
model_name = 'convnextv2_atto'
model = ConvNeXtArcFace(model_name, embedding_size).to(device)

# The schedulers step on validation precision@1 (higher is better), hence mode='max'.
# Use the config-cell `factor` / `patience` instead of duplicating their literal values.
# NOTE: `verbose` is deprecated in recent PyTorch releases.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=factor, patience=patience, verbose=True)

# ArcFaceLoss has its own learnable parameters, so it gets a separate optimizer/scheduler.
# NOTE: pytorch-metric-learning's ArcFaceLoss takes `margin` in degrees (library default 28.6).
criterion = losses.ArcFaceLoss(num_classes=num_classes, embedding_size=embedding_size, margin=4).to(device)
loss_optimizer = optim.Adam(criterion.parameters(), lr=loss_lr)
loss_scheduler = ReduceLROnPlateau(loss_optimizer, mode='max', factor=factor, patience=patience, verbose=True)

start_epoch = 1  # overwritten when resuming from a checkpoint

def load_checkpoint(filepath, model, optimizer, scheduler, loss_optimizer, loss_scheduler, criterion, map_location=None):
    """Restore training state from a checkpoint written by the training loop / EarlyStopping.

    All passed-in objects are updated in place via ``load_state_dict``.

    Args:
        filepath: Path to the ``.pth`` checkpoint.
        model, optimizer, scheduler, loss_optimizer, loss_scheduler, criterion: Objects
            whose state is restored.
        map_location: Forwarded to ``torch.load``; e.g. ``'cpu'`` lets a checkpoint saved
            on GPU be resumed on a CPU-only machine. ``None`` keeps torch's default.

    Returns:
        ``(model, optimizer, scheduler, loss_optimizer, loss_scheduler, criterion, epoch, loss)``
        where ``epoch`` is the epoch to resume from (saved epoch + 1) and ``loss`` is the
        stored running loss.
    """
    checkpoint = torch.load(filepath, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    criterion.load_state_dict(checkpoint['criterion_state_dict'])
    loss_optimizer.load_state_dict(checkpoint['loss_optimizer_state_dict'])
    loss_scheduler.load_state_dict(checkpoint['loss_scheduler_state_dict'])
    epoch = checkpoint['epoch'] + 1  # resume at the epoch after the saved one
    loss = checkpoint['loss']
    return model, optimizer, scheduler, loss_optimizer, loss_scheduler, criterion, epoch, loss

In [29]:
# Set to a file name inside checkpoints/ (e.g. "best_12.pth") to resume training;
# any falsy value starts from scratch.
checkpoint = None
if checkpoint:
    checkpoint_path = f"checkpoints/{checkpoint}"
    (model, optimizer, scheduler, loss_optimizer, loss_scheduler,
     criterion, start_epoch, loss) = load_checkpoint(
        checkpoint_path, model, optimizer, scheduler, loss_optimizer, loss_scheduler, criterion
    )

In [30]:
def get_all_embeddings(dataset, model):
    """Embed every item of ``dataset`` with ``model`` via pytorch-metric-learning's BaseTester.

    Returns whatever ``BaseTester.get_all_embeddings`` returns; this notebook unpacks it
    as ``(embeddings, labels)``.
    """
    return testers.BaseTester().get_all_embeddings(dataset, model)


# Only nearest-neighbour accuracy (precision@1) is used for model selection.
accuracy_calculator = AccuracyCalculator(include=("precision_at_1",), k=1)

In [31]:
# torchsummary defaults to device="cuda"; pass the actual device type so this also runs on CPU.
summary(model, (3, 224, 224), device=device.type)

Embeddings: torch.Size([2, 320])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 40, 56, 56]           1,960
       LayerNorm2d-2           [-1, 40, 56, 56]              80
          Identity-3           [-1, 40, 56, 56]               0
            Conv2d-4           [-1, 40, 56, 56]           2,000
       LayerNorm2d-5           [-1, 40, 56, 56]              80
            Conv2d-6          [-1, 160, 56, 56]           6,560
              GELU-7          [-1, 160, 56, 56]               0
           Dropout-8          [-1, 160, 56, 56]               0
GlobalResponseNorm-9          [-1, 160, 56, 56]             320
           Conv2d-10           [-1, 40, 56, 56]           6,440
          Dropout-11           [-1, 40, 56, 56]               0
GlobalResponseNormMlp-12           [-1, 40, 56, 56]               0
         Identity-13           [-1, 40, 56, 56]               0
  

In [None]:
# Train with ArcFace loss. Each epoch: evaluate precision@1, step the LR schedulers on
# validation accuracy, snapshot at the epochs in `ckpt`, and stop early on a plateau.
early_stopping = EarlyStopping()
ckpt = [1, 3, 5, 10, 15, 25, 40, 60, 80, 90, 110, 130, 150, 175]  # epochs to snapshot
Path("checkpoints").mkdir(exist_ok=True)  # torch.save does not create missing directories

for epoch in range(start_epoch, epochs + 1):
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        print(f"Epoch: {epoch}, Batch: {batch_idx + 1}/{len(train_loader)}")
        inputs = inputs.to(device).float()
        targets = targets.to(device)

        # Backbone and ArcFace parameters have separate optimizers; clear both.
        optimizer.zero_grad()
        loss_optimizer.zero_grad()

        embeddings = model(inputs)
        loss = criterion(embeddings, targets)
        batch_loss = loss.item()
        if log:
            writer.add_scalar('Loss/train', batch_loss, (epoch - 1) * len(train_loader) + batch_idx + 1)

        loss.backward()
        optimizer.step()
        loss_optimizer.step()

        running_loss += batch_loss

    train_embeddings, train_labels = get_all_embeddings(train_dataset, model)
    val_embeddings, val_labels = get_all_embeddings(val_dataset, model)

    train_labels = train_labels.squeeze(1)
    val_labels = val_labels.squeeze(1)

    # Query and reference are the same set here, so the 5th argument (ref_includes_query)
    # must be True; with False each embedding's nearest neighbour is itself and
    # precision@1 is trivially ~1.0.
    accuracies = accuracy_calculator.get_accuracy(
        train_embeddings, train_labels, train_embeddings, train_labels, True
    )
    training_accuracy = accuracies['precision_at_1']
    if log:
        writer.add_scalar('Accuracy/Training', training_accuracy, epoch)
    printg(f"Train Set Accuracy = {training_accuracy}")

    # Validation embeddings are queried against the training embeddings (separate sets).
    accuracies = accuracy_calculator.get_accuracy(
        val_embeddings, val_labels, train_embeddings, train_labels, False
    )
    validation_accuracy = accuracies['precision_at_1']
    if log:
        writer.add_scalar('Accuracy/Validation', validation_accuracy, epoch)
    printg(f"Validation Set Accuracy = {validation_accuracy}")

    # The schedulers monitor accuracy (higher is better).
    scheduler.step(validation_accuracy)
    loss_scheduler.step(validation_accuracy)

    if epoch in ckpt:
        torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),

                    'criterion_state_dict': criterion.state_dict(),
                    'loss_optimizer_state_dict': loss_optimizer.state_dict(),
                    'loss_scheduler_state_dict': loss_scheduler.state_dict(),
                    'loss': running_loss,
                    }, f"checkpoints/epoch_{epoch}.pth")
        if log:
            writer.flush()

    # Report the mean batch loss of the epoch instead of only the last batch's loss
    # (max(..., 1) also avoids a failure when the loader is empty).
    epoch_loss = running_loss / max(len(train_loader), 1)
    if log:
        writer.add_scalar('Loss/train_epoch', epoch_loss, epoch)
    printr(f"Epoch [{epoch}/{epochs}], Mean Loss: {epoch_loss}")

    # EarlyStopping treats lower as better, so pass the negated accuracy.
    early_stopping(-validation_accuracy, model, epoch, optimizer, scheduler, criterion, loss_optimizer, loss_scheduler, running_loss)
    if early_stopping.early_stop:
        print("Early Stopping")
        break

# NOTE: `model` now holds the last epoch's weights; the best-scoring weights are in the
# most recent checkpoints/best_<epoch>.pth written by EarlyStopping.
if log:
    writer.close()

In [None]:
# Persist the weights currently in `model` (the last epoch trained, not necessarily the best).
final_weights_path = 'convnext_atto_arcface.pth'
torch.save(model.state_dict(), final_weights_path)

### Testing Inference

In [None]:
# Rebuild the architecture and load the saved weights for inference.
model = ConvNeXtArcFace(model_name, embedding_size).to(device)
model.eval()  # inference mode for the evaluation and export cells below
# map_location lets weights saved on a GPU machine load on a CPU-only machine too.
model.load_state_dict(torch.load('convnext_atto_arcface.pth', map_location=device))

<All keys matched successfully>

In [None]:
# Evaluate on the held-out test split without random augmentation (drop the random flip).
test_transform = transforms.Compose(
    [t for t in transform.transforms if not isinstance(t, transforms.RandomHorizontalFlip)]
)
test_dataset = RobustImageFolder('../faces/split/test', transform=test_transform)
# Previously built from train_dataset by mistake; evaluation order need not be shuffled.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# NOTE(review): identical to the helper defined before training; keep the two in sync
# or drop one of them.
def get_all_embeddings(dataset, model):
    """Compute embeddings (and labels) for all of ``dataset`` using a fresh BaseTester."""
    tester = testers.BaseTester()
    embeddings_and_labels = tester.get_all_embeddings(dataset, model)
    return embeddings_and_labels

accuracy_calculator = AccuracyCalculator(k=1, include=("precision_at_1",))

In [None]:
# Embed each split, then drop dimension 1 of its labels (presumably shape (N, 1) -> (N,)).
train_embeddings, train_labels = get_all_embeddings(train_dataset, model)
train_labels = train_labels.squeeze(1)

test_embeddings, test_labels = get_all_embeddings(test_dataset, model)
test_labels = test_labels.squeeze(1)

In [None]:
print("Computing accuracy...")
# accuracies = accuracy_calculator.get_accuracy(
        # test_embeddings, test_labels, train_embeddings, train_labels, False
    # )
# print("Test set accuracy = {}".format(accuracies["precision_at_1"]))

### Exporting the Model to ONNX

In [None]:
# Export the embedding model to ONNX, in eval mode and with a dynamic batch dimension.
model = model.to(device).eval()
dummy_input = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    # Reference output, e.g. for comparing against ONNX Runtime results.
    dummy_output = model(dummy_input)
torch.onnx.export(
    model,
    dummy_input,
    "convnext_atto_arcface.onnx",
    export_params=True,
    input_names=["input"],
    output_names=["embedding"],
    dynamic_axes={"input": {0: "batch"}, "embedding": {0: "batch"}},
)

Features: torch.Size([1, 320])
Features: torch.Size([1, 320])


In [None]:
# onnx_model = onnx.load("convnext_atto_arcface.onnx")
# onnx.checker.check_model(onnx_model) # 

In [None]:
# ort_session = onnxruntime.InferenceSession("convnext_atto_arcface.onnx", providers=["CUDAExecutionProvider"])

# def to_numpy(tensor):
#     return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(dummy_input)}
# ort_outs = ort_session.run(None, ort_inputs)

# # np.testing.assert_allclose(to_numpy(dummy_output), ort_outs[0], rtol=1e-03, atol=1e-05)
# # print("Exported model has been tested with ONNXRuntime, and the result looks good!")