# BERT DDI Model

## Set-up

In [51]:
import os
import data_processor
from torch_shallow_neural_classifier import TorchShallowNeuralClassifier
from sklearn.metrics import classification_report
import utils
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import pandas as pd
import itertools
from sklearn.model_selection import train_test_split

In [7]:
# Set all the random seeds for reproducibility. Only the
# system and torch seeds are relevant for this notebook.

utils.fix_random_seeds()

In [8]:
SEMEVAL_HOME = os.path.join("semeval", "task9_train_pair")

## Start logging

In [11]:
import logging
logger = logging.getLogger()
logger.level = logging.ERROR

## Define Model

In [13]:
class HfBertClassifierModel(nn.Module):
    def __init__(self, n_classes, weights_name='bert-base-uncased'):
        super().__init__()
        self.n_classes = n_classes
        self.weights_name = weights_name
        self.bert = BertModel.from_pretrained(self.weights_name)
        self.hidden_dim = self.bert.embeddings.word_embeddings.embedding_dim
        # The only new parameters -- the classifier layer:
        self.W = nn.Linear(self.hidden_dim, self.n_classes)
        
    def forward(self, X):
        """Here, `X` is an np.array in which each element is a pair 
        consisting of an index into the BERT embedding and a 1 or 0
        indicating whether the token is masked. The `fit` method will 
        train all these parameters against a softmax objective.
        
        """
        indices = X[: , 0, : ]
        # Type conversion, since the base class insists on
        # casting this as a FloatTensor, but we ned Long
        # for `bert`.
        indices = indices.long()
        mask = X[: , 1, : ]      
        (final_hidden_states, cls_output) = self.bert(
            indices, attention_mask=mask)       
        return self.W(cls_output)

In [14]:
class HfBertClassifier(TorchShallowNeuralClassifier):
    def __init__(self, weights_name, *args, **kwargs):
        self.weights_name = weights_name
        self.tokenizer = BertTokenizer.from_pretrained(self.weights_name)
        super().__init__(*args, **kwargs)
        
    def define_graph(self):
        """This method is used by `fit`. We override it here to use our
        new BERT-based graph.
        
        """
        bert = HfBertClassifierModel(
            self.n_classes_, weights_name=self.weights_name)
        bert.train()
        return bert
    
    def encode(self, X, max_length=None):
        """The `X` is a list of strings. We use the model's tokenizer
        to get the indices and mask information.
        
        Returns
        -------
        list of [index, mask] pairs, where index is an int and mask
        is 0 or 1.
        
        """
        data = self.tokenizer.batch_encode_plus(
            X, 
            max_length=max_length,
            add_special_tokens=True, 
            pad_to_max_length=True,
            return_attention_mask=True)
        indices = data['input_ids']
        mask = data['attention_mask']
        return [[i, m] for i, m in zip(indices, mask)]

## Import data

In [36]:
os.getcwd()

'c:\\Users\\julien_lauret\\Documents\\Python Scripts\\Stanford NLU\\cs224u_project'

In [35]:
os.chdir('./Stanford NLU/cs224u_project')

In [37]:
os.chdir('./SemEval/semeval_task9_train_pair')
DB_dataset = data_processor.Dataset('DrugBank').from_training_data('DrugBank')
ML_dataset = data_processor.Dataset('MedLine').from_training_data('MedLine')
#os.chdir('../')

nk/Clomipramine_ddi.xml
./Train/DrugBank/Clonazepam_ddi.xml
./Train/DrugBank/Clonidine_ddi.xml
./Train/DrugBank/Clopidogrel_ddi.xml
./Train/DrugBank/Clorazepate_ddi.xml
./Train/DrugBank/Clozapine_ddi.xml
./Train/DrugBank/Coagulation factor VIIa_ddi.xml
./Train/DrugBank/Codeine_ddi.xml
./Train/DrugBank/Colchicine_ddi.xml
./Train/DrugBank/Colesevelam_ddi.xml
./Train/DrugBank/Colestipol_ddi.xml
./Train/DrugBank/Colistimethate_ddi.xml
./Train/DrugBank/Conivaptan_ddi.xml
./Train/DrugBank/Conjugated Estrogens_ddi.xml
./Train/DrugBank/Corticotropin_ddi.xml
./Train/DrugBank/Cortisone acetate_ddi.xml
./Train/DrugBank/Cosyntropin_ddi.xml
./Train/DrugBank/Cromoglicate_ddi.xml
./Train/DrugBank/Crotamiton_ddi.xml
./Train/DrugBank/Cyanocobalamin_ddi.xml
./Train/DrugBank/Cyclobenzaprine_ddi.xml
./Train/DrugBank/Cyclopentolate_ddi.xml
./Train/DrugBank/Cyclophosphamide_ddi.xml
./Train/DrugBank/Cycloserine_ddi.xml
./Train/DrugBank/Cyproheptadine_ddi.xml
./Train/DrugBank/Cysteamine_ddi.xml
./Train/DrugBa

