## Introduction
Semantic textual similarity (STS) is the task of determining how similar two pieces of text are. The goal of the first task is to implement a new architecture that combines ideas from the following papers:
- Siamese Recurrent Architectures for Learning Sentence Similarity, Jonas Mueller et al. (referred to below as the AAAI paper)
- A Structured Self-Attentive Sentence Embedding, Zhouhan Lin et al. (referred to below as the ICLR paper) <br/><br/>
You will then evaluate whether the new architecture improves on the results of **Siamese Recurrent Architectures for Learning Sentence Similarity, Jonas Mueller et al.** Your overall network architecture should look similar to the following figure. 
![Overall network architecture](https://raw.githubusercontent.com/shahrukhx01/ocr-test/main/download.png)
<br/><br/>


Moreover, you will need to implement further helper functions proposed in these papers, e.g., the attention penalty term for the loss.

### SICK dataset
We will use the SICK dataset throughout the project (at least in the first two tasks). For more information about the dataset, refer to the original [paper](http://www.lrec-conf.org/proceedings/lrec2014/pdf/363_Paper.pdf). You can download the dataset from either of the following links:
- [dataset page 1](https://marcobaroni.org/composes/sick.html)
- [dataset page 2](https://huggingface.co/datasets/sick)    

The relevant columns for the project are `sentence_A`, `sentence_B`, `relatedness_score`, where `relatedness_score` is the label. <br><br>
**Hint: For each task make sure to decide whether the label should be normalized or not.**<br><br>
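As a minimal sketch, assuming labels are normalized by dividing by a constant (the data-loading cell below passes `normalize_labels=True, normalization_const=5.0`), the mapping and its inverse could look like this. The function names are hypothetical and not part of `sts_data.py`.

```python
def normalize_score(score: float, normalization_const: float = 5.0) -> float:
    """Map a SICK relatedness score in [1, 5] to [0.2, 1] by dividing by the maximum."""
    return score / normalization_const


def denormalize_score(score: float, normalization_const: float = 5.0) -> float:
    """Map a normalized prediction back to the original 1-5 scale."""
    return score * normalization_const
```

Dividing by a positive constant does not change Pearson or Spearman correlation, but it does change the scale of an MSE loss, and the label range must match the range of the model's output (for example, the AAAI paper's similarity $\exp(-\lVert h_a - h_b \rVert_1)$ lies in $(0, 1]$).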

```python
%load_ext autoreload
%autoreload 2
```

```python
import torch
import warnings
import sts_data
from importlib import reload
from test import evaluate_test_set

warnings.filterwarnings('ignore')
```

## Part 1. Data pipeline (3 points)
Before working on the model, we must configure the data pipeline to load the data in the correct format. Please implement the functions for processing the data.

### Part 1.1 Loading and preprocessing the data (1 point)
Download the SICK dataset and store it in [pandas](https://pandas.pydata.org/docs/index.html) `DataFrame`s. You should use the official data split.  

Implement the `load_data` method of the `STSData` class in `sts_data.py`. The method must download the dataset and perform basic preprocessing. Minimal preprocessing required:  
1. normalize text to lower case
2. remove punctuation  
3. remove [stopwords](https://en.wikipedia.org/wiki/Stop_word) (we provide a list of English stopwords)
4. optionally, any other preprocessing you deem necessary

All the preprocessing code must be contained in the `preprocessing.py` file.  
You can use Hugging Face's [datasets library](https://huggingface.co/docs/datasets/) to download the dataset easily.
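A minimal sketch of the three required steps, using only the standard library. `preprocess` and the tiny `STOPWORDS` set are hypothetical placeholders; the real code belongs in `preprocessing.py` and should use the provided stopword list.

```python
import string

# Hypothetical, deliberately tiny stopword set; use the provided English list instead.
STOPWORDS = {"a", "an", "the", "is", "are", "in", "of", "and"}

# Translation table that deletes every ASCII punctuation character.
_PUNCT_TABLE = str.maketrans("", "", string.punctuation)


def preprocess(text: str, stopwords=STOPWORDS) -> str:
    """Lowercase the text, strip punctuation, and drop stopwords."""
    text = text.lower().translate(_PUNCT_TABLE)
    tokens = [tok for tok in text.split() if tok not in stopwords]
    return " ".join(tokens)


print(preprocess("A man is playing the guitar!"))  # man playing guitar
```

Note that the translation table also strips apostrophes ("don't" becomes "dont"), so stopwords must be compared after the same normalization. Since SICK contains many negated sentence pairs, it may be worth checking whether words such as "not" or "no" are in the stopword list before removing them.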

### Part 1.2 Building vocabulary (1 point)
Before we can feed our text to the model, it must be vectorized. We use 300-dimensional pretrained [FastText embeddings](https://fasttext.cc/docs/en/english-vectors.html) to map words to vectors. For a general introduction to embeddings, see [this video](https://www.youtube.com/watch?v=ERibwqs9p38); it describes Word2Vec rather than FastText, but both serve the same general purpose.  
To apply the embeddings, we must first construct a vocabulary for the data. Complete the `create_vocab` method of the `STSData` class in `sts_data.py`: concatenate each sentence pair, tokenize it, and construct the vocabulary over the whole training set. Use [torchtext](https://torchtext.readthedocs.io/en/latest/data.html) for processing the data. For tokenization you can use any library (or write your own tokenizer), but we recommend the [spaCy](https://spacy.io/) tokenizer. Use `fasttext.simple.300d` as the pretrained vectors.  
In the end, you must have a vocabulary object that maps your input to the corresponding vectors. Remember that the vocabulary is built from the training data only (validation and test data must not be touched).
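Conceptually, the vocabulary is a token-to-index mapping built from the training pairs only. The plain-Python sketch below is not the torchtext API; the `<pad>`/`<unk>` specials and `str.split` (standing in for the spaCy tokenizer) are assumptions.

```python
from collections import Counter

PAD, UNK = "<pad>", "<unk>"


def build_vocab(train_pairs, tokenize=str.split, min_freq=1):
    """Build a token -> index mapping from (sentence_A, sentence_B) training pairs."""
    counter = Counter()
    for sent_a, sent_b in train_pairs:
        # Concatenate each pair, then tokenize, as the task description asks.
        counter.update(tokenize(sent_a + " " + sent_b))
    # Specials first, then the remaining tokens in a deterministic (sorted) order.
    itos = [PAD, UNK] + sorted(tok for tok, c in counter.items() if c >= min_freq)
    return {tok: idx for idx, tok in enumerate(itos)}


def lookup(vocab, token):
    """Index of a token; tokens unseen in training map to <unk>."""
    return vocab.get(token, vocab[UNK])


vocab = build_vocab([("man playing guitar", "man plays guitar")])
print(lookup(vocab, "guitar"), lookup(vocab, "violin"))  # 2 1
```

With torchtext, the pretrained `fasttext.simple.300d` vectors are attached to the vocabulary so that row *i* of `vocab.vectors` is the embedding of the token with index *i*; that matrix is what the model later receives as `embedding_weights`.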

### Part 1.3 Creating DataLoader (1 point)
Implement `get_data_loader` method of `STSData` class in `sts_data.py`. It must perform the following operations on each of the data splits:
1. vectorize each pair of the sentences by replacing all tokens with their index in vocabulary
2. normalize labels
3. convert everything to PyTorch tensors
4. pad every sentence so that all of them have the same length
5. create `STSDataset` from `dataset.py`
6. create a PyTorch `DataLoader` from the dataset. 


We have provided the interfaces of possible helper functions, but you can change them as needed.   
In the end, you must have three data loaders, one for each split.
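Steps 1 and 4 can be sketched in plain Python as below. The helper names are hypothetical, and pad index 0 / unknown index 1 are assumptions matching a vocabulary whose first two entries are `<pad>` and `<unk>`. In the real pipeline the padded lists would then be converted to tensors and wrapped in `STSDataset` and a `DataLoader`.

```python
def vectorize(tokens, vocab, unk_index=1):
    """Replace each token with its vocabulary index (unknown tokens -> unk_index)."""
    return [vocab.get(tok, unk_index) for tok in tokens]


def pad(indices, max_len, pad_index=0):
    """Right-pad (or truncate) a list of indices to exactly max_len entries."""
    return indices[:max_len] + [pad_index] * max(0, max_len - len(indices))


def pad_pairs(pairs_of_indices, pad_index=0):
    """Pad both sides of every pair to the longest sentence across all pairs."""
    max_len = max(len(s) for pair in pairs_of_indices for s in pair)
    return [(pad(a, max_len, pad_index), pad(b, max_len, pad_index))
            for a, b in pairs_of_indices]


vocab = {"<pad>": 0, "<unk>": 1, "man": 2, "guitar": 3}
pair = (vectorize(["man"], vocab), vectorize(["man", "guitar", "violin"], vocab))
print(pad_pairs([pair]))  # [([2, 0, 0], [2, 3, 1])]
```

Padding both sentences of a pair to a shared length lets each side be stacked into one rectangular tensor per batch; the padded positions should then be ignored or masked by the model.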

```python
reload(sts_data)  # pick up edits to sts_data.py without restarting the kernel
from sts_data import STSData
import warnings
warnings.filterwarnings('ignore')

# map the generic column names used by STSData to the SICK column names
columns_mapping = {
    "sent1": "sentence_A",
    "sent2": "sentence_B",
    "label": "relatedness_score",
}
dataset_name = "sick"
sick_data = STSData(
    dataset_name=dataset_name,
    columns_mapping=columns_mapping,
    normalize_labels=True,
    normalization_const=5.0,  # SICK relatedness scores lie in [1, 5]
)
batch_size = 64
sick_dataloaders = sick_data.get_data_loader(batch_size=batch_size)
```

```
INFO:root:loading and preprocessing data...
Downloading and preparing dataset sick/default (download: 212.48 KiB, generated: 2.50 MiB, post-processed: Unknown size, total: 2.71 MiB) to /home/ibrahimssd/.cache/huggingface/datasets/sick/default/0.0.0/c6b3b0b44eb84b134851396d6d464e5cb8f026960519d640e087fe33472626db...
Dataset sick downloaded and prepared to /home/ibrahimssd/.cache/huggingface/datasets/sick/default/0.0.0/c6b3b0b44eb84b134851396d6d464e5cb8f026960519d640e087fe33472626db. Subsequent calls will reuse this data.
INFO:root:reading and preprocessing data completed...
INFO:root:creating vocabulary...
INFO:torchtext.vocab:Loading vectors from .vector_cache/wiki.simple.vec.pt
INFO:root:creating vocabulary completed...
```


## Part 2. Model Configuration & Hyperparameter Tuning (3 points)
In this part, you are required to define a model capable of learning the self-attentive sentence embeddings described in [this ICLR paper](https://arxiv.org/pdf/1703.03130.pdf). The sentence embedding learned by this model will be used to compute the similarity score instead of the simpler embeddings described in the original AAAI paper.  
Please familiarize yourself with the model described in the ICLR paper and implement the `SiameseBiLSTMAttention` and `SelfAttention` classes in `siamese_lstm_attention.py`. Remember that you must run the model on each sentence of the pair to calculate the similarity between them. You can use `similarity_score` from `utils.py` to compute the similarity score between two sentences. 
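For reference, the AAAI paper scores a pair of sentence representations with the Manhattan similarity $\exp(-\lVert h_a - h_b \rVert_1)$, which lies in $(0, 1]$. Below is a NumPy sketch that operates on batches; whether `similarity_score` in `utils.py` implements exactly this is an assumption, since that file is not shown.

```python
import numpy as np


def manhattan_similarity(h_a: np.ndarray, h_b: np.ndarray) -> np.ndarray:
    """exp(-||h_a - h_b||_1), computed row-wise.

    h_a, h_b: arrays of shape (batch, dim), one sentence embedding per row.
    Returns an array of shape (batch,) with values in (0, 1].
    """
    return np.exp(-np.abs(h_a - h_b).sum(axis=-1))


h_a = np.array([[0.0, 1.0], [1.0, 1.0]])
h_b = np.array([[0.0, 1.0], [0.0, 0.0]])
print(manhattan_similarity(h_a, h_b))  # identical rows score 1.0; the second pair scores exp(-2)
```

Because the score is bounded in $(0, 1]$, the relatedness labels need to be normalized into the same range when this similarity is trained with an MSE loss.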
  
For more theoretical background on attention mechanisms, refer to [this chapter](https://web.stanford.edu/~jurafsky/slp3/10.pdf) of the ["Speech and Language Processing" book](https://web.stanford.edu/~jurafsky/slp3/) by Dan Jurafsky and James H. Martin, which describes attention in the context of machine translation. 
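Following the ICLR paper's notation, with BiLSTM states $H \in \mathbb{R}^{n \times 2u}$, the annotation matrix is $A = \mathrm{softmax}(W_{s2} \tanh(W_{s1} H^\top))$, where the softmax runs over the $n$ tokens separately for each of the $r$ rows, and the sentence embedding is $M = AH$. The NumPy sketch below computes both for a batch; it illustrates the equations and is not the `SelfAttention` class itself.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along one axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(H: np.ndarray, W_s1: np.ndarray, W_s2: np.ndarray):
    """Structured self-attention (Lin et al.) for a batch of sentences.

    H:    (batch, n, 2u)  BiLSTM hidden states
    W_s1: (d_a, 2u)
    W_s2: (r, d_a)
    Returns A with shape (batch, r, n) and M = A @ H with shape (batch, r, 2u).
    """
    scores = W_s2 @ np.tanh(W_s1 @ H.transpose(0, 2, 1))  # (batch, r, n)
    A = softmax(scores, axis=-1)  # each of the r rows sums to 1 over the n tokens
    M = A @ H                     # (batch, r, 2u)
    return A, M


# Sizes from the configuration used in this notebook: u = 64 (bidirectional), d_a = 150, r = 20.
rng = np.random.default_rng(0)
H = rng.standard_normal((3, 7, 128))
A, M = self_attention(H, 0.1 * rng.standard_normal((150, 128)), 0.1 * rng.standard_normal((20, 150)))
print(A.shape, M.shape)  # (3, 20, 7) (3, 20, 128)
```

This sketch does not mask padded positions; with padded batches you may want to set the scores of padding tokens to `-inf` before the softmax. In PyTorch, it is safer to pass the token dimension to `nn.Softmax(dim=...)` explicitly: with `dim=None` the dimension is chosen implicitly (with a deprecation warning), and for 3-D inputs that choice is dimension 0, the batch axis.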

Finally, once your implementation works on the default parameters stated below, make sure to perform **hyperparameter tuning** to find the best combination of hyperparameters.
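A simple exhaustive grid search is one way to do this. In the sketch below, `train_and_score` is a hypothetical wrapper that builds a model with the given hyperparameters, trains it (e.g., with `train_model`), and returns the best validation Pearson correlation; the grid values are illustrative, not recommendations.

```python
import itertools

# Illustrative search space; keys must match the arguments of train_and_score.
EXAMPLE_GRID = {
    "hidden_size": [64, 128],
    "lstm_layers": [1, 2, 4],
    "learning_rate": [1e-3, 1e-2],
    "penalty": [0.0, 0.2, 1.0],
}


def grid_search(train_and_score, grid):
    """Try every combination in `grid`; return (best_score, best_params)."""
    keys = list(grid)
    best_score, best_params = float("-inf"), None
    for values in itertools.product(*(grid[k] for k in keys)):
        params = dict(zip(keys, values))
        score = train_and_score(**params)  # e.g. best validation Pearson r
        if score > best_score:
            best_score, best_params = score, params
    return best_score, best_params
```

Select hyperparameters on the validation split only and keep the test set for the final evaluation in Part 4. Because the number of runs grows multiplicatively with every hyperparameter added (the example grid already means 36 training runs), random search over the same ranges is a common cheaper alternative.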

```python
output_size = 1
hidden_size = 64  # LSTM hidden size per direction ('u' in the ICLR paper)
vocab_size = len(sick_data.vocab)
embedding_size = 300
embedding_weights = sick_data.vocab.vectors  # pretrained FastText vectors
lstm_layers = 4
learning_rate = 0.01
fc_hidden_size = 64
max_epochs = 100
bidirectional = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# self-attention config
self_attention_config = {
    "hidden_size": 150,  # refers to variable 'da' in the ICLR paper
    "output_size": 20,  # refers to variable 'r' in the ICLR paper
    "penalty": 2e-1,  # refers to the penalty coefficient term in the ICLR paper
}
```

```python
# initialize the Siamese BiLSTM with self-attention
from siamese_lstm_attention import SiameseBiLSTMAttention

siamese_lstm_attention = SiameseBiLSTMAttention(
    batch_size=batch_size,
    output_size=output_size,
    hidden_size=hidden_size,
    vocab_size=vocab_size,
    embedding_size=embedding_size,
    embedding_weights=embedding_weights,
    lstm_layers=lstm_layers,
    self_attention_config=self_attention_config,
    fc_hidden_size=fc_hidden_size,
    device=device,
    bidirectional=bidirectional,
)

# optimizer = torch.optim.SGD(siamese_lstm_attention.parameters(), lr=0.001, momentum=0.9)
optimizer = torch.optim.Adam(
    siamese_lstm_attention.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08
)
# multiply the learning rate by 0.9 on every scheduler.step() call
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# move model to device
siamese_lstm_attention.to(device)
```

```
SiameseBiLSTMAttention(
  (embeddings): Embedding(2029, 300)
  (lookup_table): Embedding(2029, 300)
  (bi_lstm): LSTM(300, 64, num_layers=4, bias=False, batch_first=True, bidirectional=True)
  (SelfAtt): SelfAttention(
    (ws1): Linear(in_features=128, out_features=150, bias=False)
    (ws2): Linear(in_features=150, out_features=20, bias=False)
    (tanh): Tanh()
    (softmax): Softmax(dim=None)
  )
  (normalize): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=2560, out_features=64, bias=True)
  (tanh): Tanh()
)
```
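The printed layer sizes can be cross-checked against the ICLR paper's notation: with `hidden_size` $u = 64$ per direction, the bidirectional LSTM outputs $2u = 128$ features per token, so `ws1` maps 128 to $d_a = 150$ and `ws2` maps 150 to $r = 20$. Flattening $M \in \mathbb{R}^{r \times 2u}$ gives $20 \times 128 = 2560$ features, which matches `fc`'s `in_features`. That the fc layer consumes the flattened $M$ is an inference from this match, not something the source states; `attention_shapes` is a hypothetical helper.

```python
def attention_shapes(hidden_size=64, bidirectional=True, d_a=150, r=20):
    """Layer and tensor shapes implied by the ICLR self-attention for one sentence."""
    two_u = hidden_size * (2 if bidirectional else 1)  # BiLSTM output width per token
    return {
        "ws1": (two_u, d_a),          # Linear(in=2u, out=d_a), i.e. W_s1 is d_a x 2u
        "ws2": (d_a, r),              # Linear(in=d_a, out=r), i.e. W_s2 is r x d_a
        "M": (r, two_u),              # sentence embedding matrix M = A H
        "fc_in_features": r * two_u,  # M flattened into a single vector
    }


print(attention_shapes()["fc_in_features"])  # 2560
```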

## Part 3. Training (2 points)  
Train the final model by implementing the functions in `train.py` and setting your best-chosen hyperparameters. You can use the same training function for hyperparameter tuning.
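The training log below suggests that the model is checkpointed whenever the validation Pearson correlation improves: each "new model saved" line coincides with a new best validation score. A minimal, framework-free sketch of that loop, where `fit` and its three callables are hypothetical stand-ins for one training epoch, validation, and saving:

```python
from typing import Callable, List, Tuple


def fit(
    run_epoch: Callable[[], float],
    evaluate: Callable[[], float],
    save_checkpoint: Callable[[int], None],
    max_epochs: int,
) -> Tuple[float, List[Tuple[float, float]]]:
    """Train for max_epochs, checkpointing whenever validation Pearson improves."""
    best_val = float("-inf")
    history = []
    for epoch in range(max_epochs):
        train_loss = run_epoch()   # one pass over the training DataLoader
        val_pearson = evaluate()   # Pearson r on the validation split
        if val_pearson > best_val:
            best_val = val_pearson
            save_checkpoint(epoch)  # where a "new model saved" message would be logged
        history.append((train_loss, val_pearson))
    return best_val, history
```

In `train_model`, the per-epoch step would also be the natural place for gradient clipping (`clip=True`) and `scheduler.step()`; exactly where `train.py` calls them is not shown here.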
- **What is a good choice of performance metric here for evaluating your model?** [Max 2-3 lines]

- Pearson correlation between the predicted and gold relatedness scores. It measures the linear relationship between two continuous variables and is the standard metric reported on SICK, including in the AAAI paper.

- **What other performance evaluation metric can we use here for this task? Motivate your answer.** [Max 2-3 lines]

- Spearman rank correlation: it is computed on ranks, so it measures any monotonic (not only linear) relationship and is less sensitive to outliers. The explained variance score is another option: it measures the fraction of the variance in the gold scores that the predictions account for. All of these metrics are suitable for evaluating the relationship between two continuous variables.
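To make the three metrics concrete, here is a dependency-free sketch; in practice `scipy.stats.pearsonr`, `scipy.stats.spearmanr`, and `sklearn.metrics.explained_variance_score` compute the same quantities. Spearman is simply Pearson on ranks, with tied values sharing the average rank.

```python
import math


def pearson(x, y):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def ranks(values):
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            result[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return result


def spearman(x, y):
    """Spearman correlation: Pearson correlation of the ranks."""
    return pearson(ranks(x), ranks(y))


def explained_variance(y_true, y_pred):
    """1 - Var(y_true - y_pred) / Var(y_true)."""
    n = len(y_true)
    resid = [t - p for t, p in zip(y_true, y_pred)]
    mean_r, mean_t = sum(resid) / n, sum(y_true) / n
    var_r = sum((e - mean_r) ** 2 for e in resid) / n
    var_t = sum((t - mean_t) ** 2 for t in y_true) / n
    return 1 - var_r / var_t


# A monotonic but non-linear relationship: Spearman is 1, Pearson is slightly below 1.
print(pearson([1, 2, 3, 4], [1, 4, 9, 16]), spearman([1, 2, 3, 4], [1, 4, 9, 16]))
```

Pearson and Spearman are unchanged by rescaling the labels (e.g., dividing by 5), and explained variance ignores a constant offset between predictions and gold scores, so an error metric such as MSE is a useful complement when calibration matters.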




```python
from train import train_model
import warnings
import pickle

warnings.filterwarnings('ignore')
siamese_lstm_attention = train_model(
    model=siamese_lstm_attention,
    optimizer=optimizer,
    scheduler=scheduler,
    dataloader=sick_dataloaders,
    data=sick_data,
    max_epochs=max_epochs,
    clip=True,
    config_dict={
        "device": device,
        "model_name": "siamese_lstm_attention",
        "self_attention_config": self_attention_config,
    },
)

# save the trained model to disk (the context manager closes the file)
model_filename = "lstm_model.sav"
with open(model_filename, "wb") as f:
    pickle.dump(siamese_lstm_attention, f)
print("Model saved to disk successfully")
```

INFO:root:Evaluating accuracy on dev set
INFO:root:new model saved
INFO:root:Train loss: 6.4187774658203125 - Train pearson score: 0.13878308269081924 -- Validation loss: 5.561148643493652 - Validation pearson score: 0.11024243164888813- Validation                  p_value: 0.00012070761559507383

[0/49] train_loss: 6.419, train_score: 0.139 


INFO:root:Evaluating accuracy on dev set
INFO:root:new model saved
INFO:root:Train loss: 5.099574089050293 - Train pearson score: 0.10736160276461369 -- Validation loss: 4.620316982269287 - Validation pearson score: 0.2096014655316269- Validation                  p_value: 4.16795600158878e-06

[1/49] train_loss: 5.100, train_score: 0.107 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 4.503004550933838 - Train pearson score: 0.10050299520823366 -- Validation loss: 4.283022880554199 - Validation pearson score: 0.1159030148036601- Validation                  p_value: 0.0033209428244082326

[2/49] train_loss: 4.503, train_score: 0.101 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 4.255911827087402 - Train pearson score: 0.1778635520487137 -- Validation loss: 4.136762619018555 - Validation pearson score: 0.15936731799218992- Validation                  p_value: 2.1776481279871214e-05

[3/49] train_loss: 4.256, train_score: 0.178 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.8501875400543213 - Train pearson score: 0.11717939243482907 -- Validation loss: 3.728193521499634 - Validation pearson score: 0.036257144921688154- Validation                  p_value: 0.1394507926259779

[4/49] train_loss: 3.850, train_score: 0.117 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.5747339725494385 - Train pearson score: 0.06741640679015867 -- Validation loss: 3.747328519821167 - Validation pearson score: 0.1840757753835015- Validation                  p_value: 7.048027439443743e-05

[5/49] train_loss: 3.575, train_score: 0.067 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.533460855484009 - Train pearson score: 0.08952831919473614 -- Validation loss: 3.455205202102661 - Validation pearson score: 0.07178036278036569- Validation                  p_value: 0.000989820110524137

[6/49] train_loss: 3.533, train_score: 0.090 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.5009970664978027 - Train pearson score: 0.1375464806917153 -- Validation loss: 3.469712495803833 - Validation pearson score: 0.145512414070901- Validation                  p_value: 0.0038753203576590094

[7/49] train_loss: 3.501, train_score: 0.138 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.4042751789093018 - Train pearson score: 0.15330611065207383 -- Validation loss: 3.417280912399292 - Validation pearson score: 0.07120200043445063- Validation                  p_value: 0.43445109890827893

[8/49] train_loss: 3.404, train_score: 0.153 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.460498809814453 - Train pearson score: 0.07805694150800967 -- Validation loss: 3.4055111408233643 - Validation pearson score: 0.1280383926871517- Validation                  p_value: 0.003197255704026946

[9/49] train_loss: 3.460, train_score: 0.078 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.379757881164551 - Train pearson score: 0.1943851581228037 -- Validation loss: 3.4780843257904053 - Validation pearson score: 0.17457940353631018- Validation                  p_value: 4.73796208585001e-05

[10/49] train_loss: 3.380, train_score: 0.194 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.7189998626708984 - Train pearson score: 0.03407597053822603 -- Validation loss: 3.4572525024414062 - Validation pearson score: 0.17902801072324867- Validation                  p_value: 2.586147719159484e-06

[11/49] train_loss: 3.719, train_score: 0.034 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.4655375480651855 - Train pearson score: 0.11973385930995539 -- Validation loss: 3.517150402069092 - Validation pearson score: 0.18145710915006508- Validation                  p_value: 4.1978181379551105e-05

[12/49] train_loss: 3.466, train_score: 0.120 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.4008071422576904 - Train pearson score: 0.20506928089347626 -- Validation loss: 3.393239736557007 - Validation pearson score: 0.18566924815009261- Validation                  p_value: 1.1093664871617125e-05

[13/49] train_loss: 3.401, train_score: 0.205 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.372438669204712 - Train pearson score: 0.2496713991197393 -- Validation loss: 3.377286911010742 - Validation pearson score: 0.1927730385682219- Validation                  p_value: 0.00019896235800455003

[14/49] train_loss: 3.372, train_score: 0.250 


INFO:root:Evaluating accuracy on dev set
INFO:root:new model saved
INFO:root:Train loss: 3.346240758895874 - Train pearson score: 0.2919224358895444 -- Validation loss: 3.3710713386535645 - Validation pearson score: 0.233160690188824- Validation                  p_value: 1.6086438165349005e-06

[15/49] train_loss: 3.346, train_score: 0.292 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.37420916557312 - Train pearson score: 0.2496914607286705 -- Validation loss: 3.382174253463745 - Validation pearson score: 0.20376365693029277- Validation                  p_value: 2.165437162243457e-05

[16/49] train_loss: 3.374, train_score: 0.250 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.404787540435791 - Train pearson score: 0.21658230920179133 -- Validation loss: 3.559771776199341 - Validation pearson score: 0.16266093567585482- Validation                  p_value: 0.0014755474184707166

[17/49] train_loss: 3.405, train_score: 0.217 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.410992383956909 - Train pearson score: 0.1973769418915563 -- Validation loss: 3.4299182891845703 - Validation pearson score: 0.1571563002181424- Validation                  p_value: 0.0015503200843522576

[18/49] train_loss: 3.411, train_score: 0.197 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.3955883979797363 - Train pearson score: 0.23668599658601064 -- Validation loss: 3.3901658058166504 - Validation pearson score: 0.20904246499753104- Validation                  p_value: 4.903881569749884e-06

[19/49] train_loss: 3.396, train_score: 0.237 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.3984575271606445 - Train pearson score: 0.20840727772360626 -- Validation loss: 3.4298956394195557 - Validation pearson score: 0.22246336193518462- Validation                  p_value: 2.2038500851772716e-06

[20/49] train_loss: 3.398, train_score: 0.208 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.3675460815429688 - Train pearson score: 0.2590426660330252 -- Validation loss: 3.407527208328247 - Validation pearson score: 0.22045627988278557- Validation                  p_value: 1.1696176094006004e-05

[21/49] train_loss: 3.368, train_score: 0.259 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.37738299369812 - Train pearson score: 0.24261887798755627 -- Validation loss: 3.381460428237915 - Validation pearson score: 0.20571894882876185- Validation                  p_value: 4.585418016101991e-05

[22/49] train_loss: 3.377, train_score: 0.243 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.361255645751953 - Train pearson score: 0.2674297336651077 -- Validation loss: 3.4175117015838623 - Validation pearson score: 0.19829075207870933- Validation                  p_value: 1.6916657729195807e-05

[23/49] train_loss: 3.361, train_score: 0.267 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.3642919063568115 - Train pearson score: 0.263983892265567 -- Validation loss: 3.3664798736572266 - Validation pearson score: 0.22912946948768145- Validation                  p_value: 3.477752689938851e-07

[24/49] train_loss: 3.364, train_score: 0.264 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.355557441711426 - Train pearson score: 0.28858175314234413 -- Validation loss: 3.3741471767425537 - Validation pearson score: 0.22561073266004822- Validation                  p_value: 4.594541082962626e-05

[25/49] train_loss: 3.356, train_score: 0.289 


INFO:root:Evaluating accuracy on dev set
INFO:root:new model saved
INFO:root:Train loss: 3.35164475440979 - Train pearson score: 0.28853673834359245 -- Validation loss: 3.365579843521118 - Validation pearson score: 0.24473343733365108- Validation                  p_value: 2.2060337788703412e-06

[26/49] train_loss: 3.352, train_score: 0.289 


INFO:root:Evaluating accuracy on dev set
INFO:root:new model saved
INFO:root:Train loss: 3.345743179321289 - Train pearson score: 0.30152289901903173 -- Validation loss: 3.369123935699463 - Validation pearson score: 0.2567465933202753- Validation                  p_value: 9.256597220843657e-08

[27/49] train_loss: 3.346, train_score: 0.302 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.343252658843994 - Train pearson score: 0.3095054158831476 -- Validation loss: 3.3648571968078613 - Validation pearson score: 0.23076383029096303- Validation                  p_value: 5.562858596442716e-07

[28/49] train_loss: 3.343, train_score: 0.310 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.3436594009399414 - Train pearson score: 0.30795817270887355 -- Validation loss: 3.366668462753296 - Validation pearson score: 0.2561351666940812- Validation                  p_value: 2.3063559303256545e-08

[29/49] train_loss: 3.344, train_score: 0.308 


INFO:root:Evaluating accuracy on dev set
INFO:root:new model saved
INFO:root:Train loss: 3.3437390327453613 - Train pearson score: 0.31668330636146413 -- Validation loss: 3.3593389987945557 - Validation pearson score: 0.2636649958740177- Validation                  p_value: 1.1804419399771624e-07

[30/49] train_loss: 3.344, train_score: 0.317 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.3446474075317383 - Train pearson score: 0.3196190371192033 -- Validation loss: 3.372572422027588 - Validation pearson score: 0.2635344641195485- Validation                  p_value: 8.07318680585718e-09

[31/49] train_loss: 3.345, train_score: 0.320 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.340358018875122 - Train pearson score: 0.3371979954093569 -- Validation loss: 3.3787777423858643 - Validation pearson score: 0.21637497226936933- Validation                  p_value: 1.7701337709566279e-06

[32/49] train_loss: 3.340, train_score: 0.337 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.3411388397216797 - Train pearson score: 0.3338993610026214 -- Validation loss: 3.3613038063049316 - Validation pearson score: 0.2529977565178022- Validation                  p_value: 3.4938439556004056e-08

[33/49] train_loss: 3.341, train_score: 0.334 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.337146043777466 - Train pearson score: 0.3427367717803101 -- Validation loss: 3.3648526668548584 - Validation pearson score: 0.23130088109262278- Validation                  p_value: 1.5871042704274565e-06

[34/49] train_loss: 3.337, train_score: 0.343 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.347240686416626 - Train pearson score: 0.3217476731918222 -- Validation loss: 3.3651716709136963 - Validation pearson score: 0.2594097052136943- Validation                  p_value: 2.3724653324060673e-07

[35/49] train_loss: 3.347, train_score: 0.322 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.3477628231048584 - Train pearson score: 0.3193030660052184 -- Validation loss: 3.3809432983398438 - Validation pearson score: 0.18837311207161334- Validation                  p_value: 0.0002706424185752357

[36/49] train_loss: 3.348, train_score: 0.319 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.340794324874878 - Train pearson score: 0.3420151820125905 -- Validation loss: 3.3819565773010254 - Validation pearson score: 0.2594127490425474- Validation                  p_value: 3.906463597762517e-08

[37/49] train_loss: 3.341, train_score: 0.342 


INFO:root:Evaluating accuracy on dev set
INFO:root:new model saved
INFO:root:Train loss: 3.3364920616149902 - Train pearson score: 0.33897810768004805 -- Validation loss: 3.356975555419922 - Validation pearson score: 0.26425285108524543- Validation                  p_value: 2.1145096548457657e-08

[38/49] train_loss: 3.336, train_score: 0.339 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.340237617492676 - Train pearson score: 0.3327029989141298 -- Validation loss: 3.369279623031616 - Validation pearson score: 0.22708517770866346- Validation                  p_value: 2.66292535988498e-06

[39/49] train_loss: 3.340, train_score: 0.333 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.3421378135681152 - Train pearson score: 0.33328162811878864 -- Validation loss: 3.3765690326690674 - Validation pearson score: 0.2410841779001512- Validation                  p_value: 1.1927232494720867e-07

[40/49] train_loss: 3.342, train_score: 0.333 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.3400790691375732 - Train pearson score: 0.33514678104185835 -- Validation loss: 3.3649699687957764 - Validation pearson score: 0.2321188538522526- Validation                  p_value: 1.507801212532628e-06

[41/49] train_loss: 3.340, train_score: 0.335 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.3387491703033447 - Train pearson score: 0.3352381780438473 -- Validation loss: 3.3631036281585693 - Validation pearson score: 0.23300069973455048- Validation                  p_value: 2.8991210566538345e-06

[42/49] train_loss: 3.339, train_score: 0.335 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.3402011394500732 - Train pearson score: 0.34529952692045374 -- Validation loss: 3.3881938457489014 - Validation pearson score: 0.20983221256740767- Validation                  p_value: 0.00011148483062843712

[43/49] train_loss: 3.340, train_score: 0.345 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.335355758666992 - Train pearson score: 0.3481970211683671 -- Validation loss: 3.3871936798095703 - Validation pearson score: 0.22494188129528106- Validation                  p_value: 1.1835125001407515e-06

[44/49] train_loss: 3.335, train_score: 0.348 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.3375797271728516 - Train pearson score: 0.35014351711354097 -- Validation loss: 3.3677310943603516 - Validation pearson score: 0.2607057821520484- Validation                  p_value: 7.146599005721175e-09

[45/49] train_loss: 3.338, train_score: 0.350 


INFO:root:Evaluating accuracy on dev set
INFO:root:new model saved
INFO:root:Train loss: 3.34891939163208 - Train pearson score: 0.3325580007540644 -- Validation loss: 3.3572938442230225 - Validation pearson score: 0.2642601410719377- Validation                  p_value: 7.58171678819945e-09

[46/49] train_loss: 3.349, train_score: 0.333 


INFO:root:Evaluating accuracy on dev set
INFO:root:Train loss: 3.3391799926757812 - Train pearson score: 0.336652247386211 -- Validation loss: 3.3584818840026855 - Validation pearson score: 0.2520584108704959- Validation                  p_value: 1.5717649448633558e-07

[47/49] train_loss: 3.339, train_score: 0.337 


[99/99] train_loss: 3.220, train_score: 0.877 
Final score: 0.877, expected 1.000

Train loss: 3.220442295074463 - Train pearson score: 0.8765596209715861 -- Validation loss: 3.2523105144500732 - Validation pearson score: 0.6362855344577428- Validation                  p_value: 2.162431682382807e-35

## Part 4. Evaluation and Analysis (2 points)  
Implement the `evaluate_test_set` function to compute the chosen performance metric on the test data.  
Compare the result with the original AAAI paper. Comment on the effect of the penalty loss on model capacity. Did the inclusion of the self-attention block improve the results? If yes, how? Can you think of additional techniques to improve the results? Briefly answer these questions in the markdown cells.

- The results obtained with this model (test Pearson 0.336, Spearman 0.328) are much worse than those reported in the original AAAI paper, so in our experiments adding the self-attention layer did not improve the model; the AAAI paper's model outperforms ours.
- To improve on these results, several approaches could be tried: replace the self-attention layer with a transformer, or combine both.
- We could apply transfer learning: first train the model to classify sentence pairs using the entailment `label` column of the SICK dataset, then fine-tune it on our main target task (relatedness scoring).
- The penalty term encourages the attention weight vectors of the different hops to differ from each other, which reduces redundancy in the embedding matrix and adds diversity to the attention weights; this constrains the model's flexibility.
- In other words, the penalty term acts as a regularizer: it reduces the model's effective capacity, lowering variance at the cost of some bias, which helps prevent overfitting. 
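The ICLR paper's penalty is $P = \lVert AA^\top - I \rVert_F^2$, added to the loss with a coefficient (`penalty` = 0.2 in the configuration above). The NumPy sketch below includes a few hand-checkable cases: two hops attending to the same token give $P = 2$, hops attending to different tokens give $P = 0$, and two uniform hops over three tokens give $P = 10/9$, so the term penalizes both redundant hops and diffuse attention. `total_loss`, and the choice to add the penalties of both sentences in the pair, are assumptions about how `train.py` might combine the terms.

```python
import numpy as np


def attention_penalty(A: np.ndarray) -> np.ndarray:
    """Squared Frobenius norm ||A A^T - I||_F^2 for each example.

    A: attention matrices of shape (batch, r, n); each row sums to 1.
    Returns an array of shape (batch,).
    """
    r = A.shape[1]
    AAT = A @ A.transpose(0, 2, 1)  # (batch, r, r) overlap between hops
    diff = AAT - np.eye(r)          # identity broadcast over the batch
    return (diff ** 2).sum(axis=(1, 2))


def total_loss(pred, gold, A_a, A_b, coef=0.2):
    """Hypothetical combination: MSE on the scores plus the mean penalty of both sentences."""
    mse = np.mean((np.asarray(pred) - np.asarray(gold)) ** 2)
    penalty = attention_penalty(A_a).mean() + attention_penalty(A_b).mean()
    return mse + coef * penalty


same = np.array([[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]])      # both hops on token 0
distinct = np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # hops on different tokens
uniform = np.full((1, 2, 3), 1 / 3)                        # both hops spread evenly
print(attention_penalty(same), attention_penalty(distinct), attention_penalty(uniform))
```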

```python
import pickle
import warnings
from test import evaluate_test_set

warnings.filterwarnings('ignore')

# load the model saved after training
model_filename = "lstm_model.sav"
with open(model_filename, "rb") as f:
    lstm_model = pickle.load(f)

evaluate_test_set(
    model=lstm_model,
    data_loader=sick_dataloaders,
    config_dict={
        "device": device,
        "model_name": "siamese_lstm_attention",
        "self_attention_config": self_attention_config,
    },
)
```

```
train_loss: 3.220, train_score: 0.877 Final score: 0.877, expected 1.000

test_loss: 0.257, spearman_score: 0.328 , pearson_score: 0.336
```