In [1]:
import numpy as np

from torchvision.datasets import ImageFolder
from torchvision import transforms

from facenet_pytorch import MTCNN, fixed_image_standardization, training, extract_face
from facenet_pytorch import InceptionResnetV1, training
from torch.optim import Adam
import utils

from torch.utils.data import DataLoader, SubsetRandomSampler, SequentialSampler
from triplets_loader import TripletsDataset
from losses.triplet_loss import TripletLoss

import tqdm
import torch

data_dir = 'lfw_cropped'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 32

### Load the Dataset

In [3]:
from losses.contrastive_loss import ContrastiveLoss, ContrastiveDataset


dataset = ContrastiveDataset(csv_file='lfw_cropped_annots.csv')
loader = DataLoader(dataset, 
                    num_workers=4,
                    pin_memory=True,
                    batch_size=batch_size, 
                    shuffle=True)

### Define the model, loss and optimizer

In [4]:
# Create an inception resnet (in train mode):
resnet = InceptionResnetV1(
    classify=False,
    num_classes=len(dataset.class_to_idx)
    ).to(device)

# Using Adam optimizer
optimizer = torch.optim.AdamW(resnet.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [5, 10])

metrics = {
    'fps': training.BatchTimer(),
    'acc': training.accuracy
}

In [5]:
torch.cuda.empty_cache()
contras_loss = ContrastiveLoss().to(device)
for i, (anchor, target_img, label) in enumerate(loader):
    anchor = anchor.to(device)
    target_img = target_img.to(device)
    embeddings_anchor = resnet(anchor)
    embeddings_target = resnet(target_img)
    label = label.to(device)
    loss = contras_loss(embeddings_anchor, embeddings_target, label)
    break

loss

tensor(1.2647, device='cuda:0', grad_fn=<MeanBackward0>)

In [7]:
for epoch in range(1, 3):
    resnet.train()
    train_loss = 0.0
    train_num = 0
    for i, (anchor, target_img, label) in enumerate(loader):
        anchor = anchor.to(device)
        target_img = target_img.to(device)
        
        # Compute embeddings
        embeddings_anchor = resnet(anchor)
        embeddings_target = resnet(target_img)
        
        label = label.to(device)
        loss = contras_loss(embeddings_anchor, embeddings_target, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() 
        train_num += 1
    
    scheduler.step()
    train_loss /= train_num
    print("Epoch: {}, Loss: {}".format(epoch, train_loss))

Epoch: 1, Loss: 0.8118770112143991


KeyboardInterrupt: 

In [None]:
# Save model
torch.save(resnet, 'resnet_contrLoss_17_epochs.pt')

# Remove to free up GPU memory
del resnet
torch.cuda.empty_cache()