In [None]:
%pip install --quiet "transformers==3.1.0" "seqeval[gpu]==0.0.12"

In [None]:
from google.colab import drive
# Mount Google Drive: the data files and checkpoints used below live under /content/drive.
drive.mount('/content/drive', force_remount=False)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
from pathlib import Path

import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertConfig, BertForTokenClassification

# Project folder on the mounted Drive; all inputs sit in its `data` sub-folder.
WORK_DIR = Path("/content/drive/MyDrive/my_colab/litcoin")
DATASET_DIR = WORK_DIR / "data"

ABSTRACTS_TRAIN_CSV = DATASET_DIR / "abstracts_train.csv"
ABSTRACTS_TEST_CSV = DATASET_DIR / "abstracts_test.csv"
ENTITIES_TRAIN_CSV = DATASET_DIR / "entities_train.csv"
RELATIONS_TRAIN_CSV = DATASET_DIR / "relations_train.csv"
SUBMISSION_EXAMPLE_CSV = DATASET_DIR / "submission_example.csv"

## Loading Data

In [None]:
# Tab-separated despite the .csv extension.
abstracts_train_df = pd.read_csv(ABSTRACTS_TRAIN_CSV, delimiter="\t")
abstracts_train_df.head(5)

Unnamed: 0,abstract_id,title,abstract
0,1353340,Late-onset metachromatic leukodystrophy: molec...,We report on a new allele at the arylsulfatase...
1,1671881,Two distinct mutations at a single BamHI site ...,Classical phenylketonuria is an autosomal rece...
2,1848636,Debrisoquine phenotype and the pharmacokinetic...,The metabolism of the cardioselective beta-blo...
3,2422478,Midline B3 serotonin nerves in rat medulla are...,Previous experiments in this laboratory have s...
4,2491010,Molecular and phenotypic analysis of patients ...,Eighty unrelated individuals with Duchenne mus...


In [None]:
# Tab-separated despite the .csv extension.
abstracts_test_df = pd.read_csv(ABSTRACTS_TEST_CSV, delimiter="\t")
abstracts_test_df.head(5)

Unnamed: 0,abstract_id,title,abstract
0,1711760,Delayed institution of hypertension during foc...,The effect of induced hypertension instituted ...
1,6086495,Localisation of the Becker muscular dystrophy ...,A linkage study in 30 Becker muscular dystroph...
2,7018927,Pituitary response to luteinizing hormone-rele...,The effects of a 6-hour infusion with haloperi...
3,7811247,X-linked adrenoleukodystrophy (ALD): a novel m...,Fragments of the adrenoleukodystrophy (ALD) cD...
4,8944024,Detection of heterozygous mutations in BRCA1 u...,The ability to scan a large gene rapidly and a...


In [None]:
# One row per annotated entity mention, with character offsets into title + " " + abstract.
entities_train_df = pd.read_csv(ENTITIES_TRAIN_CSV, delimiter="\t")
entities_train_df.head(5)

Unnamed: 0,id,abstract_id,offset_start,offset_finish,type,mention,entity_ids
0,0,1353340,11,39,DiseaseOrPhenotypicFeature,metachromatic leukodystrophy,D007966
1,1,1353340,111,126,GeneOrGeneProduct,arylsulfatase A,410
2,2,1353340,128,132,GeneOrGeneProduct,ARSA,410
3,3,1353340,159,187,DiseaseOrPhenotypicFeature,metachromatic leukodystrophy,D007966
4,4,1353340,189,192,DiseaseOrPhenotypicFeature,MLD,D007966


In [None]:
# Inner join: one row per (abstract, entity mention); the model input text is
# the title and abstract joined by a single space.
train_df = abstracts_train_df.merge(entities_train_df, on="abstract_id")
train_df["full_text"] = train_df["title"] + " " + train_df["abstract"]
train_df.query("abstract_id==1353340")

