In [None]:
from google.colab import drive

# Make Google Drive (and its Shareddrives) available under /content/gdrive;
# force_remount=True re-mounts even when a mount is already present.
drive.mount(mountpoint='/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [None]:
# Absolute paths under the Drive mount point ('/content/gdrive'), so they do not
# depend on the kernel's current working directory ('./gdrive/...' only works
# while the cwd is /content). Trailing slashes are kept because later cells build
# file paths by plain string concatenation.
DRIVE_ROOT = '/content/gdrive/Shareddrives/'
DATASET_FULL_TEXT = DRIVE_ROOT + 'DATASETS/PMC-Sents-FULL/'
OUTPUT_MODEL_DIR = DRIVE_ROOT + 'MODELS/'

In [None]:
# Install the libraries used below. No versions are specified, so a fresh run
# resolves whatever releases are current at that time.
# NOTE(review): the training code below uses the InputExample / model.fit(...)
# API style of sentence-transformers 2.x — consider pinning the versions this
# notebook was run with (e.g. `%pip install -q sentence-transformers==<ver>`) for reproducibility.
!pip install transformers -q
!pip install sentence_transformers -q

[K     |████████████████████████████████| 4.7 MB 5.0 MB/s 
[K     |████████████████████████████████| 596 kB 39.7 MB/s 
[K     |████████████████████████████████| 6.6 MB 33.7 MB/s 
[K     |████████████████████████████████| 101 kB 9.3 MB/s 
[K     |████████████████████████████████| 85 kB 2.7 MB/s 
[K     |████████████████████████████████| 1.3 MB 27.7 MB/s 
[?25h  Building wheel for sentence-transformers (setup.py) ... [?25l[?25hdone


In [None]:
import pandas as pd

In [None]:
# Labelled training sentences plus pre-built evaluation triplets.
# NOTE(review): test file presumably has the same (anchor, positive, negative)
# columns as the dev file — only the dev columns are displayed below.
split_paths = {
    'train': DATASET_FULL_TEXT + 'train_full_text_span1.parquet',
    'dev': DATASET_FULL_TEXT + 'dev_triplets_span1.parquet',
    'test': DATASET_FULL_TEXT + 'test_triplets_span1.parquet',
}
df_train = pd.read_parquet(split_paths['train'])
df_dev_triplets = pd.read_parquet(split_paths['dev'])
df_test_triplets = pd.read_parquet(split_paths['test'])

In [None]:
df_train.keys()  # column Index of the training frame

Index(['id', 'text', 'label_id'], dtype='object')

In [None]:
df_dev_triplets.keys()  # column Index of the dev triplet frame

Index(['anchor', 'positive', 'negative'], dtype='object')

In [None]:
# Drop every training row whose `text` occurs with more than one distinct
# `label_id` (contradictory supervision for a label-based triplet loss).
# The row count is printed before and after filtering.
rows_before = len(df_train.index)
print(f'{rows_before}')
labels_per_text = df_train.groupby('text')['label_id'].transform('nunique')
# Keep the negation form: texts dropped by groupby (NaN keys) give NaN here and must be kept.
has_conflict = labels_per_text.gt(1)
df_train = df_train.loc[~has_conflict].copy()
rows_after = len(df_train.index)
print(f'{rows_after}')

138473
138473


In [None]:
from sentence_transformers import InputExample
from tqdm import tqdm

# One single-text InputExample per training sentence, with its `label_id` as the
# example's `label` (read by the label-based triplet loss defined further down).
# guids are 1-based, as before. Columns are walked with itertuples(): iterrows()
# builds a pandas Series for every row, which is much slower and does not
# preserve column dtypes.
train_rows = df_train[['text', 'label_id']].itertuples(index=False, name=None)
train_set = [
    InputExample(guid=guid, texts=[text], label=label_id)
    for guid, (text, label_id) in enumerate(
        tqdm(train_rows, total=len(df_train.index)), start=1
    )
]
len(train_set)

100%|██████████| 138473/138473 [00:13<00:00, 10014.42it/s]


138473

In [None]:
# Dev triplets as InputExamples with texts ordered [anchor, positive, negative],
# the order TripletEvaluator.from_input_examples consumes; no label is needed.
# guids are 1-based. itertuples() replaces the slow, dtype-unsafe iterrows().
dev_rows = df_dev_triplets[['anchor', 'positive', 'negative']].itertuples(
    index=False, name=None
)
dev_set = [
    InputExample(guid=guid, texts=list(triplet))
    for guid, triplet in enumerate(
        tqdm(dev_rows, total=len(df_dev_triplets.index)), start=1
    )
]
len(dev_set)

100%|██████████| 17309/17309 [00:01<00:00, 15794.25it/s]


17309

In [None]:
# Test triplets as InputExamples with texts ordered [anchor, positive, negative],
# the order TripletEvaluator.from_input_examples consumes; no label is needed.
# guids are 1-based. itertuples() replaces the slow, dtype-unsafe iterrows().
test_rows = df_test_triplets[['anchor', 'positive', 'negative']].itertuples(
    index=False, name=None
)
test_set = [
    InputExample(guid=guid, texts=list(triplet))
    for guid, triplet in enumerate(
        tqdm(test_rows, total=len(df_test_triplets.index)), start=1
    )
]
len(test_set)

100%|██████████| 17310/17310 [00:01<00:00, 13377.67it/s]


17310

In [None]:
import logging
from sentence_transformers import LoggingHandler

# Send INFO-level log records (ours and sentence-transformers') through
# LoggingHandler. logging.basicConfig() does nothing if the root logger already
# has handlers, which silently hides every logging.info() below; detach any
# pre-existing root handlers first (what force=True does on Python 3.8+, but this
# also works on older interpreters).
root_logger = logging.getLogger()
for stale_handler in list(root_logger.handlers):
    root_logger.removeHandler(stale_handler)

logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)

In [None]:
from datetime import datetime
# Base model to fine-tune, and a timestamped run directory under OUTPUT_MODEL_DIR:
#   <OUTPUT_MODEL_DIR><model>_FullText_Span1_TripletAll-<YYYY-mm-dd_HH-MM-SS>
model_name = 'sentence-transformers/all-MiniLM-L6-v2'
model_file_name = model_name.rsplit('/', 1)[-1] + '_FullText_Span1_TripletAll'

train_batch_size = 32
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_path = OUTPUT_MODEL_DIR + "-".join([model_file_name, run_timestamp])
output_path

'./gdrive/Shareddrives/MODELS/all-MiniLM-L6-v2_FullText_Span1_TripletAll-2022-08-16_17-23-14'

In [None]:
from sentence_transformers import models, SentenceTransformer

# Load the pretrained checkpoint onto the GPU (fails if no CUDA device is available).
model = SentenceTransformer(model_name_or_path=model_name, device='cuda')

Downloading:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/350 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/13.2k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/349 [00:00<?, ?B/s]

In [None]:
from sentence_transformers import datasets
from torch.utils.data import DataLoader

# Shuffled mini-batches of the labelled single-text training examples.
# NOTE(review): a batch-all triplet loss can only form triplets from labels that
# appear at least twice in the same batch; with plain random shuffling some batches
# may contain few such pairs. sentence_transformers.datasets.SentenceLabelDataset
# (`datasets` is imported in this cell but not used in this notebook) exists for
# this purpose — worth evaluating.
loader = DataLoader(
    dataset=train_set,
    batch_size=train_batch_size,
    shuffle=True,
)

In [None]:
from sentence_transformers import losses

# Batch-all triplet loss over the `label` of each training example (library defaults).
train_loss = losses.BatchAllTripletLoss(model)

In [None]:
from sentence_transformers.evaluation import TripletEvaluator

# Triplet-accuracy evaluator on the dev triplets; used for the baseline score
# and for model selection during fit().
dev_evaluator = TripletEvaluator.from_input_examples(
    examples=dev_set,
    name='full_text-dev',
    show_progress_bar=True,
    write_csv=True,
)

In [None]:
# Baseline: score the pretrained (not yet fine-tuned) model on the dev triplets.
logging.info("Performance before fine-tuning:")
dev_evaluator(model)

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

0.5266624299497371

In [None]:
# Baseline test-set score of the pretrained model, for comparison with the
# fine-tuned score at the end of the notebook.
logging.info("Evaluating model on test set")
test_evaluator = TripletEvaluator.from_input_examples(
    examples=test_set,
    name='full_text-test',
    show_progress_bar=True,
    write_csv=True,
)
test_evaluator(model)

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

0.5194685153090699

In [None]:
num_epochs = 20

# Learning-rate warm-up steps passed to model.fit(): 10% of all optimizer steps
# (batches per epoch x epochs) — not 10% of the training data.
warmup_ratio = 0.1
warmup_steps = int(len(loader) * num_epochs * warmup_ratio)

In [None]:
# Sub-folders of the run directory, passed to model.fit() below as
# output_path (saved model) and checkpoint_path (intermediate checkpoints).
model_output_path = f'{output_path}/model'
checkpoint_output_path = f'{output_path}/checkpoint'

print(model_output_path, checkpoint_output_path, sep='\n')

./gdrive/Shareddrives/MODELS/all-MiniLM-L6-v2_FullText_Span1_TripletAll-2022-08-16_17-23-14/model
./gdrive/Shareddrives/MODELS/all-MiniLM-L6-v2_FullText_Span1_TripletAll-2022-08-16_17-23-14/checkpoint


In [None]:
%%time
model.fit(
    train_objectives=[(loader, train_loss)],
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    output_path=model_output_path,
    show_progress_bar=True,
    evaluator=dev_evaluator,
    save_best_model=True,
    checkpoint_save_total_limit=1,
    checkpoint_path=checkpoint_output_path    
)  

Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4328 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

CPU times: user 2h 5min 35s, sys: 3min 14s, total: 2h 8min 50s
Wall time: 2h 8min 36s


In [None]:
# model.fit(..., save_best_model=True) wrote the best dev-scoring model to
# model_output_path, but the in-memory `model` still holds the final epoch's
# weights. Reload the saved model so the test score (and the later model.save)
# refer to the dev-selected model, not to whatever the last epoch produced.
model = SentenceTransformer(model_output_path, device='cuda')

logging.info("Evaluating model on test set")
test_evaluator = TripletEvaluator.from_input_examples(
    test_set, write_csv=True, show_progress_bar=True, name='full_text-test'
)
model.evaluate(test_evaluator)

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

0.7024263431542461

In [None]:
# Destination of the exported model: a sibling of the run directory.
final_model_path = f'{output_path}_PMC-Sents-FULL'
final_model_path

In [None]:
# Export `model` next to the run directory (suffix "_PMC-Sents-FULL").
# NOTE(review): this saves whatever `model` currently holds. After
# fit(..., save_best_model=True) the dev-best model lives in model_output_path and can
# differ from the final-epoch in-memory weights unless `model` was reloaded from there.
model.save(path=output_path + "_PMC-Sents-FULL")