In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms.v2 as T
from torchvision.models import resnet50,ResNet50_Weights
from PIL import Image
import random, os
import wandb
import datetime


In [None]:
def set_seed(seed=42):
    """Seed every RNG the notebook uses and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's legacy global RNG, and PyTorch (CPU and the
    current CUDA device). cuDNN autotuning is disabled so that convolution algorithm
    choice does not vary between runs.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Reproducibility over speed: fixed, deterministic cuDNN kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(42)

# Output folders for checkpoints and the plots written by later cells.
for _output_dir in ("models", "tsneplots", "ROCplots"):
    os.makedirs(_output_dir, exist_ok=True)

print("Directories created or already exist!")

In [None]:
# ---- Hyper-parameters --------------------------------------------------------
BATCH_SIZE = 64 * 3
LR = 2e-4            # learning rate of the embedding head
DECAY = 1e-3         # AdamW weight decay
DROPOUT = 0.3
EPOCHS = 65
FACTOR = 0.75        # ReduceLROnPlateau multiplicative LR factor
EMB_DIM = 256
THRESHOLD = 0.9      # cosine-similarity threshold used for val accuracy during training
ALPHA = 0.4          # weight of the hard-negative term in hybrid_triplet_loss
# NOTE(review): INITIAL_ALPHA / FINAL_ALPHA are not referenced by any other cell.
INITIAL_ALPHA = 0.1
FINAL_ALPHA = ALPHA
N_EPOCHS_MARGIN = 49 # epoch at which margin_schedule reaches FINAL_MARGIN; might need to change this to 75 or 100
# NOTE(review): hybrid_triplet_loss, as called in the training loop, uses
# margin_schedule(epoch) rather than MARGIN; MARGIN is only the default of the
# standalone (semi-)hard triplet losses.
MARGIN = 0.5
INITIAL_MARGIN = 1.0
PEAK_MARGIN = 0.6 # maybe make this a bit higher
FINAL_MARGIN = 0.3 # maybe lower this

now = datetime.datetime.now()

# The run name is also used as the suffix of checkpoint and plot filenames.
name = now.strftime("experiment_%d_%m_%H_%M")

wandb.init(
    project="face-verification",
    name=name,
    config={
        "batch_size": BATCH_SIZE,
        "learning_rate": LR,
        "weight_decay": DECAY,
        "architecture": "resnet50",
        "epochs": EPOCHS,
        "loss": "hybrid triplet loss",
        'alpha': ALPHA,
        'lr_scheduler': 'ReduceOnPlateau',
        "lr factor": FACTOR,
        "margin": MARGIN,
        "n_epochs_for_margin":N_EPOCHS_MARGIN,
        "dropout": DROPOUT,
        "dataset": "fullvggface2",
        "emb_dim": EMB_DIM,
        'threshold': THRESHOLD,
        "augmentations":"agressive",
        "margin_schedueler": True,
        "cosine_sim_loss": True,
        'hard negative mining': True,
        "initial_margin":INITIAL_MARGIN,
        "peak_margin": PEAK_MARGIN,
        "final_margin":FINAL_MARGIN,
    }
)

In [None]:
def get_all_identities(root_dir):
    """Return the names of the identity sub-directories of ``root_dir``.

    Plain files directly under ``root_dir`` are ignored. The names are sorted because
    ``os.listdir`` order is filesystem-dependent; a fixed order keeps the seeded
    shuffles that follow reproducible across machines.

    Args:
        root_dir: directory containing one sub-directory per identity.

    Returns:
        Sorted list of sub-directory names.
    """
    return sorted(
        entry
        for entry in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, entry))
    )

def get_portion_of_identities(identities, fraction):
    """Return a random ``fraction`` of ``identities`` without modifying the input.

    A shuffled copy is truncated to ``int(len(identities) * fraction)`` items. The copy
    is shuffled with the global ``random`` module, so the result for a given seed is the
    same as shuffling the original list would give. The caller's list is left untouched.

    Args:
        identities: sequence of identity names.
        fraction: portion to keep, in [0, 1]. ``1`` keeps everything (shuffled).

    Returns:
        A new list with the chosen identities.

    Raises:
        ValueError: if ``fraction`` is outside [0, 1].
    """
    if not 0.0 <= fraction <= 1.0:
        raise ValueError(f"fraction must be in [0, 1], got {fraction!r}")

    shuffled = list(identities)
    random.shuffle(shuffled)
    cutoff = int(len(shuffled) * fraction)
    return shuffled[:cutoff]

def build_label_to_images(root_dir, chosen_identities, extensions=('.jpg', '.jpeg', '.png')):
    """Map each chosen identity to the image files in its directory.

    Only files whose lower-cased name ends with one of ``extensions`` are kept.
    Otherwise stray files (e.g. ``.DS_Store``) would reach the dataset and fail to load.
    File names are sorted for a filesystem-independent order. Identities whose directory
    is missing, or that have fewer than two images, are skipped because they cannot
    provide an anchor/positive pair.

    Args:
        root_dir: directory containing one sub-directory per identity.
        chosen_identities: identity (sub-directory) names to include.
        extensions: accepted file suffixes, matched case-insensitively.

    Returns:
        dict identity -> list of image paths (each list has >= 2 entries).
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    label_to_images = {}

    for label in chosen_identities:
        label_path = os.path.join(root_dir, label)
        if not os.path.isdir(label_path):
            continue

        image_paths = [
            os.path.join(label_path, image_name)
            for image_name in sorted(os.listdir(label_path))
            if image_name.lower().endswith(suffixes)
        ]

        if len(image_paths) >= 2:
            label_to_images[label] = image_paths

    return label_to_images
