<a href="https://colab.research.google.com/github/danielsaggau/IR_LDC/blob/main/model/Untitled61.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Ref Code

In [None]:
import torch
import torch.nn as nn

class NT_Xent(nn.Module):
    """Normalized temperature-scaled cross-entropy loss on cosine similarity.

    The two inputs are treated as two views of the same batch: row ``i`` of
    ``out_a`` and row ``i`` of ``out_b`` form a positive pair, and every other
    row of the concatenated ``2 * B`` batch is a negative for that row.

    Args:
        batch_size: Expected number of rows ``B`` per view. A mask for this
            size is pre-computed; other sizes (e.g. a smaller final batch)
            are handled by building a mask on the fly.
        temperature: Divisor applied to the similarity matrix.
    """

    def __init__(self, batch_size, temperature):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature

        self.mask = self.mask_correlated_samples(batch_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size):
        """Return a ``(2B, 2B)`` boolean mask that is True only for negatives.

        The diagonal (self-similarity) and the two off-diagonals at offset
        ``+/- batch_size`` (the positive pairs) are set to False.
        """
        n = 2 * batch_size
        mask = torch.ones((n, n), dtype=torch.bool)
        mask.fill_diagonal_(False)
        idx = torch.arange(batch_size)
        mask[idx, idx + batch_size] = False
        mask[idx + batch_size, idx] = False
        return mask

    def forward(self, out_a, out_b):
        """Compute the loss for two aligned views.

        Args:
            out_a: Tensor of shape ``(B, D)``.
            out_b: Tensor of shape ``(B, D)``, row-aligned with ``out_a``.

        Returns:
            Scalar tensor: summed cross-entropy divided by ``2 * B``.
        """
        # Use the real batch size so a short final batch does not break the
        # mask indexing below.
        batch_size = out_a.shape[0]
        n = 2 * batch_size
        if batch_size == self.batch_size:
            mask = self.mask
        else:
            mask = self.mask_correlated_samples(batch_size)

        out = torch.cat((out_a, out_b), dim=0)

        # Pairwise cosine similarity of all 2B rows, scaled by temperature.
        sim_matrix = self.similarity_f(out.unsqueeze(1), out.unsqueeze(0)) / self.temperature

        pos_ab = torch.diag(sim_matrix, batch_size)
        pos_ba = torch.diag(sim_matrix, -batch_size)

        positives = torch.cat((pos_ab, pos_ba), dim=0).reshape(n, 1)
        # Each row keeps n - 2 negatives; the explicit width also works when
        # that number is zero.
        negatives = sim_matrix[mask.to(sim_matrix.device)].reshape(n, n - 2)

        # The positive is placed in column 0, so the target class is 0.
        labels = torch.zeros(n, dtype=torch.long, device=out.device)
        logits = torch.cat((positives, negatives), dim=1)
        return self.criterion(logits, labels) / n

