# 1. Baseline

## Libraries

In [53]:
import os
from pathlib import Path
from easydict import EasyDict as edict
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
from torchvision import transforms as trans

from data.ms1m import get_train_loader
from data.lfw import LFW

from backbone.arcfacenet import SEResNet_IR
from margin.ArcMarginProduct import ArcMarginProduct

from util.utils import save_checkpoint, test

## Configuration

In [21]:
# Central configuration for the notebook: every tunable value lives on `conf`.
conf = edict()

# --- Data locations ---
conf.train_root = './dataset/MS1M'
conf.lfw_test_root = './dataset/lfw_aligned_112'
conf.lfw_file_list = './dataset/lfw_pair.txt'

# --- Model ---
conf.mode = 'se_ir'  # alternative: 'ir'
conf.depth = 50
conf.margin_type = 'ArcFace'
conf.feature_dim = 512  # embedding size; the article reports the best results with 512
conf.scale_size = 32.0

# --- Optimisation ---
conf.batch_size = 96  # use 16 when memory is tight (scale up in multiples of 16)
conf.lr = 0.01  # scale linearly with the batch size: a smaller batch needs a smaller learning rate
conf.milestones = [8, 10, 12]  # epochs at which the learning rate is decayed
conf.total_epoch = 14

# --- Reproducibility ---
# Seed PyTorch's RNG up front so a fresh "Restart & Run All" starts from the
# same random state.
conf.seed = 42
torch.manual_seed(conf.seed)

# --- Output and runtime ---
conf.save_folder = './saved'
conf.save_dir = os.path.join(conf.save_folder, conf.mode + '_' + str(conf.depth))  # ./saved/se_ir_50
# Use the GPU when one is available and fall back to the CPU otherwise.
conf.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
conf.num_workers = 4
# Pinned host memory only speeds up host-to-GPU copies, so enable it only on CUDA.
conf.pin_memory = conf.device.type == 'cuda'

In [22]:
# Create the checkpoint directory (including missing parents); nothing happens
# if it already exists.
Path(conf.save_dir).mkdir(parents=True, exist_ok=True)

## Data Loader

