In [None]:
# Root directory of the NCBI corpus. Later cells read `classes.txt` and the
# `NCBI-disease-IOB/*.tsv` splits relative to this folder.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
data_folder = 'F:/BioNER-Abbrev/Dataset/NCBI'

In [None]:
from transformers import AutoModel, AutoTokenizer

# Hub checkpoint to fetch. The local folder and variable are called
# "pubmedbert", but the weights that get saved are whatever `model_name` names.
model_name = "athiban2001/cord-scibert"

# Fetch the slow (pure-Python) tokenizer and the encoder weights.
pubmedbert = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=False,
)

# Write both artifacts to the same local directory so later cells can load
# them from disk instead of from the hub.
pubmedbert.save_pretrained("./pretrained_model/pubmedbert")
tokenizer.save_pretrained("./pretrained_model/pubmedbert")

In [None]:
import transformers

# ---- Run configuration ------------------------------------------------------

# Device string handed to torch.device() by the training / inference cells.
DEVICE = 'cuda'

# Fixed length of every encoded example, special tokens included.
MAX_LEN = 256

# Batch sizes for the training and validation data loaders.
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 16

# Number of passes over the training split.
EPOCHS = 10

# Worker processes used by each DataLoader.
NUM_WORKER = 5

# Local directory holding the saved encoder weights and vocabulary.
BASE_MODEL_PATH = './pretrained_model/pubmedbert'

# Checkpoint file prefix; "_<epoch>.bin" is appended when saving / loading.
MODEL_PATH = "models/NBCI_pubmedbert"

# NOTE(review): not referenced by any cell visible in this notebook.
TRAINING_FILE = 'F:/BioNER-Abbrev/Dataset/JNLPBA'

# Shared word-piece tokenizer used by the dataset and the inference cell.
TOKENIZER = transformers.BertTokenizer.from_pretrained(BASE_MODEL_PATH, do_lower_case=True)

In [None]:
import torch

class EntityDataset:
    """Map-style dataset that turns pre-tokenised sentences into fixed-length tensors.

    Each word is split into word pieces with the module-level ``TOKENIZER`` and
    its label is repeated for every piece. The sequence is truncated to
    ``MAX_LEN - 2`` pieces, wrapped in the tokenizer's [CLS]/[SEP] ids (both
    labelled "O") and padded to exactly ``MAX_LEN`` positions.

    Args:
        texts: list of sentences, each a list of word strings,
            e.g. ``[["hi", ",", "my", "name"], ["hello", ...]]``.
        tags: list of per-word integer label ids, parallel to ``texts``.
        enc_tag: fitted label encoder; must know the label ``"O"``.
    """

    def __init__(self, texts, tags,enc_tag):
        self.texts = texts
        self.tags = tags
        self.enc_tag = enc_tag
        # Encoded id of the "O" label, looked up once instead of per item.
        self.o_tag = enc_tag.transform(["O"])[0]

    def __len__(self):
        """Return the number of sentences."""
        return len(self.texts)

    def __getitem__(self, item):
        """Encode sentence ``item``.

        Returns:
            dict of ``torch.long`` tensors, each of length ``MAX_LEN``:
            ``ids``, ``mask`` (1 for real positions, 0 for padding),
            ``token_type_ids`` (all zeros) and ``target_tag``.
        """
        text = self.texts[item]
        tags = self.tags[item]

        ids = []
        target_tag = []

        for i, s in enumerate(text):
            pieces = TOKENIZER.encode(
                str(s),
                add_special_tokens=False
            )
            # e.g. "abhishek" -> ab ##hi ##sh ##ek: every piece inherits the word's tag.
            ids.extend(pieces)
            target_tag.extend([tags[i]] * len(pieces))

        # Leave room for the two special tokens added below.
        ids = ids[:MAX_LEN - 2]
        target_tag = target_tag[:MAX_LEN - 2]

        # Take the special-token ids from the tokenizer rather than hard-coding
        # 102/103/0, which are only correct for one particular vocabulary.
        ids = [TOKENIZER.cls_token_id] + ids + [TOKENIZER.sep_token_id]
        target_tag = [self.o_tag] + target_tag + [self.o_tag]

        mask = [1] * len(ids)
        token_type_ids = [0] * len(ids)

        padding_len = MAX_LEN - len(ids)

        ids = ids + ([TOKENIZER.pad_token_id] * padding_len)
        mask = mask + ([0] * padding_len)
        token_type_ids = token_type_ids + ([0] * padding_len)
        # Padded label positions carry mask == 0; label id 0 is only a filler.
        target_tag = target_tag + ([0] * padding_len)

        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(mask, dtype=torch.long),
            "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
            "target_tag": torch.tensor(target_tag, dtype=torch.long),
        }

