In [1]:
import pandas as pd
import numpy as np

import ir_datasets

import math
import logging
from datetime import datetime
import sys
import os
import gzip
import csv
import random

from pathlib import Path
from typing import List, Dict, Tuple, Iterable, Type, Union, Callable

import transformers
from sentence_transformers import models, losses, datasets
from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SentenceEvaluator 

import torch
from torch import nn, Tensor, device
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from sklearn.model_selection import train_test_split
from fastcore.basics import store_attr

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
# Seed the run through Lightning's seed_everything; `seed` is also reused as the
# random_state of the train/validation split further down.
seed = 0
seed_everything(seed, workers=True)

from pytorch_lightning.loggers import WandbLogger

import config

Global seed set to 0


In [2]:
# Read the query/document pair table for the dataset selected in the config module.
# NOTE(review): unpickling can execute arbitrary code -- only load files you trust.
dataset_name = config.DATASET
save_path = Path(f"data/{dataset_name}")
df = pd.read_pickle(save_path.joinpath("data.pkl"))
df.head()

Unnamed: 0,query_text,doc_text,relevance
0,How does Quora look to a moderator?,What does the Quora website look like to membe...,1
1,How do I refuse to chose between different thi...,Is it possible to pursue many different things...,1
2,Did Ben Affleck shine more than Christian Bale...,"According to you, whose Batman performance was...",1
3,Did Ben Affleck shine more than Christian Bale...,"No fanboys please, but who was the true batman...",1
4,Did Ben Affleck shine more than Christian Bale...,Who do you think portrayed Batman better: Chri...,1


In [3]:
# Label balance of the full table.
df["relevance"].value_counts()

0    7626
1    7626
Name: relevance, dtype: int64

In [4]:
# Stratified split: `train_size` of the rows go to training, the rest to validation,
# with the relevance label distribution preserved in both parts.
train_size = 0.8

print(len(df))
df_train, df_val = train_test_split(
    df,
    train_size=train_size,
    stratify=df["relevance"],
    random_state=seed,
)
print(len(df_train), len(df_val))
df_val.head()

15252
12201 3051


Unnamed: 0,query_text,doc_text,relevance
2289,What are the factors that can make your credit...,What are the best ways to learn Cloud Computing?,0
6499,What is a fun board game to play with only 2 p...,Why don't we have fusion reactors in power sta...,0
3620,Which business/startup should I start in Nagpu...,Which is the best business to start in nagpur?,1
2375,Why are most airliners painted white?,Why are Aeroplanes painted white?,1
1383,What are the top MBA colleges in the world?,Do you believe in free will or determinism?,0


In [5]:
# The model below is trained with MultipleNegativesRankingLoss, which assumes every
# (query, doc) pair it receives is a true positive and uses the other docs of the
# batch as negatives. Feeding it relevance == 0 pairs would teach the model that
# unrelated questions are similar, so only relevant pairs become training samples.
train_samples = [
    InputExample(texts=[row.query_text, row.doc_text], label=row.relevance)
    for row in df_train.itertuples()
    if row.relevance == 1
]

# The held-out split keeps both labels so it stays usable for label-aware evaluation.
# NOTE(review): a ranking loss computed on this list still includes the label-0 pairs;
# filter them as well if the validation loss is meant to be comparable to train loss.
test_samples = [
    InputExample(texts=[row.query_text, row.doc_text], label=row.relevance)
    for row in df_val.itertuples()
]

print(test_samples[0])

<InputExample> label: 0, texts: What are the factors that can make your credit score drop?; What are the best ways to learn Cloud Computing?


## Training

In [6]:
# Model identifier and training hyperparameters defined for the cells below.
model_name = config.MODEL  # taken from the project-level config module
train_batch_size = 128  # The larger you select this, the better the results (usually). But it requires more GPU memory
val_batch_size = 128
max_seq_length = 128
num_epochs = 1

In [7]:
# Instantiate the SentenceTransformer named in the config; the bare last expression
# echoes the model name as the cell output.
model_name = config.MODEL
model = SentenceTransformer(model_name)
model_name

'paraphrase-mpnet-base-v2'

In [8]:
class DataModule(pl.LightningDataModule):
    """Serve in-memory lists of ``InputExample`` objects to the Lightning trainer.

    Args:
        train_batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.
        train_data: Training examples. When omitted, the notebook-level
            ``train_samples`` list is used.
        val_data: Validation examples. When omitted, the notebook-level
            ``test_samples`` list is used.
    """

    def __init__(self, train_batch_size=32, val_batch_size=32, train_data=None, val_data=None):
        super().__init__()

        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.train_data = train_data
        self.val_data = val_data

    def prepare_data(self):
        """No-op: nothing has to be downloaded or written to disk."""
        # Lightning runs this hook on a single process only, so assigning
        # attributes here would leave other processes without data.

    def setup(self, stage=None):
        """Attach the datasets, falling back to the notebook globals if none were given."""
        if self.train_data is None:
            self.train_data = train_samples
        if self.val_data is None:
            self.val_data = test_samples

    def train_dataloader(self):
        """Return a loader over the training examples that avoids duplicate texts in a batch."""
        return datasets.NoDuplicatesDataLoader(self.train_data, batch_size=self.train_batch_size)

    def val_dataloader(self):
        """Return a loader over the validation examples that avoids duplicate texts in a batch."""
        return datasets.NoDuplicatesDataLoader(self.val_data, batch_size=self.val_batch_size)

