<a href="https://colab.research.google.com/github/danielsaggau/IR_LDC/blob/main/experiment_contrastive_bregman_simple_examples.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# contrastive loss
!pip install sentence_transformers

In [None]:
# Custom dataloader for torch
class Dataloader:
    """Placeholder for a custom torch data loader.

    The class currently holds no state and implements no loading logic; it only
    reserves the name so that the real implementation can be filled in later.
    """

    def __init__(self):
        # Nothing to initialise yet.
        pass



In [11]:
from enum import Enum
from typing import Iterable, Dict
import torch.nn.functional as F
from torch import nn, Tensor
from sentence_transformers.SentenceTransformer import SentenceTransformer

class DistanceMetric(Enum):
    """
    The metric for the contrastive loss

    Each attribute is a two-argument callable ``(x, y)`` that forwards both
    arguments to a ``torch.nn.functional`` routine and returns its result.

    NOTE: the values are plain functions. Python's Enum machinery does not turn
    functions (descriptors) into enum members, so e.g. ``DistanceMetric.EUCLIDEAN``
    evaluates to the lambda itself and can be called directly.
    """
    # L2 distance via F.pairwise_distance with p=2.
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    # L1 distance via F.pairwise_distance with p=1.
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
    # One minus the cosine similarity, so identical directions give 0.
    COSINE_DISTANCE = lambda x, y: 1-F.cosine_similarity(x, y)

class ContrastiveLoss(nn.Module):
    """
    Pairwise contrastive loss on two sentence embeddings.

    Every training item consists of two texts and a label that is either 0 or 1.
    Pairs labelled 1 are penalised by their squared distance, so they are pulled
    together. Pairs labelled 0 are penalised by the squared amount by which their
    distance falls short of ``margin``, so they are pushed apart.
    Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    :param model: SentenceTransformer model used to embed both texts of a pair
    :param distance_metric: Callable that takes two batches of embeddings and returns one distance per pair. The class DistanceMetric contains pre-defined metrics that can be used
    :param margin: Negative samples (label == 0) should have a distance of at least the margin value.
    :param size_average: If True, return the mean of the per-pair losses; otherwise return their sum.
    Example::
        from sentence_transformers import SentenceTransformer, InputExample
        from torch.utils.data import DataLoader
        model = SentenceTransformer('all-MiniLM-L6-v2')
        train_examples = [
            InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1),
            InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)]
        train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
        train_loss = ContrastiveLoss(model=model)
        model.fit([(train_dataloader, train_loss)], show_progress_bar=True)
    """

    def __init__(self, model: SentenceTransformer, distance_metric=DistanceMetric.COSINE_DISTANCE, margin: float = 0.5, size_average:bool = True):
        super().__init__()
        self.model = model
        self.distance_metric = distance_metric
        self.margin = margin
        self.size_average = size_average

    def get_config_dict(self):
        """Return the loss settings as a plain dict (metric name, margin, size_average)."""
        # Fall back to the callable's own name when it is not one of the
        # attributes defined on DistanceMetric.
        metric_label = self.distance_metric.__name__
        for attr_name, attr_value in vars(DistanceMetric).items():
            if attr_value == self.distance_metric:
                metric_label = "DistanceMetric.{}".format(attr_name)
                break

        config = {}
        config['distance_metric'] = metric_label
        config['margin'] = self.margin
        config['size_average'] = self.size_average
        return config

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        """Embed both sides of every pair and return the reduced contrastive loss."""
        embeddings = [self.model(features)['sentence_embedding'] for features in sentence_features]
        assert len(embeddings) == 2
        anchor_emb, other_emb = embeddings
        dist = self.distance_metric(anchor_emb, other_emb)

        # label == 1: squared distance; label == 0: squared shortfall below the margin.
        positive_term = labels.float() * dist.pow(2)
        negative_term = (1 - labels).float() * F.relu(self.margin - dist).pow(2)
        per_pair = 0.5 * (positive_term + negative_term)

        if self.size_average:
            return per_pair.mean()
        return per_pair.sum()

In [None]:
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, InputExample
from torch.utils.data import DataLoader
# Instantiate the SentenceTransformer identified by 'all-MiniLM-L6-v2'.
# `model` is reused by later cells.
model = SentenceTransformer('all-MiniLM-L6-v2')
# Two toy text pairs: label=1 marks a pair whose distance should shrink,
# label=0 a pair whose distance should grow (see the ContrastiveLoss docstring).
train_examples = [
 InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1),
 InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)]
# With two examples and batch_size=2, every batch contains both pairs.
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
# Uses the ContrastiveLoss class defined in this notebook, not losses.ContrastiveLoss.
train_loss = ContrastiveLoss(model=model)
# Train the model on the single (dataloader, loss) objective.
model.fit([(train_dataloader, train_loss)], show_progress_bar=True)

