In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so the dataset archive under MyDrive is readable (Colab only).
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
# List the GPU(s) attached to this runtime (the recorded output shows a Tesla T4).
!nvidia-smi -L


GPU 0: Tesla T4 (UUID: GPU-578476da-1f50-a56f-cf8d-7524ffb9fdc1)


In [None]:
!unzip "/content/drive/MyDrive/DATN/data/roi_extraction.zip"

In [None]:
import os
from shutil import copy, rmtree
import random
def mk_file(file_path: str):
    """Make ``file_path`` an empty directory.

    A directory that already exists at that path is removed together with
    everything inside it before being recreated, so callers always start
    from a clean folder. Missing parent directories are created too.
    """
    if not os.path.exists(file_path):
        os.makedirs(file_path)
        return
    # Something is already there: wipe it, then recreate it empty.
    rmtree(file_path)
    os.makedirs(file_path)

def main(split_rate: float = 0.2,
         seed: int = 0,
         source_dir: str = "/content/roi_extraction",
         data_root: str = "/content/data"):
    """Split per-class ROI images into ``<data_root>/train`` and ``<data_root>/val``.

    ``source_dir`` must hold one sub-directory of images per class
    (ImageFolder layout). For each class, ``int(n_images * split_rate)``
    images are drawn at random (seeded) and copied to ``val/<class>``; the
    rest are copied to ``train/<class>``. Both output trees are deleted and
    recreated on every call.

    Args:
        split_rate: fraction of each class copied to the validation split.
        seed: seed for Python's ``random`` module so the split is repeatable.
        source_dir: folder containing one sub-folder per class.
        data_root: destination root for the ``train``/``val`` trees.
    """
    random.seed(seed)

    assert os.path.exists(source_dir), "path '{}' does not exist.".format(source_dir)

    # Only sub-directories are classes: a stray file (e.g. .DS_Store) would
    # otherwise get its own class folder and make os.listdir below fail.
    # Sorting removes the dependence on filesystem listing order, which would
    # otherwise make the seeded split differ between machines.
    palmPrint_class = sorted(
        entry for entry in os.listdir(source_dir)
        if os.path.isdir(os.path.join(source_dir, entry))
    )

    train_root = os.path.join(data_root, "train")
    val_root = os.path.join(data_root, "val")
    for split_root in (train_root, val_root):
        mk_file(split_root)
        for cla in palmPrint_class:
            mk_file(os.path.join(split_root, cla))

    for cla in palmPrint_class:
        cla_path = os.path.join(source_dir, cla)
        images = sorted(os.listdir(cla_path))
        num = len(images)
        # A set gives O(1) membership tests instead of scanning a list per image.
        eval_images = set(random.sample(images, k=int(num * split_rate)))
        for index, image in enumerate(images):
            split_root = val_root if image in eval_images else train_root
            copy(os.path.join(cla_path, image), os.path.join(split_root, cla))
            print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
        print()

    print("processing done")


if __name__ == '__main__':
    main()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import label_binarize
from typing import Tuple

def binarize_and_smooth_labels(labels: torch.Tensor,
                               n_classes: int,
                               smoothing_factor: float,
                               device: torch.device
                               ) -> torch.FloatTensor:
    """One-hot encode integer class labels and apply label smoothing.

    The true class receives ``1 - smoothing_factor``; every other class
    receives ``smoothing_factor / (n_classes - 1)``.

    Implemented with ``F.one_hot`` rather than
    ``sklearn.preprocessing.label_binarize``. label_binarize returns a single
    column when ``n_classes == 2``, which gave a wrongly shaped result here.
    It also forced a CPU/numpy round-trip on every call.

    Args:
        labels: 1-D tensor of class indices in ``[0, n_classes)``.
        n_classes: number of classes (width of the returned matrix).
        smoothing_factor: probability mass moved off the true class.
        device: device of the returned tensor.

    Returns:
        float32 tensor of shape ``[batch_size, n_classes]`` on ``device``.

    Raises:
        RuntimeError: (from ``F.one_hot``) if a label is outside ``[0, n_classes)``.
    """
    one_hot = F.one_hot(labels.to(device).long(), num_classes=n_classes).float()
    # With a single class there are no "other" classes to spread mass to.
    off_value = smoothing_factor / (n_classes - 1) if n_classes > 1 else 0.0
    on_value = 1.0 - smoothing_factor
    return one_hot * on_value + (1.0 - one_hot) * off_value


class ProxyNCALoss(nn.Module):
    """Proxy-NCA loss with one learnable proxy per class and label smoothing.

    Embeddings and L2-normalised proxies are multiplied by fixed scales.
    Squared Euclidean distances between them act as negative logits. The loss
    is the cross-entropy of ``softmax(-distances)`` against one-hot labels
    smoothed by ``smoothing_factor``.
    """

    def __init__(self,
                 n_classes: int,
                 embedding_size: int,
                 embedding_scale: float,
                 proxy_scale: float,
                 smoothing_factor: float,
                 device: torch.device,
                 ):
        super().__init__()

        self.device: torch.device = device
        self.n_classes: int = n_classes
        self.embedding_size: int = embedding_size
        self.embedding_scale: float = embedding_scale

        # Allocate directly on ``device`` and wrap the result in nn.Parameter.
        # ``nn.Parameter(...).to(device)`` returns a plain, non-leaf tensor
        # whenever the device changes. On CUDA the proxies were therefore
        # missing from ``self.parameters()`` and were never optimised.
        self.proxies = nn.Parameter(torch.randn(n_classes, embedding_size, device=device) / 8)
        self.proxy_scale: float = proxy_scale
        self.smoothing_factor: float = smoothing_factor

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Return ``(mean loss over the batch, 0)``.

        Args:
            embeddings: assumed ``[batch_size, embedding_size]``.
            labels: ``[batch_size]`` integer class indices.

        The constant 0 keeps the ``(loss, extra)`` return shape that the
        other losses in this notebook use.
        """
        embeddings = embeddings * self.embedding_scale
        # proxies shape: [n_classes, embedding_size]
        proxies: torch.Tensor = F.normalize(self.proxies, p=2, dim=1) * self.proxy_scale
        # distances shape: [batch_size, n_classes]
        distances: torch.Tensor = torch.cdist(embeddings, proxies).square()

        # labels shape after smoothing: [batch_size, n_classes]
        labels = binarize_and_smooth_labels(labels, self.n_classes, self.smoothing_factor, self.device)
        proxy_nca_loss: torch.Tensor = (-labels * F.log_softmax(-distances, dim=1)).sum(dim=1)

        return proxy_nca_loss.mean(), 0


class ProxyAnchorLoss(nn.Module):
    """Proxy-Anchor loss with one learnable proxy per class.

    For every proxy it sums ``exp(-alpha * (cos - margin))`` over batch samples
    of that class (positive term) and ``exp(alpha * (cos + margin))`` over all
    other samples (negative term). The positive term is averaged over the
    proxies whose class appears in the batch; the negative term is averaged
    over all ``n_classes`` proxies.
    """

    def __init__(self, n_classes: int, embedding_size: int, margin: float, alpha: float, device: torch.device):
        super().__init__()

        self.device: torch.device = device
        self.n_classes: int = n_classes
        self.embedding_size: int = embedding_size

        # shape: [n_classes, embedding_size]
        # Created directly on ``device``: ``nn.Parameter(...).to(device)``
        # returns an unregistered copy when the device changes, which left the
        # proxies untrained on CUDA. The values are set by kaiming_normal_ below.
        self.proxies = nn.Parameter(torch.empty(n_classes, embedding_size, device=device))
        nn.init.kaiming_normal_(self.proxies, mode="fan_out")

        self.margin: float = margin
        self.alpha: float = alpha

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Return ``(loss, 0)`` for a batch.

        Args:
            embeddings: ``[batch_size, embedding_size]``. Only the proxies are
                normalised here, so the similarities are cosines only if the
                embeddings are already L2-normalised (Mobilenet_v3_large.forward
                normalises its output).
            labels: ``[batch_size]`` integer class indices.
        """
        # proxies shape: [n_classes, embedding_size]
        proxies: torch.Tensor = F.normalize(self.proxies, p=2, dim=1)

        # cosine_similarities shape: [batch_size, n_classes]
        cosine_similarities = F.linear(embeddings, proxies)
        positive_exp = torch.exp(-self.alpha * (cosine_similarities - self.margin))
        negative_exp = torch.exp(self.alpha * (cosine_similarities + self.margin))

        # shape: [batch_size, n_classes], exact 0/1 entries (no smoothing)
        labels_onehot: torch.Tensor = binarize_and_smooth_labels(
            labels, self.n_classes, smoothing_factor=0, device=self.device
        )
        # Number of proxies whose class occurs at least once in the batch
        n_distinct_proxies: int = int((labels_onehot.sum(dim=0) != 0).sum())

        # Per proxy: sum over the positive samples in the batch. shape: [n_classes]
        sum_positive_distances = torch.where(
            labels_onehot == 1, positive_exp, torch.zeros_like(positive_exp)
        ).sum(dim=0)

        # Per proxy: sum over the negative samples in the batch. shape: [n_classes]
        sum_negative_distances = torch.where(
            labels_onehot == 0, negative_exp, torch.zeros_like(negative_exp)
        ).sum(dim=0)

        # Absent classes contribute log(1 + 0) = 0 to the positive sum.
        positive_term = torch.log(1 + sum_positive_distances).sum() / n_distinct_proxies
        negative_term = torch.log(1 + sum_negative_distances).sum() / self.n_classes
        proxy_anchor_loss: torch.Tensor = positive_term + negative_term
        return proxy_anchor_loss, 0