In [None]:
import torch
import transformers
import torch.nn as nn
from torchcrf import CRF


class EntityModel(nn.Module):
    """BERT -> BiLSTM -> dropout -> linear -> CRF sequence tagger.

    Args:
        num_tag: number of distinct labels the CRF scores.
    """

    def __init__(self, num_tag):
        super(EntityModel, self).__init__()
        self.num_tag = num_tag
        # return_dict=False makes the encoder return a plain tuple.
        self.bert = transformers.BertModel.from_pretrained(BASE_MODEL_PATH,return_dict=False)
        # Read the encoder width from its config instead of assuming 768.
        self.bilstm = nn.LSTM(
            self.bert.config.hidden_size,
            1024 // 2,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )

        self.dropout_tag = nn.Dropout(0.3)

        # Both LSTM directions are concatenated: 2 * 512 = 1024 features.
        self.hidden2tag_tag = nn.Linear(1024, self.num_tag)

        self.crf_tag = CRF(self.num_tag, batch_first=True)

    def _emissions(self, ids, mask, token_type_ids):
        """Run BERT, the BiLSTM, dropout and the linear layer; return per-position tag scores."""
        x, _ = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        h, _ = self.bilstm(x)
        return self.hidden2tag_tag(self.dropout_tag(h))

    def forward(self, ids, mask, token_type_ids, target_tag):
        """Return the training loss only (no decoding).

        The loss is the negated CRF log-likelihood of ``target_tag`` with
        ``reduction='token_mean'``, evaluated where ``mask == 1``.
        """
        emissions = self._emissions(ids, mask, token_type_ids)
        # The CRF takes a boolean mask; the batches carry 0/1 integers.
        return -self.crf_tag(emissions, target_tag, mask=(mask == 1), reduction='token_mean')

    def encode(self, ids, mask, token_type_ids, target_tag=None):
        """Decode the best tag sequence; no loss is computed.

        ``target_tag`` is accepted and ignored so a full dataset batch can be
        passed as ``model.encode(**data)``.

        Returns:
            The result of ``CRF.decode`` for the positions where ``mask == 1``.
        """
        emissions = self._emissions(ids, mask, token_type_ids)
        return self.crf_tag.decode(emissions, mask=(mask == 1))

In [None]:
import pandas as pd
import numpy as np

import joblib
import torch
import torch.utils.data

from sklearn import preprocessing
from sklearn import model_selection

# Read the label inventory, one tag per line, with surrounding whitespace removed.
with open(data_folder + "/classes.txt") as f:
    total_tags = [raw_line.strip() for raw_line in f]

# Fit the encoder on the inventory so label ids do not depend on which tags
# happen to occur in a given split.
enc_tag = preprocessing.LabelEncoder()
enc_tag.fit(total_tags)

def process_data(data_path):
    """Read tab-separated ``token<TAB>tag`` files into sentences and encoded tags.

    Blank lines separate sentences and lines starting with ``-DOCSTART-`` are
    skipped. A sentence still open at the end of a file is flushed, so a file
    without a trailing blank line no longer loses its last sentence or leaks
    it into the next file.

    Args:
        data_path: iterable of file paths to read, in order.

    Returns:
        ``(sentences, tags, enc_tag)`` where ``sentences`` is a list of token
        lists, ``tags`` is a parallel list of label-id arrays produced by the
        module-level ``enc_tag``, and ``enc_tag`` is that same encoder.
    """
    sentences, tags = [], []
    lines_seen = 0

    for path in data_path:
        sentence, tag = [], []
        with open(path, "r") as f:
            for line in f:
                # Coarse progress report: sentences collected so far.
                if lines_seen % 10000 == 0:
                    print(len(sentences))
                lines_seen += 1

                line = line.strip()
                if line.startswith("-DOCSTART-"):
                    continue
                if len(line) == 0:
                    # Sentence boundary; ignore runs of consecutive blank lines.
                    if sentence:
                        sentences.append(sentence)
                        tags.append(tag)
                        sentence, tag = [], []
                    continue

                s, t = line.split("\t")
                sentence.append(s)
                tag.append(t)

        # Flush the sentence that was still open when the file ended.
        if sentence:
            sentences.append(sentence)
            tags.append(tag)

    tags = [enc_tag.transform(sentence_tags) for sentence_tags in tags]

    return sentences, tags, enc_tag

In [None]:
from __future__ import division, print_function, unicode_literals

import sys
from collections import defaultdict

def split_tag(chunk_tag):
    """
    split chunk tag into IOBES prefix and chunk_type
    e.g.
    B-PER -> (B, PER)
    O -> (O, None)

    Always returns a 2-tuple. Only the first "-" separates prefix from type,
    so "I-cell-line" -> (I, cell-line). A tag with no "-" at all yields
    (tag, None), so callers can always unpack two values.
    """
    if chunk_tag == 'O':
        return ('O', None)
    prefix, dash, chunk_type = chunk_tag.partition("-")
    return (prefix, chunk_type if dash else None)

def is_chunk_end(prev_tag, tag):
    """
    Tell whether the chunk open at the previous word closes before the current word.
    e.g.
    (B-PER, I-PER) -> False
    (B-LOC, O)  -> True
    Contradicting tags such as (B-PER, I-LOC) are treated as (B-PER, B-LOC),
    i.e. the previous chunk ends.
    """
    prev_prefix, prev_type = split_tag(prev_tag)
    cur_prefix, cur_type = split_tag(tag)

    # No chunk was open at the previous word, so nothing can end here.
    if prev_prefix == 'O':
        return False
    # A chunk was open and the current word is outside any chunk.
    if cur_prefix == 'O':
        return True

    # Both words are inside chunks: the previous one ends on a type change, on
    # an explicit new start (B/S), or after an explicit end marker (E/S).
    return (
        prev_type != cur_type
        or cur_prefix in ('B', 'S')
        or prev_prefix in ('E', 'S')
    )

def is_chunk_start(prev_tag, tag):
    """
    Tell whether a new chunk opens at the current word.
    """
    prev_prefix, prev_type = split_tag(prev_tag)
    cur_prefix, cur_type = split_tag(tag)

    # The current word is outside any chunk, so nothing starts here.
    if cur_prefix == 'O':
        return False
    # The current word is in a chunk and the previous one was not.
    if prev_prefix == 'O':
        return True

    # Both words are inside chunks: a new one starts on a type change, on an
    # explicit start marker (B/S), or right after an explicit end marker (E/S).
    return (
        prev_type != cur_type
        or cur_prefix in ('B', 'S')
        or prev_prefix in ('E', 'S')
    )


def calc_metrics(tp, p, t, percent=True):
    """
    Derive precision, recall and FB1 from raw counts.

    tp: number of correct predictions
    p:  number of predictions made
    t:  number of gold items
    Any ratio whose denominator is zero falls back to 0.
    With percent=True every value is multiplied by 100.
    """
    precision = 0
    if p:
        precision = tp / p

    recall = 0
    if t:
        recall = tp / t

    fb1 = 0
    if precision + recall:
        fb1 = 2 * precision * recall / (precision + recall)

    scores = (precision, recall, fb1)
    if not percent:
        return scores
    return tuple(100 * score for score in scores)


def count_chunks(true_seqs, pred_seqs):
    """
    Walk two parallel tag sequences once and collect chunk- and tag-level counts.

    true_seqs: a list of true tags
    pred_seqs: a list of predicted tags
    (the two are zipped, so any surplus items in the longer one are ignored)

    return:
    correct_chunks: a dict (counter),
                    key = chunk types,
                    value = number of correctly identified chunks per type
    true_chunks:    a dict, number of true chunks per type
    pred_chunks:    a dict, number of identified chunks per type
    correct_counts, true_counts, pred_counts: similar to above, but keyed by
                    the full tag string and counted per token
    """
    correct_chunks = defaultdict(int)
    true_chunks = defaultdict(int)
    pred_chunks = defaultdict(int)

    correct_counts = defaultdict(int)
    true_counts = defaultdict(int)
    pred_counts = defaultdict(int)

    # Both sequences start as if preceded by an 'O' tag.
    prev_true_tag, prev_pred_tag = 'O', 'O'
    # Type of the chunk that gold and prediction currently agree on, or None.
    correct_chunk = None

    for true_tag, pred_tag in zip(true_seqs, pred_seqs):
        # Token-level bookkeeping.
        if true_tag == pred_tag:
            correct_counts[true_tag] += 1
        true_counts[true_tag] += 1
        pred_counts[pred_tag] += 1

        _, true_type = split_tag(true_tag)
        _, pred_type = split_tag(pred_tag)

        # Resolve the chunk that was being matched up to the previous token.
        if correct_chunk is not None:
            true_end = is_chunk_end(prev_true_tag, true_tag)
            pred_end = is_chunk_end(prev_pred_tag, pred_tag)

            if pred_end and true_end:
                # Both ended at the same boundary: a fully correct chunk.
                correct_chunks[correct_chunk] += 1
                correct_chunk = None
            elif pred_end != true_end or true_type != pred_type:
                # Only one side ended, or the types diverged: the match is void.
                correct_chunk = None

        true_start = is_chunk_start(prev_true_tag, true_tag)
        pred_start = is_chunk_start(prev_pred_tag, pred_tag)

        # A candidate match begins only when both sides start the same type here.
        if true_start and pred_start and true_type == pred_type:
            correct_chunk = true_type
        if true_start:
            true_chunks[true_type] += 1
        if pred_start:
            pred_chunks[pred_type] += 1

        prev_true_tag, prev_pred_tag = true_tag, pred_tag
    # A match still open when the input ends counts as correct.
    if correct_chunk is not None:
        correct_chunks[correct_chunk] += 1

    return (correct_chunks, true_chunks, pred_chunks, 
        correct_counts, true_counts, pred_counts)

def get_result(correct_chunks, true_chunks, pred_chunks,
    correct_counts, true_counts, pred_counts, verbose=True):
    """
    Turn the counters produced by count_chunks into overall scores.

    if verbose, print overall performance, as well as performance per chunk type;
    otherwise, simply return overall prec, rec, f1 scores

    return: (precision, recall, FB1) as percentages, computed over chunks.
    Token accuracies are reported as 0 when there are no tokens (or no non-O
    tokens) instead of raising ZeroDivisionError.
    """
    # sum counts
    sum_correct_chunks = sum(correct_chunks.values())
    sum_true_chunks = sum(true_chunks.values())
    sum_pred_chunks = sum(pred_chunks.values())

    sum_correct_counts = sum(correct_counts.values())
    sum_true_counts = sum(true_counts.values())

    nonO_correct_counts = sum(v for k, v in correct_counts.items() if k != 'O')
    nonO_true_counts = sum(v for k, v in true_counts.items() if k != 'O')

    chunk_types = sorted(list(set(list(true_chunks) + list(pred_chunks))))

    # compute overall precision, recall and FB1 (default values are 0.0)
    prec, rec, f1 = calc_metrics(sum_correct_chunks, sum_pred_chunks, sum_true_chunks)
    res = (prec, rec, f1)
    if not verbose:
        return res

    # Token-level accuracies; guard the denominators so empty or all-'O'
    # input does not crash the report.
    nonO_accuracy = 100*nonO_correct_counts/nonO_true_counts if nonO_true_counts else 0
    accuracy = 100*sum_correct_counts/sum_true_counts if sum_true_counts else 0

    # print overall performance, and performance per chunk type

    print("processed %i tokens with %i phrases; " % (sum_true_counts, sum_true_chunks), end='')
    print("found: %i phrases; correct: %i.\n" % (sum_pred_chunks, sum_correct_chunks), end='')

    print("accuracy: %6.2f%%; (non-O)" % nonO_accuracy)
    print("accuracy: %6.2f%%; " % accuracy, end='')
    print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" % (prec, rec, f1))

    # for each chunk type, compute precision, recall and FB1 (default values are 0.0)
    for t in chunk_types:
        prec, rec, f1 = calc_metrics(correct_chunks[t], pred_chunks[t], true_chunks[t])
        print("%17s: " %t , end='')
        print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" %
                    (prec, rec, f1), end='')
        print("  %d" % pred_chunks[t])

    return res
    # you can generate LaTeX output for tables like in
    # http://cnts.uia.ac.be/conll2003/ner/example.tex
    # but I'm not implementing this