Unnamed: 0,abstract_id,title,abstract,id,offset_start,offset_finish,type,mention,entity_ids,full_text
0,1353340,Late-onset metachromatic leukodystrophy: molec...,We report on a new allele at the arylsulfatase...,0,11,39,DiseaseOrPhenotypicFeature,metachromatic leukodystrophy,D007966,Late-onset metachromatic leukodystrophy: molec...
1,1353340,Late-onset metachromatic leukodystrophy: molec...,We report on a new allele at the arylsulfatase...,1,111,126,GeneOrGeneProduct,arylsulfatase A,410,Late-onset metachromatic leukodystrophy: molec...
2,1353340,Late-onset metachromatic leukodystrophy: molec...,We report on a new allele at the arylsulfatase...,2,128,132,GeneOrGeneProduct,ARSA,410,Late-onset metachromatic leukodystrophy: molec...
3,1353340,Late-onset metachromatic leukodystrophy: molec...,We report on a new allele at the arylsulfatase...,3,159,187,DiseaseOrPhenotypicFeature,metachromatic leukodystrophy,D007966,Late-onset metachromatic leukodystrophy: molec...
4,1353340,Late-onset metachromatic leukodystrophy: molec...,We report on a new allele at the arylsulfatase...,4,189,192,DiseaseOrPhenotypicFeature,MLD,D007966,Late-onset metachromatic leukodystrophy: molec...
5,1353340,Late-onset metachromatic leukodystrophy: molec...,We report on a new allele at the arylsulfatase...,5,210,220,SequenceVariant,arginine84,rs74315458,Late-onset metachromatic leukodystrophy: molec...
6,1353340,Late-onset metachromatic leukodystrophy: molec...,We report on a new allele at the arylsulfatase...,6,264,277,GeneOrGeneProduct,arylsulfatase,410,Late-onset metachromatic leukodystrophy: molec...
7,1353340,Late-onset metachromatic leukodystrophy: molec...,We report on a new allele at the arylsulfatase...,7,363,366,DiseaseOrPhenotypicFeature,MLD,D007966,Late-onset metachromatic leukodystrophy: molec...
8,1353340,Late-onset metachromatic leukodystrophy: molec...,We report on a new allele at the arylsulfatase...,8,372,395,SequenceVariant,arginine84 to glutamine,rs74315458,Late-onset metachromatic leukodystrophy: molec...
9,1353340,Late-onset metachromatic leukodystrophy: molec...,We report on a new allele at the arylsulfatase...,9,442,446,GeneOrGeneProduct,ARSA,410,Late-onset metachromatic leukodystrophy: molec...


## Data preprocessing

In [None]:
# --- Run configuration ------------------------------------------------------
SEED = 42
BERT_MODEL_NAME = 'bert-large-uncased'
MAX_LEN = 512            # tokens per abstract; equals BERT's 512 position embeddings

TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 4
ACCUM_STEPS = 4          # batches whose gradients are accumulated per optimizer step
EPOCHS = 20
LEARNING_RATE = 2e-05

IS_GRAD_CLIP = False     # toggle gradient-norm clipping ...
MAX_GRAD_NORM = 10       # ... with this threshold

tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)

In [None]:
def iter_word_offsets(input_string):
    """Yield ``(char_start, char_end, token)`` for every token of ``input_string``.

    The text is encoded with the module-level ``tokenizer`` using the same
    settings as the model input: truncated and padded to ``MAX_LEN``.
    Consequently the special tokens and the padding tokens are yielded too.
    For those, the offsets are presumably the empty ``(0, 0)`` span that
    HF fast tokenizers report for special tokens.
    """
    inputs = tokenizer(
        input_string,
        is_pretokenized=False,
        return_offsets_mapping=True,
        padding='max_length',
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    # Batch dimension is 1: squeeze it away before converting to Python lists.
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze().tolist())
    offsets = inputs["offset_mapping"].squeeze().tolist()
    for token, (start, end) in zip(tokens, offsets):
        yield start, end, token