def get_embedding(image_path, model, transform, device):
    """
    Load an image, apply transforms, and compute its embedding.

    The model is temporarily switched to eval mode and then restored to its previous
    mode. In train mode, a single-image batch breaks BatchNorm and dropout would make
    the embedding random. The image file is closed as soon as it has been decoded.

    Args:
        image_path: path of the image file.
        model: embedding model taking a (1, C, H, W) batch.
        transform: callable turning a PIL RGB image into a (C, H, W) tensor.
        device: device the batch is moved to.

    Returns:
        The embedding of the single image as a NumPy array (batch dim removed).
    """
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    batch = transform(image).unsqueeze(0).to(device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            embedding = model(batch)
    finally:
        model.train(was_training)

    # squeeze(0) only drops the batch axis, even for a size-1 embedding.
    return embedding.squeeze(0).cpu().numpy()

In [None]:

class FaceVerificationDataset(Dataset):
    """Randomly sampled (anchor, positive, negative) face triplets.

    Each item draws:
    - an anchor identity with at least two images;
    - two distinct images of that identity (anchor and positive);
    - one image of a different such identity (negative).

    The nominal length (10 x number of usable identities) only sets how many triplets
    one DataLoader pass yields.

    Args:
        label_to_imgs: dict identity -> list of image paths.
        transform: optional callable applied to each loaded RGB PIL image.
        max_retries: number of fresh triplets drawn when loading or transforming fails,
            before giving up with ``RuntimeError``.
        seed: if None (default), triplets come from the global ``random`` module, so
            ``idx`` does not determine the triplet. If an int, each ``idx`` uses its own
            ``random.Random`` seeded from ``seed`` and ``idx``, so the same index always
            yields the same triplet. This is useful for a stable validation set.

    Raises:
        ValueError: if fewer than two identities have >= 2 images. No valid triplet
            can be formed in that case.
    """

    def __init__(self, label_to_imgs, transform=None, max_retries=10, seed=None):
        self.label_to_imgs = label_to_imgs
        self.transform = transform
        self.labels = list(label_to_imgs.keys())
        self.max_retries = max_retries
        self.seed = seed

        # Only identities with >= 2 images can supply an anchor/positive pair.
        self.valid_pairs = [lbl for lbl in self.labels if len(self.label_to_imgs[lbl]) >= 2]
        if len(self.valid_pairs) < 2:
            raise ValueError(
                "FaceVerificationDataset needs at least two identities with >= 2 images, "
                f"got {len(self.valid_pairs)}"
            )

    def __len__(self):
        return len(self.valid_pairs) * 10

    def _sample_paths(self, rng):
        """Draw anchor, positive and negative image paths using ``rng``."""
        anchor_label = rng.choice(self.valid_pairs)
        anchor_path, positive_path = rng.sample(self.label_to_imgs[anchor_label], 2)

        # Rejection sampling instead of building an O(#identities) candidate list per
        # item. It terminates because __init__ guarantees another identity exists.
        negative_label = anchor_label
        while negative_label == anchor_label:
            negative_label = rng.choice(self.valid_pairs)
        negative_path = rng.choice(self.label_to_imgs[negative_label])
        return anchor_path, positive_path, negative_path

    def __getitem__(self, idx):
        rng = random if self.seed is None else random.Random(self.seed * 1_000_003 + idx)
        attempts = max(1, self.max_retries)
        last_error = None

        for _ in range(attempts):
            anchor_path, positive_path, negative_path = self._sample_paths(rng)
            try:
                return (
                    self.load_and_transform(anchor_path),
                    self.load_and_transform(positive_path),
                    self.load_and_transform(negative_path),
                )
            except Exception as exc:  # unreadable/corrupt image or failing transform
                last_error = exc

        raise RuntimeError(
            f"Could not load a valid triplet after {attempts} attempts"
        ) from last_error

    def load_and_transform(self, img_path):
        """Open ``img_path`` as RGB, close the file, and apply the transform if set."""
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img



In [None]:
def margin_schedule(epoch, initial_margin=INITIAL_MARGIN, peak_margin=PEAK_MARGIN, final_margin=FINAL_MARGIN, peak_epoch=15, total_epochs=N_EPOCHS_MARGIN):
    """Piecewise-linear triplet-margin schedule.

    - ``epoch < peak_epoch``: linear from ``initial_margin`` (epoch 0) to ``peak_margin``.
      With the current config constants (1.0 -> 0.6) this leg decreases, despite the name.
    - ``peak_epoch <= epoch < total_epochs``: linear from ``peak_margin`` to ``final_margin``.
    - ``epoch >= total_epochs``: held at ``final_margin``. Previously the line kept being
      extrapolated, so training longer than ``total_epochs`` pushed the margin below
      ``final_margin`` and eventually negative.

    Args:
        epoch: zero-based epoch index.
        initial_margin, peak_margin, final_margin: margin values at the schedule knots.
        peak_epoch: epoch at which ``peak_margin`` is reached.
        total_epochs: epoch at which ``final_margin`` is reached.

    Returns:
        The margin for ``epoch``.
    """
    #peak margin was 0.6 in GOOD RUN now 0.8
    if epoch < peak_epoch:
        return initial_margin + (peak_margin - initial_margin) * (epoch / peak_epoch)
    if epoch >= total_epochs or total_epochs <= peak_epoch:
        return final_margin
    progress = (epoch - peak_epoch) / (total_epochs - peak_epoch)
    return peak_margin - (peak_margin - final_margin) * progress
    
def cosine_sim_loss(anchor_embs, positive_embs):
    """Return ``1 - mean cosine similarity`` between paired rows of the two batches.

    The loss is 0 when every anchor points in the same direction as its positive.
    """
    mean_similarity = F.cosine_similarity(anchor_embs, positive_embs).mean()
    return 1 - mean_similarity
  
def semi_hard_triplet_loss(anchor_embs, positive_embs, negative_embs, margin=MARGIN):
    """Triplet loss restricted to semi-hard negatives.

    A triplet is semi-hard when its negative is farther than the positive but still
    inside the margin (``d_ap < d_an < d_ap + margin``). If no triplet in the batch is
    semi-hard, the ordinary mean triplet loss over the whole batch is returned.
    """
    pos_dist = F.pairwise_distance(anchor_embs, positive_embs)
    neg_dist = F.pairwise_distance(anchor_embs, negative_embs)

    in_band = (neg_dist > pos_dist) & (neg_dist < (pos_dist + margin))
    if in_band.any():
        pos_dist = pos_dist[in_band]
        neg_dist = neg_dist[in_band]

    return F.relu(pos_dist - neg_dist + margin).mean()

def hard_triplet_loss(anchor_embs, positive_embs, negative_embs, margin=MARGIN):
    """Triplet loss restricted to hard negatives.

    A triplet is hard when its negative is strictly closer to the anchor than the
    positive (``d_an < d_ap``). If the batch has no hard triplet, the ordinary mean
    triplet loss over the whole batch is returned.
    """
    dist_pos = F.pairwise_distance(anchor_embs, positive_embs)
    dist_neg = F.pairwise_distance(anchor_embs, negative_embs)

    violating = dist_neg < dist_pos
    if violating.any():
        selected_pos, selected_neg = dist_pos[violating], dist_neg[violating]
    else:
        selected_pos, selected_neg = dist_pos, dist_neg

    return F.relu(selected_pos - selected_neg + margin).mean()
  
def hybrid_triplet_loss(anchor_embs, positive_embs, negative_embs, epoch, margin = MARGIN, alpha = ALPHA, use_margin_schedule=True):
    """Triplet loss mixing hard and semi-hard negatives, plus a late cosine term.

    Triplets are partitioned by anchor distances ``d_ap`` (to positive) and ``d_an``
    (to negative):
    - hard: ``d_an <= d_ap``. Ties count as hard; previously they fell into neither
      group and were silently ignored.
    - semi-hard: ``d_ap < d_an < d_ap + margin``.

    The base loss is ``alpha * hard + (1 - alpha) * semi`` when both groups are present.
    If only one group is present, its mean triplet loss is used alone. If neither is
    present, the plain mean triplet loss over the batch is used. After epoch 10 a
    ``0.25 * cosine_sim_loss(anchor, positive)`` term is added.

    Args:
        anchor_embs, positive_embs, negative_embs: paired embedding batches.
        epoch: zero-based epoch; drives the margin schedule and the cosine-term switch.
        margin: triplet margin, only used when ``use_margin_schedule`` is False.
        alpha: weight of the hard-negative term.
        use_margin_schedule: when True (default, the previous behaviour) the margin is
            ``margin_schedule(epoch)`` and ``margin`` is ignored.

    Returns:
        Scalar loss tensor.
    """
    d_an = F.pairwise_distance(anchor_embs, negative_embs)
    d_ap = F.pairwise_distance(anchor_embs, positive_embs)

    if use_margin_schedule:
        margin = margin_schedule(epoch)

    hard_mask = d_an <= d_ap
    semi_hard_mask = (d_an > d_ap) & (d_an < (d_ap + margin))

    def _masked_triplet_mean(mask):
        # Mean hinge loss over the triplets selected by ``mask``.
        return F.relu(d_ap[mask] - d_an[mask] + margin).mean()

    has_hard = bool(hard_mask.any())
    has_semi = bool(semi_hard_mask.any())

    if has_hard and has_semi:
        base_loss = alpha * _masked_triplet_mean(hard_mask) + (1 - alpha) * _masked_triplet_mean(semi_hard_mask)
    elif has_hard:
        base_loss = _masked_triplet_mean(hard_mask)
    elif has_semi:
        base_loss = _masked_triplet_mean(semi_hard_mask)
    else:
        base_loss = F.relu(d_ap - d_an + margin).mean()

    cosine_weight = 0.25 if epoch > 10 else 0.0
    return base_loss + cosine_weight * cosine_sim_loss(anchor_embs, positive_embs)


In [None]:
class FaceVerificationModel(nn.Module):
    """Backbone plus MLP projection head producing L2-normalised embeddings.

    The backbone must expose ``fc`` (with ``in_features``) and ``layer4``, as a
    torchvision ResNet does. Every backbone parameter is frozen except those in
    ``layer4``. ``fc`` is replaced by an identity so the backbone outputs its pooled
    features. The head is Linear -> BatchNorm1d -> PReLU -> Dropout -> Linear, with a
    1024-wide hidden layer.

    Args:
        backbone: feature extractor (see requirements above).
        embedding_size: output embedding dimension.
        dropout: dropout probability inside the head.
    """

    def __init__(self, backbone, embedding_size=256, dropout=0.3):
        super().__init__()
        self.backbone = backbone
        self._freeze_layers()

        # The classifier's input width is the pooled feature width.
        feature_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        hidden_dim = 1024
        self.embedding_layer = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embedding_size),
        )

    def _freeze_layers(self):
        """Disable gradients for the whole backbone, then re-enable them for layer4."""
        self.backbone.requires_grad_(False)
        self.backbone.layer4.requires_grad_(True)

    def forward(self, x):
        """Embed ``x`` and project each row onto the unit sphere."""
        embeddings = self.embedding_layer(self.backbone(x))
        return F.normalize(embeddings, p=2, dim=1)


