## Mount drive, unzip data

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!unzip /content/drive/MyDrive/fingerprint/train_data.zip

[1;30;43mKết quả truyền trực tuyến bị cắt bớt đến 5000 dòng cuối.[0m
  inflating: train_data/14484_v_shifted.jpg  
  inflating: train_data/14485_h_shifted.jpg  
  inflating: train_data/14494_zoom.jpg  
  inflating: train_data/14494_v_shifted.jpg  
  inflating: train_data/14497_h_shifted.jpg  
  inflating: train_data/14506_zoom.jpg  
  inflating: train_data/14515_rotated.jpg  
  inflating: train_data/14578_zoom.jpg  
  inflating: train_data/14578_rotated.jpg  
  inflating: train_data/14605_h_shifted.jpg  
  inflating: train_data/14606_original.jpg  
  inflating: train_data/14613_original.jpg  
  inflating: train_data/14621_zoom.jpg  
  inflating: train_data/14666_noise.jpg  
  inflating: train_data/14680_original.jpg  
  inflating: train_data/14693_v_shifted.jpg  
  inflating: train_data/14713_h_shifted.jpg  
  inflating: train_data/14763_v_shifted.jpg  
  inflating: train_data/14775_rotated.jpg  
  inflating: train_data/14785_v_shifted.jpg  
  inflating: train_data/14803_original.jpg

## Dataset class

In [3]:
import random
from PIL import Image

from torch.utils.data import Dataset


def get_img_label(img_fp):
    img_fn = img_fp.split('/')[-1]
    img_label = img_fn.split('_')[0]
    return int(img_label)


class TripletFingerprintDataset(Dataset):
    def __init__(self, imgs_fp, classes, transform=None):
        self.imgs_fp = imgs_fp
        self.classes = classes
        self.img_label_dict = {get_img_label(img_fp): [img_fp] for img_fp in self.imgs_fp}
        self.transform = transform

    def __getitem__(self, idx):
        def generate_triplets(idx):
            anchor_fp = self.imgs_fp[idx]
            anchor_class = get_img_label(anchor_fp)
            pos_fp = random.choice(self.img_label_dict[anchor_class])
            neg_class = random.randint(0, len(self.classes) - 1)
            while neg_class == anchor_class:
                neg_class = random.randint(0, len(self.classes) - 1)
            neg_fp = random.choice(self.img_label_dict[neg_class])
            anchor = Image.open(anchor_fp).convert('L')
            pos = Image.open(pos_fp).convert('L')
            neg = Image.open(neg_fp).convert('L')

            return anchor, pos, neg

        anchor, pos, neg = generate_triplets(idx)
        if self.transform is not None:
            anchor = self.transform(anchor)
            pos = self.transform(pos)
            neg = self.transform(neg)

        return (anchor, pos, neg), []

    def __len__(self):
        return len(self.imgs_fp)


## Network and loss

In [4]:
import torch.nn as nn


class EmbeddingNet(nn.Module):
    # Input size = (1, 128, 128)
    def __init__(self):
        super(EmbeddingNet, self).__init__()
        self.convnet = nn.Sequential(nn.Conv2d(1, 32, 3), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2),
                                     nn.Conv2d(32, 64, 3), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2),
                                     nn.Conv2d(64, 128, 3), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2)
                                     )

        self.fc = nn.Sequential(nn.Linear(128 * 14 * 14, 256),
                                nn.PReLU(),
                                nn.Linear(256, 256),
                                nn.PReLU(),
                                nn.Linear(256, 128)
                                )

    def forward(self, x):
        output = self.convnet(x)
        output = output.view(output.size()[0], -1)
        output = self.fc(output)
        return output

    def get_embedding(self, x):
        return self.forward(x)


class TripletNet(nn.Module):
    def __init__(self, embedding_net):
        super(TripletNet, self).__init__()
        self.embedding_net = embedding_net

    def forward(self, anchor, pos, neg):
        anchor_embedding = self.embedding_net(anchor)
        pos_embedding = self.embedding_net(pos)
        neg_embedding = self.embedding_net(neg)
        return anchor_embedding, pos_embedding, neg_embedding

    def get_embedding(self, x):
        return self.embedding_net(x)


In [5]:
import torch.nn as nn
import torch.nn.functional as F


class TripletLoss(nn.Module):
    """
    Triplet loss
    Takes embeddings of an anchor sample, a positive sample and a negative sample
    """

    def __init__(self, margin=1.):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        distance_positive = (anchor - positive).pow(2).sum(1)  # .pow(.5)
        distance_negative = (anchor - negative).pow(2).sum(1)  # .pow(.5)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()


## Trainer

In [8]:
import torch
import numpy as np
from tqdm import tqdm
import math 