def transform_data(df: pd.DataFrame) -> pd.DataFrame:
    """Convert entity-annotated abstracts into token-level BIO tag sequences.

    ``df`` has one row per annotated entity, with a ``full_text`` column and
    character offsets (``offset_start``/``offset_finish``) into that text. For
    each abstract the text is tokenized with ``iter_word_offsets``, and every
    token gets one of three tags:

    * ``"O"``: special/padding tokens (empty ``start == end`` span) and tokens
      not fully contained in an annotated entity;
    * ``"B-<type>"``: the first token of an entity;
    * ``"I-<type>"``: every following token of that same entity.

    Args:
        df: frame with ``abstract_id``, ``full_text``, ``offset_start``,
            ``offset_finish`` and ``type`` columns.

    Returns:
        One row per abstract with columns ``abstract_id``, ``orig_text`` (the
        untokenized text), ``text`` (token strings) and ``tags`` (BIO tags
        aligned 1:1 with ``text``).
    """
    records = []
    # Group by the column name itself (not a one-element list) so the group key
    # is the scalar id rather than a 1-tuple on newer pandas versions.
    for abstract_id, frame in df.groupby("abstract_id"):
        text = frame["full_text"].iloc[0]
        entity_starts = frame["offset_start"]
        entity_finishes = frame["offset_finish"]
        tokens, tags = [], []
        # Row label of the entity the previous token belonged to (None = outside).
        previous_entity = None

        for offset_start, offset_finish, token in iter_word_offsets(text):
            tokens.append(token)

            if offset_start == offset_finish:
                # Special/padding tokens cover no characters. Without this check
                # their (0, 0) span would "fall inside" any entity starting at 0.
                tags.append("O")
                previous_entity = None
                continue

            covering = frame[(entity_starts <= offset_start) & (offset_finish <= entity_finishes)]
            if covering.empty:
                tags.append("O")
                previous_entity = None
                continue

            entity = covering.index[0]
            entity_type = covering["type"].iloc[0]
            # A different covering annotation starts a new entity, so adjacent
            # entities are not merged into one I- run.
            prefix = "I" if entity == previous_entity else "B"
            tags.append(f"{prefix}-{entity_type}")
            previous_entity = entity

        records.append({"abstract_id": abstract_id, "orig_text": text, "text": tokens, "tags": tags})

    return pd.DataFrame(records, columns=["abstract_id", "orig_text", "text", "tags"])


# Building token/tag sequences is slow, so the result is cached on Drive.
# CSV stores the list columns as their Python repr. Parse them back so that
# `text`/`tags` hold lists whether or not the cache already existed.
import ast

BERT_TRAIN_CACHE_CSV = DATASET_DIR / "bert_train_df2.csv"
if BERT_TRAIN_CACHE_CSV.exists():
    bert_train_df = pd.read_csv(BERT_TRAIN_CACHE_CSV)
    for list_col in ("text", "tags"):
        bert_train_df[list_col] = bert_train_df[list_col].map(ast.literal_eval)
else:
    bert_train_df = transform_data(train_df)
    bert_train_df.to_csv(BERT_TRAIN_CACHE_CSV, index=False)

bert_train_df.head()

Sanity check: every abstract must have exactly one tag per token, i.e. its `text` and `tags` sequences must have the same length.

In [None]:
# The list columns hold lists when freshly built but their str repr when read
# from the CSV cache. Parse strings first; `.str.len()` on a repr would count
# characters, not tokens.
import ast

def _seq_len(value):
    """Number of elements in a list, or in the list literal ``value`` encodes."""
    return len(ast.literal_eval(value) if isinstance(value, str) else value)

misaligned = bert_train_df[bert_train_df["text"].map(_seq_len) != bert_train_df["tags"].map(_seq_len)]
assert misaligned.empty, (
    f"{len(misaligned)} abstracts have mismatched token/tag counts: "
    f"{misaligned['abstract_id'].tolist()}"
)

Fix this test. I suppose the test is incorrect now