In [25]:
# Preprocessing used for the LFW evaluation images: convert to a float tensor
# in [0.0, 1.0], then normalise every channel with mean 0.5 and std 0.5.
eval_steps = [
    trans.ToTensor(),
    trans.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = trans.Compose(eval_steps)

# The training loader is built from the config alone; it also reports the
# number of classes found in the training set.
trainloader, class_num = get_train_loader(conf)

In [26]:
# Report how many identities (classes) get_train_loader found in the training set.
print('Number of ID: ', class_num)

Number of ID:  200


In [27]:
# Summarise the training dataset: size, root directory and the transforms applied.
print(trainloader.dataset)

Dataset ImageFolder
    Number of datapoints: 29148
    Root location: ./dataset/MS1M
    StandardTransform
Transform: Compose(
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
           )


In [29]:
# LFW evaluation set, preprocessed with the `transform` pipeline.
lfwdataset = LFW(
    conf.lfw_test_root,
    conf.lfw_file_list,
    transform=transform,
)

# Evaluation loader: fixed batch size of 128, same worker count as the config.
lfw_loader_kwargs = dict(batch_size=128, num_workers=conf.num_workers)
lfwloader = torch.utils.data.DataLoader(lfwdataset, **lfw_loader_kwargs)

## Model

In [32]:
# Backbone network producing `conf.feature_dim`-dimensional features.
net = SEResNet_IR(conf.depth, feature_dim=conf.feature_dim, mode=conf.mode)
net = net.to(conf.device)

# Margin head on top of the features, sized by the number of training classes.
# NOTE(review): conf.scale_size and conf.margin_type are never passed to the
# head, so it runs with its own defaults; confirm they match the config.
margin = ArcMarginProduct(conf.feature_dim, class_num)
margin = margin.to(conf.device)

In [33]:
# Dump the backbone's layer structure for inspection (the output is long).
print(net)

SEResNet_IR(
  (input_layer): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=64)
  )
  (output_layer): Sequential(
    (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Dropout(p=0.4, inplace=False)
    (2): Flatten()
    (3): Linear(in_features=25088, out_features=512, bias=True)
    (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (body): Sequential(
    (0): BottleNeck_IR_SE(
      (shortcut_layer): MaxPool2d(kernel_size=1, stride=2, padding=0, dilation=1, ceil_mode=False)
      (res_layer): Sequential(
        (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (2): BatchNorm2d(64, eps=1e-05, moment

In [34]:
# Classification loss; the training loop applies it to the margin head's output.
criterion = nn.CrossEntropyLoss()

In [35]:
# The backbone and the margin head are optimised together, each in its own
# parameter group with the same weight decay (which helps prevent overfitting).
weight_decay = 5e-4
param_groups = [
    {'params': module.parameters(), 'weight_decay': weight_decay}
    for module in (net, margin)
]
optimizer = optim.SGD(param_groups, lr=conf.lr, momentum=0.9, nesterov=True)

In [36]:
# Show the optimizer's parameter groups and their hyper-parameters.
print(optimizer)

SGD (
Parameter Group 0
    dampening: 0
    lr: 0.01
    maximize: False
    momentum: 0.9
    nesterov: True
    weight_decay: 0.0005

Parameter Group 1
    dampening: 0
    lr: 0.01
    maximize: False
    momentum: 0.9
    nesterov: True
    weight_decay: 0.0005
)


In [40]:
def schedule_lr(opt=None, factor=10):
    """Divide the learning rate of every parameter group by ``factor``.

    Args:
        opt: Optimizer whose parameter groups are updated in place. When
            omitted, the notebook-level ``optimizer`` is used, so existing
            ``schedule_lr()`` calls keep working unchanged.
        factor: Divisor applied to each group's ``lr`` (default 10).

    The updated optimizer is printed so the new learning rates appear in the
    cell output.
    """
    target = optimizer if opt is None else opt
    for group in target.param_groups:
        group['lr'] /= factor
    print(target, flush=True)

## Train

In [None]:
# Train for conf.total_epoch epochs. After every epoch the backbone is
# evaluated on LFW and a checkpoint is written; `is_best` flags the epochs
# that improve on the best LFW accuracy seen so far.
best_acc = 0

for epoch in range(1, conf.total_epoch + 1):

    net.train()

    print(f'# epoch: {epoch}/{conf.total_epoch}', flush=True)

    # Decay the learning rate once for every milestone equal to this epoch.
    # Iterating over the list (instead of indexing entries 0, 1 and 2) keeps
    # this working for any number of milestones.
    for milestone in conf.milestones:
        if epoch == milestone:
            schedule_lr()

    # Accumulate the per-batch loss so the epoch summary reports the mean over
    # the whole epoch rather than the value of the final batch only.
    running_loss = 0.0
    num_batches = 0

    for data in tqdm(trainloader):
        img, label = data[0].to(conf.device), data[1].to(conf.device)
        optimizer.zero_grad()

        features = net(img)                # backbone output
        output = margin(features, label)   # margin head needs the labels too
        total_loss = criterion(output, label)
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()
        num_batches += 1

    # Guard against an empty loader so the summary below never divides by zero.
    train_loss = running_loss / max(num_batches, 1)

    # test

    net.eval()

    # Evaluation needs no gradients; disabling autograd saves memory and time.
    with torch.no_grad():
        lfw_acc = test(conf, net, lfwdataset, lfwloader)
    print('\nLFW: {:.4f} | train_loss: {:.4f}\n'.format(lfw_acc, train_loss))

    is_best = lfw_acc > best_acc
    best_acc = max(lfw_acc, best_acc)

    save_checkpoint({
        'epoch': epoch,
        'net_state_dict': net.state_dict(),
        'margin_state_dict': margin.state_dict(),
        'best_acc': best_acc
    }, is_best, checkpoint=conf.save_dir)
        

# epoch: 1/14


  2%|▋                                      | 5/304 [17:58<17:53:54, 215.50s/it]

In [None]:
""" 
    |For Best Solution|

SOTA: The State of the Art
1. Download the whole MS1M dataset. (or CASIA)
2. conf.mode = 'ir'
3. conf.depth = 100
4. conf.total_epoch = 20
5. conf.milestones = [12, 16, 18]
6. conf.device = torch.device('cuda')


    |Result|

=> 99.83%

Check out MobileFaceNet for running computer vision on mobile devices.
"""