class SoftTripleLoss(nn.Module):
    """SoftTriple loss: every class owns ``n_centers_per_class`` learnable centers.

    A sample's similarity to a class is a softmax-weighted mix (logits scaled
    by ``gamma``) of its similarities to that class's centers. The loss is
    cross-entropy over ``lambda_ * (class_similarity - margin_on_true_class)``.
    When ``tau > 0`` and classes have more than one center, ``tau`` times a
    regulariser is added. The regulariser is the mean Euclidean distance
    between pairs of normalised centers of the same class.
    """

    def __init__(self,
                 n_classes: int,
                 embedding_size: int,
                 n_centers_per_class: int,
                 lambda_: float,
                 gamma: float,
                 tau: float,
                 margin: float,
                 device: torch.device
                 ):
        super().__init__()

        self.device: torch.device = device
        self.n_classes: int = n_classes
        self.embedding_size: int = embedding_size
        self.n_centers_per_class: int = n_centers_per_class

        self.lambda_: float = lambda_
        self.gamma: float = gamma
        self.tau: float = tau
        self.margin: float = margin

        # Each class has n centers, stored column-wise:
        # [embedding_size, n_classes * n_centers_per_class].
        # Created directly on ``device`` so it stays a registered nn.Parameter;
        # ``nn.Parameter(...).to(device)`` returns an unregistered copy on CUDA.
        self.centers: torch.Tensor = nn.Parameter(
            torch.empty(embedding_size, n_centers_per_class * n_classes, device=device)
        )
        nn.init.kaiming_uniform_(self.centers, a=5**0.5)

        # Boolean mask over center pairs: True at (r, c) when r < c and both
        # centers belong to the same class (upper triangle of each class block).
        # It must be bool. Indexing with a long 0/1 tensor performs an integer
        # gather of rows 0 and 1 instead of masking, giving a wrong regulariser.
        n_centers = n_classes * n_centers_per_class
        self.weight: torch.Tensor = torch.zeros(
            n_centers, n_centers, dtype=torch.bool, device=device
        )
        for i in range(n_classes):
            for j in range(n_centers_per_class):
                row = i * n_centers_per_class + j
                self.weight[row, row + 1:(i + 1) * n_centers_per_class] = True

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Return ``(loss, 0)`` for a batch.

        Args:
            embeddings: ``[batch_size, embedding_size]``. Only the centers are
                normalised here (presumably the embeddings are already
                L2-normalised by the model).
            labels: ``[batch_size]`` integer class indices.
        """
        labels = labels.to(self.device)
        # centers shape: [embedding_size, n_classes * n_centers_per_class]
        centers = F.normalize(self.centers, p=2, dim=0)

        # Similarity of each embedding to every center, grouped per class
        # distances shape: [batch_size, n_classes, n_centers_per_class]
        distances = embeddings.matmul(centers).reshape(-1, self.n_classes, self.n_centers_per_class)

        # Soft assignment of each embedding to the centers within each class
        probabilities = F.softmax(distances * self.gamma, dim=2)

        # Relaxed per-class similarity. shape: [batch_size, n_classes]
        relaxed_distances = torch.sum(probabilities * distances, dim=2)

        # Subtract the margin only at each sample's true class
        margin = torch.zeros_like(relaxed_distances)
        margin[torch.arange(0, len(embeddings)), labels] = self.margin
        soft_triple_loss = F.cross_entropy(self.lambda_ * (relaxed_distances - margin), labels)

        if self.tau > 0 and self.n_centers_per_class > 1:
            # Cosine similarity between all pairs of centers
            distances_centers = centers.t().matmul(centers)
            # sqrt(2 - 2cos) is the Euclidean distance between unit vectors;
            # the boolean mask keeps only same-class pairs with r < c.
            numerator = torch.sum(torch.sqrt(2.0 + 1e-5 - 2. * distances_centers[self.weight]))
            denominator = (self.n_classes * self.n_centers_per_class * (self.n_centers_per_class - 1.))
            regularization = numerator / denominator
            return soft_triple_loss + self.tau * regularization, 0

        return soft_triple_loss, 0

class TripletMarginLoss(nn.Module):
    """Triplet margin loss that delegates to one of two in-batch miners.

    ``sampling_type`` picks the miner:
      * ``"batch_hard_triplets"``    -> ``_batch_hard_triplets_loss``
      * ``"batch_hardest_triplets"`` -> ``_batch_hardest_triplets_loss``
    Any other value raises ``NotImplementedError`` when the loss is called.
    """

    def __init__(self, margin=1.0, p=2.0, sampling_type="batch_hard_triplets"):
        super().__init__()
        self.margin: float = margin
        self.p: float = p
        self.sampling_type: str = sampling_type

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, float]:
        """Return ``(loss, extra)`` as produced by the selected miner."""
        mode = self.sampling_type
        if mode not in ("batch_hard_triplets", "batch_hardest_triplets"):
            raise NotImplementedError(mode)
        miner = _batch_hard_triplets_loss if mode == "batch_hard_triplets" else _batch_hardest_triplets_loss
        return miner(labels, embeddings, self.margin, self.p)


def _batch_hard_triplets_loss(labels: torch.Tensor,
                              embeddings: torch.Tensor,
                              margin: float,
                              p: float
                              ) -> Tuple[torch.Tensor, float]:
    """Triplet loss over every valid in-batch triplet ("batch all" style).

    Every (anchor, positive, negative) index triple allowed by
    ``_get_triplet_masks`` is scored as ``d(a, p) - d(a, n) + margin``, with
    negative values clipped to 0. The loss is averaged over the triplets whose
    clipped value is still above 1e-16.

    Returns:
        ``(loss, fraction of valid triplets whose loss is positive)``.
    """
    dist = torch.cdist(embeddings, embeddings, p=p)

    # Broadcast into a [B, B, B] cube indexed as [anchor, positive, negative]
    dist_ap = dist.unsqueeze(2)
    dist_an = dist.unsqueeze(1)
    valid = _get_triplet_masks(labels)

    losses = valid.float() * (dist_ap - dist_an + margin)
    losses[losses < 0] = 0  # easy triplets contribute nothing

    n_active: int = losses[losses > 1e-16].size(0)
    n_valid: int = valid.sum().item()

    mean_loss = losses.sum() / (n_active + 1e-16)
    active_fraction: float = n_active / (n_valid + 1e-16)
    return mean_loss, active_fraction


def _batch_hardest_triplets_loss(labels: torch.Tensor,
                                 embeddings: torch.Tensor,
                                 margin: float,
                                 p: float
                                 ) -> Tuple[torch.Tensor, int]:
    """Triplet loss using the hardest positive and hardest negative per anchor.

    For each anchor, the farthest same-label sample (excluding itself) is
    paired with the closest different-label sample. The result is
    ``max(0, d_pos - d_neg + margin)`` averaged over anchors.

    Returns:
        ``(loss, 0)``.
    """
    dist = torch.cdist(embeddings, embeddings, p=p)

    positive_mask = _get_anchor_positive_mask(labels).float()
    hardest_positive, _ = (positive_mask * dist).max(1, keepdim=True)

    negative_mask = _get_anchor_negative_mask(labels).float()
    row_max, _ = dist.max(dim=1, keepdim=True)
    # Push non-negatives out of contention for the minimum by adding the row max.
    shifted = dist + row_max * (1.0 - negative_mask)
    hardest_negative, _ = shifted.min(dim=1, keepdim=True)

    losses = hardest_positive - hardest_negative + margin
    losses[losses < 0] = 0
    return losses.mean(), 0


def _get_triplet_masks(labels: torch.Tensor) -> torch.Tensor:
    # indices_equal is a square matrix, 1 in the diagonal and 0 everywhere else
    indices_equal: torch.Tensor = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
    # indices_not_equal is inversed of indices_equal, 0 in the diagonal and 1 everywhere else
    indices_not_equal: torch.Tensor = ~indices_equal

    # convention: i (anchor index), j (positive index), k (negative index)
    i_not_equal_j: torch.Tensor = indices_not_equal.unsqueeze(2)
    i_not_equal_k: torch.Tensor = indices_not_equal.unsqueeze(1)
    j_not_equal_k: torch.Tensor = indices_not_equal.unsqueeze(0)
    # Check that anchor, positive, negative are distince to each other
    distinct_indices: torch.Tensor = (i_not_equal_j & i_not_equal_k) & j_not_equal_k

    label_equal: torch.Tensor = labels.unsqueeze(0) == labels.unsqueeze(1)
    i_equal_j: torch.Tensor = label_equal.unsqueeze(2)
    i_equal_k: torch.Tensor = label_equal.unsqueeze(1)

    # i_equal_j: indices of anchor-positive pairs
    # ~i_equal_k: indices of anchor-negative pairs
    # indices of valid triplets
    indices_triplets: torch.Tensor = i_equal_j & (~i_equal_k)
    # Make sure that anchor, positive, negative are distince to each other
    indices_triplets = indices_triplets & distinct_indices
    return indices_triplets


def _get_anchor_positive_mask(labels: torch.Tensor) -> torch.Tensor:
    # Check that i and j are distince
    indices_equal: torch.Tensor = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
    indices_not_equal: torch.Tensor = ~indices_equal

    # Check anchor and negative
    labels_equal: torch.Tensor = labels.unsqueeze(0) == labels.unsqueeze(1)
    return labels_equal & indices_not_equal


def _get_anchor_negative_mask(labels: torch.Tensor) -> torch.BoolTensor:
    return labels.unsqueeze(0) != labels.unsqueeze(1)



In [None]:
from typing import Callable, List, Optional

import torch
from torch import nn, Tensor
from torch.nn import functional as F
from functools import partial


def _make_divisible(ch, divisor=8, min_ch=None):
    if min_ch is None:
        min_ch = divisor
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch


class ConvBNActivation(nn.Sequential):
    """``Conv2d`` (no bias) -> normalisation -> activation, as one Sequential.

    Padding is ``(kernel_size - 1) // 2``, which preserves spatial size for odd
    kernels at stride 1. Defaults are ``nn.BatchNorm2d`` and ``nn.ReLU6``; the
    activation is always constructed with ``inplace=True``.
    """

    def __init__(self,
                 in_planes: int,
                 out_planes: int,
                 kernel_size: int = 3,
                 stride: int = 1,
                 groups: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None):
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        make_act = nn.ReLU6 if activation_layer is None else activation_layer
        conv = nn.Conv2d(in_channels=in_planes,
                         out_channels=out_planes,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=(kernel_size - 1) // 2,
                         groups=groups,
                         bias=False)
        super(ConvBNActivation, self).__init__(conv,
                                               make_norm(out_planes),
                                               make_act(inplace=True))


class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation: channel-wise gating from globally pooled features.

    Global average pool -> 1x1 conv (to ``input_c // squeeze_factor`` rounded
    by ``_make_divisible``) -> ReLU -> 1x1 conv back to ``input_c`` ->
    hard-sigmoid. The gate then multiplies the input.
    """

    def __init__(self, input_c: int, squeeze_factor: int = 4):
        super(SqueezeExcitation, self).__init__()
        hidden_c = _make_divisible(input_c // squeeze_factor, 8)
        self.fc1 = nn.Conv2d(input_c, hidden_c, 1)
        self.fc2 = nn.Conv2d(hidden_c, input_c, 1)

    def forward(self, x: Tensor) -> Tensor:
        pooled = F.adaptive_avg_pool2d(x, output_size=(1, 1))
        hidden = F.relu(self.fc1(pooled), inplace=True)
        gate = F.hardsigmoid(self.fc2(hidden), inplace=True)
        return gate * x


class InvertedResidualConfig:
    """Hyper-parameters for one MobileNetV3 bottleneck block.

    Channel counts (``input_c``, ``expanded_c``, ``out_c``) are scaled by
    ``width_multi`` and rounded to a multiple of 8. ``activation == "HS"``
    selects h-swish (``use_hs``); any other string means ReLU in
    ``InvertedResidual``.
    """

    def __init__(self,
                 input_c: int,
                 kernel: int,
                 expanded_c: int,
                 out_c: int,
                 use_se: bool,
                 activation: str,
                 stride: int,
                 width_multi: float):
        scale = self.adjust_channels
        self.input_c = scale(input_c, width_multi)
        self.expanded_c = scale(expanded_c, width_multi)
        self.out_c = scale(out_c, width_multi)
        self.kernel = kernel
        self.stride = stride
        self.use_se = use_se
        self.use_hs = activation == "HS"  # whether using h-swish activation

    @staticmethod
    def adjust_channels(channels: int, width_multi: float):
        """Scale ``channels`` by ``width_multi`` and round to a multiple of 8."""
        return _make_divisible(channels * width_multi, 8)


class InvertedResidual(nn.Module):
    """MobileNetV3 bottleneck: optional 1x1 expand, depthwise conv, optional SE,
    then a linear 1x1 projection.

    A residual connection is added when ``stride == 1`` and the input and output
    channel counts match. Only strides 1 and 2 are accepted.
    """

    def __init__(self,
                 cnf: InvertedResidualConfig,
                 norm_layer: Callable[..., nn.Module]):
        super(InvertedResidual, self).__init__()

        if cnf.stride not in [1, 2]:
            raise ValueError("illegal stride value.")

        self.use_res_connect = (cnf.stride == 1 and cnf.input_c == cnf.out_c)

        act = nn.Hardswish if cnf.use_hs else nn.ReLU
        conv = partial(ConvBNActivation, norm_layer=norm_layer)

        # Order matters: it fixes the indices (and state_dict keys) inside self.block.
        layers: List[nn.Module] = []
        if cnf.expanded_c != cnf.input_c:
            # expand
            layers.append(conv(cnf.input_c, cnf.expanded_c,
                               kernel_size=1, activation_layer=act))
        # depthwise
        layers.append(conv(cnf.expanded_c, cnf.expanded_c,
                           kernel_size=cnf.kernel, stride=cnf.stride,
                           groups=cnf.expanded_c, activation_layer=act))
        if cnf.use_se:
            layers.append(SqueezeExcitation(cnf.expanded_c))
        # project (no non-linearity)
        layers.append(conv(cnf.expanded_c, cnf.out_c,
                           kernel_size=1, activation_layer=nn.Identity))

        self.block = nn.Sequential(*layers)
        self.out_channels = cnf.out_c
        self.is_strided = cnf.stride > 1

    def forward(self, x: Tensor) -> Tensor:
        out = self.block(x)
        if not self.use_res_connect:
            return out
        out += x
        return out


class MobileNetV3(nn.Module):
    """MobileNetV3 network: stem conv, inverted-residual blocks, a 1x1 head conv,
    global average pooling and a two-layer classifier.

    Submodules are registered as ``features``, ``avgpool`` and ``classifier``, in
    that order. Mobilenet_v3_large relies on this order when it drops the last
    child, and the attribute names define the state_dict keys.
    """

    def __init__(self,
                 inverted_residual_setting: List[InvertedResidualConfig],
                 last_channel: int,
                 num_classes: int = 1000,
                 block: Optional[Callable[..., nn.Module]] = None,
                 norm_layer: Optional[Callable[..., nn.Module]] = None):
        """
        Args:
            inverted_residual_setting: non-empty list of per-block configs.
            last_channel: hidden width of the classifier.
            num_classes: output width of the classifier.
            block: block constructor, called as ``block(cnf, norm_layer)``;
                defaults to ``InvertedResidual``.
            norm_layer: normalisation constructor; defaults to BatchNorm2d with
                eps=0.001, momentum=0.01.

        Raises:
            ValueError: if ``inverted_residual_setting`` is empty.
            TypeError: if it is not a list of ``InvertedResidualConfig``.
        """
        super(MobileNetV3, self).__init__()

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty.")
        elif not (isinstance(inverted_residual_setting, List) and
                  all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])):
            raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)

        layers: List[nn.Module] = []

        # building first layer (3-channel input, stride 2, h-swish)
        firstconv_output_c = inverted_residual_setting[0].input_c
        layers.append(ConvBNActivation(3,
                                       firstconv_output_c,
                                       kernel_size=3,
                                       stride=2,
                                       norm_layer=norm_layer,
                                       activation_layer=nn.Hardswish))
        # building inverted residual blocks
        for cnf in inverted_residual_setting:
            layers.append(block(cnf, norm_layer))

        # building last several layers: 1x1 conv widening to 6x the last block's output
        lastconv_input_c = inverted_residual_setting[-1].out_c
        lastconv_output_c = 6 * lastconv_input_c
        layers.append(ConvBNActivation(lastconv_input_c,
                                       lastconv_output_c,
                                       kernel_size=1,
                                       norm_layer=norm_layer,
                                       activation_layer=nn.Hardswish))
        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(nn.Linear(lastconv_output_c, last_channel),
                                        nn.Hardswish(inplace=True),
                                        nn.Dropout(p=0.2, inplace=True),
                                        nn.Linear(last_channel, num_classes))

        # initial weights (visits modules in registration order, so the RNG
        # draws, and hence the initial values, depend on the construction order above)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # features -> global average pool -> flatten to [B, C] -> classifier
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Return classifier logits of shape ``[batch_size, num_classes]``."""
        return self._forward_impl(x)


def mobilenet_v3_large(num_classes: int = 1000,
                       reduced_tail: bool = False) -> MobileNetV3:
    """Build MobileNetV3-Large with width multiplier 1.0.

    Args:
        num_classes: output width of the classifier.
        reduced_tail: halve the channels of the last three blocks and of the
            classifier's hidden layer.
    """
    width_multi = 1.0
    bneck_conf = partial(InvertedResidualConfig, width_multi=width_multi)
    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_multi=width_multi)

    divider = 2 if reduced_tail else 1
    tail_out = 160 // divider
    tail_expanded = 960 // divider

    # input_c, kernel, expanded_c, out_c, use_se, activation, stride
    block_table = [
        (16, 3, 16, 16, False, "RE", 1),
        (16, 3, 64, 24, False, "RE", 2),  # C1
        (24, 3, 72, 24, False, "RE", 1),
        (24, 5, 72, 40, True, "RE", 2),  # C2
        (40, 5, 120, 40, True, "RE", 1),
        (40, 5, 120, 40, True, "RE", 1),
        (40, 3, 240, 80, False, "HS", 2),  # C3
        (80, 3, 200, 80, False, "HS", 1),
        (80, 3, 184, 80, False, "HS", 1),
        (80, 3, 184, 80, False, "HS", 1),
        (80, 3, 480, 112, True, "HS", 1),
        (112, 3, 672, 112, True, "HS", 1),
        (112, 5, 672, tail_out, True, "HS", 2),  # C4
        (tail_out, 5, tail_expanded, tail_out, True, "HS", 1),
        (tail_out, 5, tail_expanded, tail_out, True, "HS", 1),
    ]
    inverted_residual_setting = [bneck_conf(*row) for row in block_table]
    last_channel = adjust_channels(1280 // divider)  # C5

    return MobileNetV3(inverted_residual_setting=inverted_residual_setting,
                       last_channel=last_channel,
                       num_classes=num_classes)
    

class Mobilenet_v3_large(nn.Module):
    """MobileNetV3-Large backbone whose classifier is replaced by a linear
    embedding head. The output embeddings are L2-normalised.

    Args:
        embedding_size: dimensionality of the output embedding.
        weight_path: optional path to a state dict for the full
            ``mobilenet_v3_large()`` model (including its 1000-class head).
            When ``None`` the backbone keeps its random initialisation.

    Raises:
        AssertionError: if ``weight_path`` is given but does not exist.
    """

    def __init__(self, embedding_size: int, weight_path: Optional[str] = None):
        super().__init__()

        model = mobilenet_v3_large()
        # The documented default (None) used to reach os.path.exists(None),
        # which raises TypeError; None now means "no pretrained weights".
        if weight_path is not None:
            assert os.path.exists(weight_path), "Weight_path {} does not exists!".format(weight_path)
            # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
            model.load_state_dict(torch.load(weight_path, map_location="cpu"))
        # Features extraction layers without the last fully-connected:
        # MobileNetV3's children are (features, avgpool, classifier), so keep the first two.
        self.features = nn.Sequential(*list(model.children())[:-1])
        # Embedding layer. Its input width is read from the classifier's first
        # Linear (960 for the standard large config) rather than hard-coded.
        in_features = model.classifier[0].in_features
        self.embedding = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=embedding_size)
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Return L2-normalised embeddings of shape ``[batch_size, embedding_size]``."""
        embedding: torch.Tensor = self.features(image)
        # Pooled output [B, C, 1, 1] -> [B, C]
        embedding = embedding.flatten(start_dim=1)

        embedding = self.embedding(embedding)
        embedding = F.normalize(embedding, p=2, dim=1)
        return embedding