In [None]:
# Label vocabulary: "O", then one B- tag per entity type, then one I- tag per
# type, in the order types first appear in train_df. Changing this order would
# silently remap the classifier outputs of an already-saved checkpoint.
original_tags = list(train_df["type"].unique())
bert_b_tags = ["B-" + tag for tag in original_tags]
bert_i_tags = ["I-" + tag for tag in original_tags]
bert_tags = ["O", *bert_b_tags, *bert_i_tags]

ids_to_labels = dict(enumerate(bert_tags))
labels_to_ids = {label: label_id for label_id, label in ids_to_labels.items()}
labels_to_ids

{'B-CellLine': 6,
 'B-ChemicalEntity': 5,
 'B-DiseaseOrPhenotypicFeature': 1,
 'B-GeneOrGeneProduct': 2,
 'B-OrganismTaxon': 4,
 'B-SequenceVariant': 3,
 'I-CellLine': 12,
 'I-ChemicalEntity': 11,
 'I-DiseaseOrPhenotypicFeature': 7,
 'I-GeneOrGeneProduct': 8,
 'I-OrganismTaxon': 10,
 'I-SequenceVariant': 9,
 'O': 0}

In [None]:
import os
import random

def seed_everything(seed):
    """Seed every RNG used in this notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, and torch's CPU and all-GPU
    generators. ``PYTHONHASHSEED`` is exported too, though setting it at runtime
    only affects child processes started afterwards.

    ``cudnn.benchmark`` is switched *off*. Benchmark mode auto-tunes kernels per
    input shape and may pick non-deterministic algorithms, which defeats
    ``cudnn.deterministic``.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU, not only the current one
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(seed=SEED)  # before the model and data loaders are created below

In [None]:
import ast

class dataset(Dataset):
    """Torch ``Dataset`` encoding one abstract per item for token classification.

    Each row of ``dataframe`` must provide ``orig_text`` (the raw text) and
    ``tags`` (one BIO tag per tokenizer token). ``tags`` may be a list or its
    ``str`` repr, as read back from a CSV cache.

    Each item is a dict of tensors: every key returned by ``tokenizer``
    (including ``offset_mapping``) plus ``labels``, the tag ids.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, index):
        # step 1: raw text and its per-token tags
        sentence = self.data.orig_text[index]
        word_labels = self.data.tags[index]
        if type(word_labels) is str:
            # list columns read back from CSV arrive as their Python repr
            word_labels = ast.literal_eval(word_labels)

        # step 2: tokenize with the same settings used to produce the tags,
        # so token i here lines up with tag i
        encoding = self.tokenizer(sentence,
                                  is_pretokenized=False,
                                  return_offsets_mapping=True,
                                  padding='max_length',
                                  truncation=True,
                                  max_length=self.max_len)
        offsets = encoding["offset_mapping"]
        if len(word_labels) < len(offsets):
            raise ValueError(
                f"row {index}: {len(word_labels)} tags for {len(offsets)} tokens"
            )

        # step 3: label every real token. Special and padding tokens have an
        # empty (start == end) span and keep -100, the default ignore_index of
        # torch's CrossEntropyLoss. That drops them from the loss and from the
        # `labels != -100` masks used when scoring.
        encoded_labels = np.full(len(offsets), -100, dtype=int)
        for idx, (start, end) in enumerate(offsets):
            if start != end:
                encoded_labels[idx] = labels_to_ids[word_labels[idx]]

        # step 4: turn everything into PyTorch tensors
        item = {key: torch.as_tensor(val) for key, val in encoding.items()}
        item['labels'] = torch.as_tensor(encoded_labels)
        return item

    def __len__(self):
        return self.len

In [None]:
# Hold out 20% of the annotated abstracts for validation / checkpoint selection.
train_size = 0.8
train_dataset = bert_train_df.sample(frac=train_size, random_state=SEED)
test_dataset = bert_train_df.drop(index=train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

for split_name, split_df in (("FULL", bert_train_df), ("TRAIN", train_dataset), ("TEST", test_dataset)):
    print(f"{split_name} Dataset: {split_df.shape}")

training_set = dataset(train_dataset, tokenizer, MAX_LEN)
testing_set = dataset(test_dataset, tokenizer, MAX_LEN)

FULL Dataset: (481, 4)
TRAIN Dataset: (385, 4)
TEST Dataset: (96, 4)


In [None]:
# Only the training batches are shuffled; validation order stays fixed.
train_params = dict(batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=0)
test_params = dict(train_params, batch_size=VALID_BATCH_SIZE, shuffle=False)

training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)