In [9]:
class SentenceTransformerModel(pl.LightningModule):
    """LightningModule that optimises a sentence-transformers loss module.

    The loss module is expected to expose the wrapped SentenceTransformer as
    ``loss_model.model`` (it is used for batching and for the sequence-length limit).

    Args:
        loss_model: Loss module called as ``loss_model(features, labels)``.
        max_seq_length: Maximum sequence length set on the wrapped SentenceTransformer.
        evaluator: Accepted and recorded as a hyperparameter; not used by this class yet.
        epochs: Stored for a future LR scheduler; currently unused.
        steps_per_epoch: Stored for a future LR scheduler; currently unused.
        scheduler: Stored for a future LR scheduler; currently unused.
        warmup_steps: Stored for a future LR scheduler; currently unused.
        optimizer_class: Optimizer constructor, called with the parameter groups
            and ``**optimizer_params``.
        optimizer_params: Keyword arguments for ``optimizer_class``.
        weight_decay: Weight decay applied to every parameter except biases and
            LayerNorm parameters.
    """

    def __init__(self, 
                loss_model,
                max_seq_length: int = 128,
                evaluator: SentenceEvaluator = None,
                epochs: int = 1,
                steps_per_epoch = None,
                scheduler: str = 'WarmupLinear',
                warmup_steps: int = 10000,
                optimizer_class: Type[Optimizer] = transformers.AdamW,
                optimizer_params : Dict[str, object]= {'lr': 2e-5},
                weight_decay: float = 0.01,
                ):
        
        super(SentenceTransformerModel, self).__init__()
        self.save_hyperparameters()
        store_attr("loss_model, epochs, weight_decay, optimizer_class, optimizer_params, steps_per_epoch, scheduler, warmup_steps")
        # Keep a private copy so neither the shared default dict nor the caller's
        # dict is aliased by this instance.
        self.optimizer_params = dict(optimizer_params)
        # The length limit belongs to the SentenceTransformer; setting it on the loss
        # wrapper would only create an unused attribute and never truncate inputs.
        self.loss_model.model.max_seq_length = max_seq_length

    def forward(self, features, labels):
        """Return the loss for an already tokenised batch."""
        return self.loss_model(features, labels)

    def training_step(self, data, batch_idx):
        """Collate the raw examples, compute the loss and log it as ``train_loss``."""
        features, labels = self.loss_model.model.smart_batching_collate(data)
        loss = self.forward(features, labels)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, data, batch_idx):
        """Collate the raw examples, compute the loss and log it as ``val_loss``."""
        #TODO: dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, batch_size=train_batch_size, name='sts-dev')
        features, labels = self.loss_model.model.smart_batching_collate(data)
        loss = self.forward(features, labels)
        self.log("val_loss", loss, prog_bar=True)

    def configure_optimizers(self):
        """Build the optimizer with weight decay disabled for biases and LayerNorm parameters."""
        named_params = list(self.loss_model.named_parameters())

        # Parameters whose name contains one of these fragments get no weight decay.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        decayed = [p for n, p in named_params if not any(nd in n for nd in no_decay)]
        undecayed = [p for n, p in named_params if any(nd in n for nd in no_decay)]
        optimizer_grouped_parameters = [
            {'params': decayed, 'weight_decay': self.weight_decay},
            {'params': undecayed, 'weight_decay': 0.0},
        ]

        # TODO: no LR scheduler is wired up yet; `scheduler`, `warmup_steps`, `epochs`
        # and `steps_per_epoch` are stored on the instance but not consumed here.
        return self.optimizer_class(optimizer_grouped_parameters, **self.optimizer_params)

In [10]:
# EarlyStopping is used below and must be imported too, otherwise this cell raises NameError.
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar

class LitProgressBar(ProgressBar):
    """Progress bar that emits a line break after each training epoch."""

    def on_train_epoch_end(self, *args, **kwargs):
        super().on_train_epoch_end(*args, **kwargs)
        print()

# Both callbacks watch the `val_loss` value logged in validation_step; lower is better.
checkpoint_callback = ModelCheckpoint(dirpath="./models", monitor="val_loss", mode="min")
# mode is spelled out as 'min' (instead of the deprecated 'auto') to match the checkpoint callback.
early_stop_callback = EarlyStopping(monitor='val_loss', min_delta=0.00, patience=5, verbose=True, mode='min')

In [11]:
# Weights & Biases logger that is passed to the Trainer below.
wandb_logger = WandbLogger(
    project="lightning-sentence-transformers",
    name="test",
    reinit=True,
)

In [13]:
# Use the batch sizes from the hyperparameter cell; a bare DataModule() would
# silently fall back to its default of 32.
pl_data = DataModule(train_batch_size=train_batch_size, val_batch_size=val_batch_size)
loss_model = losses.MultipleNegativesRankingLoss(model)
# Number of full training batches per epoch, derived from the data instead of hard-coded.
steps_per_epoch = len(train_samples) // train_batch_size
stl_model = SentenceTransformerModel(
    loss_model,
    max_seq_length=max_seq_length,
    steps_per_epoch=steps_per_epoch,
)

#TODO: Add learning rate scheduler
trainer = pl.Trainer(
    default_root_dir="logs",
    gpus=(1 if torch.cuda.is_available() else 0),
    max_epochs=10,
    fast_dev_run=False,
    gradient_clip_val=1.0,
    amp_backend='native',
    amp_level='O2',
    precision=16,
    auto_lr_find=True,
    auto_scale_batch_size=False,
    auto_select_gpus=True,
#     callbacks=[LitProgressBar()],
#     logger=pl.loggers.TensorBoardLogger("logs/", name=model_name, version=1),
    logger=wandb_logger,
    deterministic=True,
)

trainer.fit(stl_model, pl_data)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mpratik[0m (use `wandb login --relogin` to force relogin)



  | Name       | Type                         | Params
------------------------------------------------------------
0 | loss_model | MultipleNegativesRankingLoss | 109 M 
------------------------------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     Total params
437.946   Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Global seed set to 0


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…


