In [1]:
import json
from argparse import Namespace
from models import *
import torch.nn as nn
import numpy as np
from transformers import BertConfig, RobertaConfig, XLMRobertaConfig, BertModel, RobertaModel, XLMRobertaModel,RobertaTokenizer
from collections import namedtuple
import torch

In [2]:
# Experiment selection: which task / dataset / split / model this run uses.
task = "EAE"
dataset = "phee"
split = 1
model_type = "CRFTagging"
pretrained_model_name = "roberta-base"
# Short name of each pretrained model, used when building the output directory name.
pretrained_model_alias = {
    "roberta-base": "roberta-base",
}
config_dict = {
    # general config
    "task": task,
    "dataset": dataset,
    "model_type": model_type,
    "gpu_device": 0,
    "seed": 0,
    "cache_dir": "./cache",
    "output_dir": f"./outputs/{model_type}_{task}_{dataset}_split{split}_{pretrained_model_alias[pretrained_model_name]}",
    "train_file": f"../data/preprocessing/{dataset}/split{split}/train.json",
    "dev_file": f"../data/preprocessing/{dataset}/split{split}/dev.json",
    "test_file": f"../data/preprocessing/{dataset}/split{split}/test.json",

    # model config
    "pretrained_model_name": pretrained_model_name,
    "base_model_dropout": 0.2,
    "use_crf": True,
    "use_trigger_feature": True,
    "use_type_feature": True,
    "type_feature_num": 100,
    "linear_hidden_num": 150,
    "linear_dropout": 0.2,
    "linear_bias": True,
    "linear_activation": "relu",
    "multi_piece_strategy": "average",
    "max_length": 200,

    # train config
    "max_epoch": 30,
    "warmup_epoch": 5,
    "accumulate_step": 1,
    "train_batch_size": 6,
    "eval_batch_size": 12,
    "learning_rate": 0.001,
    "base_model_learning_rate": 1e-05,
    "weight_decay": 0.001,
    "base_model_weight_decay": 1e-05,
    "grad_clipping": 5.0,
}
config = Namespace(**config_dict)

# Cells below use torch randomness directly (DataLoader shuffling, dropout), so apply
# config.seed here to make a Restart & Run All reproducible.
torch.manual_seed(config.seed)
np.random.seed(config.seed)

In [3]:
# Resolve the trainer class for the configured (model_type, task) pair.
VALID_TASKS = ["E2E", "ED", "EAE", "EARL"]

TRAINER_MAP = {
    ("CRFTagging", "ED"): CRFTaggingEDTrainer,
    ("CRFTagging", "EAE"): CRFTaggingEAETrainer,
}

if config.task not in VALID_TASKS:
    raise ValueError(f"Unknown task {config.task!r}; expected one of {VALID_TASKS}")
try:
    trainer_class = TRAINER_MAP[(config.model_type, config.task)]
except KeyError:
    # a bare KeyError on a tuple is hard to read; list what is actually registered
    raise ValueError(
        f"No trainer registered for model_type={config.model_type!r}, task={config.task!r}; "
        f"available pairs: {sorted(TRAINER_MAP)}"
    ) from None

