<a href="https://colab.research.google.com/github/danielsaggau/IR_LDC/blob/main/Bregman_Sbert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import torch
from torch import nn, Tensor
from typing import Union, Tuple, List, Iterable, Dict
import torch.nn.functional as F
from enum import Enum
#from Sentence_Transformer import SentenceTransformer

In [26]:
class TripletDistanceMetric(Enum):
    """
    Distance functions that can be used as the metric of the triplet loss.

    Every attribute is a callable that takes two embedding tensors ``x`` and
    ``y`` and returns their distance. Because the values are plain functions,
    ``Enum`` leaves them as ordinary class attributes instead of turning them
    into enum members, so ``TripletDistanceMetric.EUCLIDEAN(x, y)`` can be
    called directly.
    """
    # Cosine distance: one minus the cosine similarity of x and y.
    COSINE = lambda x, y: 1 - F.cosine_similarity(x, y)
    # Euclidean distance (p-norm with p=2).
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    # Manhattan distance (p-norm with p=1).
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)

In [17]:
class TripletLoss(nn.Module):
    """
    Triplet loss on sentence embeddings.

    Given a triplet of (anchor, positive, negative), the loss minimizes the distance
    between anchor and positive while it maximizes the distance between anchor and
    negative. It computes the following loss function:

        loss = max(||anchor - positive|| - ||anchor - negative|| + margin, 0)

    averaged over the batch. The margin is an important hyperparameter and needs to
    be tuned. For further details, see: https://en.wikipedia.org/wiki/Triplet_loss

    :param model: SentenceTransformer model; it is called on each feature dict and
        must return a dict containing the key ``'sentence_embedding'``.
    :param distance_metric: Function to compute the distance between two embeddings.
        The class TripletDistanceMetric contains common distance metrics that can be used.
    :param triplet_margin: The negative should be at least this much further away
        from the anchor than the positive.

    Example::

        from torch.utils.data import DataLoader
        from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses
        from sentence_transformers.readers import InputExample

        model = SentenceTransformer('distilbert-base-nli-mean-tokens')
        train_examples = [InputExample(texts=['Anchor 1', 'Positive 1', 'Negative 1']),
            InputExample(texts=['Anchor 2', 'Positive 2', 'Negative 2'])]
        train_dataset = SentencesDataset(train_examples, model)
        train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1)
        train_loss = losses.TripletLoss(model=model)
    """

    # The annotation is a string on purpose: it is not evaluated at class-definition
    # time, so this class can be defined before SentenceTransformer is imported.
    def __init__(self, model: "SentenceTransformer", distance_metric=TripletDistanceMetric.EUCLIDEAN, triplet_margin: float = 5):
        super().__init__()
        self.model = model
        self.distance_metric = distance_metric
        self.triplet_margin = triplet_margin

    def get_config_dict(self) -> Dict[str, object]:
        """
        Return the loss configuration as a plain dict.

        The distance metric is reported as ``"TripletDistanceMetric.<NAME>"`` when it
        is one of the functions defined on TripletDistanceMetric; otherwise the
        callable's ``__name__`` (or its ``repr`` when it has none) is used.
        """
        # Callables such as functools.partial objects have no __name__.
        distance_metric_name = getattr(self.distance_metric, '__name__', repr(self.distance_metric))
        for name, value in vars(TripletDistanceMetric).items():
            if value is self.distance_metric:
                distance_metric_name = "TripletDistanceMetric.{}".format(name)
                break

        return {'distance_metric': distance_metric_name, 'triplet_margin': self.triplet_margin}

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor) -> Tensor:
        """
        Compute the mean triplet loss of one batch.

        :param sentence_features: Exactly three feature dicts, in the order
            (anchor, positive, negative); unpacking raises ValueError otherwise.
        :param labels: Not used by this loss; kept in the signature for callers.
        :return: Scalar tensor with the batch-averaged loss.
        """
        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]

        rep_anchor, rep_pos, rep_neg = reps
        distance_pos = self.distance_metric(rep_anchor, rep_pos)
        distance_neg = self.distance_metric(rep_anchor, rep_neg)

        # Hinge: only triplets that violate the margin contribute to the loss.
        losses = F.relu(distance_pos - distance_neg + self.triplet_margin)
        return losses.mean()

