In [70]:
# import os
# import re
# import csv
# import itertools

import nltk
# import pandas as pd
# import numpy as np
# import seaborn as sns
# import matplotlib.pyplot as plt
# from tqdm import tqdm, trange
from collections import defaultdict, OrderedDict

import torch
import torch.nn as nn
# from torch.utils.data import TensorDataset, DataLoader
# from torch.utils.data import RandomSampler, SequentialSampler
from transformers import BertModel, BertTokenizer, BertConfig
# from transformers import BertForTokenClassification, AdamW
# from transformers import get_linear_schedule_with_warmup

# import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
# from sklearn.model_selection import train_test_split

In [38]:
# Load the BioBERT word-piece vocabulary (case is preserved: do_lower_case=False).
# The path is relative to the notebook's working directory, the same way the
# checkpoint's config.json and pytorch_model.bin are located further down,
# rather than an absolute path that only exists on one machine.
tokenizer = BertTokenizer(vocab_file='biobert_v1.1_pubmed/vocab.txt', do_lower_case=False)

# check if GPU available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [11]:
# Fixed length, in word-piece tokens, that every input sequence is truncated or padded to.
max_seq_length = 64

In [1]:
test = [{'guid': 'test-0',
  'text': 'Clustering of missense mutations in the ataxia - telangiectasia gene in a sporadic T - cell leukaemia .',
  'label': 'O O O O O O B I I O O O B I I I I O'},
 {'guid': 'test-1',
  'text': 'Two of seventeen mutated T - PLL samples had a previously reported A - T allele .',
  'label': 'O O O O B I I O O O O O B I I O O'},
 {'guid': 'test-2',
  'text': 'In contrast , no mutations were detected in the p53 gene , suggesting that this tumour suppressor is not frequently altered in this leukaemia .',
  'label': 'O O O O O O O O O O O O O O O B O O O O O O O B O'},
 {'guid': 'test-3',
  'text': 'Occasional missense mutations in ATM were also found in tumour DNA from patients with B - cell non - Hodgkins lymphomas ( B - NHL ) and a B - NHL cell line .',
  'label': 'O O O O O O O O O B O O O O B I I I I I I O B I I O O O B I I O O O'},
 {'guid': 'test-4',
  'text': 'Constitutional RB1 - gene mutations in patients with isolated unilateral retinoblastoma .',
  'label': 'O O O O O O O O O B I O'},
 {'guid': 'test-5',
  'text': 'In most patients with isolated unilateral retinoblastoma , tumor development is initiated by somatic inactivation of both alleles of the RB1 gene .',
  'label': 'O O O O O B I O B O O O O O O O O O O O O O O'}]

In [2]:
# Tag vocabulary; a tag's position in this list is used as its integer id, and the
# list length sets the size of the classifier output.
# NOTE(review): B/I/O presumably follow the usual begin/inside/outside scheme and "X"
# presumably marks word-piece continuations -- confirm against the training code.
label_list = ["[PAD]", "B", "I", "O", "X", "[CLS]", "[SEP]"] 

In [137]:
# Two-way lookup between tag strings and their integer ids (position in label_list).
label_id_dict = {label: idx for idx, label in enumerate(label_list)}
id_label_dict = dict(enumerate(label_list))

In [68]:
text = 'Clustering of missense mutations in the ataxia-telangiectasia gene in a sporadic T-cell leukaemia.'

In [14]:
texts = [text]

In [65]:
def text_to_tokens(sentence):
    """Split a sentence into word-piece tokens, capped for the model input.

    The sentence is first split into words with NLTK; each word is then broken
    into word pieces by the module-level ``tokenizer``. At most
    ``max_seq_length - 2`` tokens are kept; anything beyond that is cut off.

    Args:
        sentence: raw text of one sentence.

    Returns:
        A list of word-piece token strings.
    """
    pieces = [
        piece
        for word in nltk.word_tokenize(sentence)
        for piece in tokenizer.tokenize(word)
    ]
    # An unconditional slice leaves short inputs untouched and truncates long ones.
    return pieces[:max_seq_length - 2]