In [4]:
def load_EAE_data(file, add_extra_info_fn, config):
    """Load one JSON-lines split and expand it into one EAE instance per event mention.

    Each non-blank line of ``file`` is a JSON object with ``doc_id``, ``wnd_id``,
    ``tokens``, ``text``, ``entity_mentions`` and ``event_mentions``. Every event
    mention becomes an instance whose trigger and arguments are span tuples
    ``(start, end, type, text)``. Argument offsets are taken from the referenced entity.

    :param file (str): path to the JSON-lines file.
    :param add_extra_info_fn (callable): ``fn(instances, data, config) -> list``
        approach-specific preprocessing; must return one entry per instance.
    :param config: passed through unchanged to ``add_extra_info_fn``.
    :return: ``(instances, type_set)`` where ``type_set`` is
        ``{"trigger": set of event types, "role": set of role names}``.
    :raises KeyError: if an argument references an entity id missing from ``entity_mentions``.
    """
    with open(file, 'r', encoding='utf-8') as fp:
        # skip blank lines (e.g. a trailing newline) instead of crashing in json.loads
        data = [json.loads(line) for line in fp if line.strip()]

    instances = []
    for dt in data:
        entity_map = {entity['id']: entity for entity in dt['entity_mentions']}

        event_mentions = dt['event_mentions']
        # sorted in place: add_extra_info_fn also receives `data` and presumably relies on this order
        event_mentions.sort(key=lambda x: x['trigger']['start'])

        for event_mention in event_mentions:
            # trigger = (start index, end index, event type, text span)
            trigger = (event_mention['trigger']['start'],
                       event_mention['trigger']['end'],
                       event_mention['event_type'],
                       event_mention['trigger']['text'])

            arguments = []
            for arg in event_mention['arguments']:
                mapped_entity = entity_map[arg['entity_id']]
                # argument = (start index, end index, role type, text span)
                arguments.append((mapped_entity['start'], mapped_entity['end'], arg['role'], arg['text']))
            arguments.sort(key=lambda x: (x[0], x[1]))

            instances.append({"doc_id": dt["doc_id"],
                              "wnd_id": dt["wnd_id"],
                              "tokens": dt["tokens"],
                              "text": dt["text"],
                              "trigger": trigger,
                              "arguments": arguments,
                              })

    trigger_type_set = {instance['trigger'][2] for instance in instances}
    role_type_set = {argument[2] for instance in instances for argument in instance["arguments"]}
    type_set = {"trigger": trigger_type_set, "role": role_type_set}

    # approach-specific preprocessing
    new_instances = add_extra_info_fn(instances, data, config)
    assert len(new_instances) == len(instances), "add_extra_info_fn must return one entry per instance"

    print('Loaded {} EAE instances ({} trigger types and {} role types) from {}'.format(
        len(new_instances), len(trigger_type_set), len(role_type_set), file))

    return new_instances, type_set
if config.task == "EAE":
    train_data, train_type_set = load_EAE_data(config.train_file, trainer_class.add_extra_info_fn, config)
    dev_data, dev_type_set = load_EAE_data(config.dev_file, trainer_class.add_extra_info_fn, config)
    test_data, test_type_set = load_EAE_data(config.test_file, trainer_class.add_extra_info_fn, config)
    # Union the label inventories across all splits so types seen only in dev/test are still in the vocab.
    type_set = {key: train_type_set[key] | dev_type_set[key] | test_type_set[key]
                for key in ("trigger", "role")}
    print("There are {} trigger types and {} role types in total".format(len(type_set["trigger"]), len(type_set["role"])))
else:
    # Every later cell needs train_data / type_set; fail here with a clear message instead of a NameError.
    raise NotImplementedError(f"This notebook only loads data for the EAE task, got task={config.task!r}")

Loaded 3003 EAE instances (2 trigger types and 16 role types) from ../data/preprocessing/phee/split1/train.json
Loaded 1011 EAE instances (2 trigger types and 16 role types) from ../data/preprocessing/phee/split1/dev.json
Loaded 1005 EAE instances (2 trigger types and 16 role types) from ../data/preprocessing/phee/split1/test.json
There are 2 trigger types and 16 role types in total


In [58]:
# Inspect one loaded EAE instance (fields built by load_EAE_data, plus anything add_extra_info_fn added).
train_data[3]

