### Installing Dependencies

In [2]:
# Uncomment to install dependencies into the active kernel's environment.
# %pip install timm
# %pip install onnx
# NOTE(review): the imports below also need pytorch_metric_learning, onnxruntime and the
# `tensorboard` package (for torch.utils.tensorboard); consider pinning versions, e.g.
# %pip install pytorch-metric-learning onnxruntime tensorboard   (onnxruntime-gpu for CUDA)

In [3]:
"""
TO DO:
     - Train/Test Split
     - Normalization
     - Data Augmentation
     - Hyperparameter Tuning
     - Figure out ONNX Verification, Inference
     - Export to TensorFlow.js?
"""

'\nTO DO:\n     - Train/Test Split\n     - Normalization\n     - Data Augmentation\n     - Figure out ONNX Verification, Inference\n     - Export to TensorFlow.js?\n'

## Training ConvNeXt V2 Atto with ArcFace Loss

In [17]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

from torch.utils.tensorboard import SummaryWriter

from pytorch_metric_learning import losses
import timm

import onnx
import onnxruntime

from PIL import UnidentifiedImageError

from pathlib import Path

In [5]:
batch_size = 64
epochs = 3
learning_rate = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so DataLoader shuffling, RandomHorizontalFlip and the random
# weight initialisation are repeatable on Restart & Run All.
seed = 42
torch.manual_seed(seed)

num_classes = 100  # NOTE(review): original note read "~100*12" — presumably ~100 classes in this part; confirm
embedding_size = 256

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

writer = SummaryWriter()

In [6]:
class RobustImageFolder(datasets.ImageFolder):
    """ImageFolder that skips files the loader cannot decode.

    If loading the requested sample raises ``UnidentifiedImageError``, the following
    samples are tried in order (wrapping around to index 0), so one corrupt file does
    not abort an epoch. Corrupt files should be rare; this is a safety net.

    Raises:
        IndexError: if ``index`` is out of range, exactly like the parent class.
        RuntimeError: if no sample in the dataset can be loaded.
    """

    def __getitem__(self, index):
        num_samples = len(self.samples)
        # range() indexing normalises negative indices and raises IndexError when out of
        # range, like list indexing (Python's fallback iteration protocol relies on that).
        start = range(num_samples)[index]
        # Iterate instead of recursing so a run of corrupt files cannot hit the recursion
        # limit, and wrap with modulo so a corrupt *last* file does not index past the end.
        for offset in range(num_samples):
            path, target = self.samples[(start + offset) % num_samples]
            try:
                sample = self.loader(path)
            except UnidentifiedImageError:
                print(f"\033[91mSkipping Corrupt Image:\033[0m {Path(path)}")
                continue
            if self.transform is not None:
                sample = self.transform(sample)
            if self.target_transform is not None:
                target = self.target_transform(target)
            return sample, target
        raise RuntimeError(f"No image under {self.root} could be loaded")

In [7]:
# Training images: one sub-folder per class under ../dataset/part-1, shuffled each epoch.
train_dataset = RobustImageFolder(root='../dataset/part-1', transform=transform)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [8]:
class ConvNeXtArcFace(nn.Module):
    """timm ConvNeXt backbone whose classifier head is resized to emit embeddings.

    This module contains no ArcFace margin itself; it maps an image batch to
    ``embedding_size``-dimensional vectors that a separate loss consumes.

    Args:
        model_name: timm architecture name, e.g. ``'convnextv2_atto'``.
        embedding_size: number of outputs of the replacement head.
        pretrained: forwarded to ``timm.create_model``.
    """

    def __init__(self, model_name, embedding_size, pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # Swap the class-logit head for an average-pooled head of width embedding_size.
        backbone.reset_classifier(num_classes=embedding_size, global_pool='avg')
        # The attribute name prefixes every state_dict key ("convnext.*"); keep it stable.
        self.convnext = backbone

    def forward(self, x):
        return self.convnext(x)

In [9]:
model_name = 'convnextv2_atto'
model = ConvNeXtArcFace(model_name, embedding_size)
model = model.to(device)

# ImageFolder labels run 0..len(classes)-1 and ArcFaceLoss keeps one weight column per
# class, so size the loss from the dataset; a folder count larger than a hand-set
# num_classes would make the loss index past its weight matrix.
n_train_classes = len(train_dataset.classes)
if n_train_classes != num_classes:
    print(f"Note: dataset has {n_train_classes} classes; using that instead of num_classes={num_classes}")
criterion = losses.ArcFaceLoss(num_classes=n_train_classes, embedding_size=embedding_size).to(device)

# ArcFaceLoss has its own trainable parameters (the per-class weights). They must be
# optimised together with the backbone, otherwise they stay at their random init.
optimizer = optim.Adam([*model.parameters(), *criterion.parameters()], lr=learning_rate)

In [13]:
# 1-based epoch numbers after which a checkpoint file is written.
checkpoints = [3, 5, 10, 15, 25, 40, 60, 80, 100]

num_batches = len(train_loader)
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        print("Batch: ", batch_idx + 1, "/", num_batches)
        inputs = inputs.to(device).float()
        targets = targets.to(device)

        embeddings = model(inputs)
        loss = criterion(embeddings, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # one host sync per batch, reused below
        writer.add_scalar('Loss/train', batch_loss, epoch * num_batches + batch_idx)
        running_loss += batch_loss

    epoch_loss = running_loss / max(num_batches, 1)  # mean batch loss; guards an empty loader
    writer.add_scalar('Loss/train_epoch', epoch_loss, epoch)

    # Checked once per epoch, after the batch loop, against epoch numbers: comparing the
    # leftover batch index here would never match this list and nothing would be saved.
    if epoch + 1 in checkpoints:
        torch.save({
                    'epoch': epoch,  # 0-based index of the epoch just finished
                    'model_state_dict': model.state_dict(),
                    'criterion_state_dict': criterion.state_dict(),  # ArcFace class weights
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': epoch_loss,
                    }, f"checkpoint_{epoch + 1}.pth")
        writer.flush()

    print(f'Epoch [{epoch + 1}/{epochs}], Loss: {epoch_loss}')

writer.flush()
writer.close()

Batch:  1 / 1215
Batch:  2 / 1215
Batch:  3 / 1215


KeyboardInterrupt: 

In [14]:
# Persist the backbone and, separately, the ArcFaceLoss state (its learnable per-class
# weights), which is needed to turn embeddings back into class predictions after reloading.
torch.save(model.state_dict(), 'convnext_atto_arcface.pth')
torch.save(criterion.state_dict(), 'convnext_atto_arcface_loss.pth')

### Testing Inference

In [None]:
# model = ConvNeXtArcFace(model_name, embedding_size)
# model = model.to(device)
# model.load_state_dict(torch.load('convnext_atto_arcface.pth'))
# model.eval()

In [None]:
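# NOTE: this still points at the training folder (see the Train/Test Split TODO); use a held-out split.
# test_dataset = RobustImageFolder('../dataset/part-1', transform=transform)
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)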

In [None]:
# correct = 0
# total = 0

# with torch.no_grad():
#     for batch_idx, (inputs, targets) in enumerate(test_loader):
#         inputs = inputs.to(device)
#         targets = targets.to(device)

#         embeddings = model(inputs)
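#         # ArcFaceLoss stores its class weights in `W`, shape (embedding_size, num_classes)
#         cos_sim = torch.mm(nn.functional.normalize(embeddings), nn.functional.normalize(criterion.W, dim=0))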
        
#         _, predicted = torch.max(cos_sim.data, 1)

#         total += targets.size(0)
#         correct += (predicted == targets).sum().item()

# accuracy = 100 * correct / total
# print('Test Accuracy: %f %%' % accuracy)

### Exporting and Verifying the ONNX Model

In [15]:
# Export in inference mode; the PyTorch reference output is computed the same way so it
# can be compared against ONNX Runtime later.
model.eval()
dummy_input = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    dummy_output = model(dummy_input)
torch.onnx.export(
    model,
    dummy_input,
    "convnext_atto_arcface.onnx",
    export_params=True,
    input_names=["input"],
    output_names=["embedding"],
    # Let the exported graph accept any batch size, not only the dummy's batch of 1.
    dynamic_axes={"input": {0: "batch"}, "embedding": {0: "batch"}},
)

In [16]:
onnx_model = onnx.load("convnext_atto_arcface.onnx")
onnx.checker.check_model(onnx_model) # 

In [19]:
ort_session = onnxruntime.InferenceSession("convnext_atto_arcface.onnx", providers=["CUDAExecutionProvider"])

def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array on the host.

    Tensors that track gradients are detached first, because ``.numpy()`` refuses
    tensors with ``requires_grad=True``.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(dummy_input)}
ort_outs = ort_session.run(None, ort_inputs)

# Compare ONNX Runtime with PyTorch on the same dummy input.
torch_out = to_numpy(dummy_output)
assert ort_outs[0].shape == torch_out.shape, (ort_outs[0].shape, torch_out.shape)
max_abs_diff = float(np.max(np.abs(torch_out - ort_outs[0])))
print(f"Max |PyTorch - ONNX Runtime| difference: {max_abs_diff:.3e}")
# NOTE(review): the strict check below was left disabled; GPU convolution precision settings
# may push differences past atol=1e-5 — choose a tolerance from the value printed above, then enable.
# np.testing.assert_allclose(torch_out, ort_outs[0], rtol=1e-03, atol=1e-05)
# print("Exported model has been tested with ONNXRuntime, and the result looks good!")