# DML loss functions
## 準備 (Setup)


In [None]:
!pip install -q pytorch-metric-learning[with-hooks]

[K     |████████████████████████████████| 102kB 7.1MB/s 
[K     |████████████████████████████████| 67.7MB 44kB/s 
[?25h

In [None]:
%matplotlib inline
from pytorch_metric_learning import losses, miners, samplers, trainers, testers
from pytorch_metric_learning.utils import common_functions
import pytorch_metric_learning.utils.logging_presets as logging_presets
import numpy as np
import torchvision
from torchvision import datasets, transforms
import torch
import torch.nn as nn
from PIL import Image
import logging
import matplotlib.pyplot as plt
import umap
from cycler import cycler
import record_keeper
import pytorch_metric_learning
# Show INFO-level messages from every logger (root level), then record the library version.
logging.root.setLevel(logging.INFO)
logging.info("VERSION %s", pytorch_metric_learning.__version__)
    

INFO:root:VERSION 0.9.96


In [None]:
class MLP(nn.Module):
    """Stack of fully connected layers.

    ``layer_sizes[0]`` is the input dimension and ``layer_sizes[-1]`` the output dimension;
    every consecutive pair of sizes becomes one ``nn.Linear``.

    A ``nn.ReLU`` is inserted *before* each Linear layer except the last one. With
    ``final_relu=True`` a ReLU is also inserted before the last Linear layer. No activation
    ever follows the final Linear layer.

    NOTE(review): because the ReLU precedes each Linear, with final_relu=False and two or
    more Linear layers the last two Linear layers have no activation between them. This
    matches the upstream example code; confirm it is intended before using deeper MLPs.

    :param layer_sizes: sequence of layer widths (each converted with ``int``)
    :param final_relu: also put a ReLU in front of the last Linear layer
    """

    def __init__(self, layer_sizes, final_relu=False):
        super().__init__()
        sizes = list(map(int, layer_sizes))
        n_linear = len(sizes) - 1
        # Layers with index below this cutoff get a ReLU in front of them.
        relu_cutoff = n_linear if final_relu else n_linear - 1

        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            if idx < relu_cutoff:
                modules.append(nn.ReLU(inplace=False))
            modules.append(nn.Linear(fan_in, fan_out))

        self.net = nn.Sequential(*modules)
        self.last_linear = self.net[-1]

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.net(x)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Trunk: torchvision ResNet-18 with pretrained weights. Its final fc layer is swapped for an
# identity so the network emits the features that used to feed fc instead of class scores.
trunk = torchvision.models.resnet18(pretrained=True)
trunk_output_size = trunk.fc.in_features  # width of the features entering the old fc layer
trunk.fc = common_functions.Identity()
trunk = torch.nn.DataParallel(trunk.to(device))

# Embedder: MLP([trunk_output_size, 64]) is a single Linear layer producing 64-d embeddings.
embedder = torch.nn.DataParallel(MLP([trunk_output_size, 64]).to(device))

# Optimizers: the pretrained trunk uses a 10x smaller learning rate than the embedder.
trunk_optimizer = torch.optim.Adam(trunk.parameters(), lr=1e-5, weight_decay=1e-4)
embedder_optimizer = torch.optim.Adam(embedder.parameters(), lr=1e-4, weight_decay=1e-4)

# Per-channel normalization statistics (the standard torchvision ImageNet values),
# shared by both pipelines below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize to 64, then random resized crop + horizontal flip as augmentation.
train_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.RandomResizedCrop(size=64, scale=(0.16, 1), ratio=(0.75, 1.33)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation pipeline: deterministic resize only.
val_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])





Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




In [None]:
# Download both original CIFAR-100 splits (train=True, then train=False) without transforms;
# the class-disjoint wrapper defined below applies its own transforms.
original_train, original_val = (
    datasets.CIFAR100(root="CIFAR100_Dataset", train=is_train, transform=None, download=True)
    for is_train in (True, False)
)

# Used to build train and val sets whose classes do not overlap.
class ClassDisjointCIFAR100(torch.utils.data.Dataset):
    """CIFAR-100 images restricted to one half of the label range.

    Images are pooled from BOTH original splits (train first, then val), keeping only
    labels < 50 when ``train`` is true and labels >= 50 otherwise.

    :param original_train: original train split; must expose ``.data`` (indexable array) and ``.targets``
    :param original_val: original val/test split with the same attributes
    :param train: True -> keep labels < 50; False -> keep labels >= 50
    :param transform: optional callable applied to each PIL image in ``__getitem__``
    """

    def __init__(self, original_train, original_val, train, transform):
        def keep(label):
            return label < 50 if train else label >= 50

        data_parts, target_parts = [], []
        for split in (original_train, original_val):
            selected = [i for i, label in enumerate(split.targets) if keep(label)]
            data_parts.append(split.data[selected])
            target_parts.append(np.array(split.targets)[selected])

        self.data = np.concatenate(data_parts, axis=0)
        self.targets = np.concatenate(target_parts, axis=0)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = Image.fromarray(self.data[index])
        if self.transform is not None:
            image = self.transform(image)
        return image, self.targets[index]

# Class-disjoint training and validation sets (train keeps labels < 50, val keeps the rest).
train_dataset, val_dataset = (
    ClassDisjointCIFAR100(original_train, original_val, is_train, tfm)
    for is_train, tfm in ((True, train_transform), (False, val_transform))
)
assert set(train_dataset.targets).isdisjoint(val_dataset.targets)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to CIFAR100_Dataset/cifar-100-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting CIFAR100_Dataset/cifar-100-python.tar.gz to CIFAR100_Dataset
Files already downloaded and verified


## Define the simplest loss

In [None]:
from pytorch_metric_learning.losses import BaseMetricLossFunction
import torch

class BarebonesLoss(BaseMetricLossFunction):
    """Minimal template for a custom pytorch-metric-learning loss.

    ``compute_loss`` returns a dict whose single sub-loss "loss" is the mean of all
    embedding entries, already reduced to a scalar.
    """

    def compute_loss(self, embeddings, labels, indices_tuple):
        # Debug aid: show what was passed in as indices_tuple.
        print('indices tuple:', indices_tuple)
        loss_entry = {
            "losses": embeddings.mean(),
            "indices": None,
            "reduction_type": "already_reduced",
        }
        return {"loss": loss_entry}


# TopK-Pre
## Global settings and helper functions


In [None]:
#####################
# Global Attributes
#####################

# Seed. NOTE(review): not applied to any RNG in the visible cells.
l2r_seed = 137

# A small constant value.
epsilon = 1e-8

# GPU setting: follow the available hardware instead of hard-coding CUDA. The previous
# unconditional `torch.cuda.set_device(0)` crashed on CPU-only machines, whereas the models
# in this notebook already fall back to CPU via torch.cuda.is_available().
global_gpu = torch.cuda.is_available()
global_device = 'cuda:0' if global_gpu else 'cpu'
gpu_id = 0 if global_gpu else None

if global_gpu:
    torch.cuda.set_device(gpu_id)

# Uniform tensor types, placed on the GPU when one is used.
tensor      = torch.cuda.FloatTensor if global_gpu else torch.FloatTensor
byte_tensor = torch.cuda.ByteTensor  if global_gpu else torch.ByteTensor
long_tensor = torch.cuda.LongTensor  if global_gpu else torch.LongTensor

# Uniform one-element constants of the `tensor` type.
torch_one, torch_half, torch_zero = tensor([1.0]), tensor([0.5]), tensor([0.0])
torch_two = tensor([2.0])

torch_minus_one = tensor([-1.0])

def get_pairwise_stds(batch_labels):
    """
    Build the pairwise same-class indicator matrix of a batch.

    :param batch_labels: [batch_size], for each element of the batch assigns a class [0,...,C-1]
    :return: [batch_size, batch_size] tensor of the global `tensor` type, where S_ij is 1.0 if
             item-i and item-j belong to the same class and 0.0 otherwise (so S_ii == 1.0)
    """
    assert batch_labels.dim() == 1, "batch_labels must be a 1-D tensor of class ids"
    # Compare the labels in their own dtype. The previous approach cast to float32 and
    # subtracted, which can merge distinct ids above 2**24 that round to the same float.
    same_class = batch_labels.unsqueeze(1) == batch_labels.unsqueeze(0)
    # Same dtype/device as before: float values of the global `tensor` type.
    return same_class.type(tensor)

def get_pairwise_similarity(batch_reprs):
    '''
    Pairwise dot-product similarity matrix of a batch of vectors.

    Entry (i, j) is <batch_reprs[i], batch_reprs[j]>. This equals cosine similarity only when
    the rows are already L2-normalized; no normalization is performed here.

    :param batch_reprs: [batch_size, length of vector repr] a batch of vector representations
    :return: [batch_size, batch_size]
    '''
    return batch_reprs @ batch_reprs.t()

def dist(batch_reprs, eps=1e-16, squared=False):
    """
    Pairwise Euclidean distance matrix of a batch of vectors.

    Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>.

    Args:
        batch_reprs: [batch_size, dim] vector representations
        eps: float, minimal (squared) distance used for clamping so that no zero values are
             returned (this keeps sqrt differentiable on the diagonal)
        squared: if True return squared distances, otherwise their square root
    Returns:
        [batch_size, batch_size] distance matrix, clamped to at least eps before any sqrt.
    """
    gram = batch_reprs.mm(batch_reprs.t())
    sq_norms = torch.diagonal(gram).unsqueeze(1)
    # Clamp at 0 first to remove small negative values caused by floating-point round-off.
    sq_dists = (sq_norms + sq_norms.t() - 2 * gram).clamp(min=0)
    sq_dists = sq_dists.clamp(min=eps)
    return sq_dists if squared else sq_dists.sqrt()

## TopKPreMiner


In [None]:
from pytorch_metric_learning.miners.base_miner import BaseTupleMiner
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

class TopKPreMiner(BaseTupleMiner):
    """Pair miner for the TopK-Pre experiments.

    The mining rule is the one of pytorch-metric-learning's MultiSimilarityMiner. With a
    similarity measure (``self.distance.is_inverted``, e.g. the default CosineSimilarity),
    for each anchor a:
      * a positive pair (a, p) is kept if  sim(a, p) - epsilon < max_n sim(a, n)
      * a negative pair (a, n) is kept if  sim(a, n) + epsilon > min_p sim(a, p)
    With a true distance the inequalities are mirrored.

    :param k: stored and recorded only. NOTE(review): the mining rule below does not use k yet.
    :param epsilon: slack used in the comparisons above (0.1 is MultiSimilarityMiner's default)
    """

    def __init__(self, k, epsilon=0.1, **kwargs):
        super().__init__(**kwargs)
        self.k = k
        # Bug fix: `mine` reads self.epsilon, but the original class never set it, so every
        # mining call raised AttributeError.
        self.epsilon = epsilon
        self.add_to_recordable_attributes(name="k", is_stat=False)
        self.add_to_recordable_attributes(name="epsilon", is_stat=False)

    def mine(self, embeddings, labels, ref_emb, ref_labels):
        """Return (a1, p, a2, n): anchor/positive and anchor/negative index tensors."""
        mat = self.distance(embeddings, ref_emb)
        a1, p, a2, n = lmu.get_all_pairs_indices(labels, ref_labels)

        if len(a1) == 0 or len(a2) == 0:
            empty = torch.tensor([], dtype=torch.long, device=labels.device)
            return empty.clone(), empty.clone(), empty.clone(), empty.clone()

        # `mat` is not used after this point, so the negative view may alias it.
        mat_neg_sorting = mat
        mat_pos_sorting = mat.clone()

        # Sentinel values so that masked-out entries never pass the hardness tests below.
        dtype = mat.dtype
        pos_ignore = (
            c_f.pos_inf(dtype) if self.distance.is_inverted else c_f.neg_inf(dtype)
        )
        neg_ignore = (
            c_f.neg_inf(dtype) if self.distance.is_inverted else c_f.pos_inf(dtype)
        )

        mat_pos_sorting[a2, n] = pos_ignore  # hide negatives from the positive sort
        mat_neg_sorting[a1, p] = neg_ignore  # hide positives from the negative sort
        if embeddings is ref_emb:
            # Self-pairs are neither positives nor negatives.
            mat_pos_sorting.fill_diagonal_(pos_ignore)
            mat_neg_sorting.fill_diagonal_(neg_ignore)

        pos_sorted, pos_sorted_idx = torch.sort(mat_pos_sorting, dim=1)
        neg_sorted, neg_sorted_idx = torch.sort(mat_neg_sorting, dim=1)

        if self.distance.is_inverted:
            # Similarity: hardest negative is the largest, hardest positive the smallest.
            hard_pos_idx = torch.where(
                pos_sorted - self.epsilon < neg_sorted[:, -1].unsqueeze(1)
            )
            hard_neg_idx = torch.where(
                neg_sorted + self.epsilon > pos_sorted[:, 0].unsqueeze(1)
            )
        else:
            # Distance: hardest negative is the smallest, hardest positive the largest.
            hard_pos_idx = torch.where(
                pos_sorted + self.epsilon > neg_sorted[:, 0].unsqueeze(1)
            )
            hard_neg_idx = torch.where(
                neg_sorted - self.epsilon < pos_sorted[:, -1].unsqueeze(1)
            )

        # Map (row, sorted position) back to the original column indices.
        a1 = hard_pos_idx[0]
        p = pos_sorted_idx[a1, hard_pos_idx[1]]
        a2 = hard_neg_idx[0]
        n = neg_sorted_idx[a2, hard_neg_idx[1]]

        return a1, p, a2, n

    def get_default_distance(self):
        return CosineSimilarity()

## Define TopK-Pre Loss

In [None]:
from pytorch_metric_learning.losses import BaseMetricLossFunction
import torch

class TopKPreLoss(BaseMetricLossFunction):
    """
    Top-K precision loss from
    "Sampling Wisely: Deep Image Embedding by Top-K Precision Optimization",
    Jing Lu, Chaofan Xu, Wei Zhang, Ling-Yu Duan, Tao Mei; The IEEE International Conference
    on Computer Vision (ICCV), 2019, pp. 7961-7970.

    For every anchor (row of the batch) items are ranked by margin-adjusted similarity.
    Different-class items ranked inside the top-k ("false positives") are pushed down, and the
    same number of highest-scoring same-class items ranked outside it ("false negatives") are
    pulled up.

    :param k: cutoff of the top-k precision (the self-match is excluded, see compute_loss)
    :param margin: added to every different-class similarity before ranking
    """

    def __init__(self, k=4, margin=0.1, **kwargs):
        super().__init__(**kwargs)

        self.name = 'TopKPreLoss'
        # Bug fix: the original hard-coded `self.k = 4` and silently ignored the k argument.
        self.k = k
        self.margin = margin

    def compute_loss(self, embeddings, labels, indices_tuple):
        '''
        Only the case n_{+} >= k is handled: assumes every class in the batch has more than
        one instance.

        :param embeddings: [batch_size, dim] batch embeddings. NOTE(review): the ranking assumes
            each item is its own top match (true for L2-normalized embeddings); this method
            does not normalize them, so confirm the embeddings are normalized upstream.
        :param labels: [batch_size] class ids
        :param indices_tuple: unused (this loss does its own selection)
        :return: loss dict with a single already-reduced scalar
        '''
        simi_mat = get_pairwise_similarity(batch_reprs=embeddings)
        # [batch_size, batch_size] S_ij is one if d_i and d_j are of the same class, zero otherwise
        cls_match_mat = get_pairwise_stds(batch_labels=labels)
        simi_mat_hat = simi_mat + (1.0 - cls_match_mat) * self.margin  # impose margin

        # rank_mat[i, j]: 1-based rank of item j in row i under descending similarity.
        _, orgp_indice = torch.sort(simi_mat_hat, dim=1, descending=True)
        _, desc_indice = torch.sort(orgp_indice, dim=1, descending=False)
        rank_mat = desc_indice + 1.

        # Per-row cutoff: number of same-class items (self included), capped at k + 1 because
        # the self-match is not filtered out and occupies rank 1.
        batch_pos_nums = torch.sum(cls_match_mat, dim=1)
        k_mat = torch.clamp(batch_pos_nums, max=self.k + 1).unsqueeze(1)

        # N: different-class items ranked inside the cutoff ("false positives").
        batch_false_pos = (cls_match_mat < 1) & (rank_mat <= k_mat)
        batch_fp_nums = batch_false_pos.sum(dim=1)
        # P candidates: same-class items ranked outside the cutoff ("false negatives").
        batch_false_negs = cls_match_mat.bool() & (rank_mat > k_mat)

        # Zero on the embeddings' device (the original hard-coded .cuda()); requires_grad keeps
        # backward() valid when no row contributes.
        batch_loss = torch.zeros((), dtype=simi_mat_hat.dtype, device=simi_mat_hat.device,
                                 requires_grad=True)
        for i, fp_num in enumerate(batch_fp_nums.tolist()):
            if fp_num == 0:
                continue  # this row's top-k is already pure
            # The |N| highest-scoring false negatives form P. There are always at least fp_num
            # of them because the cutoff never exceeds the row's positive count.
            all_false_neg = simi_mat_hat[i][batch_false_negs[i]]
            top_false_neg, _ = torch.topk(all_false_neg, k=fp_num, sorted=False, largest=True)
            false_pos = simi_mat_hat[i][batch_false_pos[i]]
            batch_loss = batch_loss + torch.sum(false_pos - top_false_neg)

        return {
            "loss": {
                "losses": batch_loss,
                "indices": None,
                "reduction_type": "already_reduced",
            }
        }

## Define RS-TopK-Pre

In [None]:
class RSTopKPreLoss(losses.BaseMetricLossFunction):
    """
    Rank-weighted variant of the TopK-Pre loss, based on
    "Sampling Wisely: Deep Image Embedding by Top-K Precision Optimization",
    Jing Lu, Chaofan Xu, Wei Zhang, Ling-Yu Duan, Tao Mei; The IEEE International Conference
    on Computer Vision (ICCV), 2019, pp. 7961-7970.

    Selects false positives / false negatives exactly like TopKPreLoss, but weights each
    selected similarity by its rank r:
      false negatives:  beta1 = 1 - 1 / (r - k)
      false positives:  beta2 = 1 - 1 / (k + 3 - r)

    :param k: cutoff of the top-k precision (the self-match is excluded, see compute_loss)
    :param margin: added to every different-class similarity before ranking
    """

    def __init__(self, k=4, margin=0.1, **kwargs):
        super().__init__(**kwargs)

        self.name = 'RSTopKPreLoss'
        self.k = k
        self.margin = margin

    def compute_loss(self, embeddings, labels, indices_tuple):
        '''
        Only the case n_{+} >= k is handled: assumes every class in the batch has more than
        one instance.

        :param embeddings: [batch_size, dim] batch embeddings. NOTE(review): the ranking assumes
            each item is its own top match (true for L2-normalized embeddings); this method
            does not normalize them, so confirm the embeddings are normalized upstream.
        :param labels: [batch_size] class ids
        :param indices_tuple: unused (this loss does its own selection)
        :return: loss dict with a single already-reduced scalar
        '''
        simi_mat = get_pairwise_similarity(batch_reprs=embeddings)
        # [batch_size, batch_size] S_ij is one if d_i and d_j are of the same class, zero otherwise
        cls_match_mat = get_pairwise_stds(batch_labels=labels)
        simi_mat_hat = simi_mat + (1.0 - cls_match_mat) * self.margin  # impose margin

        # rank_mat[i, j]: 1-based rank of item j in row i under descending similarity.
        _, orgp_indice = torch.sort(simi_mat_hat, dim=1, descending=True)
        _, desc_indice = torch.sort(orgp_indice, dim=1, descending=False)
        rank_mat = desc_indice + 1.

        # Per-row cutoff: number of same-class items (self included), capped at k + 1 because
        # the self-match is not filtered out and occupies rank 1.
        batch_pos_nums = torch.sum(cls_match_mat, dim=1)
        k_mat = torch.clamp(batch_pos_nums, max=self.k + 1).unsqueeze(1)

        # N: different-class items ranked inside the cutoff ("false positives").
        batch_false_pos = (cls_match_mat < 1) & (rank_mat <= k_mat)
        batch_fp_nums = batch_false_pos.sum(dim=1)
        # P candidates: same-class items ranked outside the cutoff ("false negatives").
        batch_false_negs = cls_match_mat.bool() & (rank_mat > k_mat)

        # Zero on the embeddings' device (the original hard-coded .cuda()); requires_grad keeps
        # backward() valid when no row contributes.
        batch_loss = torch.zeros((), dtype=simi_mat_hat.dtype, device=simi_mat_hat.device,
                                 requires_grad=True)
        for i, fp_num in enumerate(batch_fp_nums.tolist()):
            if fp_num == 0:
                continue  # this row's top-k is already pure
            fn_mask = batch_false_negs[i]
            fp_mask = batch_false_pos[i]

            # P: the |N| highest-scoring false negatives, with their ranks.
            top_false_neg, neg_idx = torch.topk(simi_mat_hat[i][fn_mask], k=fp_num,
                                                sorted=False, largest=True)
            rank_top_neg = torch.gather(rank_mat[i][fn_mask], -1, neg_idx)
            # Rows with fewer than k positives can put a false negative at rank <= k, which made
            # the denominator zero or negative (inf/NaN loss). Clamp it to >= 1 so such entries
            # get weight 0 instead; rows in the supported case n_{+} >= k are unchanged.
            beta1 = 1.0 - 1.0 / torch.clamp(rank_top_neg - self.k, min=1.0)
            loss_neg = 3 * torch.dot(beta1.to(top_false_neg.dtype), top_false_neg) / fp_num

            # N: false positives, weighted by how far inside the cutoff they sit.
            rank_pos = rank_mat[i][fp_mask]
            false_pos = simi_mat_hat[i][fp_mask]
            beta2 = 1.0 - 1.0 / torch.clamp((self.k + 3) - rank_pos, min=1.0)
            loss_pos = 3 * torch.dot(beta2.to(false_pos.dtype), false_pos) / fp_num

            batch_loss = batch_loss + (loss_pos - loss_neg)

        return {
            "loss": {
                "losses": batch_loss,
                "indices": None,
                "reduction_type": "already_reduced",
            }
        }

# Learning-to-rank (L2R) losses for DML

## LambdaRank


In [None]:
def torch_ideal_dcg(batch_sorted_labels, gpu=False):
    '''
    DCG of relevance lists that are already sorted in descending order (hence "ideal").

    DCG = sum_r (2^label_r - 1) / log2(r + 2), with 0-based rank r.

    :param batch_sorted_labels: [batch, ranking_size] relevance labels, sorted descending per row
    :param gpu: kept for backward compatibility and ignored. The rank discounts are now built on
                the labels' own device and dtype, so a mismatch between this flag and the input's
                device can no longer crash.
    :return: [batch, 1]
    '''
    batch_gains = torch.pow(2.0, batch_sorted_labels) - 1.0
    ranking_size = batch_sorted_labels.size(1)
    batch_ranks = torch.arange(ranking_size, dtype=batch_gains.dtype, device=batch_gains.device)
    batch_discounts = torch.log2(2.0 + batch_ranks)
    return torch.sum(batch_gains / batch_discounts, dim=1, keepdim=True)


def get_delta_ndcg(batch_stds, batch_stds_sorted_via_preds):
    '''
    Absolute Delta-nDCG for swapping each pair of positions in the currently predicted ranking.

    :param batch_stds: [batch, ranking_size] ground-truth relevance labels in any order; they
                       are sorted descending here to obtain the ideal DCG
    :param batch_stds_sorted_via_preds: [batch, ranking_size] the same labels reordered by
                       descending prediction score
    :return: [batch, ranking_size, ranking_size]; entry (b, i, j) is
             |ng_i - ng_j| * |1/log2(i+2) - 1/log2(j+2)|, where ng are the normalized gains
    '''
    # Bug fix: the ideal DCG must be computed on labels sorted in descending order, but
    # lambdarank_loss passes the raw same-class matrix. Sorting here is a no-op for callers
    # that already pass sorted labels.
    batch_ideal_stds, _ = torch.sort(batch_stds, dim=1, descending=True)
    # Let the labels' device (not a global flag) decide where the discounts are built.
    batch_idcgs = torch_ideal_dcg(batch_sorted_labels=batch_ideal_stds, gpu=batch_ideal_stds.is_cuda)
    # Guard rows without any relevant item (iDCG == 0) against 0/0 = NaN.
    batch_idcgs = torch.clamp(batch_idcgs, min=1e-8)

    batch_gains = torch.pow(2.0, batch_stds_sorted_via_preds) - 1.0
    batch_n_gains = batch_gains / batch_idcgs  # normalised gains
    batch_ng_diffs = batch_n_gains.unsqueeze(2) - batch_n_gains.unsqueeze(1)

    # Discount co-efficients 1 / log2(rank + 2) for 0-based ranks, created next to the gains.
    ranking_size = batch_stds_sorted_via_preds.size(1)
    batch_std_ranks = torch.arange(ranking_size, dtype=batch_gains.dtype, device=batch_gains.device)
    batch_dists = (1.0 / torch.log2(batch_std_ranks + 2.0)).unsqueeze(0)
    batch_dists_diffs = batch_dists.unsqueeze(2) - batch_dists.unsqueeze(1)

    # Absolute changes w.r.t. pairwise swapping.
    return torch.abs(batch_ng_diffs) * torch.abs(batch_dists_diffs)


def lambdarank_loss(batch_preds=None, batch_stds=None, sigma=1.0):
    '''
    LambdaRank loss: the RankNet pairwise logistic loss, weighted per pair by |Delta-nDCG|.

    This method will impose explicit bias to highly ranked documents that are essentially ties.

    :param batch_preds: [batch, ranking_size] predicted scores (higher = ranked earlier)
    :param batch_stds: [batch, ranking_size] ground-truth relevance labels
    :param sigma: scale applied to score differences inside the logistic loss
    :return: scalar loss, summed over rows and all ordered pairs (halved, see below)
    '''
    # Sort documents by predicted relevance and reorder the labels consistently.
    # (batch_stds[batch_preds_sorted_inds] would only work for 1-D tensors.)
    batch_preds_sorted, batch_preds_sorted_inds = torch.sort(batch_preds, dim=1, descending=True)
    batch_stds_sorted_via_preds = torch.gather(batch_stds, dim=1, index=batch_preds_sorted_inds)

    # Standard pairwise differences S_{ij}, clamped to {-1, 0, 1}.
    batch_std_diffs = batch_stds_sorted_via_preds.unsqueeze(2) - batch_stds_sorted_via_preds.unsqueeze(1)
    batch_std_Sij = torch.clamp(batch_std_diffs, min=-1.0, max=1.0)

    # Predicted pairwise differences s_i - s_j.
    batch_pred_s_ij = batch_preds_sorted.unsqueeze(2) - batch_preds_sorted.unsqueeze(1)

    batch_delta_ndcg = get_delta_ndcg(batch_stds, batch_stds_sorted_via_preds)

    # cf. the 1st equation in page-3
    batch_loss_1st = 0.5 * sigma * batch_pred_s_ij * (1.0 - batch_std_Sij)
    # log(1 + exp(-sigma * s_ij)) computed with softplus. The value is the same, but the old
    # torch.log(torch.exp(x) + 1.0) overflowed to inf (and then NaN) for large score gaps.
    batch_loss_2nd = torch.nn.functional.softplus(-sigma * batch_pred_s_ij)

    # The coefficient of 0.5 is added because all (ordered) pairs are used; weight by delta-nDCG.
    batch_loss = torch.sum(0.5 * (batch_loss_1st + batch_loss_2nd) * batch_delta_ndcg)

    return batch_loss

In [None]:
class Lambdarank(torch.nn.Module):
    """LambdaRank loss for deep metric learning.

    Every anchor (row) ranks the whole batch by negative Euclidean distance, and items of the
    anchor's class count as relevant (label 1) while all others count as irrelevant (label 0).
    """

    def __init__(self):
        super().__init__()
        self.name = 'lambdarank'

    def forward(self, embeddings, labels, indices_tuple):
        '''
        :param embeddings: torch.Tensor [BS x embed_dim], batch of embeddings
        :param labels: [BS], for each element of the batch assigns a class [0,...,C-1]
        :param indices_tuple: ignored
        :return: scalar LambdaRank loss
        '''
        # relevance[i, j] == 1 iff items i and j share a class
        relevance = get_pairwise_stds(batch_labels=labels)
        # Higher score = closer, so negate the pairwise distances.
        scores = -dist(batch_reprs=embeddings, squared=False)
        return lambdarank_loss(batch_preds=scores, batch_stds=relevance)

## ListNet


In [None]:
import torch.nn.functional as F

class ListNet(torch.nn.Module):
    """ListNet-style listwise loss over all unordered pairs of the batch.

    Pair scores are negative Euclidean distances and pair labels are same-class indicators.
    The loss is the cross-entropy between softmax(labels), used as the target distribution,
    and softmax(scores).
    """

    def __init__(self):
        super(ListNet, self).__init__()

        self.name = 'listnet'

    def forward(self, embeddings, labels, indices_tuple=None):
        '''
        :param embeddings: torch.Tensor [BS x embed_dim], batch of embeddings
        :param labels: [BS], for each element of the batch assigns a class [0,...,C-1]
        :param indices_tuple: ignored; accepted so the module can be called like a
                              pytorch-metric-learning loss (embeddings, labels, indices_tuple)
        :return: scalar loss
        '''
        # [batch_size, batch_size] S_ij is one if d_i and d_j are of the same class, zero otherwise
        cls_match_mat = get_pairwise_stds(batch_labels=labels)
        sim_mat = -dist(batch_reprs=embeddings, squared=False)  # higher = closer

        # Flatten the strict upper triangle: every unordered pair once, self-pairs excluded.
        batch_size = embeddings.size(0)
        index_mat = torch.triu(torch.ones(batch_size, batch_size), diagonal=1) == 1
        sim_vec = sim_mat[index_mat]
        cls_vec = cls_match_mat[index_mat]

        # Cross-entropy between the two softmaxed vectors. dim=0 is now explicit: calling
        # softmax without dim is deprecated and emitted a warning on every call.
        return -torch.sum(F.softmax(cls_vec, dim=0) * F.log_softmax(sim_vec, dim=0))

## Create the loss, miner, sampler, and package them into dictionaries

In [None]:
# Set the loss function (swap in one of the commented alternatives to compare losses)
loss = TopKPreLoss(k=4)
# loss = Lambdarank()
# loss = ListNet()

# Set the mining function. The MultiSimilarityMiner used to be constructed and then
# immediately overwritten (dead code); it is kept only as a commented alternative.
# miner = miners.MultiSimilarityMiner(epsilon=0.1)
miner = TopKPreMiner(k=4)

# Set the dataloader sampler (m=4 samples per class)
sampler = samplers.MPerClassSampler(train_dataset.targets, m=4, length_before_new_iter=len(train_dataset))

# Set other training parameters
batch_size = 32
num_epochs = 2

# Package the above stuff into dictionaries.
models = {"trunk": trunk, "embedder": embedder}
optimizers = {"trunk_optimizer": trunk_optimizer, "embedder_optimizer": embedder_optimizer}
loss_funcs = {"metric_loss": loss}
mining_funcs = {"tuple_miner": miner}

In [None]:
# Logging / checkpointing: the arguments name the log folder and the tensorboard folder.
# NOTE(review): this rebinds `record_keeper`, shadowing the module imported in the first cell.
record_keeper, _, _ = logging_presets.get_record_keeper("example_logs", "example_tensorboard")
hooks = logging_presets.get_hook_container(record_keeper)
dataset_dict = dict(val=val_dataset)
model_folder = "example_saved_models"

def visualizer_hook(umapper, umap_embeddings, labels, split_name, keyname, *args):
    """Scatter-plot a 2-D UMAP projection of one split, one color per class.

    Passed to the tester as ``visualizer_hook``.

    :param umapper: the visualizer object (unused here)
    :param umap_embeddings: [N, 2] projected embeddings (columns 0 and 1 are plotted)
    :param labels: [N] class labels, used to color the points
    :param split_name: name of the split being plotted (used in the log message and title)
    :param keyname: label-set name (used in the log message and title)
    """
    logging.info("UMAP plot for the {} split and label set {}".format(split_name, keyname))
    label_set = np.unique(labels)
    num_classes = len(label_set)
    # Explicit figure/axes instead of the pyplot state machine (the old `fig` was unused).
    fig, ax = plt.subplots(figsize=(20, 15))
    ax.set_prop_cycle(cycler("color", [plt.cm.nipy_spectral(i) for i in np.linspace(0, 0.9, num_classes)]))
    for label in label_set:
        idx = labels == label
        ax.plot(umap_embeddings[idx, 0], umap_embeddings[idx, 1], ".", markersize=1)
    ax.set_title("UMAP of the {} split (label set {})".format(split_name, keyname))
    plt.show()
    # The hook runs at every test; release the figure so they don't accumulate in memory.
    plt.close(fig)

# Create the tester; UMAP projections of the embeddings are drawn by visualizer_hook.
tester = testers.GlobalEmbeddingSpaceTester(
    end_of_testing_hook=hooks.end_of_testing_hook,
    visualizer=umap.UMAP(),
    visualizer_hook=visualizer_hook,
    dataloader_num_workers=32,
)

# End-of-epoch hook: tests on dataset_dict with test_interval=1 and patience=1 and saves
# models into model_folder (see the pytorch-metric-learning logging_presets docs).
end_of_epoch_hook = hooks.end_of_epoch_hook(
    tester, dataset_dict, model_folder, test_interval=1, patience=1
)

In [None]:
# Trainer that optimizes only the metric loss (plus mining), fed by the MPerClass sampler.
trainer = trainers.MetricLossOnly(
    models,
    optimizers,
    batch_size,
    loss_funcs,
    mining_funcs,
    train_dataset,
    sampler=sampler,
    dataloader_num_workers=32,
    end_of_iteration_hook=hooks.end_of_iteration_hook,
    end_of_epoch_hook=end_of_epoch_hook,
)

In [None]:
# Run training for `num_epochs` epochs (set in the training-configuration cell).
trainer.train(num_epochs=num_epochs)