{'doc_id': '10082597_3',
 'wnd_id': '10082597_3_1',
 'tokens': ['RESULTS',
  ':',
  'A',
  '44',
  '-',
  'year',
  '-',
  'old',
  'man',
  'taking',
  'naproxen',
  'for',
  'chronic',
  'low',
  'back',
  'pain',
  'and',
  'a',
  '20',
  '-',
  'year',
  '-',
  'old',
  'woman',
  'on',
  'oxaprozin',
  'for',
  'rheumatoid',
  'arthritis',
  'presented',
  'with',
  'tense',
  'bullae',
  'and',
  'cutaneous',
  'fragility',
  'on',
  'the',
  'face',
  'and',
  'the',
  'back',
  'of',
  'the',
  'hands',
  '.'],
 'text': 'RESULTS : A 44 - year - old man taking naproxen for chronic low back pain and a 20 - year - old woman on oxaprozin for rheumatoid arthritis presented with tense bullae and cutaneous fragility on the face and the back of the hands .',
 'trigger': (29, 30, 'Adverse_event', 'presented'),
 'arguments': [(17, 24, 'Subject', 'a 20 - year - old woman'),
  (18, 23, 'Subject_Age', '20 - year - old'),
  (23, 24, 'Subject_Gender', 'woman'),
  (25, 26, 'Treatment', 'oxapro

In [59]:
# Same window and same trigger span as train_data[3], but a different argument set (one instance per event mention).
train_data[4]

{'doc_id': '10082597_3',
 'wnd_id': '10082597_3_1',
 'tokens': ['RESULTS',
  ':',
  'A',
  '44',
  '-',
  'year',
  '-',
  'old',
  'man',
  'taking',
  'naproxen',
  'for',
  'chronic',
  'low',
  'back',
  'pain',
  'and',
  'a',
  '20',
  '-',
  'year',
  '-',
  'old',
  'woman',
  'on',
  'oxaprozin',
  'for',
  'rheumatoid',
  'arthritis',
  'presented',
  'with',
  'tense',
  'bullae',
  'and',
  'cutaneous',
  'fragility',
  'on',
  'the',
  'face',
  'and',
  'the',
  'back',
  'of',
  'the',
  'hands',
  '.'],
 'text': 'RESULTS : A 44 - year - old man taking naproxen for chronic low back pain and a 20 - year - old woman on oxaprozin for rheumatoid arthritis presented with tense bullae and cutaneous fragility on the face and the back of the hands .',
 'trigger': (29, 30, 'Adverse_event', 'presented'),
 'arguments': [(2, 9, 'Subject', 'A 44 - year - old man'),
  (3, 8, 'Subject_Age', '44 - year - old'),
  (8, 9, 'Subject_Gender', 'man'),
  (10, 11, 'Treatment', 'naproxen'),
  (1

In [35]:
# Construct the trainer selected from TRAINER_MAP, using the config and the label types unioned over all splits.
trainer = trainer_class(config, type_set)

In [36]:
# RoBERTa tokenizer for the configured checkpoint; add_prefix_space=True makes every word
# tokenize as if preceded by a space, matching the word-by-word tokenization used below.
tokenizer = RobertaTokenizer.from_pretrained(
    config.pretrained_model_name,
    cache_dir=config.cache_dir,
    do_lower_case=False,
    add_prefix_space=True,
)

In [37]:
def process_data(data):
    """Filter and tokenize loaded EAE instances.

    Instances with more than ``config.max_length`` tokens are dropped. Overlapping
    arguments are removed greedily: arguments are visited in sorted tuple order, and
    one is kept only if none of its tokens is already covered by a kept argument.
    Each word is tokenized separately with the global ``tokenizer``.

    :param data (iterable[dict]): instances as produced by ``load_EAE_data``.
    :return (list[dict]): new instances with ``pieces`` (flattened word pieces),
        ``token_lens`` (pieces per word) and ``token_num`` added.
    """
    assert tokenizer is not None, "Please load model and tokenizer before processing data!"

    print("Removing overlapping arguments and over-length examples")

    n_total = 0
    new_data = []
    for dt in data:
        n_total += 1

        if len(dt["tokens"]) > config.max_length:
            continue

        # greedily remove overlapping arguments
        no_overlap_flag = np.ones((len(dt["tokens"]), ), dtype=bool)
        new_arguments = []
        for argument in sorted(dt["arguments"]):
            start, end = argument[0], argument[1]
            if np.all(no_overlap_flag[start:end]):
                new_arguments.append(argument)
                no_overlap_flag[start:end] = False

        pieces = [tokenizer.tokenize(t, is_split_into_words=True) for t in dt["tokens"]]
        # A word that yields no pieces would get token_len 0, which breaks piece averaging
        # later (1 / token_len); represent it with the unknown token instead.
        pieces = [p if len(p) > 0 else [tokenizer.unk_token] for p in pieces]
        token_lens = [len(p) for p in pieces]

        new_data.append({"doc_id": dt["doc_id"],
                         "wnd_id": dt["wnd_id"],
                         "tokens": dt["tokens"],
                         "pieces": [p for w in pieces for p in w],
                         "token_lens": token_lens,
                         "token_num": len(dt["tokens"]),
                         "text": dt["text"],
                         "trigger": dt["trigger"],
                         "arguments": new_arguments,
                         })

    print(f"There are {len(new_data)}/{n_total} EAE instances after removing overlapping arguments and over-length examples")

    return new_data
# Filter/tokenize the train and dev splits (test_data is not processed here).
# NOTE: process_data is redefined further down with a batch-level signature; re-running this
# cell after that one would call the wrong function.
internal_train_data, internal_dev_data = [process_data(split_data) for split_data in (train_data, dev_data)]

Removing overlapping arguments and over-length examples
There are 3003/3003 EAE instances after removing overlapping arguments and over-length examples
Removing overlapping arguments and over-length examples
There are 1011/1011 EAE instances after removing overlapping arguments and over-length examples


In [38]:
# Inspect another loaded instance; the 'extra_info' key is not built in load_EAE_data itself, so it comes from add_extra_info_fn.
train_data[1]

{'doc_id': '10048291_2',
 'wnd_id': '10048291_2_1',
 'tokens': ['Unaccountable',
  'severe',
  'hypercalcemia',
  'in',
  'a',
  'patient',
  'treated',
  'for',
  'hypoparathyroidism',
  'with',
  'dihydrotachysterol',
  '.'],
 'text': 'Unaccountable severe hypercalcemia in a patient treated for hypoparathyroidism with dihydrotachysterol .',
 'trigger': (6, 7, 'Adverse_event', 'treated'),
 'arguments': [(0, 3, 'Effect', 'Unaccountable severe hypercalcemia'),
  (5, 6, 'Subject', 'patient'),
  (8, 9, 'Treatment_Disorder', 'hypoparathyroidism'),
  (10, 11, 'Treatment', 'dihydrotachysterol'),
  (10, 11, 'Treatment_Drug', 'dihydrotachysterol')],
 'extra_info': None}

In [39]:
# Each field is "batch_" + the per-instance key it collects.
EAEBatch_fields = [
    'batch_doc_id', 'batch_wnd_id', 'batch_tokens', 'batch_pieces', 'batch_token_lens',
    'batch_token_num', 'batch_text', 'batch_trigger', 'batch_arguments',
]
# Every field defaults to None.
EAEBatch = namedtuple('EAEBatch', field_names=EAEBatch_fields, defaults=[None] * len(EAEBatch_fields))


def EAE_collate_fn(batch):
    """DataLoader collate function: gather a list of instance dicts into one EAEBatch.

    Field ``batch_<key>`` holds ``[instance[<key>] for instance in batch]``, in batch order.
    """
    instance_keys = ('doc_id', 'wnd_id', 'tokens', 'pieces', 'token_lens',
                     'token_num', 'text', 'trigger', 'arguments')
    columns = {'batch_' + key: [instance[key] for instance in batch] for key in instance_keys}
    return EAEBatch(**columns)

In [40]:
from torch.utils.data import DataLoader

# Small shuffled loader over the processed training data, used here to pull one sample batch.
train = DataLoader(
    internal_train_data,
    batch_size=3,
    shuffle=True,
    drop_last=False,
    collate_fn=EAE_collate_fn,
)
batch = next(iter(train))

In [41]:
def generate_tagging_vocab(types=None):
    """Build BIO tag vocabularies and plain type vocabularies for triggers and roles.

    :param types (dict, optional): ``{"trigger": iterable, "role": iterable}`` of type names.
        Defaults to the notebook-level ``type_set`` (the original behaviour).
    :return: ``(label_stoi, type_stoi)``:
        ``label_stoi[kind]`` maps ``'O'`` to 0, then ``'B-<type>'`` / ``'I-<type>'`` to
        consecutive ids in sorted type order;
        ``type_stoi[kind]`` maps each type to its index in sorted order.
    """
    if types is None:
        types = type_set

    def _bio_vocab(names):
        # 'O' first, then a B-/I- pair per type in sorted order
        vocab = {'O': 0}
        for name in sorted(names):
            for prefix in ('B', 'I'):
                vocab['{}-{}'.format(prefix, name)] = len(vocab)
        return vocab

    kinds = ("trigger", "role")
    label_stoi = {kind: _bio_vocab(types[kind]) for kind in kinds}
    type_stoi = {kind: {t: i for i, t in enumerate(sorted(types[kind]))} for kind in kinds}
    return label_stoi, type_stoi
# BIO tag vocabularies and plain type vocabularies derived from the combined type_set.
label_stoi, type_stoi = generate_tagging_vocab()

In [42]:
# BIO label vocabularies: 'O' is 0, followed by B-/I- pairs per type in sorted order.
label_stoi

{'trigger': {'O': 0,
  'B-Adverse_event': 1,
  'I-Adverse_event': 2,
  'B-Potential_therapeutic_event': 3,
  'I-Potential_therapeutic_event': 4},
 'role': {'O': 0,
  'B-Combination_Drug': 1,
  'I-Combination_Drug': 2,
  'B-Effect': 3,
  'I-Effect': 4,
  'B-Subject': 5,
  'I-Subject': 6,
  'B-Subject_Age': 7,
  'I-Subject_Age': 8,
  'B-Subject_Disorder': 9,
  'I-Subject_Disorder': 10,
  'B-Subject_Gender': 11,
  'I-Subject_Gender': 12,
  'B-Subject_Population': 13,
  'I-Subject_Population': 14,
  'B-Subject_Race': 15,
  'I-Subject_Race': 16,
  'B-Treatment': 17,
  'I-Treatment': 18,
  'B-Treatment_Disorder': 19,
  'I-Treatment_Disorder': 20,
  'B-Treatment_Dosage': 21,
  'I-Treatment_Dosage': 22,
  'B-Treatment_Drug': 23,
  'I-Treatment_Drug': 24,
  'B-Treatment_Duration': 25,
  'I-Treatment_Duration': 26,
  'B-Treatment_Freq': 27,
  'I-Treatment_Freq': 28,
  'B-Treatment_Route': 29,
  'I-Treatment_Route': 30,
  'B-Treatment_Time_elapsed': 31,
  'I-Treatment_Time_elapsed': 32}}

In [43]:
def get_role_seqlabels(roles, token_num, specify_role=None):
    """Convert argument spans into a BIO tag sequence of length ``token_num``.

    :param roles (iterable[tuple]): ``(start, end, role_type, ...)`` spans, ``end`` exclusive.
    :param token_num (int): number of tokens in the sequence.
    :param specify_role (str, optional): if given, only spans with this role are tagged.
    :return (list[str]): ``'O'``, ``'B-<role>'`` or ``'I-<role>'`` per token.

    Spans that are empty, start before 0, run past ``token_num``, or overlap an already
    tagged span are skipped; earlier spans in ``roles`` take precedence.
    """
    labels = ['O'] * token_num
    for role in roles:
        start, end = role[0], role[1]
        # empty or out-of-range spans would otherwise tag a token they do not cover
        if start < 0 or start >= end or end > token_num:
            continue
        role_type = role[2]

        if specify_role is not None and role_type != specify_role:
            continue

        if any(labels[i] != 'O' for i in range(start, end)):
            continue

        labels[start] = 'B-{}'.format(role_type)
        for i in range(start + 1, end):
            labels[i] = 'I-{}'.format(role_type)

    return labels

In [44]:
# Sanity check: BIO role tags for the first instance of the sample batch.
get_role_seqlabels(batch.batch_arguments[0],len(batch.batch_tokens[0]))

['O',
 'O',
 'B-Treatment',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-Effect',
 'I-Effect',
 'I-Effect',
 'I-Effect',
 'I-Effect',
 'I-Effect',
 'I-Effect',
 'I-Effect',
 'I-Effect',
 'I-Effect',
 'I-Effect',
 'O']

In [None]:
# Raw (start, end, role, text) argument tuples of the sample batch, to compare with the BIO tags from get_role_seqlabels.
batch.batch_arguments

In [45]:
def process_data(batch):
    """Turn one collated EAEBatch into padded encoder inputs and role-label targets.

    NOTE(review): this redefines the earlier ``process_data(data)`` that filters raw
    instances; any cell calling that version must run before this one.

    :param batch: an EAEBatch produced by ``EAE_collate_fn``.
    :return: ``(enc_idxs, enc_attn, role_seqidxs, trigger_types, token_lens, token_nums, triggers)``
        - enc_idxs / enc_attn: LongTensors of BOS + pieces + EOS ids and 1-masks, right-padded
          with the pad id / 0 to the longest row;
        - role_seqidxs: list of role label-id lists padded to the longest token sequence
          (pad 0 when ``config.use_crf``, else -100);
        - trigger_types: LongTensor of trigger type ids;
        - token_lens: per-instance piece counts, passed through;
        - token_nums: LongTensor of token counts;
        - triggers: trigger tuples, passed through.
    """
    longest_tokens = max(batch.batch_token_num)
    bos_id = tokenizer.convert_tokens_to_ids(tokenizer.bos_token)
    eos_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
    # 0 is 'O' in the role vocab; -100 presumably matches the loss's ignore_index in the non-CRF path
    label_pad = 0 if config.use_crf else -100

    enc_idxs, enc_attn, role_seqidxs = [], [], []
    trigger_types, token_lens, token_nums, triggers = [], [], [], []
    rows = zip(batch.batch_tokens, batch.batch_pieces, batch.batch_trigger,
               batch.batch_arguments, batch.batch_token_lens, batch.batch_token_num)
    for words, word_pieces, trigger, arguments, piece_counts, n_words in rows:
        ids = [bos_id] + tokenizer.convert_tokens_to_ids(word_pieces) + [eos_id]
        enc_idxs.append(ids)
        enc_attn.append([1] * len(ids))

        tags = get_role_seqlabels(arguments, len(words))
        trigger_types.append(type_stoi["trigger"][trigger[2]])
        token_lens.append(piece_counts)
        token_nums.append(n_words)
        triggers.append(trigger)
        role_seqidxs.append([label_stoi["role"][tag] for tag in tags] + [label_pad] * (longest_tokens - len(words)))

    width = max(len(ids) for ids in enc_idxs)
    pad_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    enc_idxs = torch.LongTensor([ids + [pad_id] * (width - len(ids)) for ids in enc_idxs])
    enc_attn = torch.LongTensor([mask + [0] * (width - len(mask)) for mask in enc_attn])
    return enc_idxs, enc_attn, role_seqidxs, torch.LongTensor(trigger_types), token_lens, torch.LongTensor(token_nums), triggers


In [46]:
# Tensorize the sample batch with the batch-level process_data defined in the previous cell.
enc_idxs, enc_attn, role_seqidxs, trigger_types, token_lens, token_nums, triggers = process_data(batch)

In [47]:
# Trigger tuples of the sample batch: (start index, end index, event type, text span).
triggers

[(15, 16, 'Adverse_event', 'induce'),
 (8, 9, 'Adverse_event', 'diagnosed'),
 (70, 71, 'Adverse_event', 'experiencing')]

In [48]:
# Pretrained encoder (with all hidden states exposed) plus its config; the config's
# hidden size sizes the downstream feature layers.
base_model = RobertaModel.from_pretrained(
    config.pretrained_model_name,
    cache_dir=config.cache_dir,
    output_hidden_states=True,
)
base_config = RobertaConfig.from_pretrained(config.pretrained_model_name, cache_dir=config.cache_dir)
base_model_dim = base_config.hidden_size
print(base_model_dim)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


768


In [49]:
def token_lens_to_idxs(token_lens):
    """Map token lengths to a word-piece index matrix (for torch.gather) and an averaging mask.

    Example (one sequence instead of a batch):
        token lengths: [1, 1, 1, 3, 1]
        =>
        indices: [[0,-1,-1], [1,-1,-1], [2,-1,-1], [3,4,5], [6,-1,-1]]
        masks:   [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0],
                  [0.33, 0.33, 0.33], [1.0, 0.0, 0.0]]
    Each sequence is returned flattened to ``max_token_num * max_token_len`` entries.
    Padding slots (short tokens and missing tokens) hold index -1 and mask 0.0. The caller
    must shift indices (e.g. +1 past a leading special token) so -1 becomes a valid gather
    position, whose vector the zero mask then cancels. Typical use:
        outputs = torch.gather(piece_outputs, 1, indices) * masks
        outputs = outputs.view(batch_size, max_token_num, max_token_len, dim).sum(2)

    A token of length 0 gets an all-zero mask (its averaged vector is zero) rather than
    raising ZeroDivisionError.

    :param token_lens (list[list[int]]): word-piece count of every token, per sequence.
    :return: ``(idxs, masks, max_token_num, max_token_len)``.
    """
    max_token_num = max((len(x) for x in token_lens), default=0)
    max_token_len = max((max(x, default=0) for x in token_lens), default=0)
    idxs, masks = [], []
    for seq_token_lens in token_lens:
        seq_idxs, seq_masks = [], []
        offset = 0
        for token_len in seq_token_lens:
            pad = max_token_len - token_len
            weight = 1.0 / token_len if token_len > 0 else 0.0
            seq_idxs.extend(list(range(offset, offset + token_len)) + [-1] * pad)
            seq_masks.extend([weight] * token_len + [0.0] * pad)
            offset += token_len
        # pad whole missing tokens so every sequence has max_token_num slots
        missing = max_token_num - len(seq_token_lens)
        seq_idxs.extend([-1] * max_token_len * missing)
        seq_masks.extend([0.0] * max_token_len * missing)
        idxs.append(seq_idxs)
        masks.append(seq_masks)
    return idxs, masks, max_token_num, max_token_len

In [50]:
base_model_dropout = nn.Dropout(p=config.base_model_dropout)


def _first_piece_offsets(token_lens):
    """Position of the first word piece of each token, per sequence, padded with -1 to the longest sequence."""
    max_token_num = max((len(x) for x in token_lens), default=0)
    offsets = []
    for seq_token_lens in token_lens:
        seq_offsets, position = [], 0
        for token_len in seq_token_lens:
            seq_offsets.append(position)
            position += token_len
        offsets.append(seq_offsets + [-1] * (max_token_num - len(seq_offsets)))
    return offsets


def encode(piece_idxs, attention_masks, token_lens):
    """Encode word-piece sequences with the base model and pool them to one vector per word.

    :param piece_idxs (LongTensor): batch x num_pieces ids, starting with one BOS token.
    :param attention_masks (LongTensor): batch x num_pieces, 1 for real pieces.
    :param token_lens (list[list[int]]): word-piece count of every word, per sequence.
    :return (Tensor): batch x max_token_num x base_model_dim word vectors, after dropout.
    :raises ValueError: for an unknown ``config.multi_piece_strategy``.
    """
    batch_size, _ = piece_idxs.size()
    all_base_model_outputs = base_model(piece_idxs, attention_mask=attention_masks)
    base_model_outputs = all_base_model_outputs[0]
    if config.multi_piece_strategy == 'first':
        # select the first piece for multi-piece words
        offsets = piece_idxs.new_tensor(_first_piece_offsets(token_lens))  # batch x max_token_num
        # +1 skips the leading BOS vector; padded offsets (-1) therefore read position 0
        offsets = offsets.unsqueeze(-1).expand(batch_size, -1, base_model_dim) + 1
        base_model_outputs = torch.gather(base_model_outputs, 1, offsets)
    elif config.multi_piece_strategy == 'average':
        # average all pieces for multi-piece words
        idxs, masks, token_num, token_len = token_lens_to_idxs(token_lens)
        # +1 skips the leading BOS vector; padding indices (-1) become 0 and are zeroed by the mask
        idxs = piece_idxs.new_tensor(idxs).unsqueeze(-1).expand(batch_size, -1, base_model_dim) + 1
        masks = base_model_outputs.new_tensor(masks).unsqueeze(-1)
        base_model_outputs = torch.gather(base_model_outputs, 1, idxs) * masks
        base_model_outputs = base_model_outputs.view(batch_size, token_num, token_len, base_model_dim)
        base_model_outputs = base_model_outputs.sum(2)
    else:
        raise ValueError(f'Unknown multi-piece token handling strategy: {config.multi_piece_strategy}')
    base_model_outputs = base_model_dropout(base_model_outputs)
    return base_model_outputs

In [51]:
# Encode the sample batch into one vector per word (pieces pooled per config.multi_piece_strategy).
base_model_outputs = encode(enc_idxs, enc_attn, token_lens)

In [52]:
# batch x max token count x hidden size
base_model_outputs.size()

torch.Size([3, 92, 768])

In [54]:
def get_trigger_embedding(base_model_outputs, triggers):
    """Average the word vectors inside each trigger span.

    :param base_model_outputs (Tensor): word-level outputs indexed (batch, token, dim).
    :param triggers (list[tuple]): one ``(start, end, ...)`` per batch row, ``end`` exclusive.
    :return (Tensor): batch x dim mean vector of each trigger span. An empty span
        (``start == end``) yields a zero vector instead of NaN.
    :raises IndexError: if a trigger ends beyond the padded token length.
    """
    max_tokens = base_model_outputs.size(1)
    masks = []
    for trigger in triggers:
        seq_mask = [0] * max_tokens
        for position in range(trigger[0], trigger[1]):
            seq_mask[position] = 1
        masks.append(seq_mask)
    masks = base_model_outputs.new_tensor(masks)  # batch x max_tokens, same dtype/device as the outputs
    # clamp avoids 0/0 for empty spans; non-empty spans are unaffected
    span_sizes = masks.sum(dim=1, keepdim=True).clamp(min=1)
    return (base_model_outputs * masks.unsqueeze(-1)).sum(1) / span_sizes


In [55]:
# One averaged trigger vector per instance of the sample batch (batch x hidden size).
trigger_vec = get_trigger_embedding(base_model_outputs, triggers)
print(trigger_vec.size())

torch.Size([3, 768])


In [56]:
# Copy each trigger vector to every token position (batch x tokens x hidden), presumably to be
# concatenated with the word vectors as a per-token trigger feature (cf. feature_dim below) — TODO confirm.
extend_tri_vec = trigger_vec.unsqueeze(1).repeat(1, base_model_outputs.size(1), 1)
print(extend_tri_vec.size())

torch.Size([3, 92, 768])


In [57]:
# Per-token feature size: the word vector, doubled when the trigger vector is appended,
# plus type_feature_num when a trigger-type embedding is used.
feature_dim = base_model_dim*2 if config.use_trigger_feature else base_model_dim
if config.use_type_feature:
        feature_dim += config.type_feature_num
        # one embedding row per trigger type (ids presumably from type_stoi["trigger"])
        type_feature_module = nn.Embedding(len(type_set["trigger"]), config.type_feature_num)