# 1. Baseline

## Libraries

In [1]:
import os
from pathlib import Path
from tqdm import tqdm
from easydict import EasyDict

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
from torchvision import transforms as trans

from data.ms1m import get_train_loader
from data.lfw import LFW

from backbone.arcfacenet import SEResNet_IR
from margin.ArcMarginProduct import ArcMarginProduct

from util.utils import save_checkpoint, test

## Configuration

In [2]:
conf = EasyDict()

conf.train_root = "./dataset/MS1M/"
conf.lfw_test_root = "./dataset/lfw_aligned_112/lfw_aligned_112"
conf.lfw_file_list = "./dataset/lfw_pair.txt"

conf.mode = "se_ir" #' "ir"
conf.depth = 50 #100,152
conf.margin_type = "ArcFace"
conf.feature_dim = 512 #ArcFace
conf.scale_size = 32.0 #ArcFace
conf.batch_size = 64 #16,32,64,80 (6GB)
conf.lr = 0.01 #Learning rate
conf.milestones = [8,10,12]
conf.total_epoch = 2

conf.save_folder = "./saved"
conf.save_dir = os.path.join(conf.save_folder, conf.mode + "_" + str(conf.depth)) #./saved/se_ir_50
conf.device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
conf.num_workers = 4
conf.pin_memory = True

In [3]:
os.makedirs(conf.save_dir, exist_ok=True)

## Data Loader

In [4]:
transform = trans.Compose([
    trans.ToTensor(), #range [0,255] -> [0.0, 1.0]
    trans.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)) #R-G-B 
])

train_loader, class_num = get_train_loader(conf)

lfwdataset = LFW(conf.lfw_test_root, conf.lfw_file_list, transform= transform)
lfwloader = torch.utils.data.DataLoader(lfwdataset, batch_size=128, num_workers=conf.num_workers)

## Model

In [5]:
net = SEResNet_IR(conf.depth, feature_dim=conf.feature_dim,mode=conf.mode).to(conf.device)
margin = ArcMarginProduct(conf.feature_dim, class_num).to(conf.device)

criteron = nn.CrossEntropyLoss()
optimizer = optim.SGD([
    {"params":net.parameters(), "weight_decay": 5e-4},
    {"params":margin.parameters(), "weight_decay": 5e-4}
], lr=conf.lr, momentum=0.9, nesterov=True)

def schedule_lr():
    for params in optimizer.param_groups:
        params["lr"] /= 10
    print(optimizer)

## Train

In [6]:
best_acc = 0

for epoch in range(1,conf.total_epoch+1):
    
    net.train()

    print("epoch {}/{}".format(epoch, conf.total_epoch))
    if epoch == conf.milestones[0]:
        schedule_lr()
    if epoch == conf.milestones[1]:
        schedule_lr()
    if epoch == conf.milestones[2]:
        schedule_lr()
    for data in tqdm(train_loader):
        img, label = data[0].to(conf.device), data[1].to(conf.device)
        optimizer.zero_grad()
        logits = net(img)
        output = margin(logits,label)
        total_loss = criteron(output, label)
        total_loss.backward()
        optimizer.step()
    #test
    net.eval()
    lfw_acc = test(conf,net,lfwdataset,lfwloader)
    print("\nLFW: {:.4f} | train_loss: {:.4f}\n].".format(lfw_acc,total_loss.item()))
    is_best = lfw_acc > best_acc
    best_acc = max(lfw_acc, best_acc)
    save_checkpoint({'epoch':epoch,
                    'net_state_dict':net.state_dict(),
                     'margin_state_dict': margin.state_dict(),
                     'best_acc': best_acc
                    }, is_best, checkpoint=conf.save_dir)






        
        

 

   

epoch 1/2


100%|████████████████████████████████████████████████████████████████████████████████| 456/456 [05:14<00:00,  1.45it/s]



LFW: 0.6847 | train_loss: 0.0000
].
best model saved

epoch 2/2


100%|████████████████████████████████████████████████████████████████████████████████| 456/456 [05:33<00:00,  1.37it/s]



LFW: 0.6812 | train_loss: 0.0000
].