def evaluate(true_seqs, pred_seqs, verbose=True):
    """
    Score predicted tags against gold tags.

    Counts chunks and tags with count_chunks, then hands the six counters to
    get_result, which returns (precision, recall, FB1) and prints a report
    when verbose is true.
    """
    counters = count_chunks(true_seqs, pred_seqs)
    return get_result(*counters, verbose=verbose)

In [None]:
from tqdm import tqdm
import numpy as np
from sklearn.metrics import classification_report

def train_fn(data_loader, model, optimizer, device):
    """Run one optimisation pass over ``data_loader``.

    Every batch is moved to ``device`` and passed to the model as keyword
    arguments; the model is expected to return the loss.

    Returns:
        The mean of the per-batch losses.
    """
    model.train()
    running_loss = 0
    for batch in tqdm(data_loader, total=len(data_loader)):
        batch = {name: tensor.to(device) for name, tensor in batch.items()}

        optimizer.zero_grad()
        batch_loss = model(**batch)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
    return running_loss / len(data_loader)

def eval_fn(data_loader, model, device):
    """Compute the mean loss over ``data_loader`` without updating the model.

    The model is switched to eval mode and the loop runs under
    ``torch.no_grad()`` so that validation does not build autograd graphs.

    Returns:
        The mean of the per-batch losses.
    """
    model.eval()
    final_loss = 0

    with torch.no_grad():
        for data in tqdm(data_loader, total=len(data_loader)):
            for k, v in data.items():
                data[k] = v.to(device)
            loss = model(**data)
            final_loss += loss.item()
    return final_loss / len(data_loader)