In [None]:
class BregmanLoss(nn.Module):
    """Contrastive cross-entropy loss on a max-feature ("Bregman") similarity.

    Row ``i`` of ``out_a`` and row ``i`` of ``out_b`` form a positive pair;
    all other rows of the concatenated batch act as negatives.

    Args:
        batch_size: Expected number of rows ``B`` per view. A mask for this
            size is pre-computed; other sizes are handled on the fly.
        temperature: Divisor applied to the similarity matrix.
        sigma: Bandwidth; similarities are ``exp(-dist / (2 * sigma ** 2))``.
    """

    def __init__(self, batch_size, temperature, sigma):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.sigma = sigma

        self.mask = self.mask_correlated_samples(batch_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def mask_correlated_samples(self, batch_size):
        """Return a ``(2B, 2B)`` boolean mask that is True only for negatives.

        The diagonal and the two off-diagonals at offset ``+/- batch_size``
        (the positive pairs) are set to False.
        """
        n = 2 * batch_size
        mask = torch.ones((n, n), dtype=torch.bool)
        mask.fill_diagonal_(False)
        idx = torch.arange(batch_size)
        mask[idx, idx + batch_size] = False
        mask[idx + batch_size, idx] = False
        return mask

    def b_sim(self, features):
        """Compute the similarity matrix and per-dimension argmax counts.

        Args:
            features: Tensor of shape ``(N, D)``.

        Returns:
            Tuple ``(sim_matrix, num_max)`` where ``sim_matrix[i, j]`` is
            ``exp(-(max(features[i]) - features[i, argmax(features[j])]) /
            (2 * sigma ** 2))`` and ``num_max[d]`` counts the rows whose
            largest entry is at dimension ``d``.
        """
        max_features, indx_max_features = torch.max(features, dim=1)
        max_features = max_features.reshape(-1, 1)

        # Count how often each output dimension is the argmax. The identity
        # matrix must live on the same device as the indices used to index it.
        eye = torch.eye(features.shape[1], device=features.device)
        num_max = eye[indx_max_features].sum(dim=0)

        dist_matrix = max_features - features[:, indx_max_features]
        sigma = torch.tensor([self.sigma], device=features.device)
        sig2 = 2 * torch.pow(sigma, 2)
        sim_matrix = torch.exp(torch.div(-dist_matrix, sig2))

        return sim_matrix, num_max

    def forward(self, out_a, out_b):
        """Compute the loss for two aligned views.

        Args:
            out_a: Tensor of shape ``(B, D)``.
            out_b: Tensor of shape ``(B, D)``, row-aligned with ``out_a``.

        Returns:
            Tuple ``(loss, num_max)``: the summed cross-entropy divided by
            ``2 * B`` and the argmax counts from :meth:`b_sim`.
        """
        # Use the real batch size so a short final batch does not break the
        # mask indexing below.
        batch_size = out_a.shape[0]
        n = 2 * batch_size
        if batch_size == self.batch_size:
            mask = self.mask
        else:
            mask = self.mask_correlated_samples(batch_size)

        features = torch.cat((out_a, out_b), dim=0)

        sim_matrix, num_max = self.b_sim(features)
        sim_matrix = sim_matrix / self.temperature

        pos_ab = torch.diag(sim_matrix, batch_size)
        pos_ba = torch.diag(sim_matrix, -batch_size)

        positives = torch.cat((pos_ab, pos_ba), dim=0).reshape(n, 1)
        # Each row keeps n - 2 negatives; the explicit width also works when
        # that number is zero.
        negatives = sim_matrix[mask.to(sim_matrix.device)].reshape(n, n - 2)

        # The positive is placed in column 0, so the target class is 0.
        labels = torch.zeros(n, dtype=torch.long, device=features.device)
        logits = torch.cat((positives, negatives), dim=1)
        loss = self.criterion(logits, labels) / n
        return loss, num_max


# Custom Sentence Transformer loss and trainer

In [None]:
from enum import Enum
from typing import Iterable, Dict
import torch.nn.functional as F
from torch import nn, Tensor
!pip install sentence_transformers
from sentence_transformers.SentenceTransformer import SentenceTransformer

In [None]:
import torch
class CustomBregmanLoss(nn.Module):
    """Bregman contrastive loss wrapped for sentence-transformers training.

    The loss expects batches of (anchor, positive) text pairs. For every
    sentence, its partner in the pair is the positive and every other
    sentence in the batch is treated as a negative.
    """

    def __init__(self, model: SentenceTransformer, batch_size, temperature, sigma):
        """
        :param model: SentenceTransformer model used to embed the sentences
        :param batch_size: expected number of pairs per batch; a mask for this
            size is pre-computed, other sizes are handled on the fly
        :param temperature: divisor applied to the similarity matrix
        :param sigma: bandwidth; similarities are exp(-dist / (2 * sigma ** 2))
        """
        super().__init__()
        self.model = model
        self.batch_size = batch_size
        self.temperature = temperature
        self.sigma = sigma
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="sum")
        self.mask = self.mask_correlated_samples(batch_size)

    def mask_correlated_samples(self, batch_size):
        """Return a (2B, 2B) boolean mask that is True only for negatives.

        The diagonal and the two off-diagonals at offset +/- batch_size
        (the positive pairs) are set to False.
        """
        n = 2 * batch_size
        mask = torch.ones((n, n), dtype=torch.bool)
        mask.fill_diagonal_(False)
        idx = torch.arange(batch_size)
        mask[idx, idx + batch_size] = False
        mask[idx + batch_size, idx] = False
        return mask

    def b_sim(self, features):
        """Compute the similarity matrix and per-dimension argmax counts.

        :param features: tensor of shape (N, D)
        :return: (sim_matrix, num_max); sim_matrix[i, j] is
            exp(-(max(features[i]) - features[i, argmax(features[j])]) / (2 * sigma ** 2))
            and num_max[d] counts the rows whose largest entry is at dimension d
        """
        max_features, indx_max_features = torch.max(features, dim=1)
        max_features = max_features.reshape(-1, 1)

        # The identity matrix must live on the same device as the indices
        # used to index it, otherwise this fails on GPU.
        eye = torch.eye(features.shape[1], device=features.device)
        num_max = eye[indx_max_features].sum(dim=0)

        dist_matrix = max_features - features[:, indx_max_features]
        sigma = torch.tensor([self.sigma], device=features.device)
        sig2 = 2 * torch.pow(sigma, 2)
        sim_matrix = torch.exp(torch.div(-dist_matrix, sig2))

        return sim_matrix, num_max

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        """Embed both text columns and return the scalar Bregman loss.

        :param sentence_features: one tokenized feature dict per text column
        :param labels: ignored; targets are derived from the pair structure
        :return: summed cross-entropy divided by 2 * B
        """
        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
        embeddings_a = reps[0]
        embeddings_b = torch.cat(reps[1:])

        # Use the real batch size so a short final batch does not break the
        # mask indexing below.
        batch_size = embeddings_a.shape[0]
        if embeddings_b.shape[0] != batch_size:
            raise ValueError(
                "CustomBregmanLoss expects exactly one positive per anchor; got "
                f"{batch_size} anchors and {embeddings_b.shape[0]} other embeddings"
            )
        n = 2 * batch_size
        if batch_size == self.batch_size:
            mask = self.mask
        else:
            mask = self.mask_correlated_samples(batch_size)

        features = torch.cat((embeddings_a, embeddings_b), dim=0)

        sim_matrix, _ = self.b_sim(features)
        sim_matrix = sim_matrix / self.temperature

        pos_ab = torch.diag(sim_matrix, batch_size)
        pos_ba = torch.diag(sim_matrix, -batch_size)

        positives = torch.cat((pos_ab, pos_ba), dim=0).reshape(n, 1)
        negatives = sim_matrix[mask.to(sim_matrix.device)].reshape(n, n - 2)

        # The positive is placed in column 0, so the target class is 0.
        targets = torch.zeros(n, dtype=torch.long, device=features.device)
        scores = torch.cat((positives, negatives), dim=1)
        return self.cross_entropy_loss(scores, targets) / n