In [None]:
class EarlyStopping:
    """Track the validation loss, checkpoint on improvement, and flag when to stop.

    The first observed loss always becomes the baseline and is checkpointed. After that,
    a loss counts as an improvement only if it is lower than ``best_loss - min_delta``.
    Each improvement saves ``model.state_dict()`` to ``path`` and resets the patience
    counter. ``early_stop`` becomes True after ``patience`` consecutive non-improving
    calls.

    Args:
        patience: non-improving calls tolerated before ``early_stop`` is set.
        min_delta: minimum decrease that counts as an improvement.
        path: checkpoint file path (defaults to the run-named file under ./models/).
    """

    def __init__(self, patience=5, min_delta=0.0, path='./models/'+name+'_model.pth'):
        self.patience = patience
        self.min_delta = min_delta
        self.path = path
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, val_acc, model):
        if self.best_loss is None:
            # First observation: becomes the baseline checkpoint unconditionally.
            self.best_loss = val_loss
            torch.save(model.state_dict(), self.path)
            print("saved model")
            return

        improved = val_loss < self.best_loss - self.min_delta
        if not improved:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return

        self.best_loss = val_loss
        self.counter = 0
        torch.save(model.state_dict(), self.path)
        print(f"saved model at loss : {val_loss:.4f} - Accuracy: {val_acc:.4f}")
                