In [115]:
def process_sentences(sentences):
    """Convert raw sentences into padded id and attention-mask tensors.

    Args:
        sentences: iterable of sentence strings.

    Returns:
        A ``(tensor_inputs, tensor_masks)`` pair. ``tensor_inputs`` holds the
        vocabulary ids of each sentence, padded or truncated at the end to
        ``max_seq_length`` with 0 as the fill value. ``tensor_masks`` is 1.0
        wherever the id is non-zero and 0.0 elsewhere.
    """
    id_sequences = [
        tokenizer.convert_tokens_to_ids(text_to_tokens(sentence))
        for sentence in sentences
    ]
    padded_ids = pad_sequences(
        id_sequences,
        maxlen=max_seq_length,
        dtype="long",
        value=0.0,
        padding="post",
        truncating="post",
    )
    # A position counts as real input only when its id is non-zero, e.g. a
    # 3-token sentence padded to length 5 gives [1.0, 1.0, 1.0, 0.0, 0.0].
    mask_rows = [
        [1.0 if token_id != 0.0 else 0.0 for token_id in row]
        for row in padded_ids
    ]
    return torch.tensor(padded_ids), torch.tensor(mask_rows)

In [58]:
# Model configuration read from the JSON file stored next to the checkpoint.
config = BertConfig.from_json_file('biobert_v1.1_pubmed/config.json')

# Load the pretrained BioBERT checkpoint onto the selected device.
# NOTE(review): torch.load unpickles the file -- only use checkpoints from a trusted source.
tmp_d = torch.load('biobert_v1.1_pubmed/pytorch_model.bin', map_location=device)

# Copy the first 199 checkpoint entries, dropping the leading dotted name component
# from every key that contains 'bert'; other keys are copied unchanged.
# NOTE(review): presumably this renames the keys to match BertModel's own parameter
# names and 199 is the number of encoder tensors -- confirm against the checkpoint.
state_dict = OrderedDict()
for ckpt_key in list(tmp_d)[:199]:
    target_key = ckpt_key.partition('.')[2] if 'bert' in ckpt_key else ckpt_key
    state_dict[target_key] = tmp_d[ckpt_key]