In [None]:
from torch import cuda
# Run on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [None]:
from torch import nn

# Token-classification head sized to the BIO label set. The `dropout` module
# applied before the classifier is replaced with a stronger p=0.3.
model = BertForTokenClassification.from_pretrained(BERT_MODEL_NAME, num_labels=len(labels_to_ids))
model.dropout = nn.Dropout(p=0.3)
model.to(device)

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-large

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-1

In [None]:
# sent = bert_train_df2.iloc[0].text

# if type(sent) is str:
#   sent = ast.literal_eval(sent)

# type(sent)

In [None]:
# Sanity check: with an untrained head the loss on one example is expected to
# be close to ln(num_labels) (= ln 13 ≈ 2.56 here).
inputs = training_set[2]
input_ids = inputs["input_ids"].unsqueeze(0).to(device)
attention_mask = inputs["attention_mask"].unsqueeze(0).to(device)
labels = inputs["labels"].unsqueeze(0).to(device)

# Nothing is back-propagated here. Without no_grad, `outputs` / `initial_loss`
# would keep the whole autograd graph (bert-large activations) alive in GPU
# memory for the rest of the notebook.
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
initial_loss = outputs[0]
initial_loss

tensor(2.5234, device='cuda:0', grad_fn=<NllLossBackward0>)

In [None]:
# Logits of the sanity-check forward pass: (batch=1, MAX_LEN, num_labels).
tr_logits = outputs[1]
tr_logits.size()

torch.Size([1, 512, 13])

In [None]:
from transformers import get_linear_schedule_with_warmup

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Stepped once per epoch with the validation loss, hence mode="min".
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")

In [None]:
# Training function for the 80% training split.
def train(epoch):
    """Run one training epoch over ``training_loader`` and return its mean loss.

    Gradients of ``ACCUM_STEPS`` consecutive batches are accumulated before each
    optimizer step. Each batch loss is divided by ``ACCUM_STEPS``, so the step
    uses the mean gradient. A trailing group with fewer than ``ACCUM_STEPS``
    batches is still stepped, so no gradient leaks into the next epoch. When
    ``IS_GRAD_CLIP`` is set, the accumulated gradient is clipped to
    ``MAX_GRAD_NORM`` right before the step, i.e. after ``backward()``.

    Args:
        epoch: epoch index (unused; kept for the caller's signature).

    Returns:
        Mean unscaled batch loss over the epoch.
    """
    model.train()
    optimizer.zero_grad()
    tr_loss, nb_tr_steps = 0.0, 0
    n_batches = len(training_loader)

    for idx, batch in enumerate(training_loader):
        ids = batch['input_ids'].to(device, dtype=torch.long)
        mask = batch['attention_mask'].to(device, dtype=torch.long)
        labels = batch['labels'].to(device, dtype=torch.long)

        loss, _tr_logits = model(input_ids=ids, attention_mask=mask, labels=labels)
        tr_loss += loss.item()
        nb_tr_steps += 1

        # backward pass; scaled so accumulated gradients average over the group
        (loss / ACCUM_STEPS).backward()

        if (idx + 1) % ACCUM_STEPS == 0 or idx + 1 == n_batches:
            if IS_GRAD_CLIP:
                torch.nn.utils.clip_grad_norm_(
                    parameters=model.parameters(), max_norm=MAX_GRAD_NORM
                )
            optimizer.step()
            optimizer.zero_grad()

    return tr_loss / max(nb_tr_steps, 1)

