In [35]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from PIL import Image, ImageFile
from torch.nn import Parameter,Module, Sequential
from torch.nn import Conv2d,BatchNorm2d, PReLU,Flatten,BatchNorm1d, Linear
from torch import optim
from torch.optim.lr_scheduler import StepLR
from torch.nn import CrossEntropyLoss
from torchsummary import summary
import numpy as np
import cv2
import torch
import math,os

In [36]:
# Mobile Facenet
def l2_norm(input, axis = 1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output

class Flatten(Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class Conv_block(Module):
    def __init__(self, in_c, out_c, kernel = (1,1), stride = (1,1), padding = (0,0), groups = 1):
        super(Conv_block, self).__init__()
        self.conv = Conv2d(in_channels=in_c, 
                           out_channels = out_c, 
                           kernel_size = kernel,
                           groups = groups,
                           stride = stride,
                           padding = padding,
                           bias = False)
        self.bn = BatchNorm2d(num_features =out_c)
        self.prelu = PReLU(num_parameters = out_c)
        
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.prelu(x)
        return x
    
class Linear_block(Module):
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Linear_block, self).__init__()
        self.conv = Conv2d(in_channels=in_c,
                          out_channels= out_c,
                          kernel_size = kernel,
                          groups = groups,
                          stride = stride,
                          padding = padding,
                          bias = False)
        self.bn = BatchNorm2d(out_c)
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x
    
class Depth_Wise(Module):
    def __init__(self, in_c, out_c, residual = False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(Depth_Wise, self).__init__()
        self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
        self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.residual = residual
        
    def forward(self, x):
        if self.residual:
            short_cut = x
        x = self.conv(x)
        x = self.conv_dw(x)
        x = self.project(x)
        if self.residual:
            output = short_cut + x
        else:
            output = x
        return output
    
class Residual(Module):
    def __init__(self, c, num_block, groups, kernel = (3,3), stride = (1,1), padding = (1,1)):
        super(Residual, self).__init__()
        modules = []
        for _ in range(num_block):
            modules.append(Depth_Wise(c, c, residual = True, kernel = kernel, padding = padding, stride = stride, groups = groups ))
        self.model = Sequential(*modules)
    def forward(self, x):
        return self.model(x)
        
class MobileFaceNet(Module):
    def __init__(self, embedding_size):
        super(MobileFaceNet, self).__init__()
        self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
        self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
        self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
        self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
        self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
        self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7,7), stride=(1, 1), padding=(0, 0))
        self.conv_6_flatten = Flatten()
        self.linear = Linear(512, embedding_size, bias=False)
        self.bn = BatchNorm1d(embedding_size)
        
    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2_dw(out)
        out = self.conv_23(out)
        out = self.conv_3(out)
        out = self.conv_34(out)
        out = self.conv_4(out)
        out = self.conv_45(out)
        out = self.conv_5(out)
        out = self.conv_6_sep(out)
        out = self.conv_6_dw(out)
        out = self.conv_6_flatten(out)
        out = self.linear(out)
        out = self.bn(out)
        return l2_norm(out)
        

In [37]:
# Arcface head
class Arcface(Module):
#     def __init__(self, embedding_size = 512, classnum = 51332, s = 64, m = 0.5):
    def __init__(self, embedding_size = 512, classnum = 3, s = 64, m = 0.5):
        super(Arcface, self).__init__()
        self.classnum = classnum
        self.kernel = Parameter(torch.Tensor(embedding_size, classnum))
        
        # initial kernel
        self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
        self.m = m 
        self.s = s
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.mm = self.sin_m * m
        self.threshold = math.cos(math.pi - m)
        
    def forward(self, embbedings, label):
        # weights norm
        nB = len(embbedings)
        kernel_norm = l2_norm(self.kernel,axis=0)
        # cos(theta+m)
        cos_theta = torch.mm(embbedings,kernel_norm)  # inner_product
#         output = torch.mm(embbedings,kernel_norm)
        cos_theta = cos_theta.clamp(-1,1) # for numerical stability
        cos_theta_2 = torch.pow(cos_theta, 2)
        sin_theta_2 = 1 - cos_theta_2
        sin_theta = torch.sqrt(sin_theta_2)
        
        # cos(theta + m)
        cos_theta_m = (cos_theta * self.cos_m - sin_theta * self.sin_m)
        
        # this condition controls the theta+m should in range [0, pi]
        #      0<=theta+m<=pi
        #     -m<=theta<=pi-m
        cond_v = cos_theta - self.threshold
        cond_mask = cond_v <= 0
        keep_val = (cos_theta - self.mm) # when theta not in [0,pi], use cosface instead
        cos_theta_m[cond_mask] = keep_val[cond_mask]
        output = cos_theta * 1.0 # a little bit hacky way to prevent in_place operation on cos_theta
        idx_ = torch.arange(0, nB, dtype=torch.long)
        output[idx_, label] = cos_theta_m[idx_, label]
        output *= self.s # scale up in order to make softmax work, first introduced in normface
        return output