In [None]:
class BregmanLoss(nn.Module):
    """
    Contrastive loss computed from a max-component divergence between embeddings.

    Expects as input two batches of tokenized texts (anchor and other side of each
    pair). Both sides are embedded with ``model`` and stacked into one matrix of
    ``2 * batch`` rows. A pairwise divergence matrix is built from the largest
    component of every row, turned into a similarity matrix, and scored with a
    cross-entropy in which row ``i``'s positive is its partner row
    (``i + batch`` or ``i - batch``) and all remaining rows are negatives.

    :param batch_size: Expected number of pairs per batch; used to pre-build the negative mask.
        Batches of a different size are supported, the mask is then rebuilt on the fly.
    :param model: SentenceTransformer model (any callable returning a dict with a
        ``'sentence_embedding'`` entry works).
    :param sigma: Bandwidth of the Gaussian kernel used when ``similarity_case == 2``.
    :param temperature: Divisor applied to the similarity matrix before the cross-entropy.
    :param margin: Stored on the instance; not used by the computation.
    :param size_average: Stored on the instance; not used by the computation.
    :param return_num_max: If True (default) ``forward`` returns ``(loss, num_max)``.
        If False it returns only the scalar loss.
        NOTE(review): ``SentenceTransformer.fit`` presumably needs a single loss tensor,
        in which case ``return_num_max=False`` is required there - confirm.
    """

    # Selects how b_sim converts the divergence matrix d into similarities:
    #   0 -> 1 / (1 + d / max(d)),  1 -> exp(-d),
    #   2 -> exp(-d / (2 * sigma ** 2)),  3 -> 1 - d
    similarity_case = 2

    def __init__(self, batch_size, model: "SentenceTransformer", sigma, temperature, margin: float = 0.5,
                 size_average: bool = True, return_num_max: bool = True):
        super().__init__()
        self.margin = margin
        self.model = model
        self.temperature = temperature
        self.sigma = sigma
        self.batch_size = batch_size
        self.size_average = size_average
        self.return_num_max = return_num_max
        self.mask = self.mask_correlated_samples(batch_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def mask_correlated_samples(self, batch_size):
        """Return a ``(2 * batch_size, 2 * batch_size)`` bool mask of negative pairs.

        Entries are False on the diagonal (a row with itself) and at
        ``[i, batch_size + i]`` / ``[batch_size + i, i]`` (a row with its positive
        partner); every other entry is True.
        """
        total = 2 * batch_size
        mask = ~torch.eye(total, dtype=torch.bool)
        first_half = torch.arange(batch_size)
        second_half = first_half + batch_size
        mask[first_half, second_half] = False
        mask[second_half, first_half] = False
        return mask

    def b_sim(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor = None):
        """Compute the similarity matrix and the per-dimension arg-max counts.

        :param sentence_features: Either the two tokenized feature dicts of a batch
            (they are embedded and concatenated along dim 0), or an already stacked
            2-D embedding tensor.
        :param labels: Accepted for signature compatibility; not used.
        :return: ``(sim_matrix, num_max)`` where ``sim_matrix`` is square with one
            row per embedding and ``num_max[d]`` counts the rows whose largest
            component is dimension ``d``.
        :raises ValueError: if the features do not contain exactly two sides or
            ``similarity_case`` is not one of 0-3.
        """
        if isinstance(sentence_features, Tensor):
            features = sentence_features
        else:
            reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
            if len(reps) != 2:
                raise ValueError("BregmanLoss expects exactly two texts per example, got {}".format(len(reps)))
            features = torch.cat(reps, dim=0)

        max_features, indx_max_features = torch.max(features, dim=1)
        max_features = max_features.reshape(-1, 1)

        # Count how many rows select each dimension as their maximum.
        # The identity matrix is created on the features' device so that indexing
        # it with indx_max_features does not mix devices.
        eye = torch.eye(features.shape[1], device=features.device)
        num_max = eye[indx_max_features].sum(dim=0)

        # dist_matrix[i, j] = max(features[i]) - features[i, argmax(features[j])]
        dist_matrix = max_features - features[:, indx_max_features]

        case = self.similarity_case
        if case == 0:
            sim_matrix = 1.0 / (dist_matrix / torch.max(dist_matrix) + 1)
        elif case == 1:
            sim_matrix = torch.exp(-dist_matrix)
        elif case == 2:
            sim_matrix = torch.exp(-dist_matrix / (2 * self.sigma ** 2))
        elif case == 3:
            sim_matrix = 1 - dist_matrix
        else:
            raise ValueError("Unknown similarity_case: {!r}".format(case))

        return sim_matrix, num_max

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        """Return the loss for one batch of pairs.

        :param sentence_features: The two tokenized feature dicts of the batch.
        :param labels: Accepted for interface compatibility; the targets used by the
            cross-entropy are built internally, so these labels are ignored.
        :return: ``(loss, num_max)`` or, when ``return_num_max`` is False, ``loss``.
        :raises ValueError: if the stacked embeddings have an odd number of rows.
        """
        sim_matrix, num_max = self.b_sim(sentence_features, labels)
        sim_matrix = sim_matrix / self.temperature

        # Derive the sizes from the actual batch so that a batch that differs from
        # the configured batch_size (e.g. the last one) still works.
        n_rows = sim_matrix.shape[0]
        if n_rows % 2 != 0:
            raise ValueError("Both sides of the batch must contain the same number of texts")
        batch_size = n_rows // 2
        mask = self.mask if batch_size == self.batch_size else self.mask_correlated_samples(batch_size)
        mask = mask.to(sim_matrix.device)

        pos_ab = torch.diag(sim_matrix, batch_size)
        pos_ba = torch.diag(sim_matrix, -batch_size)

        positives = torch.cat((pos_ab, pos_ba), dim=0).reshape(n_rows, 1)
        # Each row has n_rows - 2 negatives (everything but itself and its partner).
        # The explicit width keeps the reshape valid when that number is zero.
        negatives = sim_matrix[mask].reshape(n_rows, n_rows - 2)

        # The positive logit sits in column 0, so the target class is 0 for every row.
        targets = torch.zeros(n_rows, dtype=torch.long, device=sim_matrix.device)
        logits = torch.cat((positives, negatives), dim=1)
        loss = self.criterion(logits, targets) / n_rows

        if self.return_num_max:
            return loss, num_max
        return loss

In [None]:
!pip install torch
import torch
import torch.nn as nn
# BregmanLoss builds its negative-pair mask from batch_size, so the value must equal
# the number of pairs the DataLoader yields per batch; read it from the loader
# instead of hard-coding a number that can drift out of sync.
train_loss = BregmanLoss(model=model, sigma=1, temperature=0.2, batch_size=train_dataloader.batch_size)

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# Train the model with whatever loss object is currently bound to train_loss.
# NOTE(review): BregmanLoss.forward returns a (loss, num_max) tuple; fit presumably
# expects a single loss tensor it can backpropagate through - confirm before running.
model.fit([(train_dataloader, train_loss)], show_progress_bar=True)

# Triplet implementation 


In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import os
import math
import csv

import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tensorflow_datasets as tfds
import numpy as np

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
#from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score

from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
import tqdm
from ipywidgets import IntProgress

Instructions for updating:
non-resource variables are not supported in the long term


In [2]:
def _pairwise_divergences(embed):
    """Build the pairwise divergence matrix for a batch of embeddings.

    Entry ``[i, j]`` equals ``embed[i, argmax(embed[i])] - embed[i, argmax(embed[j])]``,
    clipped below at 0: the gap between row ``i``'s largest component and the
    component of row ``i`` at the position where row ``j`` is largest.

    :param embed: 2-D tensor with one embedding per row (assumed floating point - TODO confirm).
    :return: tensor of shape ``[batch, batch]``.
    """
    # Column index of the largest component in every row.
    winners = tf.math.argmax(embed, 1, output_type=tf.dtypes.int32)

    # cross[i, j] = embed[i, winners[j]]: row i read at row j's winning column.
    cross = tf.gather(embed, winners, axis=1)

    # The diagonal holds each row's own maximum, embed[i, winners[i]];
    # shaped [batch, 1] so it broadcasts across the columns of `cross`.
    own_max = tf.expand_dims(tf.linalg.diag_part(cross), 1)

    div_matrix = tf.maximum(tf.subtract(own_max, cross), 0.0)

#    #for differentiability, this version uses softmax instead of argmax
#    sftmx = tf.nn.softmax(tf.multiply(1.0, embed))
#    ES = tf.linalg.matmul(embed, sftmx, transpose_b=True)
#    one_vec = tf.reshape(tf.ones([tf.shape(embed)[0]]), [1, tf.shape(embed)[0]])
#    diag_ES = tf.reshape(tf.linalg.diag_part(ES), [1, tf.shape(embed)[0]])
#    max_outputs = tf.linalg.matmul(diag_ES, one_vec, transpose_a=True)
#    div_matrix = tf.maximum(tf.subtract(max_outputs, ES), 0.0)

    return div_matrix

In [3]:
def _get_anchor_positive_triplet_mask(labels):
    """Return a 2-D bool mask that is True where ``[a, p]`` is a valid anchor-positive pair.

    A pair is valid when ``a`` and ``p`` are different indices and carry the same label.

    :param labels: 1-D tensor of labels, one per batch element.
    :return: bool tensor of shape ``[batch, batch]``.
    """
    batch = tf.shape(labels)[0]

    # True where both positions hold the same label.
    same_label = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))

    # True everywhere except on the diagonal.
    off_diagonal = tf.logical_not(tf.cast(tf.eye(batch), tf.bool))

    return tf.logical_and(off_diagonal, same_label)