In [None]:
def valid(model, testing_loader):
    """Evaluate ``model`` on ``testing_loader`` without gradient tracking.

    Returns:
        ``(labels, predictions, eval_loss)``:
        - ``labels`` / ``predictions``: flat lists of gold and predicted tag
          strings for every position whose label is not -100, concatenated
          over all batches;
        - ``eval_loss``: the mean per-batch loss.
    """
    # put model in evaluation mode
    model.eval()

    eval_loss, nb_eval_steps = 0.0, 0
    eval_label_ids, eval_pred_ids = [], []

    with torch.no_grad():
        for batch in testing_loader:
            ids = batch['input_ids'].to(device, dtype=torch.long)
            mask = batch['attention_mask'].to(device, dtype=torch.long)
            labels = batch['labels'].to(device, dtype=torch.long)

            loss, eval_logits = model(input_ids=ids, attention_mask=mask, labels=labels)
            eval_loss += loss.item()
            nb_eval_steps += 1

            flattened_targets = labels.view(-1)  # (batch_size * seq_len,)
            active_logits = eval_logits.view(-1, model.num_labels)  # (batch_size * seq_len, num_labels)
            flattened_predictions = torch.argmax(active_logits, dim=1)  # (batch_size * seq_len,)

            # keep only scored positions; one device->host transfer per batch
            # instead of one .item() per token
            active = flattened_targets != -100
            eval_label_ids.extend(flattened_targets[active].tolist())
            eval_pred_ids.extend(flattened_predictions[active].tolist())

    labels = [ids_to_labels[label_id] for label_id in eval_label_ids]
    predictions = [ids_to_labels[pred_id] for pred_id in eval_pred_ids]
    eval_loss = eval_loss / max(nb_eval_steps, 1)
    return labels, predictions, eval_loss

In [None]:
from seqeval.metrics import classification_report, f1_score

# Train for EPOCHS epochs. Keep the checkpoint with the best validation F1 and
# let the scheduler react to the validation loss.
best_f1 = -1
for epoch in range(EPOCHS):
    print(f"Training epoch: {epoch + 1}/{EPOCHS}")
    epoch_loss = train(epoch)
    labels, predictions, val_loss = valid(model, testing_loader)
    f1_val = f1_score(labels, predictions)
    scheduler.step(val_loss)

    for report_line in (f"Train loss: {epoch_loss}", f"Valid loss: {val_loss}", f"F1: {f1_val}", "\n"):
        print(report_line)

    is_first_epoch = best_f1 == -1
    if is_first_epoch or f1_val > best_f1:
        torch.save(model.state_dict(), WORK_DIR / "best.pth")
        best_f1 = f1_val

print(f"Best F1: {best_f1}")

In [None]:
# Build the model input exactly as for training: title + " " + abstract.
abstracts_test_df = abstracts_test_df.assign(
    full_text=abstracts_test_df["title"] + " " + abstracts_test_df["abstract"]
)
abstracts_test_df.head()

Unnamed: 0,abstract_id,title,abstract,full_text
0,1711760,Delayed institution of hypertension during foc...,The effect of induced hypertension instituted ...,Delayed institution of hypertension during foc...
1,6086495,Localisation of the Becker muscular dystrophy ...,A linkage study in 30 Becker muscular dystroph...,Localisation of the Becker muscular dystrophy ...
2,7018927,Pituitary response to luteinizing hormone-rele...,The effects of a 6-hour infusion with haloperi...,Pituitary response to luteinizing hormone-rele...
3,7811247,X-linked adrenoleukodystrophy (ALD): a novel m...,Fragments of the adrenoleukodystrophy (ALD) cD...,X-linked adrenoleukodystrophy (ALD): a novel m...
4,8944024,Detection of heterozygous mutations in BRCA1 u...,The ability to scan a large gene rapidly and a...,Detection of heterozygous mutations in BRCA1 u...


