# Note
This notebook can be run on Google Colab for improved performance. The code changes needed to run it there are provided as commented-out lines alongside the code they replace.

## Data preprocessing

In [None]:
# ! pip install sentence_transformers==0.4.0

In [None]:
import pandas as pd
import sys
from sklearn.model_selection import train_test_split
from sentence_transformers import SentencesDataset, SentenceTransformer, InputExample, losses
from sentence_transformers.evaluation import LabelAccuracyEvaluator
from torch import nn, Tensor
from typing import Iterable, Dict
from torch.utils.data import DataLoader
import math

# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
def country_labeled_sentences(excel_map):
    """Build a per-country mapping of cleaned sentences and their primary label.

    Only rows that have BOTH a sentence and a primary instrument are used.
    The two columns are filtered together so that every sentence stays paired
    with the label from its own row; dropping missing values from each column
    separately would shift the pairing as soon as a single cell is empty.

    :param excel_map: Mapping of country (sheet) name to a DataFrame with
        "Sentence" and "Primary Instrument" columns.
    :return: ``{country: {sentence_id: {"text": str, "labels": [str]}}}``.
        ``sentence_id`` is a counter that keeps increasing across countries,
        so ids are unique over the whole result.
    """
    result = {}
    sent_num = 0

    for country, dataframe in excel_map.items():
        # Drop incomplete rows jointly so sentences and labels stay aligned.
        complete_rows = dataframe[["Sentence", "Primary Instrument"]].dropna()

        country_sents = {}
        for raw_sent, raw_label in zip(complete_rows["Sentence"],
                                       complete_rows["Primary Instrument"]):
            text = raw_sent.replace("\n", "").strip()
            # Remove the "(PES)" / "(Bond)" qualifiers, then keep only the
            # first entry of a comma-separated label list.
            cleaned = raw_label.replace("(PES)", "").replace("(Bond)", "").strip()
            label = cleaned.split(", ")[0].strip()

            country_sents[sent_num] = {"text": text, "labels": [label]}
            sent_num += 1

        result[country] = country_sents

    return result

def merge_labels(all_labels, labels_to_merge):
    """Collapse the labels listed in ``labels_to_merge`` into one combined label.

    Every label found in ``labels_to_merge`` is replaced by
    ``"<first> & <second>"``, built from the first two entries of
    ``labels_to_merge``; all other labels are kept unchanged.

    :param all_labels: Iterable of label strings.
    :param labels_to_merge: Labels to be replaced by the combined label.
    :return: New list with the merged labels, in the original order.
    """
    merged = []
    for label in all_labels:
        if label in labels_to_merge:
            # Built only on a match, so a short labels_to_merge is never indexed needlessly.
            merged.append(f"{labels_to_merge[0]} & {labels_to_merge[1]}")
        else:
            merged.append(label)
    return merged

In [None]:
# Read the policy-tags Excel workbook. sheet_name=None loads all sheets;
# the result is later iterated as {sheet name: DataFrame} by country_labeled_sentences.
data_excel = pd.read_excel("../../input/WRI_Policy_Tags.xlsx", engine="openpyxl", sheet_name=None)
# Google Colab alternative (workbook read from the mounted Drive):
# data_excel = pd.read_excel("/content/drive/MyDrive/WRI-LatinAmerica-Talent/Cristina_Policy_Files/WRI_Policy_Tags.xlsx", engine="openpyxl", sheet_name=None)

In [None]:
# Formatting the data: flatten the per-country maps into a single
# {sentence_id: {"text": ..., "labels": [...]}} dictionary.
all_labeled_sentences = country_labeled_sentences(data_excel)
labeled_sents = dict()
for sents in all_labeled_sentences.values():
    labeled_sents.update(sents)

# Filtering out "General incentive" and "Unknown" sentences
filtered_sents_maps = [sent for sent in labeled_sents.values() if sent['labels'][0] not in ["General incentive", "Unknown"]]
all_sents = [sent['text'] for sent in filtered_sents_maps]
all_labels = [sent['labels'][0] for sent in filtered_sents_maps]
# "Credit" and "Guarantee" are treated as a single class
all_labels = merge_labels(all_labels, ["Credit", "Guarantee"])
# Sort the unique labels: the iteration order of a set of strings changes between
# interpreter runs (hash randomisation), which would make any label -> id mapping
# derived from label_names differ from one run to the next.
label_names = sorted(set(all_labels))
label_names

## Fine-tuning the embedding model on the labeled data

### Something we can try out:
https://www.sbert.net/examples/training/data_augmentation/README.html#extend-to-your-own-datasets

### Links:
https://github.com/UKPLab/sentence-transformers/issues/350

https://omoindrot.github.io/triplet-loss

### Possible tasks for fine-tuning:
1) Given a pair of sentence embeddings, do they belong to the same category (binary)?

