<a href="https://colab.research.google.com/github/bryaanabraham/Deep-Fake-AI/blob/main/ArcFace_loss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
!pip uninstall torch torchvision torchaudio
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu

Found existing installation: torch 2.3.1+cu121
Uninstalling torch-2.3.1+cu121:
  Would remove:
    /usr/local/bin/convert-caffe2-to-onnx
    /usr/local/bin/convert-onnx-to-caffe2
    /usr/local/bin/torchrun
    /usr/local/lib/python3.10/dist-packages/functorch/*
    /usr/local/lib/python3.10/dist-packages/torch-2.3.1+cu121.dist-info/*
    /usr/local/lib/python3.10/dist-packages/torch/*
    /usr/local/lib/python3.10/dist-packages/torchgen/*
Proceed (Y/n)? y
Y
[31mERROR: Operation cancelled by user[0m[31m
[0mTraceback (most recent call last):
  File "/usr/lib/python3.10/shutil.py", line 816, in move
    os.rename(src, real_dst)
OSError: [Errno 18] Invalid cross-device link: '/usr/local/lib/python3.10/dist-packages/torch/' -> '/usr/local/lib/python3.10/dist-packages/~orch'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/base_command.py", line 179, in exc_logging_w

In [1]:
import torch
import torch.nn.functional as F

from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms as T, datasets

import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px

from tqdm.notebook import tqdm
from sklearn.manifold import TSNE

device = torch.device('mps' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


# Preparing dataset

In [2]:
train_dataset = datasets.MNIST(root='./sample_data', train=True, transform=T.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./sample_data', train=False, transform=T.ToTensor(), download=True)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

# Additive Angular Margin Penalty

In [3]:
class AdditiveAngularMarginPenalty(nn.Module):
    """
        Insightface implementation : https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/losses.py
        ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
    """
    def __init__(self, s=64.0, margin=0.5):
        super(AdditiveAngularMarginPenalty, self).__init__()
        self.s = s
        self.margin = margin

        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.theta = math.cos(math.pi - margin)
        self.sinmm = math.sin(math.pi - margin) * margin
        self.easy_margin = False

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        index = torch.where(labels != -1)[0]
        target_logit = logits[index, labels[index].view(-1)]

        with torch.no_grad():
            target_logit.arccos_()
            logits.arccos_()
            final_target_logit = target_logit + self.margin
            logits[index, labels[index].view(-1)] = final_target_logit
            logits.cos_()
        logits = logits * self.s
        return logits

# Example CNN model

In [4]:
class ToyMNISTModel(nn.Module):
  def __init__(self):
    super(ToyMNISTModel, self).__init__()

    self.conv1 = nn.Conv2d(1, 32, 5)
    self.conv2 = nn.Conv2d(32, 32, 5)
    self.conv3 = nn.Conv2d(32, 64, 5)
    self.dropout = nn.Dropout(0.25)
    self.fc1 = nn.Linear(3*3*64, 256)
    self.fc2 = nn.Linear(256, 10)
    self.angular_margin_penalty = AdditiveAngularMarginPenalty(10, 10)
    self.relu = nn.ReLU(inplace=True)
    self.maxpooling = nn.MaxPool2d(2, 2)

  def forward(self, x, label=None):
    # CNN part
    x = self.relu(self.conv1(x))
    x = self.dropout(x)
    x = self.relu(self.maxpooling(self.conv2(x)))
    x = self.dropout(x)
    x = self.relu(self.maxpooling(self.conv3(x)))
    x = self.dropout(x)

    # fully connected part
    x = x.view(x.size(0), -1)    # (batch_size, 3*3*64)
    x = self.relu(self.fc1(x))
    x = self.fc2(x)

    if label is not None:
      # angular margin penalty part
      logits = self.angular_margin_penalty(x, label)
    else:
      logits = x

    return logits

In [5]:
model = ToyMNISTModel()
model.to(device)

ToyMNISTModel(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (dropout): Dropout(p=0.25, inplace=False)
  (fc1): Linear(in_features=576, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=10, bias=True)
  (angular_margin_penalty): AdditiveAngularMarginPenalty()
  (relu): ReLU(inplace=True)
  (maxpooling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

# Training

In [6]:
epochs = 10
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [7]:
model.train()

for e in range(epochs):
  print('epochs: ', e)
  for idx, (images, labels) in enumerate(train_loader):
    images = images.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()
    outputs = model(images, labels)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

  with torch.no_grad():

    losses = []
    accs = 0.0

    for j, (val_images, val_labels) in enumerate(test_loader):
      val_images = val_images.to(device)
      val_labels = val_labels.to(device)

      outputs = model(val_images, val_labels)
      loss = criterion(outputs, val_labels)

      pred = outputs.argmax(dim=1, keepdim=True)
      acc = pred.eq(val_labels.view_as(pred)).sum().item()
      accs += acc

      losses.append(loss.item())

    loss = np.mean(losses)
    acc = accs / len(test_dataset)

    print(f'validation loss: {loss}, validation acc: {acc}')

epochs:  0
validation loss: nan, validation acc: 0.098
epochs:  1
validation loss: nan, validation acc: 0.098
epochs:  2
validation loss: nan, validation acc: 0.098
epochs:  3
validation loss: nan, validation acc: 0.098
epochs:  4
validation loss: nan, validation acc: 0.098
epochs:  5
validation loss: nan, validation acc: 0.098
epochs:  6
validation loss: nan, validation acc: 0.098
epochs:  7
validation loss: nan, validation acc: 0.098
epochs:  8
validation loss: nan, validation acc: 0.098
epochs:  9
validation loss: nan, validation acc: 0.098


In [16]:
torch.save(model.state_dict(), 'model.pt')

In [None]:
activations = {}

def get_activation(name):
  def hook(model, input, output):
    activations[name] = output.detach()
  return hook

In [None]:
h1 = model.fc2.register_forward_hook(get_activation('fc2'))

# Image feature visualization with t-SNE

In [None]:
with torch.no_grad():

    image_features = []
    labels = []

    for j, (val_images, val_labels) in enumerate(test_loader):
      val_images = val_images.to(device)
      val_labels = val_labels.to(device)

      outputs = model(val_images, val_labels)

      image_features.append(activations['fc2'].cpu().numpy())
      labels.append(val_labels.cpu().numpy())

    image_features = np.concatenate(image_features, axis=0)
    labels = np.concatenate(labels, axis=0)

tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
tsne_results = tsne.fit_transform(image_features)

[t-SNE] Computing 121 nearest neighbors...
[t-SNE] Indexed 10000 samples in 0.013s...
[t-SNE] Computed neighbors for 10000 samples in 0.625s...
[t-SNE] Computed conditional probabilities for sample 1000 / 10000
[t-SNE] Computed conditional probabilities for sample 2000 / 10000
[t-SNE] Computed conditional probabilities for sample 3000 / 10000
[t-SNE] Computed conditional probabilities for sample 4000 / 10000
[t-SNE] Computed conditional probabilities for sample 5000 / 10000
[t-SNE] Computed conditional probabilities for sample 6000 / 10000
[t-SNE] Computed conditional probabilities for sample 7000 / 10000
[t-SNE] Computed conditional probabilities for sample 8000 / 10000
[t-SNE] Computed conditional probabilities for sample 9000 / 10000
[t-SNE] Computed conditional probabilities for sample 10000 / 10000
[t-SNE] Mean sigma: 0.012805
[t-SNE] KL divergence after 250 iterations with early exaggeration: 65.657303
[t-SNE] KL divergence after 300 iterations: 2.257129


In [None]:
fig = px.scatter(x=tsne_results[:, 0], y=tsne_results[:, 1], color=labels)
fig.show()

# Training model without an additive angular margin penalty (normal softmax)

In [8]:
model_softmax = ToyMNISTModel()
model_softmax.to(device)

ToyMNISTModel(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (dropout): Dropout(p=0.25, inplace=False)
  (fc1): Linear(in_features=576, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=10, bias=True)
  (angular_margin_penalty): AdditiveAngularMarginPenalty()
  (relu): ReLU(inplace=True)
  (maxpooling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

In [9]:
epochs = 10
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_softmax.parameters(), lr=1e-4)

In [10]:
model_softmax.train()

for e in range(epochs):
  print('epochs: ', e)
  for idx, (images, labels) in enumerate(train_loader):
    images = images.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()
    outputs = model_softmax(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

  with torch.no_grad():

    losses = []
    accs = 0.0

    for j, (val_images, val_labels) in enumerate(test_loader):
      val_images = val_images.to(device)
      val_labels = val_labels.to(device)

      outputs = model_softmax(val_images)
      loss = criterion(outputs, val_labels)

      pred = outputs.argmax(dim=1, keepdim=True)
      acc = pred.eq(val_labels.view_as(pred)).sum().item()
      accs += acc

      losses.append(loss.item())

    loss = np.mean(losses)
    acc = accs / len(test_dataset)

    print(f'validation loss: {loss}, validation acc: {acc}')

epochs:  0
validation loss: 0.17375295973460006, validation acc: 0.9477
epochs:  1
validation loss: 0.10390236543693171, validation acc: 0.9688
epochs:  2
validation loss: 0.07932388334635906, validation acc: 0.9744
epochs:  3
validation loss: 0.06066950064769406, validation acc: 0.9801
epochs:  4
validation loss: 0.054899129367012314, validation acc: 0.9827
epochs:  5
validation loss: 0.051520519332284, validation acc: 0.9828
epochs:  6
validation loss: 0.0520197206449689, validation acc: 0.9832
epochs:  7
validation loss: 0.04152064457325158, validation acc: 0.9852
epochs:  8
validation loss: 0.04326070704242321, validation acc: 0.9867
epochs:  9
validation loss: 0.03890685338206637, validation acc: 0.986


In [11]:
activations_softmax = {}

def get_activation_softmax(name):
  def hook(model, input, output):
    activations_softmax[name] = output.detach()
  return hook

In [12]:
h2 = model_softmax.fc2.register_forward_hook(get_activation_softmax('fc2'))

In [14]:
model_softmax.eval()

with torch.no_grad():

    image_features = []
    labels = []

    for j, (val_images, val_labels) in enumerate(test_loader):
      val_images = val_images.to(device)
      val_labels = val_labels.to(device)

      outputs = model_softmax(val_images)

      image_features.append(activations_softmax['fc2'].cpu().numpy())
      labels.append(val_labels.cpu().numpy())

    image_features = np.concatenate(image_features, axis=0)
    labels = np.concatenate(labels, axis=0)

    if np.isnan(image_features).any():
        print("Warning: NaN values found in image_features. Consider investigating model stability or data preprocessing.")

tsne_ = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
tsne_results_ = tsne_.fit_transform(image_features)

[t-SNE] Computing 121 nearest neighbors...
[t-SNE] Indexed 10000 samples in 0.012s...
[t-SNE] Computed neighbors for 10000 samples in 0.879s...
[t-SNE] Computed conditional probabilities for sample 1000 / 10000
[t-SNE] Computed conditional probabilities for sample 2000 / 10000
[t-SNE] Computed conditional probabilities for sample 3000 / 10000
[t-SNE] Computed conditional probabilities for sample 4000 / 10000
[t-SNE] Computed conditional probabilities for sample 5000 / 10000
[t-SNE] Computed conditional probabilities for sample 6000 / 10000
[t-SNE] Computed conditional probabilities for sample 7000 / 10000
[t-SNE] Computed conditional probabilities for sample 8000 / 10000
[t-SNE] Computed conditional probabilities for sample 9000 / 10000
[t-SNE] Computed conditional probabilities for sample 10000 / 10000
[t-SNE] Mean sigma: 2.347072
[t-SNE] KL divergence after 250 iterations with early exaggeration: 69.588326
[t-SNE] KL divergence after 300 iterations: 2.457272


# Image feature visualization without an additive angular margin penalty

In [15]:
fig = px.scatter(x=tsne_results_[:, 0], y=tsne_results_[:, 1], color=labels)
fig.show()