def test_fn(dataset,model,device,enc_tag):
    """Decode every example of ``dataset`` and print chunk-level metrics.

    Examples are processed one at a time (batch size 1). Gold labels are cut
    to the length of the decoded sequence, and the first and last position of
    each example (the special tokens) are left out of the comparison.

    The model is put into eval mode first, so dropout is disabled even when
    the model was just constructed or loaded.

    Returns:
        The ``(precision, recall, FB1)`` tuple from ``evaluate`` (also printed).
    """
    final_test = []
    final_pred = []

    model.eval()
    with torch.no_grad():
        for data in tqdm(dataset):
            # Add a leading batch dimension of size 1.
            batch = {k: v.to(device).unsqueeze(0) for k, v in data.items()}

            decoded = model.encode(**batch)[0]
            gold = batch["target_tag"].cpu()[0][:len(decoded)].numpy()

            gold_labels = enc_tag.inverse_transform(gold)
            pred_labels = enc_tag.inverse_transform(decoded)

            final_pred.extend(pred_labels[1:-1])
            final_test.extend(gold_labels[1:-1])

    result = evaluate(final_test, final_pred)
    print(result)
    return result

def load_model(epochs):
    """Rebuild the tagger and load the checkpoint saved for epoch index ``epochs``.

    The file read is ``MODEL_PATH + "_<epochs>.bin"``. The weights are mapped
    onto ``DEVICE`` while loading, so a checkpoint written on another device
    can still be restored.

    Relies on the notebook-level ``num_tag`` being defined before the call.

    Returns:
        The model, moved to ``DEVICE``.
    """
    path = MODEL_PATH+f"_{epochs}.bin"
    device = torch.device(DEVICE)
    model = EntityModel(num_tag=num_tag)
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    return model