In [None]:
def train_one_epoch(model, loader, optimizer, loss_fn, scaler, device, epoch):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Each batch is an (anchor, positive, negative) image triple. The three are embedded
    separately and passed to ``loss_fn(anchor, pos, neg, epoch)``. On CUDA the forward
    pass runs under fp16 autocast. Elsewhere no autocast context is used; the previous
    ``torch.cuda.amp.autocast`` was deprecated and only warned-and-disabled off-GPU.
    ``scaler`` is expected to provide ``scale``/``step``/``update``, like
    ``torch.cuda.amp.GradScaler``.

    Args:
        model: embedding model; switched to train mode.
        loader: iterable of (anchor_imgs, pos_imgs, neg_imgs) tensors.
        optimizer: optimizer stepped through ``scaler``.
        loss_fn: callable ``(anchor_embs, pos_embs, neg_embs, epoch) -> scalar tensor``.
        scaler: gradient scaler.
        device: target device (``torch.device`` or string).
        epoch: zero-based epoch index forwarded to ``loss_fn``.

    Returns:
        Mean of the per-batch losses, or ``nan`` if ``loader`` yielded no batches.
    """
    import contextlib

    model.train()
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"

    running_loss = 0.0
    n_batches = 0

    for anchor_imgs, pos_imgs, neg_imgs in loader:
        anchor_imgs = anchor_imgs.to(device)
        pos_imgs = pos_imgs.to(device)
        neg_imgs = neg_imgs.to(device)

        optimizer.zero_grad()

        amp_context = torch.autocast(device_type="cuda") if use_amp else contextlib.nullcontext()
        with amp_context:
            anchor_embs = model(anchor_imgs)
            pos_embs = model(pos_imgs)
            neg_embs = model(neg_imgs)
            loss = loss_fn(anchor_embs, pos_embs, neg_embs, epoch)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        n_batches += 1

    return running_loss / n_batches if n_batches else float("nan")

@torch.no_grad()
def validate(model, loader, loss_fn, device, epoch, threshold = 0.75):
    """Compute mean validation loss and verification accuracy.

    For every triplet, the anchor/positive cosine similarity is a genuine pair (label 1)
    and the anchor/negative similarity is an impostor pair (label 0). A pair is
    predicted genuine when its similarity is ``> threshold``.

    NOTE(review): ``epoch`` is forwarded to ``loss_fn``. With ``hybrid_triplet_loss``
    the margin and the cosine-term weight depend on the epoch, so losses from
    different epochs are not on the same scale.

    Args:
        model: embedding model; switched to eval mode.
        loader: iterable of (anchor_imgs, pos_imgs, neg_imgs) tensors.
        loss_fn: callable ``(anchor_embs, pos_embs, neg_embs, epoch) -> scalar tensor``.
        device: device the batches are moved to.
        epoch: forwarded to ``loss_fn``.
        threshold: fixed similarity threshold. If None, the best of 200 evenly spaced
            thresholds in [-1, 1] is used (chosen on this same data, so optimistic).

    Returns:
        ``(mean_loss, accuracy)``; both are ``nan`` if ``loader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    n_batches = 0
    sim_chunks = []
    label_chunks = []

    for anchor_imgs, pos_imgs, neg_imgs in loader:
        anchor_embs = model(anchor_imgs.to(device))
        pos_embs = model(pos_imgs.to(device))
        neg_embs = model(neg_imgs.to(device))

        running_loss += loss_fn(anchor_embs, pos_embs, neg_embs, epoch).item()
        n_batches += 1

        pos_sim = F.cosine_similarity(anchor_embs, pos_embs).cpu().numpy()
        neg_sim = F.cosine_similarity(anchor_embs, neg_embs).cpu().numpy()
        sim_chunks += [pos_sim, neg_sim]
        label_chunks += [np.ones(len(pos_sim), dtype=int), np.zeros(len(neg_sim), dtype=int)]

    if n_batches == 0:
        return float("nan"), float("nan")

    all_sim = np.concatenate(sim_chunks)
    all_labels = np.concatenate(label_chunks)

    if threshold is None:
        thresholds = np.linspace(-1, 1, 200)
        accuracies = [np.mean((all_sim > t) == all_labels) for t in thresholds]
        best_idx = np.argmax(accuracies)
        threshold = thresholds[best_idx]
        best_acc = accuracies[best_idx]
    else:
        best_acc = np.mean((all_sim > threshold) == all_labels)

    return running_loss / n_batches, best_acc