In [40]:
def create_classification_task(dataset):
    '''Take a dataprocessor dataset object and return a Pandas dataframe for classification task '''

    classification_task_df = pd.DataFrame(columns=['e1_id','e1_type','e1_name','e2_id','e2_type','e2_name','sentence','ddi','label'])

    for doc in dataset.documents:
        for sent in doc.sentences:
            if len(sent.map)==0:
                if len(sent.entities)==1:
                    classification_task_df = classification_task_df.append({
                        'e1_id': sent.entities[0]._id,
                        'e1_type': sent.entities[0].type,
                        'e1_name': sent.entities[0].text,
                        'e2_id':"",
                        'e2_type':"",
                        'e2_name':"",
                        'sentence':sent.text,
                        'ddi':False,
                        'label':'NO_DDI'
                    },ignore_index=True)
                if len(sent.entities)>2:
                    for (i,j)in itertools.combinations(range(len(sent.entities)),2):
                        classification_task_df = classification_task_df.append({
                            'e1_id': sent.entities[i]._id,
                            'e1_type': sent.entities[i].type,
                            'e1_name': sent.entities[i].text,
                            'e2_id':sent.entities[j]._id,
                            'e2_type':sent.entities[j].type,
                            'e2_name':sent.entities[j].text,
                            'sentence':sent.text,
                            'ddi':False,
                            'label':'NO_DDI'
                        },ignore_index=True)
            if len(sent.map)>0:
                for i, (k, v) in enumerate(sent.map.items()):
                    for entity in sent.entities:
                        if entity._id == k:
                            e1_id = k
                            e1_type = entity.type
                            e1_name = entity.text
                        if entity._id == v[0]:
                            e2_id = v[0]
                            e2_type = entity.type
                            e2_name = entity.text
                    classification_task_df = classification_task_df.append({
                            'e1_id': e1_id,
                            'e1_type':e1_type,
                            'e1_name':e1_name,
                            'e2_id':e2_id,
                            'e2_type':e2_type,
                            'e2_name':e2_name,
                            'sentence':sent.text,
                            'ddi':True,
                            'label':v[1]
                        },ignore_index=True)
    return classification_task_df

In [44]:
classification_task_df = pd.concat([
    create_classification_task(DB_dataset),
    create_classification_task(ML_dataset)
])

In [46]:
classification_task_df.tail()

Unnamed: 0,e1_id,e1_type,e1_name,e2_id,e2_type,e2_name,sentence,ddi,label
1610,DDI-MedLine.d113.s1.e1,group,anxiolytic,DDI-MedLine.d113.s1.e2,group,hypnotic drugs,The benzodiazepines are a family of anxiolytic...,False,NO_DDI
1611,DDI-MedLine.d113.s2.e0,drug,ethanol,,,,"When taken concurrently with ethanol, a pharma...",False,NO_DDI
1612,DDI-MedLine.d113.s3.e0,drug,temazepam,DDI-MedLine.d113.s3.e2,drug,ethanol,In addition to this pharmacological interactio...,True,mechanism
1613,DDI-MedLine.d113.s5.e0,drug,alcohol,DDI-MedLine.d113.s5.e1,group,"3-hydroxy-1,4-benzodiazepine",The results raise the possibility that the eth...,True,mechanism
1614,DDI-MedLine.d113.s6.e0,drug,ethanol,,,,The acid-catalyzed ethanol-drug reaction is a ...,False,NO_DDI


In [47]:
classification_task_df.groupby("label").count()

Unnamed: 0_level_0,e1_id,e1_type,e1_name,e2_id,e2_type,e2_name,sentence,ddi
label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
NO_DDI,11248,11248,11248,11248,11248,11248,11248,11248
advise,589,589,589,589,589,589,589,589
effect,1085,1085,1085,1085,1085,1085,1085,1085
int,61,61,61,61,61,61,61,61
mechanism,934,934,934,934,934,934,934,934


In [48]:
# the task is quite unbalanced. "int" sample is very small

## Train model 

In [50]:
bert_experiment_1 = HfBertClassifier(
    'bert-base-uncased', 
    batch_size=16, # small batch size for use on notebook
    max_iter=4, 
    eta=0.00002)

In [52]:
X_text_train, X_text_test, y_train, y_test = train_test_split(classification_task_df['sentence'], classification_task_df['label'], test_size=0.2, random_state=42)

In [54]:
X_indices_train = bert_experiment_1.encode(X_text_train)
X_indices_test = bert_experiment_1.encode(X_text_test)

In [56]:
%time _ = bert_experiment_1.fit(X_indices_train, y_train)

Downloading: 100%|██████████| 433/433 [00:00<00:00, 452kB/s]
Downloading: 100%|██████████| 440M/440M [00:41<00:00, 10.5MB/s]
Finished epoch 4 of 4; error is 29.41819013369968Wall time: 19h 56min 7s


In [59]:

bert_experiment_1_preds = bert_experiment_1.predict(X_indices_test)

In [60]:
print(classification_report(bert_experiment_1_preds, y_test , digits=3))

precision    recall  f1-score   support

      NO_DDI      0.987     0.982     0.984      2264
      advise      0.854     0.874     0.864       127
      effect      0.856     0.877     0.867       204
         int      0.643     0.692     0.667        13
   mechanism      0.894     0.909     0.901       176

    accuracy                          0.963      2784
   macro avg      0.847     0.867     0.857      2784
weighted avg      0.964     0.963     0.964      2784



In [65]:
os.getcwd()
#os.chdir('../')

'c:\\Users\\julien_lauret\\Documents\\Python Scripts\\Stanford NLU\\cs224u_project'

In [67]:
bert_experiment_1.to_pickle('BERT_exp1.pkl')