In [None]:
!pip install sentence_transformers
from sentence_transformers import SentenceTransformer,  SentencesDataset, LoggingHandler, losses
from sentence_transformers.readers import InputExample
from torch.utils.data import DataLoader

In [28]:
# Load a pretrained SBERT model and wrap two toy (anchor, positive, negative)
# triplets in a dataset / dataloader for training.
model = SentenceTransformer('distilbert-base-nli-mean-tokens')
train_examples = [
    InputExample(texts=['Anchor 1', 'Positive 1', 'Negative 1']),
    InputExample(texts=['Anchor 2', 'Positive 2', 'Negative 2']),
]
train_dataset = SentencesDataset(train_examples, model)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1)

# Triplet loss from the sentence_transformers `losses` module, using the
# Manhattan distance as the metric.
train_loss = losses.TripletLoss(
    model=model,
    distance_metric=TripletDistanceMetric.MANHATTAN,
)

In [None]:
# Train for a single epoch on the (dataloader, loss) pair defined above.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    warmup_steps=100,
)

In [None]:
# Sentences we want to encode.
sentences = [
    'This framework generates embeddings for each input sentence',
    'Sentences are passed as a list of string.',
    'The quick brown fox jumps over the lazy dog.',
]

# Encode all sentences with a single model.encode() call.
embeddings = model.encode(sentences)

# Print every sentence together with its embedding, followed by an empty line.
for sentence, embedding in zip(sentences, embeddings):
    print("Sentence:", sentence)
    print("Embedding:", embedding)
    print("")

In [31]:
from sentence_transformers import SentenceTransformer, util
# Two parallel lists: sentences1[i] is compared with sentences2[i].
sentences1 = [
    'The cat sits outside',
    'A man is playing guitar',
    'The new movie is awesome',
]

sentences2 = [
    'The dog plays in the garden',
    'A woman watches TV',
    'The new movie is so great',
]

# Embed both lists as tensors.
embeddings1 = model.encode(sentences1, convert_to_tensor=True)
embeddings2 = model.encode(sentences2, convert_to_tensor=True)

# All-pairs cosine similarities; only the diagonal (matching indices) is printed.
cosine_scores = util.cos_sim(embeddings1, embeddings2)

# Report each aligned pair with its score.
for i, (left, right) in enumerate(zip(sentences1, sentences2)):
    print("{} \t\t {} \t\t Score: {:.4f}".format(left, right, cosine_scores[i][i]))

The cat sits outside 		 The dog plays in the garden 		 Score: 0.3298
A man is playing guitar 		 A woman watches TV 		 Score: 0.2403
The new movie is awesome 		 The new movie is so great 		 Score: 0.9838


In [None]:
      # Augmentation for text, i.e. masking
      '''
      params::tau augmentation function 
      Q: Masking in contrastive is just masking correlated samples
      Q: Check SBERT
      '''

      # projection 

      '''
      params::f projection function 
      params:: 
      '''

      #subnetworks 
      
      '''
      params:: d bregman subnetworks 
      params:: 
      '''

      #concatenation

      '''
      function::
      '''

      # computing triplet loss 


      # Computing bregman loss 
    # Placeholder for the Bregman divergence computation: the function has no
    # parameters and no body yet; the docstring only lists the inputs it is
    # planned to take.
    def Divergence():
      '''params:: logits_b
         params:: labels_b
      '''

     # combine the two losses
     # NOTE(review): pseudocode, not runnable -- `lambda` is a reserved word in
     # Python, so the weighting factor needs another name (e.g. `lmbda`), and
     # the two losses must be computed values rather than the classes themselves.

     TripletBregmanloss = lambda * TripletLoss + BregmanLoss #(lambda =2 in other paper) 


     # SGD update
     # NOTE(review): `update` is not defined in this file; presumably this stands
     # for an optimizer step -- confirm.
     TripletBregmanloss.backward()
     update(model.params)
  
  def D(o1, o2):
    """Bregman divergence between two subnetwork output vectors (Eq. 6).

    :param o1: tensor of outputs for the first sample; assumed to be 1-D, since
        ``torch.argmax`` without ``dim`` indexes the flattened tensor.
    :param o2: tensor of outputs for the second sample, same layout as ``o1``.
    :return: ``o1[p*] - o1[q*]`` where ``p*`` and ``q*`` are the positions of
        the largest entries of ``o1`` and ``o2``. The result is >= 0 and is 0
        when both vectors peak at the same position.
    """
    # Index of the strongest output of each sample.
    p_star = torch.argmax(o1)
    q_star = torch.argmax(o2)

    # Bregman divergence (Eq.6)
    return o1[p_star] - o1[q_star]