In [None]:
if __name__ == "__main__":
    # The official train and dev files are pooled here and re-split below.
    sentences, tag, enc_tag = process_data([data_folder + '/NCBI-disease-IOB/NCBI_train.tsv',
                                            data_folder + '/NCBI-disease-IOB/NCBI_dev.tsv'])
    test_sentences, test_tag, _ = process_data([data_folder + '/NCBI-disease-IOB/NCBI_test.tsv'])

    # Persist the label encoder so the inference cell can map ids back to tags.
    meta_data = {
        "enc_tag": enc_tag
    }

    joblib.dump(meta_data, "meta.bin")

    num_tag = len(list(enc_tag.classes_))

    # Hold out 10% of the pooled sentences for validation (fixed seed).
    (
        train_sentences,
        valid_sentences,
        train_tag,
        valid_tag
    ) = model_selection.train_test_split(sentences, tag, random_state=42, test_size=0.1)

    train_dataset = EntityDataset(texts=train_sentences, tags=train_tag,enc_tag=enc_tag)

    # shuffle=True: without it every epoch visits the batches in the same order.
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=TRAIN_BATCH_SIZE,
        num_workers=NUM_WORKER,
        shuffle=True,
    )

    valid_dataset = EntityDataset(texts=valid_sentences, tags=valid_tag,enc_tag=enc_tag)

    valid_data_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=VALID_BATCH_SIZE, num_workers=NUM_WORKER)

    # The test set is iterated directly by test_fn, so it needs no loader.
    test_dataset=EntityDataset(texts=test_sentences,tags=test_tag,enc_tag=enc_tag)

    device = torch.device(DEVICE)
    model = EntityModel(num_tag=num_tag)
    model.to(device)

    # optimizer = torch.optim.Adam(model.parameters(), lr=3e-5, weight_decay=0.001)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