In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Both pipelines first resize to 256x256 and then take a 224x224 crop (random for
# training, central for evaluation). Faces are therefore seen at the same scale at
# train and test time; previously test images were squashed to 224 without cropping.
train_transforms = T.Compose([
        T.Resize((256, 256)),
        T.RandomCrop(224),     # Random crop
        T.RandomHorizontalFlip(),
        # NOTE(review): RandomRotation(10) and RandomAffine(degrees=5) both rotate,
        # so the combined rotation can reach 15 degrees.
        T.RandomRotation(10),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation = 0.1),
        T.RandomAffine(degrees=5, translate=(0.05, 0.05), scale=(0.95, 1.05)),  # More conservative
        T.ToImage(),
        T.ToDtype(torch.float32, scale=True),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225])
])

test_transforms = T.Compose([
    T.Resize((256, 256)),
    T.CenterCrop(224),
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406],
                [0.229, 0.224, 0.225])
])

In [None]:
dataset_path = "./data"

train_data_path = os.path.join(dataset_path, 'train/train')
test_data_path = os.path.join(dataset_path, 'val/test')

train_identities = get_all_identities(train_data_path)

# fraction=1 keeps every identity; the call only shuffles them.
train_chosen_identities = get_portion_of_identities(train_identities, 1)
train_label_to_imgs = build_label_to_images(train_data_path, train_chosen_identities)

