In [4]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class ArcFace(nn.Module):
    def __init__(self, in_features, out_features, s=30.0, m=0.5):
        super(ArcFace, self).__init__()
        self.s = s  # Scaling factor
        self.m = m  # Margin
        self.W = nn.Parameter(torch.randn(out_features, in_features))  # Weight matrix
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, features, labels):
        # Normalize features and weights
        cosine = F.linear(F.normalize(features), F.normalize(self.W))
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))

        # Compute phi (cos(theta + m))
        phi = cosine * self.cos_m - sine * self.sin_m

        # Apply margin only to the correct class
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1, 1), 1)

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

In [5]:
class FaceModel(nn.Module):
    def __init__(self, embedding_size=512, num_classes=10):
        super(FaceModel, self).__init__()
        self.backbone = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, embedding_size)
        )
        self.arcface = ArcFace(embedding_size, num_classes)

    def forward(self, x, labels=None):
        embeddings = self.backbone(x)
        if labels is not None:
            return self.arcface(embeddings, labels)
        else:
            return embeddings

In [None]:
import torch.optim as optim

# Hyperparameters
embedding_size = 512
num_classes = 10
batch_size = 32
learning_rate = 0.001
epochs = 5

# Initialize model, loss function, optimizer
model = FaceModel(embedding_size=embedding_size, num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Synthetic data
features = torch.randn(1000, 128)  # 1000 samples, 128-dim input
labels = torch.randint(0, num_classes, (1000,))

# DataLoader
dataset = torch.utils.data.TensorDataset(features, labels)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop
for epoch in range(epochs):
    for inputs, targets in loader:
        outputs = model(inputs, targets)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

Epoch 1/5, Loss: 15.5128
Epoch 2/5, Loss: 15.0218
Epoch 3/5, Loss: 14.9976
Epoch 4/5, Loss: 13.6927
Epoch 5/5, Loss: 12.8411