In [None]:
import numpy as np
from torch.utils.data import Subset
from torchvision.datasets import ImageFolder
from PIL import ImageFile

from itertools import groupby
from typing import List, Tuple, Dict


ImageFile.LOAD_TRUNCATED_IMAGES = True


class Dataset(ImageFolder):
    """ImageFolder that also exposes ``idx_to_class`` (class index -> folder name).

    ``__getitem__`` returns ``(image, class_index)``. Only ``transform`` is
    applied; the constructor does not accept a ``target_transform``.
    """

    def __init__(self, images_dir: str, transform=None):
        super().__init__(images_dir, transform=transform)

        # Reverse of ImageFolder's class_to_idx mapping
        self.idx_to_class: Dict[int, str] = dict(
            (idx, class_name) for class_name, idx in self.class_to_idx.items()
        )

    def __len__(self) -> int:
        return len(self.imgs)

    def __getitem__(self, index):
        path, target = self.samples[index]
        image = self.loader(path)
        if self.transform is None:
            return image, target
        return self.transform(image), target


def get_subset_from_dataset(dataset: Dataset, n_samples_per_class: int) -> Dataset:
    """Randomly pick up to ``n_samples_per_class`` samples from every class.

    Classes are read from ``dataset.samples``, a sequence of
    ``(path, class_index)`` pairs as in ImageFolder. Each class is sampled
    without replacement using NumPy's global RNG, in order of first appearance.
    A class with fewer samples than requested contributes all of its samples;
    ``np.random.choice`` would otherwise raise ``ValueError``.

    Returns:
        A ``torch.utils.data.Subset`` over ``dataset``. The ``Dataset``
        annotation is kept for backward compatibility.
    """
    # Group dataset indices by class. Unlike itertools.groupby, this does not
    # require samples of the same class to be contiguous.
    indices_by_class: Dict[int, List[int]] = {}
    for i, (_, class_idx) in enumerate(dataset.samples):
        indices_by_class.setdefault(class_idx, []).append(i)

    indices: List[int] = []
    for class_indices in indices_by_class.values():
        size = min(n_samples_per_class, len(class_indices))
        chosen: List[int] = np.random.choice(class_indices, size=size, replace=False).tolist()
        indices.extend(chosen)

    subset: Dataset = Subset(dataset, indices=indices)
    return subset

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import os
import datetime
import random
from typing import List, Dict, Tuple, Any, Union