In [38]:
class faceLoader:
    def __init__(self, data_root, batch_size, shuffle = True):
            self.data_root = data_root
            self.batch_size = batch_size
            self.shuffle = shuffle
            
    def get_loader(self, img_size = [112,112]):
        train_transforms = transforms.Compose([
            transforms.Resize(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
        ])
        
        train_datasets = ImageFolder(self.data_root, train_transforms)
        train_loader = DataLoader(train_datasets, batch_size = self.batch_size, num_workers = 4, pin_memory = True)
        
        num_classes = train_datasets[-1][1] + 1
        
        return train_loader, num_classes

In [47]:
class faceTrainer:
    def __init__(self, device, dataloader, embedding_size= 512):
        print('Trainer Initializing')
        
        self.step = 0
        self.device = device
        self.model = MobileFaceNet(embedding_size)
        print(summary(self.model, (3,112,112)))
        
        if torch.cuda.device_count() >1:
            print('CUDAs', torch.cuda.device_count(), 'GPUs')
        self.model = self.model.to(self.device)
        
        self.train_loader, self.class_num = dataloader.get_loader()
        
        self.header = Arcface(embedding_size = embedding_size, classnum = self.class_num).to(self.device)
        self.optimizer = optim.SGD(self.model.parameters(), lr = 1e-1, weight_decay = 5e-4)
        self.scheduler = StepLR(self.optimizer, step_size = 10, gamma = 0.1)
        self.loss = CrossEntropyLoss()
        
    def train(self, epochs, print_freq):
        self.model.train()
        
        for epoch in range(epochs):
            self.step = 0
            train_loss = 0
            correct = 0
            total = 0
            
            for imgs, labels in iter(self.train_loader):
                
                imgs = imgs.to(self.device)
                labels = labels.to(self.device)
                
                # Gradient Initialization
                self.optimizer.zero_grad()
                
                embeddings = self.model(imgs)
                thetas = self.header(embeddings, labels)
                output = self.loss(thetas, labels)
                output.backward()
                train_loss += output.item()
                
                self.optimizer.step()
                
                if self.step % print_freq == 0 and self.step !=0:
                    print('epoch:', epoch, 'step:', self.step, 'loss', output.item())
                    
                self.step +=1
                
            loss_avg = train_loss/len(self.train_loader)
            if not os.path.isdir('./data/weights_lr'):
                os.makedirs('./data/weights_lr')
            
            torch.save(self.header.state_dict(), f'./data/weights_lr/{str(epoch)}_{str(loss_avg)}.pth')
            print('epoch:', epoch, 'loss_avg', loss_avg)

            self.scheduler.step()


In [49]:
def start():
    print('Strated')
    data_root = './db/small_vgg/train'
#     batch_size = 64
    batch_size = 8

    classnum = 3
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    print('data Loading')
    dataloader = faceLoader(data_root, batch_size, shuffle = True)
    trainer = faceTrainer(device, dataloader, embedding_size = 512)
    print('Begin Trianing on:', device)
    trainer.train(50, 100)

In [50]:
start()


Strated
data Loading
Trainer Initializing
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 56, 56]           1,728
       BatchNorm2d-2           [-1, 64, 56, 56]             128
             PReLU-3           [-1, 64, 56, 56]              64
        Conv_block-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]             576
       BatchNorm2d-6           [-1, 64, 56, 56]             128
             PReLU-7           [-1, 64, 56, 56]              64
        Conv_block-8           [-1, 64, 56, 56]               0
            Conv2d-9          [-1, 128, 56, 56]           8,192
      BatchNorm2d-10          [-1, 128, 56, 56]             256
            PReLU-11          [-1, 128, 56, 56]             128
       Conv_block-12          [-1, 128, 56, 56]               0
           Conv2d-13          [-1, 128, 28, 28]           1,1

epoch: 0 step: 100 loss 15.133891105651855
epoch: 0 loss_avg 20.697264048058216
epoch: 1 step: 100 loss 9.070046424865723
epoch: 1 loss_avg 17.075614635399948
epoch: 2 step: 100 loss 3.2182557582855225
epoch: 2 loss_avg 18.00717204762256
epoch: 3 step: 100 loss 2.502995014190674
epoch: 3 loss_avg 15.694953623719103
epoch: 4 step: 100 loss 4.19165563583374
epoch: 4 loss_avg 15.482749274396522
epoch: 5 step: 100 loss 5.9325032234191895
epoch: 5 loss_avg 18.871565285630112
epoch: 6 step: 100 loss 3.4222664833068848
epoch: 6 loss_avg 12.876978333540789
epoch: 7 step: 100 loss 4.630937099456787
epoch: 7 loss_avg 14.121239853656197
epoch: 8 step: 100 loss 2.8835349082946777
epoch: 8 loss_avg 11.823997749118355
epoch: 9 step: 100 loss 2.6124894618988037
epoch: 9 loss_avg 12.173814538895614
epoch: 10 step: 100 loss 9.021117210388184
epoch: 10 loss_avg 26.493115719847793
epoch: 11 step: 100 loss 12.610273361206055
epoch: 11 loss_avg 25.287000832595226


KeyboardInterrupt: 