def fit(train_loader, val_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, model_save_path='model.pth', start_epoch=0):
    for epoch in range(0, start_epoch):
        scheduler.step()

    best_loss = math.inf

    for epoch in range(start_epoch, n_epochs):
        scheduler.step()
        print('Epoch: {}/{}'.format(epoch + 1, n_epochs))

        # Train stage
        train_loss = train_epoch(train_loader, model, loss_fn, optimizer, cuda)

        message = '\nAverage training loss: {:.4f}. '.format(train_loss)

        val_loss = test_epoch(val_loader, model, loss_fn, cuda)
        val_loss /= len(val_loader)

        message += 'Average validating loss: {:.4f}'.format(val_loss)
        print(message)

        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), model_save_path, _use_new_zipfile_serialization=False)
            print('Saving best model...')

        print('\n' + '='*80 + '\n')

def train_epoch(train_loader, model, loss_fn, optimizer, cuda):
    model.train()
    losses = []
    total_loss = 0

    for batch_idx, (data, target) in enumerate(tqdm(train_loader, desc="Training epoch", position=0, leave=False)):
        target = target if len(target) > 0 else None
        if not type(data) in (tuple, list):
            data = (data,)
        if cuda:
            data = tuple(d.cuda() for d in data)
            if target is not None:
                target = target.cuda()

        optimizer.zero_grad()
        outputs = model(*data)

        if type(outputs) not in (tuple, list):
            outputs = (outputs,)

        loss_inputs = outputs
        if target is not None:
            target = (target,)
            loss_inputs += target

        loss_outputs = loss_fn(*loss_inputs)
        loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
        losses.append(loss.item())
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

        losses = []

    total_loss /= (batch_idx + 1)
    return total_loss


def test_epoch(val_loader, model, loss_fn, cuda):
    with torch.no_grad():
        model.eval()
        val_loss = 0
        for batch_idx, (data, target) in enumerate(tqdm(val_loader, desc="Validating epoch", position=0, leave=False)):
            target = target if len(target) > 0 else None
            if not type(data) in (tuple, list):
                data = (data,)
            if cuda:
                data = tuple(d.cuda() for d in data)
                if target is not None:
                    target = target.cuda()

            outputs = model(*data)

            if type(outputs) not in (tuple, list):
                outputs = (outputs,)
            loss_inputs = outputs
            if target is not None:
                target = (target,)
                loss_inputs += target

            loss_outputs = loss_fn(*loss_inputs)
            loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
            val_loss += loss.item()

    return val_loss


## Train model

In [None]:
import glob
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
import torchvision.transforms as transforms


# Device
cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')

# Hyperparameters
in_channel = 1
batch_size = 128
learning_rate = 0.001
step_size = 50
num_epochs = 40


# Load Data
img_dir = sorted(glob.glob('train_data/*.jpg'))
classes = [int(i) for i in range(get_img_label(img_dir[-1]))]

transforms = transforms.Compose([
    transforms.ToTensor()
])

dataset = TripletFingerprintDataset(img_dir, classes, transform=transforms)
lengths = [int(len(dataset)*0.8), int(len(dataset)*0.2)]
train_set, test_set = torch.utils.data.random_split(dataset, [int(len(dataset)*0.8), len(dataset)-int(len(dataset)*0.8)])

train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=True)

# Model
embedding_net = EmbeddingNet()
model = TripletNet(embedding_net)

model.to(device)

# Loss and Optimizer
margin = 1.
loss = TripletLoss(margin)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)

# Train Network
fit(train_loader, test_loader, model, loss, optimizer, scheduler, num_epochs, cuda)


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]

Epoch: 1/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0294. Average validating loss: 0.0151
Saving best model...


Epoch: 2/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0039. Average validating loss: 0.0007
Saving best model...


Epoch: 3/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0133. Average validating loss: 0.0735


Epoch: 4/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0075. Average validating loss: 0.0002
Saving best model...


Epoch: 5/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0068. Average validating loss: 0.0066


Epoch: 6/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0706. Average validating loss: 0.0095


Epoch: 7/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0026. Average validating loss: 0.0017


Epoch: 8/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0016. Average validating loss: 0.0009


Epoch: 9/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0048. Average validating loss: 0.0467


Epoch: 10/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0773. Average validating loss: 0.0012


Epoch: 11/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0006. Average validating loss: 0.0003


Epoch: 12/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0001. Average validating loss: 0.0011


Epoch: 13/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0006. Average validating loss: 0.0000
Saving best model...


Epoch: 14/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0006. Average validating loss: 0.0000


Epoch: 15/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0053. Average validating loss: 0.0008


Epoch: 16/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0043. Average validating loss: 0.0054


Epoch: 17/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0015. Average validating loss: 0.0000
Saving best model...


Epoch: 18/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0004. Average validating loss: 0.0011


Epoch: 19/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.1963. Average validating loss: 0.0408


Epoch: 20/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0089. Average validating loss: 0.0021


Epoch: 21/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0013. Average validating loss: 0.0004


Epoch: 22/40


Training epoch:   0%|          | 0/743 [00:00<?, ?it/s]


Average training loss: 0.0006. Average validating loss: 0.0004


Epoch: 23/40


Training epoch:  18%|█▊        | 134/743 [00:42<03:13,  3.15it/s]