### Fine-tune sentence transformers on unlabeled pairs of similar texts (one text is a summary of the other) using PyTorch. The summaries are generated by extracting random sentences from the input text.
Case 2 in https://huggingface.co/blog/how-to-train-sentence-transformers

In [None]:
!pip install sentence_transformers==2.2.2

In [None]:
import pandas as pd
import numpy as np
import string
import re
import spacy
from numpy.linalg import norm
from sentence_transformers import InputExample
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer
from sentence_transformers import losses

In [3]:
def preprocess(txt):
    '''
    Function to preprocess the article
    '''
    txt = re.sub(r'^By \. [\w\s]+ \. ', ' ', txt) # By . Ellie Zolfagharifard . 
    txt = re.sub(r'\d{1,2}\:\d\d [a-zA-Z]{3}', ' ', txt) # 10:30 EST
    txt = re.sub(r'\d{1,2} [a-zA-Z]+ \d{4}', ' ', txt) # 10 November 1990
    txt = txt.replace('PUBLISHED:', ' ')
    txt = txt.replace('UPDATED', ' ')
    txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt) # remove puncts with spaces before and after
    txt = txt.replace(' : ', ' ')
    txt = txt.replace('(CNN)', ' ')
    txt = txt.replace('--', ' ')
    txt = re.sub(r'^\s*[\,\.\:\'\;\|]', ' ', txt) # remove puncts at beginning of sent
    txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt) # remove puncts with spaces before and after
    txt = " ".join(txt.split())
    return txt

In [4]:
def get_summary(txt):
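    '''
    Function that returns an extractive summary by concatenating a random
    50% of the sentences whose length is >= the median sentence length.
    The seed is fixed, so the same sentence positions are drawn on every call.
    '''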
    txt_split = re.split(r'\.', txt)
    median_sent_len = np.median([len(sent.split()) for sent in txt_split])
    sents = [sent for sent in txt_split if len(sent.split()) >= median_sent_len]
    n_sents = len(sents)
    np.random.seed(42)
    rand_idx = np.random.choice(range(n_sents), size=int(0.5 * n_sents), replace=False)
    rand_idx.sort()
    rand_sents = [sents[i] for i in rand_idx]
    summary = '. '.join(rand_sents)
    return summary.strip()
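
Splitting on every period also cuts abbreviations such as "U.S." in half, which is why the summary shown further down contains a fragment starting with "District Court". A minimal sketch of a more conservative splitter follows. The helper name `split_sentences` and the boundary heuristic are assumptions for illustration, not part of the notebook, and the heuristic will still miss some cases (for example a sentence that ends in a single capital letter).

```python
import re

# Hypothetical helper, not part of the notebook.
# Split on a period only when it is followed by whitespace plus an uppercase
# letter or a double quote, and is not preceded by a single capital letter
# (as in "U.S.").
_SENT_BOUNDARY = re.compile(r'(?<!\b[A-Z])\.\s+(?=[A-Z"])')

def split_sentences(txt):
    # the matched period is dropped, mirroring re.split(r'\.', txt)
    return [s.strip() for s in _SENT_BOUNDARY.split(txt) if s.strip()]

example = ('A complaint unsealed in U.S. District Court accuses Mata. '
           'In one instance, Mata arranged a payment.')
sentences = split_sentences(example)
print(sentences)
```

With such a splitter, `get_summary` could call `split_sentences(txt)` instead of `re.split(r'\.', txt)`; the rest of the sampling logic would stay the same.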

In [5]:
def get_data(path: str, n_samples: int=50000):
    data = pd.read_csv(path, nrows=n_samples)
    data['article'] = data['article'].map(preprocess)
    data['highlights'] = data['article'].map(get_summary)
    return data

In [6]:
data = get_data("train.csv")

In [10]:
# text
data['article'][1]

'Ralph Mata was an internal affairs lieutenant for the Miami-Dade Police Department, working in the division that investigates allegations of wrongdoing by cops. Outside the office, authorities allege that the 45-year-old longtime officer worked with a drug trafficking organization to help plan a murder plot and get guns. A criminal complaint unsealed in U.S. District Court in New Jersey Tuesday accuses Mata, also known as "The Milk Man," of using his role as a police officer to help the drug trafficking organization in exchange for money and gifts, including a Rolex watch. In one instance, the complaint alleges, Mata arranged to pay two assassins to kill rival drug dealers. The killers would pose as cops, pulling over their targets before shooting them, according to the complaint. "Ultimately, the (organization) decided not to move forward with the murder plot, but Mata still received a payment for setting up the meetings," federal prosecutors said in a statement. The complaint also a

In [11]:
# summary
data['highlights'][1]

'Ralph Mata was an internal affairs lieutenant for the Miami-Dade Police Department, working in the division that investigates allegations of wrongdoing by cops.  District Court in New Jersey Tuesday accuses Mata, also known as "The Milk Man," of using his role as a police officer to help the drug trafficking organization in exchange for money and gifts, including a Rolex watch.  "Ultimately, the (organization) decided not to move forward with the murder plot, but Mata still received a payment for setting up the meetings," federal prosecutors said in a statement.  Mata has worked for the Miami-Dade Police Department since 1992, including directing investigations in Miami Gardens and working as a lieutenant in the K-9 unit at Miami International Airport, according to the complaint.  Mata faces charges of aiding and abetting a conspiracy to distribute cocaine, conspiring to distribute cocaine and engaging in monetary transactions in property derived from specified unlawful activity'

In [12]:
def finetune_model(data: pd.DataFrame, cols_to_use: tuple=('article', 'highlights'), 
                   model_id: str="distilbert-base-nli-mean-tokens", 
                   batch_size: int=32, epochs: int=2):
    model = SentenceTransformer(model_id)
    col1, col2 = cols_to_use
    # each training example is an (article, summary) pair; no label is needed
    train_examples = [InputExample(texts=[txt1, txt2])
                      for txt1, txt2 in zip(data[col1], data[col2])]
        
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
    # recommended loss for cases having 2 sentences with no labels
    train_loss = losses.MultipleNegativesRankingLoss(model=model)
    model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=epochs)
    model_save_path = model_id + '_finetuned'
    model.save(model_save_path)
    return model_save_path
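
To make the choice of loss more concrete, here is a minimal NumPy sketch of the idea behind `MultipleNegativesRankingLoss`: within a batch, each article should score highest against its own summary, and the other summaries in the batch act as negatives. The function name `mnr_loss` is hypothetical, and the sketch assumes cosine similarity with a scale factor of 20; it is an illustration of the objective, not the library's implementation.

```python
import numpy as np

def mnr_loss(anchors, positives, scale=20.0):
    """Cross-entropy over scaled cosine similarities with in-batch negatives.

    Row i of `anchors` is paired with row i of `positives`; every other row
    of `positives` serves as a negative for anchor i.
    """
    a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)
    p = positives / np.linalg.norm(positives, axis=1, keepdims=True)
    scores = scale * (a @ p.T)                            # (batch, batch)
    scores = scores - scores.max(axis=1, keepdims=True)   # numerical stability
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    # the correct "class" for anchor i is positive i, i.e. the diagonal
    return float(-np.mean(np.diag(log_probs)))

anchors = np.eye(3)
aligned_loss = mnr_loss(anchors, np.eye(3))                      # pairs match
shuffled_loss = mnr_loss(anchors, np.roll(np.eye(3), 1, axis=0)) # pairs mismatched
print(aligned_loss, shuffled_loss)
```

Because the negatives come from the same batch, a larger `batch_size` gives each pair more negatives to be ranked against.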

In [13]:
non_finetuned_model_id = "distilbert-base-nli-mean-tokens"
non_finetuned_model = SentenceTransformer(non_finetuned_model_id)

finetuned_model_id = finetune_model(data=data)
finetuned_model = SentenceTransformer(finetuned_model_id)


In [14]:
class TextPreprocessor:
    def __init__(self, remove_punct: bool = True, remove_digits: bool = True,
                 remove_stop_words: bool = True):
        self.remove_punct = remove_punct
        self.remove_digits = remove_digits
        self.remove_stop_words = remove_stop_words
        self.stop_words = ['i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you',
                           'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself',
                           'she', 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them',
                           'their', 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that',
                           'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has',
                           'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'if', 'or',
                           'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about',
                           'into', 'through', 'during', 'before', 'after', 'to', 'from',
                           'in', 'out', 'on', 'off', 'further', 'then', 'once',
                           'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each',
                           'other', 'such', 'only', 'own', 'same', 'so', 'than',
                           'too', 'can', 'will', 'just', 'should',
                           'now']


    @staticmethod
    def __remove_double_whitespaces(string: str):
        return " ".join(string.split())
    

    def __remove_punct(self, string_series: pd.Series):
        """
        Removes punctuation, newlines and tabs from the input strings.
        :param string_series: pd.Series, input string series
        :return: pd.Series, cleaned string series
        """
        clean_string_series = string_series.copy()
        # plain (non-raw) strings: with regex=False the pattern is matched literally,
        # so r'\n' would look for a backslash followed by 'n' rather than a newline
        puncts = ['\n', '\r', '\t']
        puncts.extend(list(string.punctuation))
        for i in puncts:
            clean_string_series = clean_string_series.str.replace(pat=i, repl=" ", regex=False)
        return clean_string_series.map(self.__remove_double_whitespaces)

    def __remove_digits(self, string_series: pd.Series):
        """
        Removes digits from the input strings.
        :param string_series: pd.Series, input string series
        :return: pd.Series, cleaned string series
        """
        clean_string_series = string_series.str.replace(pat=r'\d', repl=" ", regex=True)
        return clean_string_series.map(self.__remove_double_whitespaces)
 

    def __remove_stop_words(self, string_series: pd.Series):
        """
        Removes stop words from the input strings.
        :param string_series: pd.Series, input string series
        :return: pd.Series, cleaned string series
        """
        def str_remove_stop_words(string: str):
            stops = self.stop_words
            return " ".join([token for token in string.split() if token not in stops])

        return string_series.map(str_remove_stop_words)

    
    def preprocess(self, string_series: pd.Series, dataset: str = "train"):
        """
        Entry point: lowercases the text, then applies the enabled cleaning steps.
        :param string_series: pd.Series, input string series
        :param dataset: str, "train" for training set, "test" for val/dev/test set (currently unused).
        :return: pd.Series, cleaned string series
        """
        string_series = string_series.str.lower()
        if self.remove_punct:
            string_series = self.__remove_punct(string_series=string_series)
        if self.remove_digits:
            string_series = self.__remove_digits(string_series=string_series)
        if self.remove_stop_words:
            string_series = self.__remove_stop_words(string_series=string_series)
        

        string_series = string_series.str.strip()
        # substitute a placeholder for strings that are empty after cleaning
        string_series = string_series.replace(to_replace="", value="this is an empty message")

        return string_series

In [15]:
text_preprocessor = TextPreprocessor()
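
As a rough illustration of what this preprocessing does to a single string before it is embedded, the same steps can be sketched in plain Python. The function `clean` and the shortened `STOP_WORDS` set are hypothetical stand-ins for the class above, chosen only to keep the example small.

```python
import re
import string

# Small illustrative subset of the stop-word list used by TextPreprocessor
STOP_WORDS = {"the", "a", "an", "of", "and", "to", "in", "was", "is", "all", "its"}

def clean(text):
    text = text.lower()
    # replace every punctuation character with a space
    text = text.translate(str.maketrans(string.punctuation, " " * len(string.punctuation)))
    # replace digits with spaces
    text = re.sub(r"\d", " ", text)
    # split() collapses repeated whitespace; drop stop words while rejoining
    return " ".join(t for t in text.split() if t not in STOP_WORDS)

cleaned = clean("The rocket ignited all 33 of its engines.")
print(cleaned)  # rocket ignited engines
```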

In [16]:
def get_sent_transformer_embeddings(sent_transformer, txt):
    '''
    Function to get sentence embeddings from SentenceTransformer using specified model
    '''
    embedding = sent_transformer.encode(txt, show_progress_bar=False)
    return embedding

In [17]:
def get_similarity_score(emb1, emb2):
    '''
    Function to compute cosine-similarity score.
    '''
    cos_sim = np.dot(emb1, emb2) / (norm(emb1) * norm(emb2))
    return cos_sim
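
When several texts need to be compared at once, the same formula can be applied to whole matrices of embeddings. The sketch below is an assumption-laden illustration (the name `cosine_similarity_matrix` is hypothetical, and the toy vectors stand in for real embeddings): it normalises each row and takes a matrix product, so entry `[i, j]` is the cosine similarity between row `i` of the first matrix and row `j` of the second.

```python
import numpy as np

def cosine_similarity_matrix(emb_a, emb_b):
    # normalise every row to unit length, then take all pairwise dot products
    a = emb_a / np.linalg.norm(emb_a, axis=1, keepdims=True)
    b = emb_b / np.linalg.norm(emb_b, axis=1, keepdims=True)
    return a @ b.T

# toy 2-dimensional "embeddings"
emb_a = np.array([[1.0, 0.0], [1.0, 1.0]])
emb_b = np.array([[2.0, 0.0], [0.0, 3.0]])
sims = cosine_similarity_matrix(emb_a, emb_b)
print(sims)
```

Parallel vectors score 1 regardless of their length, and orthogonal vectors score 0, which is why the scores in the comparisons below can be read on a common scale.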

In [18]:
def compare_models(txt1, txt2):
    '''
    Function to print the cosine similarity between two texts, computed with the
    embeddings of the fine-tuned and the non-fine-tuned model
    '''  
    txt1 = text_preprocessor.preprocess(pd.Series(txt1))[0]
    txt2 = text_preprocessor.preprocess(pd.Series(txt2))[0]
    sent_emb1 = get_sent_transformer_embeddings(non_finetuned_model, txt1)
    sent_emb2 = get_sent_transformer_embeddings(non_finetuned_model, txt2)
    
    finetuned_model_emb1 = get_sent_transformer_embeddings(finetuned_model, txt1)
    finetuned_model_emb2 = get_sent_transformer_embeddings(finetuned_model, txt2)
    
    print(f'Similarity score of fine-tuned model: {get_similarity_score(finetuned_model_emb1, finetuned_model_emb2)}')
    print(f'Similarity score of non-fine-tuned model: {get_similarity_score(sent_emb1, sent_emb2)}')

In [19]:
txt1 = '''
This test mission was a failure: SpaceX did not achieve all the goals it set out for,
and both the Starship spacecraft and Super Heavy booster exploded over the ocean.

But there were some big highlights for SpaceX.

The rocket made it much further into its flight profile than during the first flight 
attempt in April, when Starship began tumbling tail-over-head about four minutes after liftoff.
The Starship never even separated from the Super Heavy booster during that test.

This time, however, SpaceX did achieve that milestone: About two and a half minutes 
into flight, the Starship powered up its engines and successfully broke away using a 
brand new method called "hot staging."
'''

txt2 = '''
The Starship system made it much further into flight than the first attempt in April, 
but ultimately ended in another explosion.

The rocket and spacecraft safely lifted off the pad, with the Super Heavy booster 
igniting all 33 of its engines. During the last attempt, multiple engines shut down prematurely.

Then, the Super Heavy booster and Starship spacecraft successfully separated, as 
the Starship lit up its engines and pushed away.
'''

compare_models(txt1, txt2)

Similarity score of fine-tuned model: 0.8778677582740784
Similarity score of non-fine-tuned model: 0.8902333974838257


In [20]:
txt1 = '''
Israeli forces launched a raid Wednesday on Gaza’s largest hospital, Al-Shifa, 
after accusing Hamas of operating from tunnels beneath the vast complex – a claim denied 
by the militant group and hospital officials.

Thousands of Palestinians are believed to be sheltering in and around the hospital, 
which the UN said had become the “epicenter” of fighting in the area, trapping vulnerable 
patients, staff and displaced Palestinians as they run out of medical supplies and fuel.

The hospital’s main building has effectively ceased functioning, with doctors working by 
candlelight and wrapping premature babies in foil to keep them alive – with some warning 
the situation inside has become “catastrophic.”
'''

txt2 = '''
This test mission was a failure: SpaceX did not achieve all the goals it set out for, and 
both the Starship spacecraft and Super Heavy booster exploded over the ocean.

But there were some big highlights for SpaceX.

The rocket made it much further into its flight profile than during the first flight 
attempt in April, when Starship began tumbling tail-over-head about four minutes after 
liftoff. The Starship never even separated from the Super Heavy booster during that test.

This time, however, SpaceX did achieve that milestone: About two and a half minutes into 
flight, the Starship powered up its engines and successfully broke away using a brand new method called "hot staging."
'''

compare_models(txt1, txt2)

Similarity score of fine-tuned model: 0.1420629322528839
Similarity score of non-fine-tuned model: 0.6971122026443481


In [21]:
txt1 = '''
Israeli forces launched a raid Wednesday on Gaza’s largest hospital, Al-Shifa, 
after accusing Hamas of operating from tunnels beneath the vast complex – a claim denied 
by the militant group and hospital officials.

Thousands of Palestinians are believed to be sheltering in and around the hospital, 
which the UN said had become the “epicenter” of fighting in the area, trapping vulnerable 
patients, staff and displaced Palestinians as they run out of medical supplies and fuel.

The hospital’s main building has effectively ceased functioning, with doctors working by 
candlelight and wrapping premature babies in foil to keep them alive – with some warning 
the situation inside has become “catastrophic.”
'''

txt2 = '''
Gaza’s largest hospital, Al-Shifa, has become a flashpoint in Israel’s war in the enclave, 
which began when Hamas militants crossed the border into Israel on October 7, killing around 1,200 people.

Palestinians and humanitarian agencies say the current fighting in and around Al-Shifa is 
proof of Israel’s wanton disregard for civilian life in Gaza, while Israel accuses Hamas of using 
the medical center as a shield for its operations.

Since launching its operation at Al-Shifa this week, the Israel Defense Forces (IDF) claimed 
it found a tunnel shaft and military equipment, but has yet to show proof of the large-scale command 
and control center it alleges is there. Hamas denies the allegations. CNN has not verified the claims of either Israel or Hamas.
'''

compare_models(txt1, txt2)

Similarity score of fine-tuned model: 0.8547863364219666
Similarity score of non-fine-tuned model: 0.8923118114471436


In [22]:
txt1 = '''
Renewed fighting between the Myanmar Armed Forces (MAF) and the Arakan Army (AA) 
has displaced more than 26,000 people in the 
country’s western Rakhine state since Monday, according to the United Nations.

In a statement Friday, the United Nations Office for the Coordination of Humanitarian 
Affairs (UNOCHA) said the latest figures bring the total number of internally displaced 
people due to conflict between the two sides to approximately 90,000.

Eleven deaths and more than 30 injuries have been reported since an informal ceasefire 
agreed a year ago broke on November 13, the statement read.

More than 100 people have reportedly been detained by the MAF and five by the AA, it added.

Battles between the military and resistance groups have unfolded almost daily across 
Myanmar since army general Min Aung Hlaing seized power in February 2021, plunging the 
country into economic chaos and fresh civil war.
'''

txt2 = '''
Gaza’s largest hospital, Al-Shifa, has become a flashpoint in Israel’s war in the enclave, 
which began when Hamas militants crossed the border into Israel on October 7, killing around 1,200 people.

Palestinians and humanitarian agencies say the current fighting in and around Al-Shifa is 
proof of Israel’s wanton disregard for civilian life in Gaza, while Israel accuses Hamas of using 
the medical center as a shield for its operations.

Since launching its operation at Al-Shifa this week, the Israel Defense Forces (IDF) claimed 
it found a tunnel shaft and military equipment, but has yet to show proof of the large-scale command 
and control center it alleges is there. Hamas denies the allegations. CNN has not verified the claims of either Israel or Hamas.
'''

compare_models(txt1, txt2)

Similarity score of fine-tuned model: 0.3097432553768158
Similarity score of non-fine-tuned model: 0.7666799426078796


In [23]:
txt1 = '''
Renewed fighting between the Myanmar Armed Forces (MAF) and the Arakan Army (AA) 
has displaced more than 26,000 people in the 
country’s western Rakhine state since Monday, according to the United Nations.

In a statement Friday, the United Nations Office for the Coordination of Humanitarian 
Affairs (UNOCHA) said the latest figures bring the total number of internally displaced 
people due to conflict between the two sides to approximately 90,000.

Eleven deaths and more than 30 injuries have been reported since an informal ceasefire 
agreed a year ago broke on November 13, the statement read.

More than 100 people have reportedly been detained by the MAF and five by the AA, it added.

Battles between the military and resistance groups have unfolded almost daily across 
Myanmar since army general Min Aung Hlaing seized power in February 2021, plunging the 
country into economic chaos and fresh civil war.
'''

txt2 = '''
Ellen Burstyn, beloved actor who starred in the groundbreaking 1973 horror film 
“The Exorcist,” has returned to the nightmare-inducing franchise for the first 
time since its debut 50 years ago in the franchise’s next installment, “The Exorcist: Believer.”

The Oscar-winner made her onscreen return in the film’s first trailer on Wednesday, 
reprising her iconic role of Chris MacNeil, an actress who is forever haunted by her 
daughter Regan’s (Linda Blair) possession as seen in the original movie.

“Have you ever seen anything like this?” Leslie Odom Jr. – who stars in “Believer” as a 
widower trying to save his possessed daughter (Lidya Jewett) and her friend (Olivia Marcum) 
– asks “Handmaid’s Tale” actor Ann Dowd in the trailer.

“No, but there are people out there who have,” Dowd replies.

Burstyn then appears on the screen as that person who’s seen such a sight before, 
and their journey to save the two possessed girls begins.
'''

compare_models(txt1, txt2)

Similarity score of fine-tuned model: 0.10945425927639008
Similarity score of non-fine-tuned model: 0.6152119040489197


In [24]:
txt1 = '''
Renewed fighting between the Myanmar Armed Forces (MAF) and the Arakan Army (AA) 
has displaced more than 26,000 people in the country’s western Rakhine state since 
Monday, according to the United Nations.

In a statement Friday, the United Nations Office for the Coordination of 
Humanitarian Affairs (UNOCHA) said the latest figures bring the total number of 
internally displaced people due to conflict between the two sides to approximately 90,000.

Eleven deaths and more than 30 injuries have been reported since an informal 
ceasefire agreed a year ago broke on November 13, the statement read.

More than 100 people have reportedly been detained by the MAF and five by the AA, it added.

Battles between the military and resistance groups have unfolded almost daily 
across Myanmar since army general Min Aung Hlaing seized power in February 2021, 
plunging the country into economic chaos and fresh civil war.
'''

txt2 = '''
Luis Díaz scored two goals to help Colombia stun Brazil in a World Cup qualifier 
on Thursday, as his father watched on from the stands just a week after being released by kidnappers.

Luis Manuel Díaz looked overcome with emotion as he witnessed his son net an 
impressive second-half brace to fire Colombia to a 2-1 win inside packed Estadio Metropolitano.
On October 28, Díaz Sr. was abducted along with his wife, Cilenis Marulanda, by ELN gunmen in his 
hometown of Barrancas, northeastern Colombia. Marulanda was rescued later that day, but Díaz Sr. 
was handed over just under two weeks later to a mixed commission of UN personnel and Catholic 
priests on Thursday in nearby city Valledupar.
'''

compare_models(txt1, txt2)

Similarity score of fine-tuned model: 0.18763847649097443
Similarity score of non-fine-tuned model: 0.7308792471885681


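The non-fine-tuned model returns a higher similarity score than the fine-tuned model for every text pair above. The difference is small for the two pairs that cover the same story (0.89 vs. 0.88 and 0.89 vs. 0.85) and large for the four pairs that cover different stories (0.62 to 0.77 vs. 0.11 to 0.31), so the fine-tuned model separates related texts from unrelated ones much more clearly.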
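
The scores printed above can be tabulated to make this comparison concrete. In the sketch below the scores are copied from the cell outputs (rounded to four decimals), while the grouping into "related" and "unrelated" pairs is an assumption based on the topics of the texts, not something the notebook computes.

```python
import numpy as np

# (fine-tuned score, non-fine-tuned score) per comparison, rounded from the outputs above
related = np.array([
    [0.8779, 0.8902],   # Starship vs. Starship
    [0.8548, 0.8923],   # Al-Shifa vs. Al-Shifa
])
unrelated = np.array([
    [0.1421, 0.6971],   # Al-Shifa vs. Starship
    [0.3097, 0.7667],   # Myanmar vs. Al-Shifa
    [0.1095, 0.6152],   # Myanmar vs. The Exorcist
    [0.1876, 0.7309],   # Myanmar vs. Colombia-Brazil match
])

# mean score on related pairs minus mean score on unrelated pairs, per model
gap_finetuned, gap_base = related.mean(axis=0) - unrelated.mean(axis=0)
print(f"fine-tuned gap:     {gap_finetuned:.3f}")
print(f"non-fine-tuned gap: {gap_base:.3f}")
```

Six hand-picked pairs are far too few for a real evaluation; the gap only summarises the examples shown in this notebook.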