In [4]:
def _get_anchor_negative_triplet_mask(labels):
    """Return a 2-D bool mask that is True where ``[a, n]`` is a valid anchor-negative pair.

    A pair is valid when the two elements carry different labels.

    :param labels: 1-D tensor of labels, one per batch element.
    :return: bool tensor of shape ``[batch, batch]``.
    """
    as_row = tf.expand_dims(labels, 0)
    as_column = tf.expand_dims(labels, 1)
    return tf.logical_not(tf.equal(as_row, as_column))

def _get_triplet_mask(labels):
    """Return a 3-D bool mask that is True where ``[i, j, k]`` is a valid triplet.

    A triplet (anchor ``i``, positive ``j``, negative ``k``) is valid when the three
    indices are pairwise distinct, ``labels[i] == labels[j]`` and
    ``labels[i] != labels[k]``.

    :param labels: 1-D tensor of labels, one per batch element.
    :return: bool tensor of shape ``[batch, batch, batch]``.
    """
    differs = tf.logical_not(tf.cast(tf.eye(tf.shape(labels)[0]), tf.bool))
    same_label = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))

    # Index constraints, broadcast to [i, j, k]: i != j, i != k and j != k.
    all_distinct = tf.logical_and(
        tf.logical_and(tf.expand_dims(differs, 2), tf.expand_dims(differs, 1)),
        tf.expand_dims(differs, 0),
    )

    # Label constraints: anchor/positive share a label, anchor/negative do not.
    labels_ok = tf.logical_and(
        tf.expand_dims(same_label, 2),
        tf.logical_not(tf.expand_dims(same_label, 1)),
    )

    return tf.logical_and(all_distinct, labels_ok)