In [None]:
# Training loop: one train pass, one validation pass and one test-set report
# per epoch, followed by a checkpoint named after the epoch index.
import os

# torch.save does not create missing directories; make sure the checkpoint
# folder exists up front instead of failing after the first full epoch.
checkpoint_dir = os.path.dirname(MODEL_PATH)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(EPOCHS):
    train_loss = train_fn(train_data_loader, model, optimizer, device)
    torch.cuda.empty_cache()
    valid_loss = eval_fn(valid_data_loader, model, device)
    torch.cuda.empty_cache()
    print(f"Train Loss = {train_loss}")
    print(f"Validation Loss = {valid_loss}")
    test_fn(test_dataset, model, device,enc_tag)
    torch.save(model.state_dict(), MODEL_PATH+f"_{epoch}.bin")

In [None]:
import joblib
import torch
import argparse

def predict_sentence(model, sentence, enc_tag):
    """Tag a raw sentence and return the decoded labels.

    The sentence is split on whitespace and encoded through ``EntityDataset``
    with placeholder labels. Inputs are moved to the device the model's
    parameters live on, so no notebook-level ``device`` variable is needed,
    and the model is put into eval mode so dropout is off.

    Args:
        model: tagger exposing ``encode(**batch)``.
        sentence: raw text.
        enc_tag: fitted label encoder used to map ids back to label strings.

    Returns:
        Array of label strings, one per decoded position; this includes the
        positions of the two special tokens that wrap the sentence.
    """
    words = sentence.split()
    single_item_dataset = EntityDataset(
        texts=[words],
        # Placeholder labels; they are not used for decoding.
        tags=[[0] * len(words)],
        enc_tag=enc_tag
    )

    model_device = next(model.parameters()).device
    model.eval()

    with torch.no_grad():
        data = single_item_dataset[0]
        # Add a leading batch dimension of size 1.
        batch = {k: v.to(model_device).unsqueeze(0) for k, v in data.items()}

        tag = model.encode(**batch)
        tag = enc_tag.inverse_transform(tag[0])

    return tag

if __name__ == "__main__":
    sentence = "Management of Critically Ill Patients with Severe Acute Respiratory Syndrome (SARS)."

    # Word-piece ids (with special tokens) and word-piece strings of the raw
    # sentence; the ids are printed next to the predicted tags below.
    tokenized_sentence = TOKENIZER.encode(sentence)
    tokenized = TOKENIZER.tokenize(sentence)

    # meta.bin is written by the training cell and holds the fitted label encoder.
    # NOTE: joblib.load unpickles the file; only load files you produced yourself.
    meta_data = joblib.load("meta.bin")
    enc_tag = meta_data["enc_tag"]

    num_tag = len(list(enc_tag.classes_))

    # set up device, model
    device = torch.device(DEVICE)
    model = EntityModel(num_tag=num_tag)
    # Load the checkpoint of the last training epoch (indices run 0..EPOCHS-1),
    # mapping the weights onto the target device while loading.
    model.load_state_dict(torch.load(MODEL_PATH+f"_{EPOCHS - 1}.bin", map_location=device))
    model.to(device)
    # Inference mode: without this, dropout stays active and predictions vary.
    model.eval()

    tags = predict_sentence(model, sentence, enc_tag)

    print(sentence)
    print(len(tags),len(tokenized_sentence))
    for token,tag in zip(tokenized_sentence,tags):
      print(token,TOKENIZER.decode(token),tag)