test_identities = get_all_identities(test_data_path)
test_label_to_imgs = build_label_to_images(test_data_path, test_identities)

# os.cpu_count() may return None. A result of 0 workers is valid and means loading
# happens in the main process.
NUM_WORKERS = min(4, (os.cpu_count() or 1) // 2)
# Page-locked host memory speeds up host-to-GPU copies; it is pointless without CUDA.
PIN_MEMORY = device.type == "cuda"

train_dataset = FaceVerificationDataset(train_label_to_imgs, train_transforms)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

test_dataset = FaceVerificationDataset(test_label_to_imgs, test_transforms)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                         num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

In [None]:
backbone = resnet50(weights=ResNet50_Weights.DEFAULT)
model = FaceVerificationModel(backbone, dropout=DROPOUT, embedding_size=EMB_DIM).to(device)

loss_fn = hybrid_triplet_loss
early_stopping = EarlyStopping(patience=5, min_delta=0.01)

# The unfrozen backbone block (layer4) is fine-tuned with a smaller LR than the new head.
BACKBONE_LR = 5e-5
wandb.config.update({"backbone_lr": BACKBONE_LR})

optimizer = torch.optim.AdamW([
    {"params": model.backbone.layer4.parameters(), "lr": BACKBONE_LR},
    {"params": model.embedding_layer.parameters(), "lr": LR},
], weight_decay=DECAY)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=FACTOR, patience=3)
# Loss scaling only matters with CUDA fp16 autocast. Disabling it explicitly elsewhere
# avoids the "CUDA not available" warning while giving the same (disabled) behaviour.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

In [None]:
# Early stopping (and therefore checkpointing) only starts after this epoch. That is
# also when hybrid_triplet_loss adds its cosine term (epoch > 10), so the compared
# val losses share that term.
# NOTE(review): the margin from margin_schedule(epoch) still changes between epochs,
# so val_loss is not strictly comparable across epochs.
EARLY_STOPPING_START_EPOCH = 10

for epoch in range(EPOCHS):

    train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, scaler, device, epoch)
    val_loss, val_acc = validate(
        model, test_loader, loss_fn, device, epoch, threshold=THRESHOLD
    )
    if epoch > EARLY_STOPPING_START_EPOCH:
        early_stopping(val_loss, val_acc, model)
    scheduler.step(val_loss)

    current_lr = optimizer.param_groups[-1]['lr']

    wandb.log({
        "train_loss": train_loss,
        "val_loss": val_loss,
        "val_acc": val_acc,
        'epoch': epoch,
        'learning_rate': current_lr
    })

    if early_stopping.early_stop:
        print("Early stopping triggered! Stopping training.")
        break

# The evaluation cells reload early_stopping.path. If early stopping never ran
# (EPOCHS <= EARLY_STOPPING_START_EPOCH + 1), save the final weights so that file exists.
if early_stopping.best_loss is None:
    torch.save(model.state_dict(), early_stopping.path)
    print("saved final model (early stopping never engaged)")

wandb.log({'best_loss': early_stopping.best_loss})

In [None]:
import os
import random

def build_label_to_imgs(root_dir):
    """
    root_dir: path to the directory containing subfolders like 'n000001', 'n000009', etc.

    Returns a dictionary: { 'n000001': [path1, path2, ...], 'n000009': [...], ... }

    Only .jpg/.jpeg/.png files (case-insensitive) are kept, and identities without any
    such file are omitted. Identity folders and file names are sorted, because
    ``os.listdir`` order is filesystem-dependent; this keeps seeded random subsets
    reproducible.
    """
    label_to_imgs = {}
    for identity_folder in sorted(os.listdir(root_dir)):
        identity_path = os.path.join(root_dir, identity_folder)
        if not os.path.isdir(identity_path):
            continue

        image_paths = [
            os.path.join(identity_path, img_name)
            for img_name in sorted(os.listdir(identity_path))
            if img_name.lower().endswith(('.jpg', '.png', '.jpeg'))
        ]

        if image_paths:
            label_to_imgs[identity_folder] = image_paths
    return label_to_imgs

import random

def random_subset_label_dict(label_to_imgs, num_identities=10, max_imgs_per_identity=None):
    """Pick random identities and a random subset of each one's images.

    The input dict and its lists are not modified. Previously each identity's list was
    shuffled in place, which reordered ``label_to_imgs`` for later cells. The shuffles
    use the global ``random`` module, so the result for a given seed is unchanged.

    Args:
        label_to_imgs: dict identity -> list of image paths.
        num_identities: number of identities to keep (fewer if not enough exist).
        max_imgs_per_identity: cap on images per identity; None keeps all.

    Returns:
        A new dict identity -> list of (shuffled, possibly truncated) image paths.
    """
    all_identities = list(label_to_imgs.keys())
    random.shuffle(all_identities)

    subset_dict = {}
    for identity in all_identities[:num_identities]:
        paths = list(label_to_imgs[identity])
        random.shuffle(paths)
        if max_imgs_per_identity is not None:
            paths = paths[:max_imgs_per_identity]
        subset_dict[identity] = paths

    return subset_dict