In [None]:
def _decode_bio_spans(inputs, token_predictions):
    """Turn per-token BIO tags into character-level entity span dicts.

    Each token is widened to its whole word via ``inputs.token_to_word`` and
    ``inputs.word_to_chars``. Tokens with an empty (start == end) offset span,
    i.e. special and padding tokens, are skipped.

    - An open span is extended by any token of a word it already covers, and
      by an ``I-`` tag of the same type.
    - A ``B-`` tag on a new word, or an ``I-`` tag of a different type, closes
      the open span and starts a new one.
    - ``O`` closes the open span.
    """
    offsets = inputs["offset_mapping"][0].tolist()
    current_span = None
    current_tag = None
    for token_idx, (t_pred, (start, end)) in enumerate(zip(token_predictions, offsets)):
        if start == end:
            continue
        word_start, word_end = inputs.word_to_chars(inputs.token_to_word(token_idx))

        if t_pred == "O":
            if current_span is not None:
                yield {"offset_start": current_span[0], "offset_finish": current_span[1], "type": current_tag}
                current_span = current_tag = None
            continue

        prefix, entity_type = t_pred[0], t_pred[2:]
        # word-level spans: a later subword of an already-covered word never starts a new entity
        same_word = current_span is not None and word_start < current_span[1]
        same_entity = current_span is not None and prefix == "I" and entity_type == current_tag
        if same_word or same_entity:
            current_span[1] = max(current_span[1], word_end)
        else:
            if current_span is not None:
                yield {"offset_start": current_span[0], "offset_finish": current_span[1], "type": current_tag}
            current_span = [word_start, word_end]
            current_tag = entity_type

    if current_span is not None:
        yield {"offset_start": current_span[0], "offset_finish": current_span[1], "type": current_tag}


def predict(model, text):
    """Yield predicted entity spans in ``text``.

    Each yielded dict has keys ``offset_start``, ``offset_finish`` (character
    offsets into ``text``, widened to whole words) and ``type``. The text is
    truncated to ``MAX_LEN`` tokens, so nothing past that point is predicted.
    Inference runs under ``torch.no_grad()``; the model's train/eval mode is
    left as the caller set it.
    """
    inputs = tokenizer(
        text,
        is_pretokenized=False,
        return_offsets_mapping=True,
        padding='max_length',
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )

    ids = inputs["input_ids"].to(device)
    mask = inputs["attention_mask"].to(device)
    with torch.no_grad():
        logits = model(ids, attention_mask=mask)[0]

    active_logits = logits.view(-1, model.num_labels)  # (batch_size * seq_len, num_labels)
    flattened_predictions = torch.argmax(active_logits, dim=1)  # token-level predictions
    token_predictions = [ids_to_labels[i] for i in flattened_predictions.tolist()]

    yield from _decode_bio_spans(inputs, token_predictions)