def set_random_seed(seed: int) -> None:
    """
    Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device),
    and put cuDNN into deterministic, non-benchmarking mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def get_current_time() -> str:
    """Local time as ``YYYY-MM-DD_HH-MM-SS`` (filesystem-safe; no colons or spaces)."""
    now = datetime.datetime.now()
    return f"{now:%Y-%m-%d_%H-%M-%S}"


def save_checkpoint(model: nn.Module,
                    config: Dict[str, Any],
                    current_epoch: int,
                    output_dir: str,
                    mean_average_precision: float = None,
                    ) -> str:
    """Save ``{"config", "model_state_dict"}`` to ``output_dir`` and return the path.

    The file is named ``epoch{N}.pth``, or ``epoch{N}-map{X.XX}.pth`` when
    ``mean_average_precision`` is given. ``output_dir`` is created if missing.

    If ``model`` exposes a ``.module`` attribute (nn.DataParallel /
    DistributedDataParallel style wrappers), the wrapped module's weights are
    saved so the keys carry no ``module.`` prefix. A plain module is saved as-is;
    previously that case raised AttributeError.
    """
    checkpoint_name: str = f"epoch{current_epoch}"
    if mean_average_precision is not None:
        checkpoint_name += f"-map{mean_average_precision:.2f}"
    checkpoint_name += ".pth"

    os.makedirs(output_dir, exist_ok=True)
    checkpoint_path: str = os.path.join(output_dir, checkpoint_name)

    bare_model: nn.Module = getattr(model, "module", model)
    torch.save(
        {
            "config": config,
            "model_state_dict": bare_model.state_dict(),
        },
        checkpoint_path
    )
    return checkpoint_path


def log_embeddings_to_tensorboard(loader: DataLoader,
                                  model: nn.Module,
                                  device: torch.device,
                                  writer: SummaryWriter,
                                  tag: str
                                  ) -> None:
    """Embed every sample of ``loader`` and write them to TensorBoard's projector.

    For ``tag == "train"``, a sampler exposing ``sequential_sampling`` is
    temporarily switched to ``True`` during the pass. Presumably this makes the
    project's sampler visit the samples in order instead of its usual batching.
    The previous value is restored afterwards, even on error, so later training
    epochs keep the sampler's original mode. It used to be left switched on.
    """
    sampler = loader.sampler
    toggle = tag == "train" and hasattr(sampler, "sequential_sampling")
    previous_mode = sampler.sequential_sampling if toggle else None
    if toggle:
        sampler.sequential_sampling = True
    try:
        # Calculating embedding of training set for visualization
        embeddings, labels = get_embeddings_from_dataloader(loader, model, device)
    finally:
        if toggle:
            sampler.sequential_sampling = previous_mode
    writer.add_embedding(embeddings, metadata=labels.tolist(), tag=tag)


@torch.no_grad()
def get_embeddings_from_dataloader(loader: DataLoader,
                                   model: nn.Module,
                                   device: torch.device,
                                   return_numpy_array=False,
                                   return_image_paths=False,
                                   ) -> Union[Tuple[torch.Tensor, torch.Tensor],
                                              Tuple[np.ndarray, np.ndarray],
                                              Tuple[Any, Any, List[str]]]:
    """Run ``model`` (without gradients) over every batch of ``loader``.

    The model is put into eval mode for the pass and then restored to the mode
    it had before. A call made in the middle of training therefore no longer
    leaves the model in eval mode.

    Args:
        loader: iterable of ``(images, labels)`` batches.
        model: network mapping a batch of images to embeddings.
        device: device the batches are moved to before the forward pass.
        return_numpy_array: convert both outputs to NumPy arrays on the CPU.
        return_image_paths: also return the file paths from
            ``loader.dataset.samples`` (needs an ImageFolder-style dataset).
            The paths follow dataset order, so they only line up with the
            embeddings when the loader iterates sequentially.

    Returns:
        ``(embeddings [N, embedding_size], labels [N])``, plus the list of
        image paths as a third element when ``return_image_paths`` is set.
    """
    was_training = model.training
    model.eval()

    embeddings_ls: List[torch.Tensor] = []
    labels_ls: List[torch.Tensor] = []
    try:
        for images_, labels_ in loader:
            images: torch.Tensor = images_.to(device, non_blocking=True)
            labels: torch.Tensor = labels_.to(device, non_blocking=True)
            embeddings_ls.append(model(images))
            labels_ls.append(labels)
    finally:
        model.train(was_training)

    embeddings: torch.Tensor = torch.cat(embeddings_ls, dim=0)  # shape: [N x embedding_size]
    labels: torch.Tensor = torch.cat(labels_ls, dim=0)  # shape: [N]

    if return_numpy_array:
        embeddings = embeddings.cpu().numpy()
        labels = labels.cpu().numpy()

    if return_image_paths:
        images_paths: List[str] = [path for path, _ in loader.dataset.samples]
        return (embeddings, labels, images_paths)

    return (embeddings, labels)


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import normalized_mutual_info_score
from typing import Tuple, Dict

def calculate_all_metrics(model: nn.Module,
                          test_loader: DataLoader,
                          ref_loader: DataLoader,
                          device: torch.device,
                          k: Tuple[int, ...] = (1, 5, 10)
                          ) -> Dict[str, float]:
    """Compute retrieval and clustering metrics of ``model`` on a test set.

    Every test embedding is used as a query against all reference embeddings
    (Euclidean distance).

    Args:
        model: embedding model.
        test_loader: query set; its dataset must expose ``classes``.
        ref_loader: reference (gallery) set.
        device: device used for the forward passes.
        k: neighbourhood sizes for precision@k and top-k accuracy; each must be
            <= the number of reference samples.

    Returns:
        Dict with keys in this order:
        - ``average_precision_at_{k}``
        - ``mean_average_precision`` (the mean of the precision@k values)
        - ``top_{k}_accuracy``
        - ``normalized_mutual_information``

        All values except NMI are percentages.
    """
    # Calculate all embeddings of the test set and the reference set
    embeddings_test, labels_test = get_embeddings_from_dataloader(test_loader, model, device)
    embeddings_ref, labels_ref = get_embeddings_from_dataloader(ref_loader, model, device)

    # Pairwise distances between every test and reference embedding.
    # A batch dim is added for cdist and then removed with squeeze(0) only: a bare
    # .squeeze() would also drop M or N when either set holds a single sample,
    # breaking the topk(dim=1) calls below.
    distances: torch.Tensor = torch.cdist(
        embeddings_test.unsqueeze(dim=0),
        embeddings_ref.unsqueeze(dim=0),
        p=2,
    ).squeeze(dim=0)  # [M x N]
    labels_test_column = labels_test.unsqueeze(dim=1)  # [M] -> [M x 1]

    metrics: Dict[str, float] = {}

    # Precision@k for every requested k
    for i in k:
        metrics[f"average_precision_at_{i}"] = calculate_precision_at_k(
            distances, labels_test_column, labels_ref, k=i
        )
    # "Mean average precision" here is the mean of the precision@k values above
    precision_values = list(metrics.values())
    metrics["mean_average_precision"] = sum(precision_values) / len(precision_values)

    # Top-k accuracy for every requested k
    for i in k:
        metrics[f"top_{i}_accuracy"] = calculate_topk_accuracy(
            distances, labels_test_column, labels_ref, top_k=i
        )

    # NMI between KMeans clusters of the test embeddings and the true labels
    n_classes: int = len(test_loader.dataset.classes)
    metrics["normalized_mutual_information"] = calculate_normalized_mutual_information(
        embeddings_test, labels_test, n_classes
    )

    return metrics


def calculate_precision_at_k(distances: torch.Tensor,
                             labels_test: torch.Tensor,
                             labels_ref: torch.Tensor,
                             k: int
                             ) -> float:
    """Mean precision@k (in percent) of nearest-neighbour retrieval.

    Args:
        distances: [M x N] distances between M queries and N reference items.
        labels_test: query labels, shape [M x 1] or [M].
        labels_ref: reference labels, shape [N].
        k: number of nearest references per query; must be <= N.

    Returns:
        For each query, the fraction of its k nearest references that share its
        label, averaged over all queries and multiplied by 100.
    """
    _, indices = distances.topk(k=k, dim=1, largest=False)  # [M x k], nearest first

    # Gather the labels of the k nearest references in one indexing op
    y_pred: torch.Tensor = labels_ref[indices]  # [M x k]
    # reshape lets [M] and [M x 1] labels both broadcast against the k columns
    hits: torch.Tensor = y_pred == labels_test.reshape(-1, 1)  # [M x k]

    precision_at_k: float = (hits.sum(dim=1) / k).mean().item() * 100
    return precision_at_k


def calculate_topk_accuracy(distances: torch.Tensor,
                            labels_test: torch.Tensor,
                            labels_ref: torch.Tensor,
                            top_k: int
                            ) -> float:
    """Top-k retrieval accuracy (in percent).

    A query counts as correct when at least one of its ``top_k`` nearest
    references has the same label.

    Args:
        distances: [M x N] distances between M queries and N reference items.
        labels_test: query labels, shape [M x 1] or [M].
        labels_ref: reference labels, shape [N].
        top_k: number of nearest references per query; must be <= N.

    Returns:
        Percentage of queries with a correct hit among their ``top_k`` neighbours.
    """
    _, indices = distances.topk(k=top_k, dim=1, largest=False)  # [M x k], nearest first

    # Labels of the top_k nearest references, gathered in one indexing op
    y_pred: torch.Tensor = labels_ref[indices]  # [M x k]
    # reshape lets [M] and [M x 1] labels both broadcast against the k columns
    any_hit: torch.Tensor = (y_pred == labels_test.reshape(-1, 1)).any(dim=1)  # [M]

    n_predictions: int = y_pred.shape[0]
    n_true_predictions: int = int(any_hit.sum().item())
    topk_accuracy: float = n_true_predictions / n_predictions * 100
    return topk_accuracy


def calculate_normalized_mutual_information(embeddings: torch.Tensor,
                                            labels_test: torch.Tensor,
                                            n_classes: int
                                            ) -> float:
    """NMI between KMeans clusters of ``embeddings`` and the true labels.

    Args:
        embeddings: test embeddings (moved to CPU and converted to numpy).
        labels_test: true labels of the embeddings.
        n_classes: number of KMeans clusters to fit.

    Returns:
        ``sklearn.metrics.normalized_mutual_info_score`` of the true labels vs.
        the KMeans assignments.
    """
    embeddings = embeddings.cpu().numpy()
    # np.int was a deprecated alias of int (NumPy 1.20) and is removed in NumPy 1.24,
    # where it raises AttributeError; use an explicit integer dtype instead.
    y_test: np.ndarray = labels_test.cpu().numpy().astype(np.int64)

    y_pred: np.ndarray = KMeans(n_clusters=n_classes).fit(embeddings).labels_
    NMI_score: float = normalized_mutual_info_score(y_test, y_pred)

    return NMI_score

In [None]:
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
from collections import defaultdict

import random
from typing import List, DefaultDict, Dict, Iterator, Any


def _create_groups(targets: List[int], samples_per_class: int) -> DefaultDict[int, List[int]]:
    
    class_to_idxs: DefaultDict[int, List[int]] = defaultdict(list)
    for idx, class_idx in enumerate(targets):
        class_to_idxs[class_idx].append(idx)

    # Get classes that have number of samples less than k
    classes_to_extend: List[int] = []
    for class_ in class_to_idxs:
        if len(class_to_idxs[class_]) < samples_per_class:
            classes_to_extend.append(class_)
            continue

    # For class that have less than k sample, we will extend that class
    # by duplicating random samples in class until the class has k samples
    for class_ in classes_to_extend:
        n_samples: int = len(class_to_idxs[class_])
        n_samples_to_extend = samples_per_class - n_samples

        random_samples: List[int] = []
        for _ in range(n_samples_to_extend):
            # Choose a random sample in class to duplicate
            random_sample: int = np.random.choice(class_to_idxs[class_])
            random_samples.append(random_sample)

        # Extend current class to have k samples
        class_to_idxs[class_].extend(random_samples)

    return class_to_idxs


class PKSampler(Sampler):
    """Sampler that yields indices in groups of ``samples_per_class`` per class.

    In random mode, each round:
    1. Picks ``classes_per_batch`` distinct classes uniformly at random from the
       classes that still have at least ``samples_per_class`` unused indices.
    2. Yields ``samples_per_class`` consecutive unused indices from each picked
       class. Indices are shuffled within each class at the start of iteration.

    Iteration stops when fewer than ``classes_per_batch`` eligible classes
    remain. With a DataLoader batch size of
    ``classes_per_batch * samples_per_class``, each batch therefore holds P
    classes x K samples.

    In sequential mode (``sequential_sampling=True``) it simply yields
    ``0 .. len(labels) - 1`` in order.
    """

    def __init__(self,
                 labels: List[int],
                 classes_per_batch: int,
                 samples_per_class: int,
                 sequential_sampling: bool = False
                 ):
        """
        Args:
            labels: class label of every dataset index.
            classes_per_batch: number of distinct classes (P) drawn per round.
            samples_per_class: number of indices (K) yielded per drawn class.
            sequential_sampling: if True, iterate indices in order instead.

        Raises:
            Exception: if there are fewer distinct classes than ``classes_per_batch``.
        """
        self.labels: List[int] = labels
        self.classes_per_batch: int = classes_per_batch
        self.samples_per_class: int = samples_per_class
        self.class_to_idxs: DefaultDict[int, List[int]] = _create_groups(labels, self.samples_per_class)

        self.sequential_sampling: bool = sequential_sampling

        # Ensures there are enough classes to sample from
        if len(self.class_to_idxs) < classes_per_batch:
            raise Exception(f"There are not enough classes to sample. "
                            f"Got: n_classes={len(self.class_to_idxs)}, "
                            f"classes_per_batch={classes_per_batch}"
                            )

    def __iter__(self) -> Iterator[Any]:
        """
        Return index of images and targets to be sampled
        """
        if not self.sequential_sampling:
            # Shuffle samples within each class (in place, so every epoch differs)
            for key in self.class_to_idxs:
                random.shuffle(self.class_to_idxs[key])

            # Keep track of the number of unused samples left for each class
            class_to_n_samples: Dict[int, int] = {}
            for class_, sample_idxs in self.class_to_idxs.items():
                class_to_n_samples[class_] = len(sample_idxs)

            while len(class_to_n_samples) >= self.classes_per_batch:
                # Select P distinct classes uniformly at random from the remaining ones
                classes: List[int] = list(class_to_n_samples.keys())
                selected_class_idxs: List[int] = torch.multinomial(
                    torch.ones(len(classes)),
                    self.classes_per_batch
                ).tolist()

                for i in selected_class_idxs:
                    class_: int = classes[i]
                    # List of indexes of samples in a particular class
                    sample_idxs: List[int] = self.class_to_idxs[class_]
                    for _ in range(self.samples_per_class):
                        # Next unused position: consumed = total - remaining
                        sample_idx: int = len(sample_idxs) - class_to_n_samples[class_]
                        yield sample_idxs[sample_idx]
                        class_to_n_samples[class_] -= 1

                    # Don't sample from a class once it has fewer than K samples remaining
                    if class_to_n_samples[class_] < self.samples_per_class:
                        class_to_n_samples.pop(class_)

        else:
            for i in range(len(self.labels)):
                yield i

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm
import sys
import logging
from typing import Tuple, Dict, Any, Union

logger = logging.getLogger(__name__)


def _validate_and_checkpoint(model: nn.Module,
                             test_loader: DataLoader,
                             reference_loader: DataLoader,
                             writer: SummaryWriter,
                             device: torch.device,
                             config: Dict[str, Any],
                             checkpoint_dir: str,
                             current_epoch: int,
                             current_iter: int,
                             output_dict: Dict[str, Any]
                             ) -> None:
    """Validate ``model.module``, log the metrics, and checkpoint when MAP improves.

    On improvement, ``output_dict["metrics"]`` is replaced with the new metrics.
    """
    metrics: Dict[str, Any] = calculate_all_metrics(model.module, test_loader, reference_loader, device)
    log_info_metrics(logger, metrics, current_epoch)

    # Log all metrics to tensorboard
    for metric_name, value in metrics.items():
        writer.add_scalar(f"test/{metric_name}", value, current_iter)

    # Save the checkpoint that has the highest MAP so far
    if metrics['mean_average_precision'] > output_dict["metrics"]["mean_average_precision"]:
        output_dict["metrics"] = metrics
        save_checkpoint(
            model,
            config,
            current_epoch,
            current_iter,
            checkpoint_dir,
            metrics['mean_average_precision'],
        )


def train_one_epoch(model: nn.Module,
                    optimizer: Optimizer,
                    loss_function: "Union[TripletMarginLoss, ProxyNCALoss, ProxyAnchorLoss, SoftTripleLoss]",
                    train_loader: DataLoader,
                    test_loader: DataLoader,
                    reference_loader: DataLoader,
                    writer: SummaryWriter,
                    device: torch.device,
                    config: Dict[str, Any],
                    checkpoint_dir: str,
                    log_frequency: int,
                    validate_frequency: int,
                    output_dict: Dict[str, Any]
                    ) -> Dict[str, Any]:
    """Train for one epoch, validating and checkpointing along the way.

    Validation runs at three points:
    - on the very first iteration;
    - on every ``validate_frequency``-th iteration (before that iteration's
      training step);
    - once more after the last batch when this is the final epoch.

    Args:
        model: model wrapped so that ``model.module`` is the bare network
            (used for validation).
        optimizer, loss_function: used by ``train_one_batch``.
        train_loader, test_loader, reference_loader: training / query / gallery data.
        writer: TensorBoard writer for train and test scalars.
        device: device batches are moved to.
        config, checkpoint_dir: forwarded to ``save_checkpoint``.
        log_frequency: log averaged training stats every this many iterations.
        validate_frequency: validate every this many iterations.
        output_dict: running state, updated in place:
            - "current_epoch" and "current_iter" are incremented;
            - "metrics" is replaced when MAP improves;
            - "total_epoch" is read.

    Returns:
        The same ``output_dict`` object.
    """
    # Increase number of epochs so far
    output_dict["current_epoch"] += 1
    current_epoch: int = output_dict["current_epoch"]

    running_loss: float = 0.0
    running_fraction_hard_triplets: float = 0.0
    # Batches accumulated since the last log line. The running sums restart every
    # epoch while current_iter does not, so a log window can hold fewer than
    # log_frequency batches; averaging must divide by the real count.
    n_batches_since_log: int = 0

    for images, labels in train_loader:
        # Increase number of iterations so far
        output_dict["current_iter"] += 1
        current_iter: int = output_dict["current_iter"]

        # Run validation
        if current_iter == 1 or (current_iter % validate_frequency == 0):
            _validate_and_checkpoint(model, test_loader, reference_loader, writer, device,
                                     config, checkpoint_dir, current_epoch, current_iter, output_dict)

        output_batch: Dict[str, Any] = train_one_batch(model, optimizer, loss_function, images, labels, device)
        running_loss += output_batch["loss"]
        running_fraction_hard_triplets += output_batch["fraction_hard_triplets"]
        n_batches_since_log += 1

        # Logging to tensorboard
        for metric_name, value in output_batch.items():
            writer.add_scalar(f"train/{metric_name}", value, current_iter)

        # Logging to standard output stream
        if current_iter % log_frequency == 0:
            average_loss: float = running_loss / n_batches_since_log
            average_hard_triplets: float = running_fraction_hard_triplets / n_batches_since_log * 100
            logger.info(
                f"TRAINING\t[{current_epoch}|{current_iter}]\t"
                f"train_loss: {average_loss:.6f}\t"
                f"hard triplets: {average_hard_triplets:.2f}%"
            )
            running_loss = 0.0
            running_fraction_hard_triplets = 0.0
            n_batches_since_log = 0

    # Run validation at the final iteration. Read the iteration counter from
    # output_dict: the loop-local one is unbound if train_loader was empty.
    if output_dict["current_epoch"] == output_dict["total_epoch"]:
        _validate_and_checkpoint(model, test_loader, reference_loader, writer, device,
                                 config, checkpoint_dir, current_epoch, output_dict["current_iter"], output_dict)
    return output_dict


def train_one_batch(model: nn.Module,
                    optimizer: Optimizer,
                    loss_function: "Union[TripletMarginLoss, ProxyNCALoss, ProxyAnchorLoss, SoftTripleLoss]",
                    images: torch.Tensor,
                    labels: torch.Tensor,
                    device: torch.device,
                    ) -> Dict[str, float]:
    """Run one optimisation step on a single batch.

    Puts the model in train mode, computes embeddings and the loss,
    backpropagates, and steps the optimizer.

    Args:
        model: embedding model.
        optimizer: optimizer over the model (and possibly loss) parameters.
        loss_function: called as ``loss_function(embeddings, labels)``. It may
            return either ``(loss, fraction_hard_triplets)`` or just ``loss``.
        images, labels: one batch; moved to ``device``.
        device: target device.

    Returns:
        ``{"loss": float, "fraction_hard_triplets": float}``. The fraction is
        0.0 when the loss function does not report one.
    """
    model.train()
    optimizer.zero_grad()

    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)

    embeddings: torch.Tensor = model(images)
    loss_output = loss_function(embeddings, labels)
    # Unpacking a bare scalar loss tensor raises TypeError, so accept both a
    # (loss, fraction_hard_triplets) pair and a lone loss.
    if isinstance(loss_output, (tuple, list)):
        loss, fraction_hard_triplets = loss_output
    else:
        loss, fraction_hard_triplets = loss_output, None

    loss.backward()
    optimizer.step()

    return {
        "loss": loss.item(),
        "fraction_hard_triplets": 0.0 if fraction_hard_triplets is None else float(fraction_hard_triplets)
    }


def log_info_metrics(logger, metrics: Dict[str, float], current_epoch: int) -> None:
    """
    Write a framed validation summary (MAP, AP@1, AP@5, Top-1, Top-5, NMI)
    through ``logger.info``: a row of 130 asterisks, the summary line, and
    another row of asterisks.
    """
    separator = "*" * 130
    logger.info(separator)

    summary_line = (
        f"VALIDATING\t[{current_epoch}]\t"
        f"MAP: {metrics['mean_average_precision']:.2f}%\t"
        f"AP@1: {metrics['average_precision_at_1']:.2f}%\t"
        f"AP@5: {metrics['average_precision_at_5']:.2f}%\t"
        f"Top-1: {metrics['top_1_accuracy']:.2f}%\t"
        f"Top-5: {metrics['top_5_accuracy']:.2f}%\t"
        f"NMI: {metrics['normalized_mutual_information']:.2f}\t"
    )
    logger.info(summary_line)
    logger.info(separator)
    logger.info("*" * 130)

In [None]:
from torchvision.transforms.transforms import CenterCrop
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim import RAdam

import argparse
import logging
import json
import time
import sys
import os
import yaml
from pprint import pformat
from typing import Dict, Any

# Timestamp of this run; it names the log file below and the checkpoint directory.
CURRENT_TIME: str = get_current_time()

# Send log records both to the notebook output and to a per-run text file.
_log_handlers = [
    logging.StreamHandler(sys.stdout),
    logging.FileHandler(f"./{CURRENT_TIME}.txt", mode="w", encoding="utf-8"),
]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s  %(name)s  %(levelname)s: %(message)s',
    datefmt='%y-%b-%d %H:%M:%S',
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)


def main():
    """Train an embedding model on the palm-print ROI split and log the results.

    Steps:
    1. Build the model, optimizer, data loaders and the loss selected by
       ``config["loss"]``.
    2. Call ``train_one_epoch`` ``config["n_epochs"]`` times; it validates and
       checkpoints periodically.
    3. Write embeddings, the model graph and the hyper-parameters to
       TensorBoard, and dump ``output_dict.json`` into the checkpoint directory.

    The TensorBoard writer is always closed, even if training raises.

    Raises:
        Exception: if ``config["loss"]`` is not one of the supported losses.
    """
    start = time.time()

    # Initialize config
    config = {
        "lr": 0.0001,
        "image_size": 256,
        "embedding_size": 128,
        "batch_size": 48,
        "smoothing_factor": 0.1,
        "alpha": 32,
        "embedding_scale": 1.0,
        "proxy_scale": 3.0,
        "n_epochs": 100,
        "loss": "soft_triple",
        "n_samples_per_reference_class": -1,
        "checkpoint_root_dir": "/content/drive/MyDrive/DATN/save_check_points",
        "log_frequency": 100,
        "validate_frequency": 1000,
        "sampling_type": "batch_hard_triplets",
        "samples_per_class": 4,
        "classes_per_batch": 12,
        "k_queries": 3,
        "n_centers_per_class": 5,
        # lambda is for calculating distance probability. See section 3.2 in the paper for more detail.
        "lambda": 20,
        # gamma is for calculating cross entropy loss. See section 3.2 in the paper for more detail.
        "gamma": 0.1,
        # tau is for regularization. See section 3.2 in the paper for more detail.
        "tau": 0.,
        # Margin factor
        "margin": 0.01,
        "n_workers": 2
    }

    # Weight_path
    weight_path = "/content/drive/MyDrive/DATN/mobilenet_v3_large-8738ca79.pth"
    # Train_data_path
    train_dir = "/content/data/train"
    # Val_data_dir
    val_dir = "/content/data/val"
    # Initialize device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Initialize number of DataLoader workers.
    # os.cpu_count() returns None when the count is unknown, which would make min() raise TypeError.
    num_worker = min([os.cpu_count() or 1, config["batch_size"] if config["batch_size"] > 1 else 0, 8])
    # Initialize model
    model = nn.DataParallel(Mobilenet_v3_large(
        embedding_size=config["embedding_size"],
        weight_path=weight_path
    ))
    model = model.to(device)

    # Initialize optimizer
    optimizer = RAdam(model.parameters(), lr=config["lr"])

    # Initialize train transforms
    transform_train = transforms.Compose([
        transforms.Resize((config["image_size"], config["image_size"])),
        transforms.RandomResizedCrop((224, 224)),
        transforms.RandomRotation(degrees=30),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomAffine(degrees=5, scale=(0.8, 1.2), translate=(0.2, 0.2)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Initialize training set
    train_set = Dataset(train_dir, transform=transform_train)

    if config["loss"] == "triplet_loss":
        # Initialize train loader for triplet loss (P classes x K samples per batch)
        batch_size: int = config["classes_per_batch"] * config["samples_per_class"]
        train_loader = DataLoader(
            train_set,
            batch_size,
            sampler=PKSampler(
                train_set.targets,
                config["classes_per_batch"],
                config["samples_per_class"]
            ),
            shuffle=False,
            num_workers=num_worker,
            pin_memory=True,
        )

        # Initialize loss function
        loss_function = TripletMarginLoss(
            margin=config["margin"],
            sampling_type=config["sampling_type"]
        )

    elif config["loss"] == "proxy_nca":
        # Initialize train loader for proxy-nca loss
        batch_size: int = config["batch_size"]
        train_loader = DataLoader(
            train_set,
            config["batch_size"],
            shuffle=True,
            num_workers=num_worker,
            pin_memory=True,
        )

        loss_function = ProxyNCALoss(
            n_classes=len(train_set.classes),
            embedding_size=config["embedding_size"],
            embedding_scale=config["embedding_scale"],
            proxy_scale=config["proxy_scale"],
            smoothing_factor=config["smoothing_factor"],
            device=device
        )

    elif config["loss"] == "proxy_anchor":
        # Initialize train loader for proxy-anchor loss
        # NOTE(review): this branch uses config["n_workers"] whereas the triplet and
        # proxy-nca branches use num_worker; confirm this is intended.
        batch_size: int = config["batch_size"]
        train_loader = DataLoader(
            train_set,
            config["batch_size"],
            shuffle=True,
            num_workers=config["n_workers"],
            pin_memory=True,
        )

        loss_function = ProxyAnchorLoss(
            n_classes=len(train_set.classes),
            embedding_size=config["embedding_size"],
            margin=config["margin"],
            alpha=config["alpha"],
            device=device
        )

    elif config["loss"] == "soft_triple":
        # Initialize train loader for soft-triple loss
        # NOTE(review): uses config["n_workers"] like proxy_anchor; confirm intended.
        batch_size: int = config["batch_size"]
        train_loader = DataLoader(
            train_set,
            config["batch_size"],
            shuffle=True,
            num_workers=config["n_workers"],
            pin_memory=True,
        )

        loss_function = SoftTripleLoss(
            n_classes=len(train_set.classes),
            embedding_size=config["embedding_size"],
            n_centers_per_class=config["n_centers_per_class"],
            lambda_=config["lambda"],
            gamma=config["gamma"],
            tau=config["tau"],
            margin=config["margin"],
            device=device
        )
    else:
        # The listed names must match the config["loss"] keys checked above.
        raise Exception("Only the following losses is supported: "
                        "['triplet_loss', 'proxy_nca', 'proxy_anchor', 'soft_triple']. "
                        f"Got {config['loss']}")

    # Initialize test transforms
    transform_test = transforms.Compose([
        transforms.Resize((config["image_size"], config["image_size"])),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Initialize test set and test loader
    test_dataset = Dataset(val_dir, transform=transform_test)
    test_loader = DataLoader(
        test_dataset, batch_size,
        shuffle=False,
        num_workers=num_worker,
    )

    # Initialize reference set and reference loader
    reference_set = Dataset(train_dir, transform=transform_test)

    n_samples_per_reference_class: int = config["n_samples_per_reference_class"]
    if n_samples_per_reference_class > 0:
        reference_set = get_subset_from_dataset(reference_set, n_samples_per_reference_class)

    reference_loader = DataLoader(reference_set, batch_size, shuffle=False, num_workers=num_worker)

    # Initialize checkpointing directory
    checkpoint_dir: str = os.path.join(config["checkpoint_root_dir"], CURRENT_TIME)
    writer = SummaryWriter(log_dir=checkpoint_dir)
    logger.info(f"Created checkpoint directory at: {checkpoint_dir}")

    # Close the writer even if training fails, so buffered TensorBoard events reach disk.
    try:
        # Dictionary contains all metrics
        output_dict: Dict[str, Any] = {
            "total_epoch": config["n_epochs"],
            "current_epoch": 0,
            "current_iter": 0,
            "metrics": {
                "mean_average_precision": 0.0,
                "average_precision_at_1": 0.0,
                "average_precision_at_5": 0.0,
                "average_precision_at_10": 0.0,
                "top_1_accuracy": 0.0,
                "top_5_accuracy": 0.0,
                "normalized_mutual_information": 0.0,
            }
        }
        # Start training and testing
        logger.info("Start training...")
        for _ in range(1, config["n_epochs"] + 1):
            output_dict = train_one_epoch(
                model, optimizer, loss_function,
                train_loader, test_loader, reference_loader,
                writer, device, config,
                checkpoint_dir,
                config['log_frequency'],
                config['validate_frequency'],
                output_dict
            )
        logger.info(f"DONE TRAINING {config['n_epochs']} epochs")

        # Visualize embeddings
        logger.info("Calculating train embeddings for visualization...")
        log_embeddings_to_tensorboard(train_loader, model, device, writer, tag="train")
        logger.info("Calculating reference embeddings for visualization...")
        log_embeddings_to_tensorboard(reference_loader, model, device, writer, tag="reference")
        logger.info("Calculating test embeddings for visualization...")
        log_embeddings_to_tensorboard(test_loader, model, device, writer, tag="test")

        # Visualize model's graph
        logger.info("Adding graph for visualization")
        with torch.no_grad():
            dummy_input = torch.zeros(1, 3, config["image_size"], config["image_size"]).to(device)
            writer.add_graph(model.module.features, dummy_input)

        # Save all hyper-parameters and corresponding metrics
        logger.info("Saving all hyper-parameters")
        writer.add_hparams(
            config,
            metric_dict={f"hyperparams/{key}": value for key, value in output_dict["metrics"].items()}
        )
        with open(os.path.join(checkpoint_dir, "output_dict.json"), "w") as f:
            json.dump(output_dict, f, indent=4)
        logger.info(f"Dumped output_dict.json at {checkpoint_dir}")
    finally:
        writer.close()

    end = time.time()
    logger.info(f"EVERYTHING IS DONE. Training time: {round(end - start, 2)} seconds")


if __name__ == "__main__":
    main()

22-Jul-01 03:34:00  __main__  INFO: Created checkpoint directory at: /content/drive/MyDrive/DATN/save_check_points/2022-07-01_03-34-00
22-Jul-01 03:34:00  __main__  INFO: Start training...


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations


22-Jul-01 03:34:47  __main__  INFO: **********************************************************************************************************************************
22-Jul-01 03:34:47  __main__  INFO: VALIDATING	[1]	MAP: 81.25%	AP@1: 95.78%	AP@5: 83.63%	Top-1: 95.78%	Top-5: 98.72%	NMI: 0.89	
22-Jul-01 03:34:47  __main__  INFO: **********************************************************************************************************************************


TypeError: ignored