In [125]:
class BioBertNER(nn.Module):
    """Token-level tagger: a BERT encoder followed by dropout and a linear layer.

    Args:
        label_num: number of tags, i.e. the output size of the linear layer.
        config: configuration object used to build the ``BertModel`` encoder.
        state_dict: pretrained encoder weights; loaded non-strictly.
    """

    def __init__(self, label_num, config, state_dict):
        super().__init__()
        self.bert = BertModel(config)
        # strict=False: entries whose names do not match are skipped without an error.
        self.bert.load_state_dict(state_dict, strict=False)
        self.dropout = nn.Dropout(p=0.3)
        self.linear_output = nn.Linear(self.bert.config.hidden_size, label_num)
        # Normalise over the label axis (the last one). The logits are
        # (batch, seq_len, label_num), so dim=1 would normalise across sequence
        # positions instead of across tags.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_ids, attention_mask):
        """Predict one tag id per input position.

        Args:
            input_ids: vocabulary ids, shape (batch, seq_len).
            attention_mask: mask with the same leading shape as ``input_ids``.

        Returns:
            Tensor of shape (batch, seq_len) holding the index of the
            highest-scoring tag at each position.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # First encoder output: per-token hidden states, (batch, seq_len, hidden_size).
        encoded_layer = outputs[0]
        output = self.dropout(encoded_layer)
        output = self.linear_output(output)
        return output.argmax(-1)

In [126]:
# Build the tagger with one output per entry of label_list. Only the encoder receives
# the pretrained state_dict; the linear tagging layer is newly created inside the class,
# and no eval() call is made here.
model = BioBertNER(len(label_list), config, state_dict)

In [116]:
input_ids, attention_masks = process_sentences([text])

In [84]:
input_ids

tensor([[  140, 23225,  1158,  1104,  5529, 22615, 17157,  1107,  1103,  1120,
          7897,  1465,   118, 21359, 19514,  1663,  5822, 17506,  1161,  5565,
          1107,   170,   188, 27695,   157,   118,  2765,  5837, 12658, 20504,
           119,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0]])

In [94]:
# `bert` is not defined in any cell shown here, so call the encoder owned by `model`.
# Slicing with [:2] keeps the two-value unpacking valid whether the encoder returns a
# plain tuple or a dict-like output object.
encoded_layer, pool_layer = model.bert(input_ids=input_ids, attention_mask=attention_masks)[:2]

In [97]:
# `bert` is not defined in any cell shown here, so call the encoder owned by `model`.
output = model.bert(input_ids=input_ids, attention_mask=attention_masks)

In [108]:
output[0].shape

torch.Size([1, 64, 768])

In [122]:
# output = dropout(output[0])
# `bert` and `label_num` are not defined in any cell shown here; take the hidden size
# from the encoder owned by `model` and the number of tags from label_list.
linear_output = nn.Linear(model.bert.config.hidden_size, len(label_list))
# Feed only the per-token hidden states (index 0 of the encoder result) to the layer.
# Run this cell once per encoder call: it rebinds `output` to the resulting logits.
output = linear_output(output[0])

In [123]:
output.shape

torch.Size([1, 64, 7])

In [128]:
output.argmax(-1).shape

torch.Size([1, 64])

In [136]:
label_id_dict

{'[PAD]': 0, 'B': 1, 'I': 2, 'O': 3, 'X': 4, '[CLS]': 5, '[SEP]': 6}

In [146]:
# Switch off dropout for prediction. torch.no_grad() only disables gradient tracking,
# so without eval() the dropout layers stay active and the tags vary between runs.
model.eval()

# Collapse any leading dimensions so both tensors are (n_sequences, seq_len).
ids = input_ids.view(-1, input_ids.size()[-1])
masks = attention_masks.view(-1, attention_masks.size()[-1])
with torch.no_grad():
    y_hat = model(ids, masks)
output_ids = y_hat.to('cpu').numpy()

# Only the first sequence is decoded. Padding sits at the end, so dropping the
# '[PAD]' tokens and trimming the predictions to the same length keeps them aligned.
tokens = tokenizer.convert_ids_to_tokens(ids.to('cpu').numpy()[0])
tokens = [i for i in tokens if i != '[PAD]']
output_ids = output_ids[0][:len(tokens)]

# Merge word pieces back into words; each word keeps the tag predicted for its first piece.
new_tokens, new_labels = [], []
for token, label_idx in zip(tokens, output_ids):
    if token.startswith("##"):
        new_tokens[-1] = new_tokens[-1] + token[2:]
    else:
        new_labels.append(id_label_dict[int(label_idx)])
        new_tokens.append(token)

In [147]:
# One "word<TAB>tag" line per reconstructed word.
for word, tag in zip(new_tokens, new_labels):
    print("{}\t{}".format(word, tag))

Clustering	B
of	[SEP]
missense	[SEP]
mutations	B
in	[SEP]
the	[SEP]
ataxia	[SEP]
-	[SEP]
telangiectasia	I
gene	[PAD]
in	B
a	B
sporadic	[SEP]
T	[PAD]
-	[PAD]
cell	B
leukaemia	[PAD]
.	[PAD]


In [None]:
# NOTE(review): this cell has no execution count. Running it rebinds new_labels to an
# empty list, discarding the tags collected by the decoding loop earlier in the notebook.
new_labels = []

In [75]:
def token_long_to_short(tokens_vocab):
    """Merge word-piece tokens back into whole words.

    A token that starts with ``##`` is a continuation piece: its text, without
    the ``##`` marker, is appended to the previous word. Any other token starts
    a new word. A continuation piece with no preceding word (for example at the
    very start of the input) starts a new word itself instead of raising
    IndexError.

    Args:
        tokens_vocab: iterable of word-piece token strings.

    Returns:
        A new list with one string per reconstructed word.
    """
    tokens_word_only = []
    for token in tokens_vocab:
        is_continuation = token.startswith("##")
        piece = token[2:] if is_continuation else token
        if is_continuation and tokens_word_only:
            tokens_word_only[-1] += piece
        else:
            tokens_word_only.append(piece)
    return tokens_word_only

In [78]:
# Rebuild whole words for every sentence in `texts`. Iterating over `tokens` would not
# work here: by this point it is a flat list of token strings, so each token would be
# split into single characters.
[token_long_to_short(text_to_tokens(sentence)) for sentence in texts]

[['Clustering',
  'of',
  'missense',
  'mutations',
  'in',
  'the',
  'ataxia',
  '-',
  'telangiectasia',
  'gene',
  'in',
  'a',
  'sporadic',
  'T',
  '-',
  'cell',
  'leukaemia',
  '.']]