def create_submission(test_df):
    """Predict entities for every abstract in ``test_df`` and write the submission file.

    Loads the best checkpoint saved during training (``WORK_DIR / 'best.pth'``)
    into a fresh model and runs ``predict`` on each row's ``full_text``. The
    predictions are written to ``WORK_DIR / 'lit_bert_subm.csv'``:
    tab-separated, with the row index saved as ``id``.

    Args:
        test_df: frame with ``abstract_id`` and ``full_text`` columns.

    Returns:
        DataFrame with columns ``abstract_id``, ``offset_start``,
        ``offset_finish`` and ``type`` (empty, with these columns, when
        nothing is predicted).
    """
    model = BertForTokenClassification.from_pretrained(BERT_MODEL_NAME, num_labels=len(labels_to_ids))
    # load onto the CPU first so the checkpoint does not require the device it was saved from
    state_dict = torch.load(WORK_DIR / 'best.pth', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    del state_dict  # release the CPU copy of the weights before inference
    model.to(device)
    model.eval()

    submission_preds = []
    for _, row in test_df.iterrows():
        for pred_row in predict(model, row["full_text"]):
            pred_row["abstract_id"] = row["abstract_id"]
            submission_preds.append(pred_row)

    preds_df = pd.DataFrame(
        submission_preds,
        columns=["abstract_id", "offset_start", "offset_finish", "type"],
    )
    preds_df.to_csv(WORK_DIR / "lit_bert_subm.csv", sep="\t", index_label="id")
    return preds_df


submission_df = create_submission(abstracts_test_df)
submission_df

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-large

Unnamed: 0,abstract_id,offset_start,offset_finish,type
0,1711760,23,35,DiseaseOrPhenotypicFeature
1,1711760,49,66,DiseaseOrPhenotypicFeature
2,1711760,78,89,DiseaseOrPhenotypicFeature
3,1711760,113,125,DiseaseOrPhenotypicFeature
4,1711760,165,197,DiseaseOrPhenotypicFeature
...,...,...,...,...
3360,30442153,1766,1772,GeneOrGeneProduct
3361,30442153,1780,1788,GeneOrGeneProduct
3362,30442153,1911,1917,GeneOrGeneProduct
3363,30442153,1921,1924,GeneOrGeneProduct


In [None]:
# Re-read the written submission to check that it round-trips.
preds_loaded_df = pd.read_csv(WORK_DIR / "lit_bert_subm.csv", delimiter="\t")
preds_loaded_df

Unnamed: 0,id,abstract_id,offset_start,offset_finish,type
0,0,1711760,23,35,DiseaseOrPhenotypicFeature
1,1,1711760,49,66,DiseaseOrPhenotypicFeature
2,2,1711760,78,89,DiseaseOrPhenotypicFeature
3,3,1711760,113,125,DiseaseOrPhenotypicFeature
4,4,1711760,165,197,DiseaseOrPhenotypicFeature
...,...,...,...,...,...
3360,3360,30442153,1766,1772,GeneOrGeneProduct
3361,3361,30442153,1780,1788,GeneOrGeneProduct
3362,3362,30442153,1911,1917,GeneOrGeneProduct
3363,3363,30442153,1921,1924,GeneOrGeneProduct


In [None]:
# Attach each predicted span to its abstract's text for inspection.
merged_preds_df = preds_loaded_df.merge(abstracts_test_df, how="left", on="abstract_id")

In [None]:
# Spot-check: print a reproducible random sample of predicted spans as they
# appear in the text. Reuses the merge computed above instead of repeating it.
# The sample size is capped because .sample(n) raises when n exceeds the row count.
n_spot_checks = min(50, len(merged_preds_df))
for _, row in merged_preds_df.sample(n_spot_checks, random_state=SEED).iterrows():
    span_start, span_finish = row["offset_start"], row["offset_finish"]
    print(row["type"], row["full_text"][span_start:span_finish], span_start, span_finish)


DiseaseOrPhenotypicFeature NYS 1225 1228
ChemicalEntity Diazinon 117 125
DiseaseOrPhenotypicFeature type 1 diabetes 364 379
GeneOrGeneProduct AT1aR 1099 1104
DiseaseOrPhenotypicFeature SEPN-related myopathy 633 654
GeneOrGeneProduct IL-17 1216 1221
DiseaseOrPhenotypicFeature inflammatory 1603 1615
ChemicalEntity AraG 579 583
GeneOrGeneProduct CB 1642 1644
SequenceVariant 677C>T 1708 1714
ChemicalEntity glucose 953 960
DiseaseOrPhenotypicFeature metabolic syndrome 336 354
ChemicalEntity NaCl 422 426
SequenceVariant rs2476601 919 928
DiseaseOrPhenotypicFeature ALIOS 877 882
DiseaseOrPhenotypicFeature retinal diseases 244 260
DiseaseOrPhenotypicFeature sporadic Alzheimer's disease 93 121
ChemicalEntity ribavirin 308 317
DiseaseOrPhenotypicFeature type 2B VWD 1471 1482
ChemicalEntity CPM 235 238
DiseaseOrPhenotypicFeature cancer 1143 1149
DiseaseOrPhenotypicFeature deafness 1025 1033
GeneOrGeneProduct V-ets erythroblastosis virus E26 oncogene homolog2 441 491
GeneOrGeneProduct glycoprotein