In [None]:
from sentence_transformers import SentenceTransformer, losses, InputExample
from torch.utils.data import DataLoader
# Base encoder that the custom loss will fine-tune.
model = SentenceTransformer('distilbert-base-uncased')

# Two toy (anchor, positive) pairs, enough to smoke-test the training loop.
train_examples = [
    InputExample(texts=[f'Anchor {i}', f'Positive {i}'])
    for i in (1, 2)
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)

In [None]:
# Wrap the encoder in the custom Bregman loss. batch_size=2 equals the number
# of toy training examples, so the single batch the loader yields has 2 pairs.
train_loss = CustomBregmanLoss(model=model, batch_size=2,temperature=0.1, sigma=2) 
# Bare expression: the notebook echoes the module repr as the cell output.
train_loss

CustomBregmanLoss(
  (model): SentenceTransformer(
    (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: DistilBertModel 
    (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  )
  (cross_entropy_loss): CrossEntropyLoss()
)

In [None]:
# Train the encoder; fit() takes a list of (dataloader, loss) objectives.
model.fit([(train_dataloader, train_loss)], show_progress_bar=True)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
from typing import Iterable, Dict

import torch
from sentence_transformers import util
from sentence_transformers.SentenceTransformer import SentenceTransformer
from torch import nn, Tensor

class MultipleNegativesRankingLoss(nn.Module):
    """
        This loss expects as input a batch consisting of sentence pairs (a_1, p_1), (a_2, p_2)..., (a_n, p_n)
        where we assume that (a_i, p_i) are a positive pair and (a_i, p_j) for i!=j a negative pair.
        For each a_i, it uses all other p_j as negative samples, i.e., for a_i, we have 1 positive example (p_i) and
        n-1 negative examples (p_j). It then minimizes the negative log-likelihood for softmax normalized scores.
        This loss function works great to train embeddings for retrieval setups where you have positive pairs (e.g. (query, relevant_doc))
        as it will sample in each batch n-1 negative docs randomly.
        The performance usually increases with increasing batch sizes.
        For more information, see: https://arxiv.org/pdf/1705.00652.pdf
        (Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4)
        You can also provide one or multiple hard negatives per anchor-positive pair by structuring the data like this:
        (a_1, p_1, n_1), (a_2, p_2, n_2)
        Here, n_1 is a hard negative for (a_1, p_1). The loss will use for the pair (a_i, p_i) all p_j (j!=i) and all n_j as negatives.
        Example::
            from sentence_transformers import SentenceTransformer, losses, InputExample
            from torch.utils.data import DataLoader
            model = SentenceTransformer('distilbert-base-uncased')
            train_examples = [InputExample(texts=['Anchor 1', 'Positive 1']),
                InputExample(texts=['Anchor 2', 'Positive 2'])]
            train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
            train_loss = losses.MultipleNegativesRankingLoss(model=model)
    """
    # NOTE: the default below is evaluated when the class is defined, so
    # `util` (sentence_transformers.util) must be imported before this cell.
    def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct = util.cos_sim):
        """
        :param model: SentenceTransformer model
        :param scale: Output of similarity function is multiplied by scale value
        :param similarity_fct: similarity function between sentence embeddings. By default, cos_sim. Can also be set to dot product (and then set scale to 1)
        """
        super().__init__()
        self.model = model
        self.scale = scale
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        """
        :param sentence_features: one tokenized feature dict per text column (anchors first)
        :param labels: ignored; targets are derived from the row order
        :return: mean cross-entropy over the scaled similarity scores
        """
        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
        embeddings_a = reps[0]
        embeddings_b = torch.cat(reps[1:])

        scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale
        # Example a[i] should match with b[i], so the target of row i is column i.
        targets = torch.arange(scores.size(0), dtype=torch.long, device=scores.device)
        return self.cross_entropy_loss(scores, targets)

    def get_config_dict(self):
        """Return the scale and the similarity function's name as a dict."""
        return {'scale': self.scale, 'similarity_fct': self.similarity_fct.__name__}



In [22]:
def cos_sim(a: Tensor, b: Tensor):
    """
    Pairwise cosine similarity between the rows of ``a`` and the rows of ``b``.

    Non-tensor inputs are converted with ``torch.tensor`` and 1-D inputs are
    treated as a single row.
    :return: Matrix with res[i][j]  = cos_sim(a[i], b[j])
    """
    def _as_matrix(value):
        # Coerce to a tensor, then promote a vector to a one-row matrix.
        mat = value if isinstance(value, torch.Tensor) else torch.tensor(value)
        return mat.unsqueeze(0) if mat.dim() == 1 else mat

    rows_a = _as_matrix(a)
    rows_b = _as_matrix(b)

    # After L2-normalising each row, the dot product is the cosine similarity.
    unit_a = torch.nn.functional.normalize(rows_a, p=2, dim=1)
    unit_b = torch.nn.functional.normalize(rows_b, p=2, dim=1)
    return torch.mm(unit_a, unit_b.transpose(0, 1))



class BregmanRankingLoss(nn.Module):
    """
    Weighted sum of a multiple-negatives ranking loss and a Bregman
    contrastive loss, both computed on the same sentence embeddings:

        loss = lambda1 * bregman_loss + lambda2 * ranking_loss

    The ranking term scores every anchor against all other embeddings in the
    batch (positives and any hard negatives). The Bregman term is computed on
    the anchor and positive columns only.
    """

    def __init__(self, model: SentenceTransformer, sigma, temperature, batch_size, lambda1, lambda2, scale: float = 20.0, similarity_fct = cos_sim):
        """
        :param model: SentenceTransformer model
        :param sigma: bandwidth of the Bregman similarity, exp(-dist / (2 * sigma ** 2))
        :param temperature: divisor applied to the Bregman similarity matrix
        :param batch_size: expected number of pairs per batch; a mask for this
            size is pre-computed, other sizes are handled on the fly
        :param lambda1: weight of the Bregman term
        :param lambda2: weight of the ranking term
        :param scale: Output of similarity function is multiplied by scale value
        :param similarity_fct: similarity function between sentence embeddings. By default, cos_sim. Can also be set to dot product (and then set scale to 1)
        """
        super().__init__()
        self.model = model
        self.sigma = sigma
        self.temperature = temperature
        self.batch_size = batch_size
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.scale = scale
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        # forward() reads this mask, so it has to be created here.
        self.mask = self.mask_correlated_samples(batch_size)

    def mask_correlated_samples(self, batch_size):
        """Return a (2B, 2B) boolean mask that is True only for negatives.

        The mask must be boolean: indexing with an integer tensor would pick
        rows 0 and 1 instead of selecting the masked entries.
        """
        n = 2 * batch_size
        mask = torch.ones((n, n), dtype=torch.bool)
        mask.fill_diagonal_(False)
        idx = torch.arange(batch_size)
        mask[idx, idx + batch_size] = False
        mask[idx + batch_size, idx] = False
        return mask

    def b_sim(self, features):
        """Compute the similarity matrix and per-dimension argmax counts.

        :param features: tensor of shape (N, D)
        :return: (sim_matrix, num_max); sim_matrix[i, j] is
            exp(-(max(features[i]) - features[i, argmax(features[j])]) / (2 * sigma ** 2))
            and num_max[d] counts the rows whose largest entry is at dimension d
        """
        max_features, indx_max_features = torch.max(features, dim=1)
        max_features = max_features.reshape(-1, 1)

        # The identity matrix must live on the same device as the indices
        # used to index it, otherwise this fails on GPU.
        eye = torch.eye(features.shape[1], device=features.device)
        num_max = eye[indx_max_features].sum(dim=0)

        dist_matrix = max_features - features[:, indx_max_features]
        sigma = torch.tensor([self.sigma], device=features.device)
        sig2 = 2 * torch.pow(sigma, 2)
        sim_matrix = torch.exp(torch.div(-dist_matrix, sig2))

        return sim_matrix, num_max

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        """
        :param sentence_features: one tokenized feature dict per text column (anchors first)
        :param labels: ignored; targets are derived from the row order
        :return: lambda1 * bregman_loss + lambda2 * ranking_loss
        """
        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
        embeddings_a = reps[0]
        embeddings_b = torch.cat(reps[1:])

        # Ranking part: example a[i] should match with b[i].
        scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale
        targets = torch.arange(scores.size(0), dtype=torch.long, device=scores.device)
        rloss = self.cross_entropy_loss(scores, targets)

        # Bregman part: restricted to the (anchor, positive) columns, because
        # the pair mask assumes exactly two aligned views of the batch.
        batch_size = embeddings_a.shape[0]
        n = 2 * batch_size
        if batch_size == self.batch_size:
            mask = self.mask
        else:
            mask = self.mask_correlated_samples(batch_size)

        features = torch.cat((embeddings_a, reps[1]), dim=0)

        sim_matrix, _ = self.b_sim(features)
        sim_matrix = sim_matrix / self.temperature

        pos_ab = torch.diag(sim_matrix, batch_size)
        pos_ba = torch.diag(sim_matrix, -batch_size)

        positives = torch.cat((pos_ab, pos_ba), dim=0).reshape(n, 1)
        negatives = sim_matrix[mask.to(sim_matrix.device)].reshape(n, n - 2)

        # The positive is placed in column 0, so the target class is 0.
        blabel = torch.zeros(n, dtype=torch.long, device=features.device)
        bscores = torch.cat((positives, negatives), dim=1)
        # cross_entropy_loss uses mean reduction, so it already averages over
        # the n rows; dividing by n again would scale the term down twice.
        bloss = self.cross_entropy_loss(bscores, blabel)

        return self.lambda1 * bloss + self.lambda2 * rloss

In [None]:
from sentence_transformers import SentenceTransformer, losses, InputExample
from torch.utils.data import DataLoader
# Base encoder that the combined loss will fine-tune.
model = SentenceTransformer('distilbert-base-uncased')

# Two toy (anchor, positive) pairs, enough to smoke-test the loss.
train_examples = [
    InputExample(texts=[f'Anchor {i}', f'Positive {i}'])
    for i in (1, 2)
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)

# Equal weighting (lambda1 = lambda2 = 1) of the Bregman and ranking terms.
train_loss = BregmanRankingLoss(
    model=model,
    sigma=2,
    temperature=0.1,
    batch_size=2,
    lambda1=1,
    lambda2=1,
)

In [24]:
# Bare expression: the notebook echoes the module repr as the cell output.
train_loss

BregmanRankingLoss(
  (model): SentenceTransformer(
    (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: DistilBertModel 
    (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  )
  (cross_entropy_loss): CrossEntropyLoss()
)

# trainer in original specification

In [None]:
class Trainer:
    """Training and weighted-kNN evaluation loop for the Bregman losses.

    The methods rely on names that are not defined in this class and must be
    in scope when they run: ``tqdm``, ``bcolors`` (an object exposing the
    ``OKCYAN``, ``OKBLUE``, ``WARNING`` and ``ENDC`` strings), ``BregmanLoss``
    and ``NT_Xent``.

    Args:
        model: Module called as ``feature, out = model(x)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        temperature: Temperature passed to both losses and used to sharpen
            the kNN weights in :meth:`test`.
        num_cls: Number of classes used for the kNN vote.
        epochs: Total number of epochs (only used in progress messages).
        sigma: Bandwidth passed to ``BregmanLoss``.
        lmbda: Weight of the NT-Xent term when the mixed loss is enabled.
        device: Device that inputs are moved to.
    """

    def __init__(self,
                 model,
                 optimizer,
                 scheduler,
                 temperature,
                 num_cls,
                 epochs,
                 sigma,
                 lmbda,
                 device):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.temperature = temperature
        self.num_cls = num_cls
        self.epochs = epochs
        self.sigma = sigma
        self.lmbda = lmbda
        self.device = device
        # When True, train() adds lmbda * NT-Xent to the Bregman loss.
        self.mixed_loss = True

    # train for one epoch to learn unique features
    def train(self, data_loader, epoch):
        """Run one training epoch.

        Args:
            data_loader: Yields ``([aug_1, aug_2], target)`` batches.
            epoch: Current epoch number (only used in the progress message).

        Returns:
            Tuple ``(mean total loss, mean Bregman loss, mean NT-Xent loss,
            last learning rate)``; the means are per sample.
        """
        self.model.train()
        batch_size = data_loader.batch_size
        bloss = BregmanLoss(batch_size, self.temperature, self.sigma)
        nt_xent = NT_Xent(batch_size, self.temperature)

        total_loss, total_num, tot_max, train_bar = 0.0, 0, 0, tqdm(data_loader)
        tot_bloss, tot_nt_xent = 0.0, 0.0
        for [aug_1, aug_2], target in train_bar:
            aug_1, aug_2 = aug_1.to(self.device), aug_2.to(self.device)
            feature_1, out_1 = self.model(aug_1)
            feature_2, out_2 = self.model(aug_2)
            # Weight the running sums by the real batch size so a short final
            # batch does not skew the per-sample averages.
            n_samples = aug_1.size(0)

            # compute loss
            loss, num_max = bloss(out_1, out_2)
            tot_bloss += loss.item() * n_samples
            if self.mixed_loss:
                loss1 = nt_xent(feature_1, feature_2)
                tot_nt_xent += loss1.item() * n_samples
                loss = loss + self.lmbda * loss1

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            tot_max += num_max
            total_num += n_samples
            total_loss += loss.item() * n_samples
            train_bar.set_description(
                '{}Train{} {}Epoch:{} [{}/{}] {}Loss:{}  {:.4f} {}Active Subs:{} [{}/{}]'
                .format(
                    bcolors.OKCYAN, bcolors.ENDC,
                    bcolors.WARNING, bcolors.ENDC,
                    epoch,
                    self.epochs,
                    bcolors.WARNING, bcolors.ENDC,
                    total_loss / total_num,
                    bcolors.WARNING, bcolors.ENDC,
                    # output dimensions that were the argmax more than 10 times
                    len(torch.where(tot_max > 10)[0]),
                    tot_max.shape[0]))

        self.scheduler.step()

        return (total_loss / total_num,
                tot_bloss / total_num,
                tot_nt_xent / total_num,
                self.scheduler.get_last_lr()[0])

    def bregman_sim(self, feature, feature_bank):
        """Similarity between each row of ``feature`` and each bank entry.

        Args:
            feature: Tensor of shape ``(B, D)``.
            feature_bank: Tensor of shape ``(N, D)``.

        Returns:
            Tensor of shape ``(B, N)`` with entries ``exp(-(max(feature[i]) -
            feature[i, argmax(feature_bank[j])]) / 2)``.
        """
        # [B, 1]
        max_feature = torch.max(feature, dim=1)[0].reshape(-1, 1)
        # [N]
        indx_max_feature_bank = torch.max(feature_bank, dim=1)[1]
        # [B, N]
        dist_matrix = max_feature - feature[:, indx_max_feature_bank]
        # Computing similarity from Bregman distance.
        # NOTE(review): sigma is fixed to 1 here, while training uses
        # self.sigma; confirm that this is intended.
        sigma = torch.tensor([1.], device=self.device)
        sigma = 2 * torch.pow(sigma, 2)
        return torch.exp(torch.div(-dist_matrix, sigma))

    # test for one epoch, use weighted knn to find the most similar images' label to assign the test image
    def test(self, memory_data_loader, test_data_loader, k_nn, epoch):
        """Evaluate with a similarity-weighted k-nearest-neighbour vote.

        Args:
            memory_data_loader: Yields ``([data, _], target)`` batches used to
                build the labelled feature bank.
            test_data_loader: Yields ``([data, _], target)`` batches to classify.
            k_nn: Number of neighbours taken from the feature bank.
            epoch: Current epoch number (only used in the progress message).

        Returns:
            Tuple ``(top-1 accuracy, top-5 accuracy)`` in percent.
        """
        self.model.eval()
        total_top1, total_top5, total_num = 0.0, 0.0, 0
        feature_bank, feature_labels = [], []

        with torch.no_grad():
            # generate feature bank
            for [data, _], target in tqdm(memory_data_loader,
                                          desc=f'{bcolors.OKBLUE}Feature extracting{bcolors.ENDC}'):
                feature, out = self.model(data.to(self.device))
                feature_bank.append(out)
                feature_labels.append(target)
            # [N, D]
            feature_bank = torch.cat(feature_bank, dim=0)
            # [N]
            feature_labels = torch.cat(feature_labels, dim=0).long().to(self.device)

            # loop test data to predict the label by weighted knn search
            test_bar = tqdm(test_data_loader)
            for [data, _], target in test_bar:
                data, target = data.to(self.device), target.to(self.device)
                feature, out = self.model(data)

                total_num += data.size(0)
                # compute bregman similarity between each feature vector and feature bank ---> [B, N]
                sim_matrix = self.bregman_sim(out, feature_bank)
                # [B, K]
                sim_weight, sim_indices = sim_matrix.topk(k=k_nn, dim=-1)
                # [B, K]
                sim_labels = torch.gather(feature_labels.expand(data.size(0), -1),
                                          dim=-1,
                                          index=sim_indices)
                sim_weight = (sim_weight / self.temperature).exp()

                # counts for each class
                one_hot_label = torch.zeros(data.size(0) * k_nn, self.num_cls, device=self.device)
                # [B*K, C]
                one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
                # weighted score ---> [B, C]
                pred_scores = torch.sum(one_hot_label.view(
                    data.size(0), -1, self.num_cls) * sim_weight.unsqueeze(dim=-1), dim=1)

                pred_labels = pred_scores.argsort(dim=-1, descending=True)
                total_top1 += torch.sum(
                    (pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
                total_top5 += torch.sum(
                    (pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()

                test_bar.set_description(
                    '{}Test{}  {}Epoch:{} [{}/{}] {}Acc@1: {}{:.2f}% {}Acc@5: {}{:.2f}%'.format(
                    bcolors.OKCYAN, bcolors.ENDC,
                    bcolors.WARNING, bcolors.ENDC,
                    epoch,
                    self.epochs,
                    bcolors.WARNING, bcolors.ENDC,
                    (total_top1 / total_num) * 100,
                    bcolors.WARNING, bcolors.ENDC,
                    (total_top5 / total_num) * 100))

        return (total_top1 / total_num) * 100, (total_top5 / total_num) * 100