# conversion D to phi 
def phi(D, sigma):
    """Convert a Bregman divergence into a similarity score.

    Uses the same Gaussian kernel as ``BregmanLoss.b_sim`` (case 2):
    ``exp(-D / (2 * sigma ** 2))``. A divergence of 0 maps to a similarity
    of 1, and larger divergences decay towards 0.

    NOTE(review): the original cell left this function empty; the Gaussian
    form is taken from ``b_sim`` -- confirm it is the intended conversion.

    :param D: divergence value(s); a tensor or anything ``torch.as_tensor`` accepts.
    :param sigma: kernel bandwidth (non-zero scalar).
    :return: tensor of similarities with the same shape as ``D``.
    """
    divergence = torch.as_tensor(D)
    return torch.exp(-divergence / (2.0 * sigma ** 2))


# Pairwise Divergence for computation 

In [None]:
class BregmannLoss(nn.Module):
    """Pairwise Bregman-divergence module (work in progress).

    The Bregman loss should eventually take a triplet (anchor, negative,
    positive) and compute the loss for all valid triplets. At this stage the
    module only produces the pairwise divergence matrix of a batch; the
    triplet terms are still to be derived from it.

    Arguments:
    :param batch_size: number of samples per view.
    :param temperature: stored for later use by the loss; not used yet.
    :param sigma: stored for later use by the loss; not used yet.
    """

    def __init__(self, batch_size, temperature, sigma):
        # Zero-argument super(): the original named a different class here.
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.sigma = sigma
        self.mask = self.mask_correlated_samples(batch_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    @staticmethod
    def mask_correlated_samples(batch_size):
        """Return a (2B, 2B) bool mask that is False on the diagonal and on the
        (i, B + i) / (B + i, i) entries, i.e. for each sample itself and its
        counterpart in the other view, and True everywhere else."""
        n = 2 * batch_size
        mask = torch.ones((n, n), dtype=torch.bool)
        mask.fill_diagonal_(False)
        idx = torch.arange(batch_size)
        mask[idx, idx + batch_size] = False
        mask[idx + batch_size, idx] = False
        return mask

    @staticmethod
    def pairwise_divergence(embeddings, squared=False):
        """Return the matrix of pairwise Bregman divergences.

        Entry ``[i, j]`` is ``max(embeddings[i, p_i] - embeddings[i, p_j], 0)``
        where ``p_k`` is the argmax position of row ``k``.

        :param embeddings: tensor indexed as (sample, subnetwork output).
        :param squared: accepted for interface compatibility; currently unused.
        :return: (n, n) tensor of non-negative divergences.
        """
        max_values, max_indices = torch.max(embeddings, dim=1)
        # Column j of the gathered matrix holds every row's output at row j's argmax.
        div_matrix = max_values.unsqueeze(1) - embeddings[:, max_indices]
        return torch.clamp(div_matrix, min=0.0)

    def forward(self, out_a, out_b):
        """Stack both views along the batch axis and return their pairwise
        divergence matrix.

        :param out_a: outputs of the first view.
        :param out_b: outputs of the second view (same trailing shape).
        :return: (2B, 2B) divergence matrix of the concatenated batch.
        """
        features = torch.cat((out_a, out_b), dim=0)
        div_matrix = self.pairwise_divergence(features)
        # TODO: derive the anchor-positive / anchor-negative terms from
        # div_matrix and reduce them to the triplet loss described above.
        return div_matrix

# Masking triplets 


In [None]:
      def get_anchor_positive_mask(labels):
          """Return a 2-D bool mask of valid (anchor, positive) pairs.

          ``mask[a, p]`` is True iff ``a != p`` and ``labels[a] == labels[p]``.

          :param labels: 1-D tensor of class labels, one per sample.
          :return: bool tensor of shape (n, n) on the same device as ``labels``.
          """
          n = labels.size(0)
          # A sample cannot be its own positive: exclude the diagonal.
          indices_not_equal = torch.logical_not(
              torch.eye(n, dtype=torch.bool, device=labels.device)
          )
          # Elementwise comparison of every label with every other label.
          labels_equal = torch.eq(torch.unsqueeze(labels, 0), torch.unsqueeze(labels, 1))
          mask = torch.logical_and(indices_not_equal, labels_equal)
          return mask

      def get_anchor_negative_mask(labels):
          """Return a 2-D bool mask of valid (anchor, negative) pairs.

          ``mask[a, n]`` is True iff ``labels[a] != labels[n]``.

          :param labels: 1-D tensor of class labels, one per sample.
          :return: bool tensor of shape (n, n).
          """
          # Elementwise comparison of every label with every other label.
          labels_equal = torch.eq(torch.unsqueeze(labels, 0), torch.unsqueeze(labels, 1))
          mask = torch.logical_not(labels_equal)
          return mask

      def get_triplet_mask(labels):
          """Return a 3-D bool mask of valid (anchor, positive, negative) triplets.

          ``mask[i, j, k]`` is True iff i, j and k are pairwise distinct,
          ``labels[i] == labels[j]`` and ``labels[i] != labels[k]``.

          :param labels: 1-D tensor of class labels, one per sample.
          :return: bool tensor of shape (n, n, n) on the same device as ``labels``.
          """
          n = labels.size(0)
          # PyTorch counterpart of tf.cast(tf.eye(tf.shape(labels)[0]), tf.bool).
          indices_equal = torch.eye(n, dtype=torch.bool, device=labels.device)
          indices_not_equal = torch.logical_not(indices_equal)
          # Broadcast the pairwise "different index" matrix over the third axis
          # so that each of the three index pairs is checked.
          i_not_j = torch.unsqueeze(indices_not_equal, 2)
          i_not_k = torch.unsqueeze(indices_not_equal, 1)
          j_not_k = torch.unsqueeze(indices_not_equal, 0)
          distinct_indices = torch.logical_and(torch.logical_and(i_not_j, i_not_k), j_not_k)

          label_equal = torch.eq(torch.unsqueeze(labels, 0), torch.unsqueeze(labels, 1))
          # [i, j, 1]: anchor and positive share a label;
          # [i, 1, k]: anchor and candidate negative share a label (must be False).
          i_equal_j = torch.unsqueeze(label_equal, 2)
          i_equal_k = torch.unsqueeze(label_equal, 1)
          valid_labels = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k))

          mask = torch.logical_and(distinct_indices, valid_labels)
          return mask



