# Model 01

Evidence retrieval using a Siamese BERT classification model.

Ref:
- [STS continue training guide](https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/sts/training_stsbenchmark_continue_training.py)

## Setup

### Working Directory

In [1]:
# Change the working directory to project root
import pathlib
import os
ROOT_DIR = pathlib.Path.cwd()
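while not ROOT_DIR.joinpath("src").exists():
    # Stop at the filesystem root instead of looping forever
    if ROOT_DIR == ROOT_DIR.parent:
        raise FileNotFoundError("No parent directory containing 'src' was found")
    ROOT_DIR = ROOT_DIR.parent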
os.chdir(ROOT_DIR)

### File paths

In [2]:
MODEL_PATH = ROOT_DIR.joinpath("./result/models/*")

### Dependencies

In [3]:
# Imports and dependencies
import spacy
import torch
from torch import nn
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, LoggingHandler
from sentence_transformers.losses import SoftmaxLoss
from sentence_transformers.evaluation import BinaryClassificationEvaluator
from src.torch_utils import get_torch_device
from src.spacy_utils import process_sentence
from src.model_01 import ClaimEvidenceDataset
from datetime import datetime
import logging
import math

torch_device = get_torch_device()

Torch device is 'mps'



### Names

In [4]:
run_time = datetime.now().strftime('%Y_%m_%d_%H_%M')
model_save_path = MODEL_PATH.with_name(f"model_01_{run_time}")
eval_name = "model_01_dev"
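
The trailing `*` in `MODEL_PATH` appears to serve only as a placeholder: `Path.with_name` swaps the final path component for the timestamped run name. A minimal sketch of that behaviour, assuming a hypothetical `/project` root and a fixed timestamp rather than the notebook's real values:

```python
from datetime import datetime
from pathlib import PurePosixPath

# Hypothetical project root, for illustration only
root_dir = PurePosixPath("/project")
model_path = root_dir.joinpath("./result/models/*")

# Fixed timestamp so the result is reproducible (the notebook uses datetime.now())
run_time = datetime(2023, 4, 27, 22, 45).strftime("%Y_%m_%d_%H_%M")

# with_name() replaces only the last component ("*") and keeps the parent directories
model_save_path = model_path.with_name(f"model_01_{run_time}")
print(model_save_path)
```

Only the last path component changes, so every run lands in the same `result/models` directory under its own name.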

### Logging

In [5]:
logging.basicConfig(format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    handlers=[LoggingHandler()]
)

## Dataset

In [6]:
train_data = ClaimEvidenceDataset(
    claims_json="./data/train-claims.json",
    evidence_json="./data/evidence.json",
    negative_sample_strategy="related_random",
    negative_sample_size=500,
    preprocess_func=None
)
dev_data = ClaimEvidenceDataset(
    claims_json="./data/dev-claims.json",
    evidence_json="./data/evidence.json",
    negative_sample_strategy="related_random",
    negative_sample_size=500,
    preprocess_func=None
)

Generate claim-evidence pair with related_random strategy n=500


claims: 100%|██████████| 1228/1228 [01:20<00:00, 15.20it/s]


Generate claim-evidence pair with related_random strategy n=500


claims: 100%|██████████| 154/154 [00:08<00:00, 17.74it/s]


In [7]:
print(len(train_data))
print(len(dev_data))

1232122
148302


In [8]:
# for sample in train_data:
#     if sample.texts[0] == "Not only is there no scientific evidence that CO2 is a pollutant, higher CO2 concentrations actually help ecosystems support more plant and animal life.":
#         print(sample)

## Select model components

In [9]:
nlp = spacy.load("en_core_web_trf")
nlp

<spacy.lang.en.English at 0x28a379550>

In [10]:
model = SentenceTransformer(
    "sentence-transformers/msmarco-bert-base-dot-v5",
    device=torch_device
)
model

2023-04-27 22:47:16 - Load pretrained SentenceTransformer: sentence-transformers/msmarco-bert-base-dot-v5


SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)

In [11]:
train_loss = SoftmaxLoss(
    model=model,
    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
    num_labels=2,
    concatenation_sent_rep=True,
    concatenation_sent_difference=True,
    concatenation_sent_multiplication=False,
    loss_fct=nn.CrossEntropyLoss()
)
train_loss

2023-04-27 22:47:27 - Softmax loss: #Vectors concatenated: 3


SoftmaxLoss(
  (model): SentenceTransformer(
    (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
    (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  )
  (classifier): Linear(in_features=2304, out_features=2, bias=True)
  (loss_fct): CrossEntropyLoss()
)
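
The classifier's `in_features=2304` follows from the concatenation flags: with `concatenation_sent_rep` and `concatenation_sent_difference` enabled, the two sentence embeddings `u` and `v` and their element-wise absolute difference `|u - v|` are joined into one vector, giving 3 × 768 values. The sketch below mirrors that logic in plain Python for illustration; `build_features` is a hypothetical helper, not part of sentence-transformers.

```python
from typing import List


def build_features(
    u: List[float],
    v: List[float],
    concatenation_sent_rep: bool = True,
    concatenation_sent_difference: bool = True,
    concatenation_sent_multiplication: bool = False,
) -> List[float]:
    """Concatenate pair features in the style of a Siamese softmax classifier."""
    features: List[float] = []
    if concatenation_sent_rep:
        features += u + v                                # (u, v)
    if concatenation_sent_difference:
        features += [abs(a - b) for a, b in zip(u, v)]   # |u - v|
    if concatenation_sent_multiplication:
        features += [a * b for a, b in zip(u, v)]        # u * v
    return features


embedding_dim = 768  # word_embedding_dimension reported by the Pooling layer
u = [0.5] * embedding_dim
v = [0.25] * embedding_dim
features = build_features(u, v)
print(len(features))  # 2304 == 3 * 768
```

Enabling `concatenation_sent_multiplication` as well would add a fourth block and raise the input size to 3072.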

In [12]:
train_eval = BinaryClassificationEvaluator.from_input_examples(
    examples=dev_data,
    name=eval_name,
    write_csv=True,
    show_progress_bar=True
)
train_eval

<sentence_transformers.evaluation.BinaryClassificationEvaluator.BinaryClassificationEvaluator at 0x2c49bbaf0>

## Training

In [13]:
train_batch_size = 64
num_epochs = 5

In [14]:
train_dataloader = DataLoader(
    dataset=train_data,
    shuffle=True,
    batch_size=train_batch_size
)
# Note: dev_dataloader is not passed to model.fit below; the evaluator reads dev_data directly
dev_dataloader = DataLoader(
    dataset=dev_data,
    shuffle=False,
    batch_size=train_batch_size
)

In [15]:
# Use 10% of the total training steps for warm-up
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)
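
As a worked check of the warm-up arithmetic, take the training-set size printed earlier (1,232,122 pairs) and a batch size of 64. The `DataLoader` yields ceil(1232122 / 64) batches per epoch, since it keeps the final partial batch by default, and 10% of the total steps over 5 epochs become warm-up steps. The variable names below are illustrative.

```python
import math

num_train_pairs = 1232122   # len(train_data) printed above
train_batch_size = 64
num_epochs = 5

# A DataLoader with drop_last=False (the default) keeps the final partial batch
steps_per_epoch = math.ceil(num_train_pairs / train_batch_size)
total_steps = steps_per_epoch * num_epochs
warmup_steps = math.ceil(total_steps * 0.1)

print(steps_per_epoch)  # 19252, matching the progress bar total in the training log
print(total_steps)      # 96260
print(warmup_steps)     # 9626
```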

In [16]:
# Train the model
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    evaluator=train_eval,
    evaluation_steps=1000,
    warmup_steps=warmup_steps,
    optimizer_class=torch.optim.AdamW,
    optimizer_params={"lr": 0.00002},
    weight_decay=0.01,
    output_path=str(model_save_path),
    save_best_model=True,
    show_progress_bar=True
)


2023-04-27 23:27:15 - Binary Accuracy Evaluation of the model on model_01_dev dataset in epoch 0 after 1000 steps:

2023-04-27 23:29:18 - Accuracy with Cosine-Similarity:           99.67	(Threshold: 0.9851)
2023-04-27 23:29:18 - F1 with Cosine-Similarity:                 13.26	(Threshold: 0.9738)
2023-04-27 23:29:18 - Precision with Cosine-Similarity:          21.56
2023-04-27 23:29:18 - Recall with Cosine-Similarity:             9.57
2023-04-27 23:29:18 - Average Precision with Cosine-Similarity:  5.60

2023-04-27 23:29:19 - Accuracy with Manhattan-Distance:           99.67	(Threshold: 59.5647)
2023-04-27 23:29:19 - F1 with Manhattan-Distance:                 12.97	(Threshold: 78.7452)
2023-04-27 23:29:19 - Precision with Manhattan-Distance:          20.09
2023-04-27 23:29:19 - Recall with Manhattan-Distance:             9.57
2023-04-27 23:29:19 - Average Precision with Manhattan-Distance:  5.44

2023-04-27 23:29:19 - Accuracy with Euclidean-Distance:           99.67	(Threshold: 2.6911)
2023-04-27 23:29:19 - F1 with Euclidean-Distance:                 12.98	(Threshold: 3.5669)
2023-04-27 23:29:19 - Precision with Euclidean-Distance:          20.17
2023-04-27 23:29:19 - Recall with Euclidean-Distance:             9.57
2023-04-27 23:29:19 - Average Precision with Euclidean-Distance:  5.44

2023-04-27 23:29:19 - Accuracy with Dot-Product:           99.67	(Threshold: 252.7120)
2023-04-27 23:29:19 - F1 with Dot-Product:                 5.44	(Threshold: 239.2893)
2023-04-27 23:29:19 - Precision with Dot-Product:          3.84
2023-04-27 23:29:19 - Recall with Dot-Product:             9.37
2023-04-27 23:29:19 - Average Precision with Dot-Product:  2.02

2023-04-27 23:29:19 - Save model to /Users/johnsonzhou/git/comp90042-project/result/models/model_01_2023_04_27_22_45

2023-04-27 23:59:53 - Binary Accuracy Evaluation of the model on model_01_dev dataset in epoch 0 after 2000 steps:

2023-04-28 00:01:31 - Accuracy with Cosine-Similarity:           99.67	(Threshold: 0.9813)
2023-04-28 00:01:31 - F1 with Cosine-Similarity:                 14.54	(Threshold: 0.9632)
2023-04-28 00:01:31 - Precision with Cosine-Similarity:          14.23
2023-04-28 00:01:31 - Recall with Cosine-Similarity:             14.87
2023-04-28 00:01:31 - Average Precision with Cosine-Similarity:  6.95

2023-04-28 00:01:31 - Accuracy with Manhattan-Distance:           99.67	(Threshold: 67.3073)
2023-04-28 00:01:31 - F1 with Manhattan-Distance:                 14.43	(Threshold: 94.4607)
2023-04-28 00:01:31 - Precision with Manhattan-Distance:          14.20
2023-04-28 00:01:31 - Recall with Manhattan-Distance:             14.66
2023-04-28 00:01:32 - Average Precision with Manhattan-Distance:  6.87

2023-04-28 00:01:32 - Accuracy with Euclidean-Distance:           99.67	(Threshold: 3.0587)
2023-04-28 00:01:32 - F1 with Euclidean-Distance:                 14.71	(Threshold: 4.2692)
2023-04-28 00:01:32 - Precision with Euclidean-Distance:          14.75
2023-04-28 00:01:32 - Recall with Euclidean-Distance:             14.66
2023-04-28 00:01:32 - Average Precision with Euclidean-Distance:  6.82

2023-04-28 00:01:32 - Accuracy with Dot-Product:           99.67	(Threshold: 255.0250)
2023-04-28 00:01:32 - F1 with Dot-Product:                 10.99	(Threshold: 243.0355)
2023-04-28 00:01:32 - Precision with Dot-Product:          8.01
2023-04-28 00:01:32 - Recall with Dot-Product:             17.52
2023-04-28 00:01:32 - Average Precision with Dot-Product:  5.27

2023-04-28 00:01:32 - Save model to /Users/johnsonzhou/git/comp90042-project/result/models/model_01_2023_04_27_22_45

2023-04-28 00:23:44 - Binary Accuracy Evaluation of the model on model_01_dev dataset in epoch 0 after 3000 steps:

2023-04-28 00:25:21 - Accuracy with Cosine-Similarity:           99.67	(Threshold: 0.9890)
2023-04-28 00:25:21 - F1 with Cosine-Similarity:                 15.45	(Threshold: 0.9745)
2023-04-28 00:25:21 - Precision with Cosine-Similarity:          17.16
2023-04-28 00:25:21 - Recall with Cosine-Similarity:             14.05
2023-04-28 00:25:21 - Average Precision with Cosine-Similarity:  7.03

2023-04-28 00:25:22 - Accuracy with Manhattan-Distance:           99.67	(Threshold: 53.9586)
2023-04-28 00:25:22 - F1 with Manhattan-Distance:                 15.16	(Threshold: 81.3441)
2023-04-28 00:25:22 - Precision with Manhattan-Distance:          17.05
2023-04-28 00:25:22 - Recall with Manhattan-Distance:             13.65
2023-04-28 00:25:22 - Average Precision with Manhattan-Distance:  7.06

2023-04-28 00:25:22 - Accuracy with Euclidean-Distance:           99.67	(Threshold: 2.4397)
2023-04-28 00:25:22 - F1 with Euclidean-Distance:                 15.45	(Threshold: 3.7775)
2023-04-28 00:25:22 - Precision with Euclidean-Distance:          15.42
2023-04-28 00:25:22 - Recall with Euclidean-Distance:             15.48
2023-04-28 00:25:22 - Average Precision with Euclidean-Distance:  7.12

2023-04-28 00:25:23 - Accuracy with Dot-Product:           99.67	(Threshold: 269.8538)
2023-04-28 00:25:23 - F1 with Dot-Product:                 9.46	(Threshold: 261.7487)
2023-04-28 00:25:23 - Precision with Dot-Product:          7.29
2023-04-28 00:25:23 - Recall with Dot-Product:             13.44
2023-04-28 00:25:23 - Average Precision with Dot-Product:  3.98

2023-04-28 00:25:23 - Save model to /Users/johnsonzhou/git/comp90042-project/result/models/model_01_2023_04_27_22_45


Iteration:  19%|█▉        | 3737/19252 [1:48:11<7:29:10,  1.74s/it]
Epoch:   0%|          | 0/5 [1:48:11<?, ?it/s]


RuntimeError: MPS backend out of memory (MPS allocated: 40.27 GB, other allocations: 82.52 GB, max allowed: 122.40 GB). Tried to allocate 89.06 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).
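
The F1 values in the evaluation logs above can be cross-checked from the logged precision and recall, since F1 is their harmonic mean, F1 = 2PR / (P + R). The sketch below does this for the cosine-similarity metrics reported after 1000, 2000 and 3000 steps. Because the logged inputs are already rounded to two decimals, the comparison allows a small tolerance; `f1_score` here is an illustrative helper, not a library call.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (both in the same units)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (steps, precision, recall, logged F1) for cosine similarity, copied from the logs
logged = [
    (1000, 21.56, 9.57, 13.26),
    (2000, 14.23, 14.87, 14.54),
    (3000, 17.16, 14.05, 15.45),
]

for steps, precision, recall, logged_f1 in logged:
    recomputed = f1_score(precision, recall)
    # Inputs are rounded to two decimals, so allow a small tolerance
    assert abs(recomputed - logged_f1) < 0.02
    print(f"{steps} steps: recomputed F1 = {recomputed:.2f}, logged F1 = {logged_f1:.2f}")
```

Accuracy stays at 99.67 in every round while F1 remains below 16, so F1 and average precision are probably the more informative numbers to track for this dataset.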