2) Given a sentence and a category embedding, does the sentence belong to the category (binary)?

3) Given a sentence embedding, use a classifier to predict its category (multiclass) [https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/nli/training_nli.py](https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/nli/training_nli.py)

4) Use a triplet loss approach such that sentences (texts) that have the same labels will become close in vector space, while sentences with a different label will be further away [https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/other/training_batch_hard_trec_continue_training.py](https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/other/training_batch_hard_trec_continue_training.py)
   
#### In this notebook **task number 3** is used to fine-tune the model.

In [None]:
# Stratified train/test split: 15% of the sentences are held out as the test set,
# stratified on all_labels; random_state=42 fixes the split.
X_train, X_test, y_train, y_test = train_test_split(all_sents, all_labels, test_size=0.15, stratify=all_labels, random_state=42)

In [None]:
# Define the model to fine-tune: a SentenceTransformer loaded by model name
model = SentenceTransformer('stsb-xlm-r-multilingual')
# Alternative model name:
# model = SentenceTransformer('xlm-r-100langs-bert-base-nli-stsb-mean-tokens')

In [None]:
class SoftmaxClassifier(nn.Module):
    """
    Loss module that places a linear (softmax) classifier on top of the sentence
    embedding produced by the transformer network, so that the network learns to
    map each sentence embedding to its category.
    :param model: SentenceTransformer model
    :param sentence_embedding_dimension: Dimension of your sentence embeddings
    :param num_labels: Number of different labels
    """
    def __init__(self,
                 model: SentenceTransformer,
                 sentence_embedding_dimension: int,
                 num_labels: int):
        super().__init__()
        self.model = model
        self.num_labels = num_labels
        self.classifier = nn.Linear(sentence_embedding_dimension, num_labels)

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        """
        Embed the first entry of ``sentence_features`` and classify it.
        :param sentence_features: Sequence of feature dictionaries; only element 0 is used
        :param labels: Target class ids, or None to obtain predictions instead of a loss
        :return: The cross-entropy loss when ``labels`` is given, otherwise the tuple
            ``(sentence embeddings, classifier output)``
        """
        # Sentence embeddings for the batch
        embeddings = self.model(sentence_features[0])['sentence_embedding']
        # Unnormalised class scores
        logits = self.classifier(embeddings)

        # Without labels there is nothing to score against: hand back the predictions
        if labels is None:
            return embeddings, logits

        return nn.CrossEntropyLoss()(logits, labels.view(-1))

In [None]:
# Batch size shared by the training and dev data loaders
train_batch_size = 16

# Integer class id for every label name
label2int = {name: idx for idx, name in enumerate(label_names)}

# Training data: one single-text example per sentence, served in shuffled batches
train_samples = [
    InputExample(texts=[text], label=label2int[name])
    for text, name in zip(X_train, y_train)
]
train_dataset = SentencesDataset(train_samples, model=model)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)

# Define the way the loss is computed: softmax classifier on top of the sentence embeddings
classifier = SoftmaxClassifier(
    model=model,
    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
    num_labels=len(label2int),
)

# Dev set evaluator built from the held-out split - TODO: still need to test whether this works
dev_samples = [
    InputExample(texts=[text], label=label2int[name])
    for text, name in zip(X_test, y_test)
]
dev_dataset = SentencesDataset(dev_samples, model=model)
dev_dataloader = DataLoader(dev_dataset, shuffle=True, batch_size=train_batch_size)
dev_evaluator = LabelAccuracyEvaluator(dataloader=dev_dataloader, softmax_model=classifier, name='lae-dev')

In [None]:
# Configure the training (these values are passed to model.fit)
num_epochs = 1
warmup_steps = math.ceil(len(train_dataset) * num_epochs / train_batch_size * 0.1)  # 10% of the approximate number of training steps (samples * epochs / batch size)
# Directory the fine-tuned model is written to and later reloaded from
model_save_path = "../../output/FineTuning"
# Google Colab alternative:
# model_save_path = "/content/drive/MyDrive/WRI-LatinAmerica-Talent/Modeling/FineTuning"

In [None]:
# Train the model: the single training objective pairs the training batches with the
# softmax classifier loss; dev_evaluator is used for evaluation and the output is
# written to model_save_path.
# NOTE(review): evaluation_steps=1000 presumably triggers an extra evaluation every
# 1000 training steps - confirm against the sentence-transformers fit() documentation.
model.fit(train_objectives=[(train_dataloader, classifier)],
          evaluator=dev_evaluator,
          epochs=num_epochs,
          evaluation_steps=1000,
          warmup_steps=warmup_steps,
          output_path=model_save_path
          )

In [None]:
# Load the saved model from model_save_path and, as a quick check, compute the
# embedding of the first sentence in all_sents
load_model = SentenceTransformer(model_save_path)
load_model.encode(all_sents[0])