# Compute Batch All Triplet Loss 

Computing Bregman Loss 

# References


In [None]:
# code snippet kubrac
def _pairwise_divergences(embed):
    """Build the matrix of pairwise Bregman divergences for a batch (TensorFlow).

    With ``p_k`` the argmax position of row ``k`` of ``embed``, entry ``[i, j]``
    of the result is ``max(embed[i, p_i] - embed[i, p_j], 0)``.

    :param embed: tensor indexed as (sample, output); assumed to be rank 2.
    :return: (n, n) tensor of non-negative divergences, n = number of samples.
    """
    n = tf.shape(embed)[0]
    row_ids = tf.range(n, dtype=tf.dtypes.int32)
    argmax_cols = tf.math.argmax(embed, 1, output_type=tf.dtypes.int32)

    # Each row's own maximum, repeated along the columns: grid[i, j] = max_i.
    own_max = tf.gather_nd(embed, tf.transpose(tf.stack([row_ids, argmax_cols])))
    own_max_grid = tf.transpose(tf.reshape(tf.tile(own_max, [n]), [n, n]))

    # Flattened (row i, argmax column of row j) index pairs for all i, j.
    tiled_cols = tf.tile(argmax_cols, [n])
    tiled_rows = tf.reshape(
        tf.transpose(tf.reshape(tf.tile(row_ids, [n]), [n, n])), [-1]
    )
    cross_values = tf.gather_nd(embed, tf.transpose(tf.stack([tiled_rows, tiled_cols])))
    cross_grid = tf.reshape(cross_values, [n, n])

    div_matrix = tf.maximum(tf.subtract(own_max_grid, cross_grid), 0.0)

#    #for differentiability, this version uses softmax instead of argmax
#    sftmx = tf.nn.softmax(tf.multiply(1.0, embed))
#    ES = tf.linalg.matmul(embed, sftmx, transpose_b=True)
#    one_vec = tf.reshape(tf.ones([tf.shape(embed)[0]]), [1, tf.shape(embed)[0]])
#    diag_ES = tf.reshape(tf.linalg.diag_part(ES), [1, tf.shape(embed)[0]])
#    max_outputs = tf.linalg.matmul(diag_ES, one_vec, transpose_a=True)
#    div_matrix = tf.maximum(tf.subtract(max_outputs, ES), 0.0)

    return div_matrix