In [5]:
def batch_all_triplet_loss(labels, embeddings, margin, squared=False, breg=False):
    """Triplet loss averaged over all valid triplets that have a positive loss.

    For every triplet ``(i, j, k)`` accepted by ``_get_triplet_mask`` the term
    ``d[i, j] - d[i, k] + margin`` is computed and clipped at 0. The result is the
    sum of these terms divided by the number of terms that are greater than 0.

    :param labels: 1-D tensor of labels, one per batch element.
    :param embeddings: 2-D tensor with one embedding per row.
    :param margin: Margin added to ``d[i, j] - d[i, k]``.
    :param squared: Forwarded to ``_pairwise_distances`` when ``breg`` is False.
        Ignored when ``breg`` is True.
    :param breg: If True use ``_pairwise_divergences`` for ``d``, otherwise use
        ``_pairwise_distances`` (defined outside this section).
    :return: ``(triplet_loss, fraction_positive_triplets)`` where the fraction is the
        number of triplets with positive loss over the number of valid triplets.
    """
    if breg:
        pairwise_dist = _pairwise_divergences(embeddings)
    else:
        # Honour the caller's `squared` flag; it used to be overridden with True.
        pairwise_dist = _pairwise_distances(embeddings, squared=squared)

    # Broadcast to [i, j, k]: d[i, j] on the anchor-positive side, d[i, k] on the
    # anchor-negative side.
    anchor_positive_dist = tf.expand_dims(pairwise_dist, 2)
    anchor_negative_dist = tf.expand_dims(pairwise_dist, 1)
    triplet_loss = anchor_positive_dist - anchor_negative_dist + margin

    # Zero out invalid triplets, then drop the ones that already satisfy the margin.
    mask = tf.cast(_get_triplet_mask(labels), tf.float32)
    triplet_loss = tf.maximum(tf.multiply(mask, triplet_loss), 0.0)

    # Count triplets that still contribute; 1e-16 guards the divisions against 0.
    valid_triplets = tf.cast(tf.greater(triplet_loss, 1e-16), tf.float32)
    num_positive_triplets = tf.reduce_sum(valid_triplets)
    num_valid_triplets = tf.reduce_sum(mask)
    fraction_positive_triplets = num_positive_triplets / (num_valid_triplets + 1e-16)
    triplet_loss = tf.reduce_sum(triplet_loss) / (num_positive_triplets + 1e-16)

    return triplet_loss, fraction_positive_triplets

In [13]:
train_examples 

[<sentence_transformers.readers.InputExample.InputExample at 0x7fb6df669350>,
 <sentence_transformers.readers.InputExample.InputExample at 0x7fb6df6693d0>]