class TSNEDataset(Dataset):
    """Flat ``(image, identity)`` dataset used to embed a labelled subset for t-SNE.

    Samples are ordered by dict iteration order of ``label_to_imgs`` and, within an
    identity, by list order.

    Args:
        label_to_imgs: dict identity -> list of image paths.
        transform: optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, label_to_imgs, transform=None):
        self.transform = transform
        self.samples = [
            (path, label)
            for label, image_paths in label_to_imgs.items()
            for path in image_paths
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        # Decode inside a context manager so the file handle is released immediately.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

# Same directory as test_data_path (./data + val/test).
root_dir = "./data/val/test"
label_to_imgs = build_label_to_imgs(root_dir)

subset_dict = random_subset_label_dict(label_to_imgs, num_identities=10, max_imgs_per_identity=20)

tsne_dataset = TSNEDataset(subset_dict, transform=test_transforms)

# Restore the checkpoint written during training. It is a plain state_dict, so
# weights_only=True loads it without allowing arbitrary pickled objects.
checkpoint_path = './models/' + name + '_model.pth'
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

tsne_dataloader = DataLoader(tsne_dataset, batch_size=32, shuffle=False)

import numpy as np

def extract_embeddings(model, dataloader, device):
    """Embed every batch of ``dataloader`` in eval mode without gradients.

    Args:
        model: embedding model; switched to eval mode.
        dataloader: iterable of ``(images, labels)`` batches.
        device: device the images are moved to.

    Returns:
        ``(embeddings, labels)``: embeddings stacked along dim 0 as a NumPy array, and
        the per-sample labels flattened across batches into a NumPy array.
    """
    model.eval()
    embedding_batches = []
    flat_labels = []

    with torch.no_grad():
        for images, batch_labels in dataloader:
            embedding_batches.append(model(images.to(device)).cpu())
            flat_labels.extend(batch_labels)

    stacked = torch.cat(embedding_batches, dim=0).numpy()
    return stacked, np.array(flat_labels)

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Embed every image of the t-SNE subset with the restored checkpoint.
embeddings, labels = extract_embeddings(
    model, tsne_dataloader, device
)

def plot_and_log_tsne(embeddings, labels, perplexity=30, artifact_name="tsne_plots"):
    """Project embeddings to 2-D with t-SNE, save the scatter plot, and log it to wandb.

    scikit-learn's TSNE requires ``perplexity < n_samples``, so the requested perplexity
    is clamped to ``n_samples - 1``; small subsets can still be plotted.

    Args:
        embeddings: array of shape (n_samples, dim).
        labels: per-sample identity labels (one colour per unique label).
        perplexity: requested t-SNE perplexity.
        artifact_name: name of the wandb artifact that receives the PNG.

    Side effects:
        Writes ``./tsneplots/tsne_plot<name>.png``, logs it as a wandb artifact, and
        logs it as the ``tsne_plot`` image.
    """
    n_samples = len(embeddings)
    effective_perplexity = min(perplexity, max(n_samples - 1, 1))
    tsne = TSNE(n_components=2, perplexity=effective_perplexity, random_state=42)
    reduced = tsne.fit_transform(embeddings)

    # return_inverse maps each label to its index in the sorted unique labels.
    unique_labels, numeric_labels = np.unique(np.asarray(labels), return_inverse=True)

    fig, ax = plt.subplots(figsize=(10, 7))
    scatter = ax.scatter(
        reduced[:, 0],
        reduced[:, 1],
        c=numeric_labels,
        cmap="viridis",
        alpha=0.7
    )
    cbar = fig.colorbar(scatter, ax=ax, ticks=range(len(unique_labels)))
    cbar.ax.set_yticklabels(unique_labels)

    ax.set_title("t-SNE Visualization of Face Embeddings")
    ax.set_xlabel("t-SNE Dim 1")
    ax.set_ylabel("t-SNE Dim 2")

    plot_filename = "./tsneplots/tsne_plot"+name+".png"
    fig.savefig(plot_filename)
    plt.close(fig)

    artifact = wandb.Artifact(artifact_name, type="analysis")
    artifact.add_file(plot_filename)
    wandb.log_artifact(artifact)

    wandb.log({"tsne_plot": wandb.Image(plot_filename)})



# One artifact per run, named after the experiment.
plot_and_log_tsne(
    embeddings,
    labels,
    perplexity=40,
    artifact_name="tsne_" + name,
)


In [None]:
import numpy as np
import random
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
import matplotlib.pyplot as plt
import torch.nn.functional as F
import wandb


def _pair_similarity(path_a, path_b):
    """Cosine similarity between the embeddings of two image files (current model)."""
    emb_a = torch.tensor(get_embedding(path_a, model, test_transforms, device)).unsqueeze(0)
    emb_b = torch.tensor(get_embedding(path_b, model, test_transforms, device)).unsqueeze(0)
    return F.cosine_similarity(emb_a, emb_b).item()


# Genuine pairs: two distinct images of the same identity (identities with < 2 images skipped).
positive_similarities = []
for label, img_paths in label_to_imgs.items():
    if len(img_paths) < 2:
        continue
    img1, img2 = random.sample(img_paths, 2)
    positive_similarities.append(_pair_similarity(img1, img2))

# Impostor pairs: one image from each of two different identities, as many as genuine pairs.
all_labels = list(label_to_imgs.keys())
num_negatives = len(positive_similarities)
negative_similarities = []
for _ in range(num_negatives):
    label1, label2 = random.sample(all_labels, 2)
    img1 = random.choice(label_to_imgs[label1])
    img2 = random.choice(label_to_imgs[label2])
    negative_similarities.append(_pair_similarity(img1, img2))

similarities = np.concatenate([np.array(positive_similarities), np.array(negative_similarities)])
true_labels = np.concatenate([np.ones(len(positive_similarities)), np.zeros(len(negative_similarities))])

# Threshold sweep on [0, 1). The first threshold reaching the max accuracy wins.
# NOTE(review): the threshold is chosen and scored on the same pairs (optimistic).
thresholds = np.arange(0.0, 1.0, 0.01)
acc_list = [accuracy_score(true_labels, similarities > thresh) for thresh in thresholds]
best_idx = int(np.argmax(acc_list))
best_thresh = thresholds[best_idx]
best_acc = acc_list[best_idx]

print("Optimal threshold: {:.2f} with accuracy: {:.2f}".format(best_thresh, best_acc))

pred_labels = (similarities > best_thresh).astype(int)

precision = precision_score(true_labels, pred_labels)
recall = recall_score(true_labels, pred_labels)
f1 = f1_score(true_labels, pred_labels)

fp = np.sum((pred_labels == 1) & (true_labels == 0))
tn = np.sum((pred_labels == 0) & (true_labels == 0))
fn = np.sum((pred_labels == 0) & (true_labels == 1))
tp = np.sum((pred_labels == 1) & (true_labels == 1))
fpr = fp / (fp + tn) if (fp + tn) > 0 else 0
fnr = fn / (fn + tp) if (fn + tp) > 0 else 0

fpr_curve, tpr_curve, _ = roc_curve(true_labels, similarities, pos_label=1)
roc_auc = auc(fpr_curve, tpr_curve)

print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"False Positive Rate: {fpr:.4f}")
print(f"False Negative Rate: {fnr:.4f}")
print(f"ROC AUC: {roc_auc:.4f}")

wandb.log({
    'best_threshold': best_thresh,
    'best_thresh_acc': best_acc,
    'precision': precision,
    'recall': recall,
    'f1_score': f1,
    'FPR': fpr,
    'FNR': fnr,
    'ROC_AUC': roc_auc
})

# ROC curve. It is saved BEFORE plt.show(): the inline backend closes figures on show,
# so the old savefig-after-show wrote a blank image (and it was the accuracy plot).
plot_filename = "./ROCplots/ROC_CURVE"+name+".png"
fig_roc, ax_roc = plt.subplots(figsize=(8, 6))
ax_roc.plot(fpr_curve, tpr_curve, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
ax_roc.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.set_title('ROC Curve for Face Verification (Cosine Similarity)')
ax_roc.legend(loc="lower right")
fig_roc.savefig(plot_filename)
plt.show()

fig_acc, ax_acc = plt.subplots(figsize=(8, 6))
ax_acc.plot(thresholds, acc_list, color='blue', label='Accuracy')
ax_acc.set_xlabel('Threshold')
ax_acc.set_ylabel('Accuracy')
ax_acc.set_title('Accuracy vs. Threshold (Cosine Similarity)')
ax_acc.axvline(x=best_thresh, color='red', linestyle='--', label=f"Best threshold: {best_thresh:.2f}")
ax_acc.legend()
plt.show()

plt.close(fig_roc)
plt.close(fig_acc)

artifact = wandb.Artifact("roc_plots", type="analysis")
artifact.add_file(plot_filename)
wandb.log_artifact(artifact)

# Logged under its own key; "tsne_plot" previously overwrote the t-SNE image.
wandb.log({"roc_curve": wandb.Image(plot_filename)})
wandb.finish()