In [None]:
  # code snippet rezaei
class BregmanLoss(nn.Module):
    """Contrastive loss whose similarity is derived from Bregman divergences.

    Two views of a batch are concatenated, a pairwise divergence matrix is
    computed from the per-sample maxima, converted to similarities and fed to
    a cross-entropy objective in which each sample's counterpart in the other
    view is the positive (class 0) and all remaining samples are negatives.

    :param batch_size: expected number of samples per view; the matching
        negatives mask is precomputed for this size.
    :param temperature: divisor applied to the similarity matrix.
    :param sigma: bandwidth of the Gaussian kernel used by ``b_sim`` (case 2).
    """

    def __init__(self, batch_size, temperature, sigma):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.sigma = sigma

        self.mask = self.mask_correlated_samples(batch_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def mask_correlated_samples(self, batch_size):
        """Return a (2B, 2B) bool mask selecting the negatives of every row.

        Entries are False on the diagonal (a sample with itself) and at
        (i, B + i) / (B + i, i) (a sample with its other view), True elsewhere.
        """
        n = 2 * batch_size
        mask = torch.ones((n, n), dtype=torch.bool)
        mask.fill_diagonal_(False)
        idx = torch.arange(batch_size)
        mask[idx, idx + batch_size] = False
        mask[idx + batch_size, idx] = False
        return mask

    def b_sim(self, features, case=2):
        """Compute the Bregman similarity matrix of ``features``.

        :param features: tensor indexed as (sample, subnetwork output).
        :param case: which divergence-to-similarity conversion to use:
            0 = ``1 / (d / max(d) + 1)``, 1 = ``exp(-d)``,
            2 = ``exp(-d / (2 * sigma**2))`` (default, the previously
            hard-coded choice), 3 = ``1 - d``.
        :return: ``(sim_matrix, num_max)`` where ``num_max[k]`` counts the
            samples whose largest output is at position ``k``.
        :raises ValueError: if ``case`` is not one of 0, 1, 2, 3.
        """
        max_features, indx_max_features = torch.max(features, dim=1)
        max_features = max_features.reshape(-1, 1)

        # Compute the number of active subnets in one batch. The identity
        # matrix is created on the features' device: a CPU tensor cannot be
        # indexed with index tensors that live on another device.
        eye = torch.eye(features.shape[1], device=features.device)
        num_max = torch.sum(eye[indx_max_features], dim=0)

        # dist[i, j] = (row i's maximum) - (row i's output at row j's argmax).
        dist_matrix = max_features - features[:, indx_max_features]

        if case == 0:
            m2 = torch.divide(dist_matrix, torch.max(dist_matrix))
            sim_matrix = torch.reciprocal(m2 + 1)
        elif case == 1:
            sim_matrix = torch.exp(-dist_matrix)
        elif case == 2:
            sim_matrix = torch.exp(-dist_matrix / (2 * self.sigma ** 2))
        elif case == 3:
            sim_matrix = 1 - dist_matrix
        else:
            raise ValueError("case must be 0, 1, 2 or 3, got {!r}".format(case))

        return sim_matrix, num_max

    def forward(self, out_a, out_b):
        """Return ``(loss, num_max)`` for two views of the same batch.

        The batch size is taken from the inputs, so a final batch smaller
        than ``self.batch_size`` is handled by building a matching mask.

        :param out_a: outputs for the first view, indexed as (sample, output).
        :param out_b: outputs for the second view, same shape as ``out_a``.
        :return: mean cross-entropy loss over the 2B rows, and ``num_max``
            as returned by ``b_sim``.
        :raises ValueError: if the two views differ in shape.
        """
        if out_a.shape != out_b.shape:
            raise ValueError(
                "out_a and out_b must have the same shape, got {} and {}".format(
                    tuple(out_a.shape), tuple(out_b.shape)
                )
            )

        batch_size = out_a.shape[0]
        n = 2 * batch_size

        features = torch.cat((out_a, out_b), dim=0)

        sim_matrix, num_max = self.b_sim(features)
        sim_matrix = sim_matrix / self.temperature

        # The +/- batch_size diagonals hold each sample's similarity to its other view.
        pos_ab = torch.diag(sim_matrix, batch_size)
        pos_ba = torch.diag(sim_matrix, -batch_size)
        positives = torch.cat((pos_ab, pos_ba), dim=0).reshape(n, 1)

        if batch_size == self.batch_size:
            mask = self.mask
        else:
            mask = self.mask_correlated_samples(batch_size)
        negatives = sim_matrix[mask.to(sim_matrix.device)].reshape(n, -1)

        # The positive sits in column 0 of the logits, hence every label is 0.
        labels = torch.zeros(n, dtype=torch.long, device=features.device)
        logits = torch.cat((positives, negatives), dim=1)
        loss = self.criterion(logits, labels) / n
        return loss, num_max

In [None]:
  # code snippet rezaei
  # train for one epoch to learn unique features
    def train(self, data_loader, epoch):
        """Train the model for one epoch.

        Every batch yields two augmented views; both are passed through
        ``self.model`` and scored with the Bregman loss. When
        ``self.mixed_loss`` is set, an NT-Xent term weighted by ``self.lmbda``
        is added. The scheduler is stepped once at the end of the epoch.

        :param data_loader: iterable of ``([aug_1, aug_2], target)`` batches
            exposing a ``batch_size`` attribute.
        :param epoch: current epoch number, used only for the progress bar.
        :return: tuple of (mean total loss, mean Bregman loss, mean NT-Xent
            loss, last learning rate reported by the scheduler).
        """
        self.model.train()
        batch_size = data_loader.batch_size
        #bloss = BregMarginLoss(batch_size)
        bregman_criterion = BregmanLoss(batch_size, self.temperature, self.sigma)
        contrastive_criterion = NT_Xent(batch_size, self.temperature)

        # Running sums, weighted by batch size so they can be averaged at the end.
        total_loss = 0.0
        total_num = 0
        tot_max = 0
        train_bar = tqdm(data_loader)
        tot_bloss = 0.0
        tot_nt_xent = 0.0
        num_max = torch.tensor([0])

        for [aug_1, aug_2], target in train_bar:
            aug_1 = aug_1.to(self.device)
            aug_2 = aug_2.to(self.device)
            feature_1, out_1 = self.model(aug_1)
            feature_2, out_2 = self.model(aug_2)

            # compute loss
            loss, num_max = bregman_criterion(out_1, out_2)
            tot_bloss += loss.item() * batch_size
            if self.mixed_loss:
                contrastive_loss = contrastive_criterion(feature_1, feature_2)
                tot_nt_xent += contrastive_loss.item() * batch_size
                loss = loss + self.lmbda * contrastive_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            tot_max += num_max
            total_num += batch_size
            total_loss += loss.item() * batch_size

            # Output positions that were the per-sample maximum more than 10 times so far.
            active_subnets = len(torch.where(tot_max>10)[0])
            description = '{}Train{} {}Epoch:{} [{}/{}] {}Loss:{}  {:.4f} {}Active Subs:{} [{}/{}]'.format(
                bcolors.OKCYAN, bcolors.ENDC,
                bcolors.WARNING, bcolors.ENDC,
                epoch,
                self.epochs,
                bcolors.WARNING, bcolors.ENDC,
                total_loss / total_num,
                bcolors.WARNING, bcolors.ENDC,
                active_subnets,
                tot_max.shape[0])
            train_bar.set_description(description)

        # warmup with nt_xent loss for the first 50 epochs
        #if epoch >= 100:
        self.scheduler.step()

        mean_total = total_loss/total_num
        mean_bregman = tot_bloss/total_num
        mean_nt_xent = tot_nt_xent/total_num
        return (mean_total,
                mean_bregman,
                mean_nt_xent,
                self.scheduler.get_last_lr()[0])

In [None]:
class BregmanTriplet_loss(nn.Module):
  '''
  Placeholder for a loss that combines the triplet and the Bregman objectives.
  The class has no methods yet, so it adds nothing to nn.Module.

  Arguments
  param:: lambda -- presumably the weight applied to one of the two loss terms
      when they are summed (TODO: confirm once implemented)
  param:: (TODO: remaining parameters)

  Example
  (TODO: add a usage example)
  '''

In [None]:
# subnetworks

In [None]:
# Computation 



