In [1]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [2]:
DATASET_FULL_TEXT = './gdrive/Shareddrives/DATASETS/PMC-Sents-FULL/'
OUTPUT_MODEL_DIR = './gdrive/Shareddrives/MODELS/'

In [3]:
!pip install transformers -q
!pip install sentence_transformers -q

[K     |████████████████████████████████| 4.7 MB 9.1 MB/s 
[K     |████████████████████████████████| 101 kB 13.3 MB/s 
[K     |████████████████████████████████| 6.6 MB 51.6 MB/s 
[K     |████████████████████████████████| 596 kB 73.3 MB/s 
[K     |████████████████████████████████| 85 kB 3.9 MB/s 
[K     |████████████████████████████████| 1.3 MB 18.0 MB/s 
[?25h  Building wheel for sentence-transformers (setup.py) ... [?25l[?25hdone


In [4]:
import pandas as pd

In [5]:
df_train = pd.read_parquet(DATASET_FULL_TEXT + 'train_full_text_span1.parquet')
df_dev = pd.read_parquet(DATASET_FULL_TEXT + 'dev_triplets_span1.parquet')
df_test = pd.read_parquet(DATASET_FULL_TEXT + 'test_triplets_span1.parquet')

In [6]:
df_train.columns

Index(['id', 'text', 'label_id'], dtype='object')

In [7]:
df_dev.columns

Index(['anchor', 'positive', 'negative'], dtype='object')

In [9]:
from sentence_transformers import InputExample
from tqdm import tqdm

train_set = []
guid = 1
for idx, row in tqdm(df_train.iterrows(), total=len(df_train.index)):
    train_set.append(InputExample(
        guid=guid,
        texts=[row['text']],
        label=row['label_id']
    ))
    guid += 1
len(train_set)

100%|██████████| 138473/138473 [00:07<00:00, 18322.20it/s]


138473

In [10]:
dev_set = []

guid = 1
for idx, row in tqdm(df_dev.iterrows(), total=len(df_dev.index)):
    dev_set.append(InputExample(
        guid=guid,
        texts=[row['anchor'], row['positive'], row['negative']],
    ))
    guid += 1
len(dev_set)

100%|██████████| 17309/17309 [00:00<00:00, 18942.94it/s]


17309

In [11]:
test_set = []

guid = 1
for idx, row in tqdm(df_test.iterrows(), total=len(df_test.index)):
    test_set.append(InputExample(
        guid=guid,
        texts=[row['anchor'], row['positive'], row['negative']],
    ))
    guid += 1
len(test_set)

100%|██████████| 17310/17310 [00:01<00:00, 16405.46it/s]


17310

In [12]:
import logging
from sentence_transformers import LoggingHandler

logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)

In [13]:
from datetime import datetime
model_name = 'allenai/scibert_scivocab_uncased'
model_file_name = model_name.split('/')[-1] + '_FullText1_TripletAll'

train_batch_size = 16
output_path = (
    OUTPUT_MODEL_DIR
    + model_file_name
    + "-"
    + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)
output_path

'./gdrive/Shareddrives/MODELS/scibert_scivocab_uncased_FullText1_TripletAll-2022-08-12_20-39-50'

In [14]:
from sentence_transformers import models, SentenceTransformer

bert = models.Transformer(model_name)
pooler = models.Pooling(
    bert.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True
)

model = SentenceTransformer(modules=[bert, pooler])

Downloading config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/422M [00:00<?, ?B/s]

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading vocab.txt:   0%|          | 0.00/223k [00:00<?, ?B/s]

In [15]:
from sentence_transformers import datasets
from torch.utils.data import DataLoader

loader = DataLoader(train_set, shuffle=True, batch_size=train_batch_size)

In [16]:
from sentence_transformers import losses

train_loss = losses.BatchAllTripletLoss(model=model)

In [17]:
from sentence_transformers.evaluation import TripletEvaluator

dev_evaluator = TripletEvaluator.from_input_examples(
    dev_set, write_csv=True, show_progress_bar=True, name='full_text-dev'
)

In [18]:
logging.info("Performance before fine-tuning:")
dev_evaluator(model)

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

0.5507539430354151

In [19]:
logging.info("Evaluating model on test set")
test_evaluator = TripletEvaluator.from_input_examples(
    test_set, write_csv=True, show_progress_bar=True, name='full_text-test'
)
test_evaluator(model)

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

0.5456383593298672

In [20]:
num_epochs = 20

warmup_steps = int(len(loader) * num_epochs  * 0.1)  # 10% of train data

In [21]:
model_output_path = output_path + '/model'
checkpoint_output_path = output_path + '/checkpoint'

print(model_output_path)
print(checkpoint_output_path)

./gdrive/Shareddrives/MODELS/scibert_scivocab_uncased_FullText1_TripletAll-2022-08-12_20-39-50/model
./gdrive/Shareddrives/MODELS/scibert_scivocab_uncased_FullText1_TripletAll-2022-08-12_20-39-50/checkpoint


In [22]:
%%time
model.fit(
    train_objectives=[(loader, train_loss)],
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    output_path=model_output_path,
    show_progress_bar=True,
    evaluator=dev_evaluator,
    save_best_model=True,
    checkpoint_save_total_limit=1,
    checkpoint_path=checkpoint_output_path    
)

Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Iteration:   0%|          | 0/8655 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

CPU times: user 7h 7min 58s, sys: 1h 37min 47s, total: 8h 45min 46s
Wall time: 8h 57min 27s


In [23]:
logging.info("Performance after fine-tuning:")
dev_evaluator(model)

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

0.7782078687388064

In [24]:
logging.info("Evaluating model on test set")
test_evaluator = TripletEvaluator.from_input_examples(
    test_set, write_csv=True, show_progress_bar=True, name='full_text-test'
)
model.evaluate(test_evaluator)

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

Batches:   0%|          | 0/1082 [00:00<?, ?it/s]

0.7822645869439631

In [25]:
output_path + "_PMC-Sents-FULL"

'./gdrive/Shareddrives/MODELS/scibert_scivocab_uncased_FullText1_TripletAll-2022-08-12_20-39-50_PMC-Sents-FULL'

In [26]:
model.save(output_path + "_PMC-Sents-FULL")