In [1]:
import json
from argparse import Namespace
from models import *
import torch.nn as nn
import numpy as np
from transformers import BertConfig, RobertaConfig, XLMRobertaConfig, BertModel, RobertaModel, XLMRobertaModel,RobertaTokenizer
from collections import namedtuple
import torch

In [2]:
# Experiment selection: which task / dataset split / model family to run.
task = "EAE"
dataset = "phee"
split = 1
model_type = "TagPrime"
pretrained_model_name = "roberta-base"
pretrained_model_alias = {"roberta-base": "roberta-base"}

# Flat hyper-parameter table; it is exposed below as attribute-style `config`.
config_dict = dict(
    # --- general ---
    task=task,
    dataset=dataset,
    model_type=model_type,
    gpu_device=0,
    seed=0,
    cache_dir="./cache",
    train_file=f"../data/processed_data/{dataset}/split{split}/train.json",
    dev_file=f"../data/processed_data/{dataset}/split{split}/dev.json",
    test_file=f"../data/processed_data/{dataset}/split{split}/test.json",
    # --- model ---
    pretrained_model_name=pretrained_model_name,
    base_model_dropout=0.2,
    use_crf=True,
    use_trigger_feature=True,
    use_type_feature=True,
    type_feature_num=100,
    linear_hidden_num=150,
    linear_dropout=0.2,
    linear_bias=True,
    linear_activation="relu",
    multi_piece_strategy="average",
    priming_type="condition",
    max_length=200,
    # --- training ---
    max_epoch=90,
    warmup_epoch=5,
    accumulate_step=1,
    train_batch_size=6,
    eval_batch_size=12,
    learning_rate=0.001,
    base_model_learning_rate=1e-05,
    weight_decay=0.001,
    base_model_weight_decay=1e-05,
    grad_clipping=5.0,
)

config = Namespace(**config_dict)

In [3]:
# Resolve the trainer class for the configured (model_type, task) pair.
VALID_TASKS = ["E2E", "ED", "EAE", "EARL"]

# Only the combinations listed here have a trainer available in this notebook.
TRAINER_MAP = {
    ("CRFTagging", "ED"): CRFTaggingEDTrainer,
    ("CRFTagging", "EAE"): CRFTaggingEAETrainer,
    ("TagPrime", "ED"): TagPrimeEDTrainer,
    ("TagPrime", "EAE"): TagPrimeEAETrainer,
}

# Fail with a readable message (still a KeyError, as before) when the task is
# unknown or the pair has no trainer, instead of a bare tuple KeyError.
if config.task not in VALID_TASKS or (config.model_type, config.task) not in TRAINER_MAP:
    raise KeyError(
        "No trainer for model_type={!r}, task={!r}; valid tasks: {}; supported pairs: {}".format(
            config.model_type, config.task, VALID_TASKS, sorted(TRAINER_MAP)))
trainer_class = TRAINER_MAP[(config.model_type, config.task)]

In [4]:
def load_EAE_data(file, add_extra_info_fn, config):
    """Load EAE instances from a JSON-lines file, merging events that share a trigger.

    Each non-blank line of ``file`` is parsed as one JSON document with the keys
    ``doc_id``, ``wnd_id``, ``tokens``, ``text``, ``entity_mentions`` and
    ``event_mentions``. Within a document, event mentions whose
    (start, end, event type, trigger text) tuple is identical are collapsed into
    a single instance whose argument list is the concatenation of their
    arguments (appended in mention order, not re-sorted).

    :param file: path of the JSON-lines file to read.
    :param add_extra_info_fn: callable invoked as
        ``add_extra_info_fn(instances, data, config)``; it must return a
        sequence with one entry per instance.
    :param config: passed through unchanged to ``add_extra_info_fn``.
    :return: ``(new_instances, type_set)`` where ``type_set`` is
        ``{"trigger": set_of_event_types, "role": set_of_role_types}``.
    :raises AssertionError: if ``add_extra_info_fn`` changes the instance count.
    """
    with open(file, 'r', encoding='utf-8') as fp:
        # Skip blank lines (e.g. a stray trailing newline) instead of crashing
        # inside json.loads.
        data = [json.loads(line) for line in fp if line.strip()]

    instances = []
    cnt = 0  # event mentions folded into an earlier mention with the same trigger
    for dt in data:
        entity_map = {entity['id']: entity for entity in dt['entity_mentions']}

        event_mentions = dt['event_mentions']
        # Sorted in place, as before, so add_extra_info_fn sees the same ordering
        # in `data` that it always did.
        event_mentions.sort(key=lambda x: x['trigger']['start'])

        # trigger tuple -> merged argument list (dict keeps first-seen order)
        ins = {}
        for event_mention in event_mentions:
            # trigger = (start index, end index, event type, text span)
            trigger = (event_mention['trigger']['start'],
                       event_mention['trigger']['end'],
                       event_mention['event_type'],
                       event_mention['trigger']['text'])

            arguments = []
            for arg in event_mention['arguments']:
                mapped_entity = entity_map[arg['entity_id']]
                # argument = (start index, end index, role type, text span)
                arguments.append((mapped_entity['start'], mapped_entity['end'],
                                  arg['role'], arg['text']))
            arguments.sort(key=lambda x: (x[0], x[1]))

            # Same trigger span, type and text: combine the arguments.
            if trigger in ins:
                ins[trigger].extend(arguments)
                cnt += 1
            else:
                ins[trigger] = arguments

        for trigger, arguments in ins.items():
            instances.append({"doc_id": dt["doc_id"],
                              "wnd_id": dt["wnd_id"],
                              "tokens": dt["tokens"],
                              "text": dt["text"],
                              "trigger": trigger,
                              "arguments": arguments,
                              })

    trigger_type_set = {instance['trigger'][2] for instance in instances}
    role_type_set = {argument[2]
                     for instance in instances
                     for argument in instance["arguments"]}
    type_set = {"trigger": trigger_type_set, "role": role_type_set}

    # approach-specific preprocessing
    new_instances = add_extra_info_fn(instances, data, config)
    assert len(new_instances) == len(instances)

    print('Loaded {} EAE instances ({} trigger types and {} role types) from {}'.format(
        len(new_instances), len(trigger_type_set), len(role_type_set), file))
    print(f"combine cnt :{cnt}")
    return new_instances, type_set
# Load every split with the trigger-merging loader and pool their label inventories.
if config.task == "EAE":
    train_data, train_type_set = load_EAE_data(config.train_file, trainer_class.add_extra_info_fn, config)
    dev_data, dev_type_set = load_EAE_data(config.dev_file, trainer_class.add_extra_info_fn, config)
    test_data, test_type_set = load_EAE_data(config.test_file, trainer_class.add_extra_info_fn, config)
    # Union over the three splits, separately for trigger types and role types.
    type_set = {
        kind: set().union(*(per_split[kind]
                            for per_split in (train_type_set, dev_type_set, test_type_set)))
        for kind in ("trigger", "role")
    }
    print("There are {} trigger types and {} role types in total".format(len(type_set["trigger"]), len(type_set["role"])))

Loaded 2987 EAE instances (2 trigger types and 16 role types) from ../data/processed_data/phee/split1/train.json
combine cnt :16
Loaded 1005 EAE instances (2 trigger types and 16 role types) from ../data/processed_data/phee/split1/dev.json
combine cnt :6
Loaded 1003 EAE instances (2 trigger types and 16 role types) from ../data/processed_data/phee/split1/test.json
combine cnt :2
There are 2 trigger types and 16 role types in total


In [5]:
def load_EAE_data(file, add_extra_info_fn, config):
    """Load EAE instances from a JSON-lines file, one instance per event mention.

    NOTE: this redefines ``load_EAE_data`` from the previous cell; unlike that
    version it does not merge event mentions that share a trigger.

    :param file: path of the JSON-lines file; every line is parsed as JSON.
    :param add_extra_info_fn: callable invoked as
        ``add_extra_info_fn(instances, data, config)``; must return one entry
        per instance.
    :param config: passed through unchanged to ``add_extra_info_fn``.
    :return: ``(new_instances, type_set)`` with ``type_set`` being
        ``{"trigger": set_of_event_types, "role": set_of_role_types}``.
    :raises AssertionError: if ``add_extra_info_fn`` changes the instance count.
    """
    with open(file, 'r', encoding='utf-8') as fp:
        data = [json.loads(raw_line) for raw_line in fp.readlines()]

    instances = []
    for doc in data:
        entity_mentions = doc['entity_mentions']
        mentions = doc['event_mentions']
        # In-place sort: `data` (handed to add_extra_info_fn below) is reordered too.
        mentions.sort(key=lambda mention: mention['trigger']['start'])
        entity_by_id = {ent['id']: ent for ent in entity_mentions}

        for mention in mentions:
            trig = mention['trigger']
            # (start index, end index, role type, text span), ordered by span
            arguments = sorted(
                [(entity_by_id[arg['entity_id']]['start'],
                  entity_by_id[arg['entity_id']]['end'],
                  arg['role'],
                  arg['text'])
                 for arg in mention['arguments']],
                key=lambda span: (span[0], span[1]))

            instances.append({
                "doc_id": doc["doc_id"],
                "wnd_id": doc["wnd_id"],
                "tokens": doc["tokens"],
                "text": doc["text"],
                # (start index, end index, event type, text span)
                "trigger": (trig['start'], trig['end'], mention['event_type'], trig['text']),
                "arguments": arguments,
            })

    trigger_type_set = {inst['trigger'][2] for inst in instances}
    role_type_set = {arg[2] for inst in instances for arg in inst["arguments"]}
    type_set = {"trigger": trigger_type_set, "role": role_type_set}

    # approach-specific preprocessing
    new_instances = add_extra_info_fn(instances, data, config)
    assert len(new_instances) == len(instances)

    print('Loaded {} EAE instances ({} trigger types and {} role types) from {}'.format(
        len(new_instances), len(trigger_type_set), len(role_type_set), file))

    return new_instances, type_set
    
# Reload every split with the loader defined in this cell; the names assigned
# here replace the ones produced by the previous cell.
if config.task == "EAE":
    train_data, train_type_set = load_EAE_data(config.train_file, trainer_class.add_extra_info_fn, config)
    dev_data, dev_type_set = load_EAE_data(config.dev_file, trainer_class.add_extra_info_fn, config)
    test_data, test_type_set = load_EAE_data(config.test_file, trainer_class.add_extra_info_fn, config)
    # Pool trigger types and role types across train/dev/test.
    type_set = {
        kind: set().union(*(per_split[kind]
                            for per_split in (train_type_set, dev_type_set, test_type_set)))
        for kind in ("trigger", "role")
    }
    print("There are {} trigger types and {} role types in total".format(len(type_set["trigger"]), len(type_set["role"])))

Loaded 3003 EAE instances (2 trigger types and 16 role types) from ../data/processed_data/phee/split1/train.json
Loaded 1011 EAE instances (2 trigger types and 16 role types) from ../data/processed_data/phee/split1/dev.json
Loaded 1005 EAE instances (2 trigger types and 16 role types) from ../data/processed_data/phee/split1/test.json
There are 2 trigger types and 16 role types in total


In [6]:
# Build the trainer selected above from the config and the pooled label
# inventory, then print it. No training is started in this cell.
trainer = trainer_class(config, type_set)
print(trainer)

<models.TagPrime.EAEtrainer.TagPrimeEAETrainer object at 0x0000016E471BE3D0>


In [7]:
# Load the RoBERTa tokenizer, caching downloads under config.cache_dir.
# NOTE(review): add_prefix_space=True is presumably needed because words are
# tokenized one at a time with is_split_into_words=True further down - confirm
# against the transformers documentation.
tokenizer = RobertaTokenizer.from_pretrained(config.pretrained_model_name, cache_dir=config.cache_dir, do_lower_case=False, add_prefix_space=True)

'HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /roberta-base/resolve/main/vocab.json (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x0000016E4281DD90>, 'Connection to huggingface.co timed out. (connect timeout=10)'))' thrown while requesting HEAD https://huggingface.co/roberta-base/resolve/main/vocab.json


In [8]:
def process_data(data):
    """Filter and tokenize EAE instances.

    Instances with more than ``config.max_length`` tokens are dropped. For the
    rest, arguments are visited in sorted order and kept greedily: an argument
    is kept only if none of its token positions is already covered by a kept
    argument. Each token is split into word pieces with the module-level
    ``tokenizer``.

    :param data: iterable of instance dicts with the keys ``doc_id``,
        ``wnd_id``, ``tokens``, ``text``, ``trigger`` and ``arguments``.
    :return: list of new instance dicts carrying the additional keys
        ``pieces`` (flat piece list), ``token_lens`` (pieces per token) and
        ``token_num``.
    """
    assert tokenizer, "Please load model and tokneizer before processing data!"

    print("Removing overlapping arguments and over-length examples")

    n_total = 0
    kept = []
    for n_total, instance in enumerate(data, start=1):
        tokens = instance["tokens"]
        if len(tokens) > config.max_length:
            continue

        # True while a token position is not yet covered by a kept argument.
        free = np.ones((len(tokens), ), dtype=bool)
        selected = []
        for candidate in sorted(instance["arguments"]):
            span = slice(candidate[0], candidate[1])
            if free[span].all():
                selected.append(candidate)
                free[span] = False

        pieces_per_token = [tokenizer.tokenize(tok, is_split_into_words=True) for tok in tokens]

        kept.append({
            "doc_id": instance["doc_id"],
            "wnd_id": instance["wnd_id"],
            "tokens": tokens,
            "pieces": [piece for word in pieces_per_token for piece in word],
            "token_lens": [len(word) for word in pieces_per_token],
            "token_num": len(tokens),
            "text": instance["text"],
            "trigger": instance["trigger"],
            "arguments": selected,
        })

    print(f"There are {len(kept)}/{n_total} EAE instances after removing overlapping arguments and over-length examples")

    return kept
# Filter/tokenize the train and dev splits with the instance-level
# process_data defined in this cell (the test split is not processed here).
internal_train_data = process_data(train_data)
internal_dev_data = process_data(dev_data)

Removing overlapping arguments and over-length examples
There are 3003/3003 EAE instances after removing overlapping arguments and over-length examples
Removing overlapping arguments and over-length examples
There are 1011/1011 EAE instances after removing overlapping arguments and over-length examples


In [9]:
# Batch container: one "batch_<key>" field per instance key, each defaulting to None.
EAEBatch_fields = ['batch_' + key for key in (
    'doc_id', 'wnd_id', 'tokens', 'pieces', 'token_lens',
    'token_num', 'text', 'trigger', 'arguments')]
EAEBatch = namedtuple('EAEBatch', field_names=EAEBatch_fields, defaults=[None] * len(EAEBatch_fields))

def EAE_collate_fn(batch):
    """Transpose a list of instance dicts into an ``EAEBatch`` of per-field lists.

    :param batch: sequence of instance dicts holding every key named by
        ``EAEBatch_fields`` without its ``batch_`` prefix.
    :return: ``EAEBatch`` whose field ``batch_<key>`` is the list of
        ``instance[<key>]`` in batch order.
    """
    columns = {}
    for field in EAEBatch_fields:
        key = field[len('batch_'):]
        columns[field] = [instance[key] for instance in batch]
    return EAEBatch(**columns)

In [10]:
# Build a shuffled DataLoader over the processed training instances and
# pretty-print one collated batch as a sanity check.
# NOTE(review): batch_size is hard-coded to 4 here rather than taken from
# config.train_batch_size, and no seeding is visible in this notebook, so the
# sampled batch presumably differs from run to run.
from torch.utils.data import DataLoader
import pprint
train = DataLoader(internal_train_data, batch_size=4, 
                                                         shuffle=True, drop_last=False, collate_fn=EAE_collate_fn)
batch = next(iter(train))
pprint.pprint(batch)

EAEBatch(batch_doc_id=['480917_1', '19249953_10', '17458405_2', '24318743_3'], batch_wnd_id=['480917_1_1', '19249953_10_1', '17458405_2_1', '24318743_3_1'], batch_tokens=[['A', 'case', 'of', 'Erythema', 'Multiforme', 'Bullosum', 'in', 'patient', 'of', 'lepromatous', 'leprosy', 'with', 'pulmonary', 'tuberculosis', 'due', 'to', 'Rifampicin', 'is', 'described', '.'], ['To', 'our', 'knowledge', ',', 'however', ',', 'this', 'is', 'the', 'first', 'case', 'report', 'of', 'a', 'possible', 'sitagliptin', '-', 'lovastatin', 'interaction', 'that', 'may', 'have', 'caused', 'rhabdomyolysis', '.'], ['Since', 'its', 'FDA', 'approval', 'in', '2002', ',', 'there', 'are', 'no', 'known', 'citations', 'of', 'ezetimibe', '-', 'induced', 'pancreatitis', '.'], ['Based', 'on', 'prior', 'data', 'suggesting', 'that', 'scheduling', 'alterations', 'of', 'platinum', 'would', 'increase', 'activity', ',', 'the', 'aim', 'of', 'the', 'present', 'study', 'was', 'to', 'assess', 'the', 'potential', 'therapeutic', 'benefi

In [11]:
# Display the pooled trigger/role type inventory built by the loading cell.
type_set

{'trigger': {'Adverse_event', 'Potential_therapeutic_event'},
 'role': {'Combination_Drug',
  'Effect',
  'Subject',
  'Subject_Age',
  'Subject_Disorder',
  'Subject_Gender',
  'Subject_Population',
  'Subject_Race',
  'Treatment',
  'Treatment_Disorder',
  'Treatment_Dosage',
  'Treatment_Drug',
  'Treatment_Duration',
  'Treatment_Freq',
  'Treatment_Route',
  'Treatment_Time_elapsed'}}

In [12]:
def generate_tagging_vocab():
    """Build the BIO label vocabularies and type indices from ``type_set``.

    Reads the module-level ``type_set`` and ``config``. Labels are numbered
    with ``O`` at 0, followed by ``B-<type>``/``I-<type>`` for every type in
    sorted order. When ``config.priming_type`` is ``"condition+relation"`` the
    role vocabulary is the fixed ``O``/``B-Pred``/``I-Pred`` set instead.

    :return: ``(label_stoi, type_stoi)``; both are dicts with the keys
        ``"trigger"`` and ``"role"`` mapping names to integer ids.
    """
    def bio_vocab(sorted_types):
        # 'O' first, then B-/I- tags numbered in insertion order.
        vocab = {'O': 0}
        for type_name in sorted_types:
            for prefix in ('B', 'I'):
                vocab['{}-{}'.format(prefix, type_name)] = len(vocab)
        return vocab

    trigger_types = sorted(type_set["trigger"])
    role_types = sorted(type_set["role"])

    trigger_label_stoi = bio_vocab(trigger_types)
    role_label_stoi = bio_vocab(role_types)

    unified_roles = config.priming_type == "condition+relation"
    label_stoi = {
        "trigger": trigger_label_stoi,
        "role": {"O": 0, "B-Pred": 1, "I-Pred": 2} if unified_roles else role_label_stoi,
    }
    type_stoi = {
        "trigger": dict(zip(trigger_types, range(len(trigger_types)))),
        "role": dict(zip(role_types, range(len(role_types)))),
    }
    return label_stoi, type_stoi
# Materialize the label and type vocabularies as module-level lookups.
label_stoi , type_stoi = generate_tagging_vocab()

In [13]:
# Show the trigger and role BIO label -> id maps.
print(label_stoi)

{'trigger': {'O': 0, 'B-Adverse_event': 1, 'I-Adverse_event': 2, 'B-Potential_therapeutic_event': 3, 'I-Potential_therapeutic_event': 4}, 'role': {'O': 0, 'B-Combination_Drug': 1, 'I-Combination_Drug': 2, 'B-Effect': 3, 'I-Effect': 4, 'B-Subject': 5, 'I-Subject': 6, 'B-Subject_Age': 7, 'I-Subject_Age': 8, 'B-Subject_Disorder': 9, 'I-Subject_Disorder': 10, 'B-Subject_Gender': 11, 'I-Subject_Gender': 12, 'B-Subject_Population': 13, 'I-Subject_Population': 14, 'B-Subject_Race': 15, 'I-Subject_Race': 16, 'B-Treatment': 17, 'I-Treatment': 18, 'B-Treatment_Disorder': 19, 'I-Treatment_Disorder': 20, 'B-Treatment_Dosage': 21, 'I-Treatment_Dosage': 22, 'B-Treatment_Drug': 23, 'I-Treatment_Drug': 24, 'B-Treatment_Duration': 25, 'I-Treatment_Duration': 26, 'B-Treatment_Freq': 27, 'I-Treatment_Freq': 28, 'B-Treatment_Route': 29, 'I-Treatment_Route': 30, 'B-Treatment_Time_elapsed': 31, 'I-Treatment_Time_elapsed': 32}}


In [14]:
# Role label map on its own (same content as the "role" entry printed above).
label_stoi["role"]

{'O': 0,
 'B-Combination_Drug': 1,
 'I-Combination_Drug': 2,
 'B-Effect': 3,
 'I-Effect': 4,
 'B-Subject': 5,
 'I-Subject': 6,
 'B-Subject_Age': 7,
 'I-Subject_Age': 8,
 'B-Subject_Disorder': 9,
 'I-Subject_Disorder': 10,
 'B-Subject_Gender': 11,
 'I-Subject_Gender': 12,
 'B-Subject_Population': 13,
 'I-Subject_Population': 14,
 'B-Subject_Race': 15,
 'I-Subject_Race': 16,
 'B-Treatment': 17,
 'I-Treatment': 18,
 'B-Treatment_Disorder': 19,
 'I-Treatment_Disorder': 20,
 'B-Treatment_Dosage': 21,
 'I-Treatment_Dosage': 22,
 'B-Treatment_Drug': 23,
 'I-Treatment_Drug': 24,
 'B-Treatment_Duration': 25,
 'I-Treatment_Duration': 26,
 'B-Treatment_Freq': 27,
 'I-Treatment_Freq': 28,
 'B-Treatment_Route': 29,
 'I-Treatment_Route': 30,
 'B-Treatment_Time_elapsed': 31,
 'I-Treatment_Time_elapsed': 32}

In [15]:
# Show the trigger-type and role-type -> id maps.
print(type_stoi)

{'trigger': {'Adverse_event': 0, 'Potential_therapeutic_event': 1}, 'role': {'Combination_Drug': 0, 'Effect': 1, 'Subject': 2, 'Subject_Age': 3, 'Subject_Disorder': 4, 'Subject_Gender': 5, 'Subject_Population': 6, 'Subject_Race': 7, 'Treatment': 8, 'Treatment_Disorder': 9, 'Treatment_Dosage': 10, 'Treatment_Drug': 11, 'Treatment_Duration': 12, 'Treatment_Freq': 13, 'Treatment_Route': 14, 'Treatment_Time_elapsed': 15}}


In [16]:
def get_role_seqlabels(roles, token_num, specify_role=None, use_unified_label=False):
    """Convert argument spans into a BIO tag sequence over ``token_num`` tokens.

    Spans are applied in the given order; a span is skipped when it ends past
    the sequence, is empty or starts before position 0, has a different role
    than ``specify_role`` (when given), or touches a token that is already
    tagged by an earlier span.

    :param roles: iterable of ``(start, end, role_type, ...)`` with ``end``
        exclusive.
    :param token_num: length of the returned tag list.
    :param specify_role: if not None, only spans with this role type are tagged.
    :param use_unified_label: together with ``specify_role``, tag spans as
        ``B-Pred``/``I-Pred`` instead of ``B-<role>``/``I-<role>``.
    :return: list of ``token_num`` tag strings.
    """
    labels = ['O'] * token_num
    for role in roles:
        start, end, role_type = role[0], role[1], role[2]

        if end > token_num:
            continue
        # Malformed spans: a zero-width span would tag a token it does not cover
        # (or index past the end when start == token_num), and a negative start
        # would wrap around to the end of the list.
        if start < 0 or start >= end:
            continue

        if specify_role is not None and role_type != specify_role:
            continue

        # Skip the span if at least one of its tokens already carries a label.
        if any(labels[i] != 'O' for i in range(start, end)):
            continue

        if (specify_role is not None) and use_unified_label:
            tag = "Pred"
        else:
            tag = role_type
        labels[start] = 'B-{}'.format(tag)
        for i in range(start + 1, end):
            labels[i] = 'I-{}'.format(tag)

    return labels

In [17]:
# Inspect the (start, end, role, text) argument tuples of the sampled batch.
batch.batch_arguments

[[(3, 6, 'Effect', 'Erythema Multiforme Bullosum'),
  (7,
   14,
   'Subject',
   'patient of lepromatous leprosy with pulmonary tuberculosis'),
  (16, 17, 'Treatment', 'Rifampicin')],
 [(15, 16, 'Combination_Drug', 'sitagliptin'),
  (17, 18, 'Combination_Drug', 'lovastatin'),
  (23, 24, 'Effect', 'rhabdomyolysis')],
 [(13, 14, 'Treatment', 'ezetimibe'), (16, 17, 'Effect', 'pancreatitis')],
 [(24, 27, 'Effect', 'potential therapeutic benefit'),
  (28, 29, 'Combination_Drug', 'phenoxodiol'),
  (46, 47, 'Treatment_Freq', 'weekly'),
  (49, 50, 'Combination_Drug', 'carboplatin'),
  (51, 52, 'Treatment_Disorder', 'PROC')]]

In [18]:
# Draw a batch from a fresh iterator; the loader shuffles, so this is generally
# not the batch displayed above.
batch = next(iter(train))

In [19]:
from pattern import event_type_tags,event_description
def process_data(batch):
    """Turn a collated ``EAEBatch`` into padded model inputs.

    NOTE(review): this reuses the name ``process_data`` of the earlier
    instance-level preprocessing function; once this cell runs, that earlier
    function is no longer reachable under this name.

    For every instance two id sequences are built, each wrapped in the
    tokenizer's bos/eos ids: the word pieces of the text, and a prompt of the
    form ``<event tag> <sep> <trigger text> <sep> <event description>``.
    Role labels come from ``get_role_seqlabels`` and are padded to the longest
    token count in the batch with 0 (``config.use_crf``) or -100.

    :param batch: object exposing ``batch_tokens``, ``batch_pieces``,
        ``batch_trigger``, ``batch_arguments``, ``batch_token_lens`` and
        ``batch_token_num``.
    :return: tuple ``(enc_idxs, enc_attn, role_seqidxs, trigger_types,
        token_lens, token_nums, triggers, pro_idxs, enc_attn_pro)``;
        ``token_lens`` and ``triggers`` are Python lists, the rest LongTensors.
    """
    max_token_num = max(batch.batch_token_num)

    # Special-token ids do not depend on the instance, so look them up once.
    bos_id = tokenizer.convert_tokens_to_ids(tokenizer.bos_token)
    eos_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
    pad_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    label_pad = 0 if config.use_crf else -100

    text_rows, prompt_rows, label_rows = [], [], []
    trigger_types, token_lens, token_nums, triggers = [], [], [], []

    per_instance = zip(batch.batch_tokens, batch.batch_pieces, batch.batch_trigger,
                       batch.batch_arguments, batch.batch_token_lens, batch.batch_token_num)
    for tokens, pieces, trigger, arguments, token_len, token_num in per_instance:
        event_type, trigger_text = trigger[2], trigger[3]

        prompt = "{} {} {} {} {}".format(
            event_type_tags[config.dataset][event_type],
            tokenizer.sep_token,
            trigger_text,
            tokenizer.sep_token,
            event_description[config.dataset][event_type]["event description"])
        prompt_ids = tokenizer.encode(prompt, add_special_tokens=False, is_split_into_words=False)

        prompt_rows.append([bos_id, *prompt_ids, eos_id])
        text_rows.append([bos_id, *tokenizer.convert_tokens_to_ids(pieces), eos_id])

        role_tags = get_role_seqlabels(arguments, len(tokens))
        label_rows.append([label_stoi["role"][tag] for tag in role_tags]
                          + [label_pad] * (max_token_num - len(tokens)))

        trigger_types.append(type_stoi["trigger"][event_type])
        token_lens.append(token_len)
        token_nums.append(token_num)
        triggers.append(trigger)

    def right_pad(rows, fill):
        # Pad every row on the right to the longest row and stack as LongTensor.
        width = max([len(row) for row in rows])
        return torch.LongTensor([row + [fill] * (width - len(row)) for row in rows])

    enc_idxs = right_pad(text_rows, pad_id)
    enc_attn = right_pad([[1] * len(row) for row in text_rows], 0)
    pro_idxs = right_pad(prompt_rows, pad_id)
    enc_attn_pro = right_pad([[1] * len(row) for row in prompt_rows], 0)

    return (enc_idxs, enc_attn, torch.LongTensor(label_rows), torch.LongTensor(trigger_types),
            token_lens, torch.LongTensor(token_nums), triggers, pro_idxs, enc_attn_pro)

In [20]:
# Tensorize the sampled batch: text ids/mask, role labels, trigger types,
# per-token piece counts, token counts, triggers, and prompt ids/mask.
enc_idxs, enc_attn, role_seqidxs, trigger_types, token_lens, token_nums, triggers ,pro_idxs,enc_attn_pro = process_data(batch)

In [21]:
# Token count of each instance in the current batch.
batch.batch_token_num

[14, 11, 27, 19]

In [22]:
# Prompt ids and their attention mask are padded together, so the two sizes should match.
print(pro_idxs.size(),enc_attn_pro.size())

torch.Size([4, 25]) torch.Size([4, 25])


In [23]:
# Load the pretrained RoBERTa encoder (hidden states of every layer are
# requested) together with its config, and record the hidden size used to
# dimension the layers built later in the notebook.
base_model = RobertaModel.from_pretrained(config.pretrained_model_name, 
                                                           cache_dir=config.cache_dir, 
                                                           output_hidden_states=True)
base_config = RobertaConfig.from_pretrained(config.pretrained_model_name, 
                                                             cache_dir=config.cache_dir)
base_model_dim = base_config.hidden_size
print(base_model_dim)

'HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /roberta-base/resolve/main/config.json (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x0000016E4D11A610>, 'Connection to huggingface.co timed out. (connect timeout=10)'))' thrown while requesting HEAD https://huggingface.co/roberta-base/resolve/main/config.json
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
'HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /roberta-base/resolve/main/config.json (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x0000016E4D25B550>, 'Connection to huggingface.co timed out. (connect timeout=10)'))' thrown while requesting HEAD https://hugg

768


In [29]:
def token_lens_to_idxs(token_lens):
    """Map per-token piece counts to flat gather indices and averaging weights.

    Every token gets ``max_token_len`` slots and every sequence gets
    ``max_token_num`` tokens; unused slots hold index -1 and weight 0.0.
    For example (one sequence shown), token lengths ``[1, 1, 1, 3, 1]`` give,
    before flattening::

        indices: [[0,-1,-1], [1,-1,-1], [2,-1,-1], [3,4,5], [6,-1,-1]]
        weights: [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0],
                  [0.33, 0.33, 0.33], [1.0, 0.0, 0.0]]

    The caller is expected to shift the indices (e.g. ``+ 1`` to skip a leading
    special token, which also turns the -1 padding into a valid index), gather
    the piece vectors, multiply by the weights and sum over the slot axis.

    A token with zero pieces contributes only padding slots, and an empty
    sequence counts as having no tokens.

    :param token_lens (list): one list of piece counts per sequence.
    :return: ``(idxs, masks, max_token_num, max_token_len)`` where ``idxs`` and
        ``masks`` are lists (one per sequence) of
        ``max_token_num * max_token_len`` values.
    """
    max_token_num = max([len(x) for x in token_lens])
    # default=0 keeps an empty sequence in the batch from raising on max().
    max_token_len = max([max(x, default=0) for x in token_lens])
    idxs, masks = [], []
    for seq_token_lens in token_lens:
        seq_idxs, seq_masks = [], []
        offset = 0
        for token_len in seq_token_lens:
            slot_padding = max_token_len - token_len
            seq_idxs.extend([i + offset for i in range(token_len)]
                            + [-1] * slot_padding)
            # Guard the division: a token that produced no pieces gets weight 0
            # everywhere instead of raising ZeroDivisionError.
            weights = [1.0 / token_len] * token_len if token_len else []
            seq_masks.extend(weights + [0.0] * slot_padding)
            offset += token_len
        missing_tokens = max_token_num - len(seq_token_lens)
        seq_idxs.extend([-1] * max_token_len * missing_tokens)
        seq_masks.extend([0.0] * max_token_len * missing_tokens)
        idxs.append(seq_idxs)
        masks.append(seq_masks)
    return idxs, masks, max_token_num, max_token_len

In [30]:
# Dropout module applied by encode() to the token-level encoder outputs.
base_model_dropout = nn.Dropout(p=config.base_model_dropout)
def encode(piece_idxs, attention_masks, token_lens):
    """Encode word pieces with the base model and pool them to token level.

    Uses the module-level ``base_model``, ``base_model_dim``,
    ``base_model_dropout`` and ``config.multi_piece_strategy``:

    * ``'first'``   - keep the vector of each token's first piece;
    * ``'average'`` - average the vectors of all pieces of each token.

    In both cases indices are shifted by one to skip the leading special token.

    :param piece_idxs (LongTensor): word piece ids, batch x piece_num.
    :param attention_masks (Tensor): attention mask matching ``piece_idxs``.
    :param token_lens (list): per sequence, the number of pieces of each token.
    :return: token-level representations after dropout,
        batch x max_token_num x base_model_dim.
    :raises ValueError: for an unknown ``config.multi_piece_strategy``.
    """
    batch_size, _ = piece_idxs.size()
    all_base_model_outputs = base_model(piece_idxs, attention_mask=attention_masks)
    base_model_outputs = all_base_model_outputs[0]
    if config.multi_piece_strategy == 'first':
        # select the first piece for multi-piece words
        # NOTE(review): token_lens_to_offsets is not defined in this notebook;
        # presumably it is expected to come from `from models import *` -
        # confirm before switching to this strategy.
        offsets = token_lens_to_offsets(token_lens)
        offsets = piece_idxs.new_tensor(offsets) # batch x max_token_num
        # + 1 because the first vector is the leading special token
        # (this branch previously referenced the undefined name `bert_dim`)
        offsets = offsets.unsqueeze(-1).expand(batch_size, -1, base_model_dim) + 1
        base_model_outputs = torch.gather(base_model_outputs, 1, offsets)
    elif config.multi_piece_strategy == 'average':
        # average all pieces for multi-piece words
        idxs, masks, token_num, token_len = token_lens_to_idxs(token_lens)
        # + 1 skips the leading special token and maps the -1 padding to index 0;
        # padded slots carry weight 0.0, so whatever is gathered there is ignored.
        idxs = piece_idxs.new_tensor(idxs).unsqueeze(-1).expand(batch_size, -1, base_model_dim) + 1
        masks = base_model_outputs.new_tensor(masks).unsqueeze(-1)
        base_model_outputs = torch.gather(base_model_outputs, 1, idxs) * masks
        base_model_outputs = base_model_outputs.view(batch_size, token_num, token_len, base_model_dim)
        base_model_outputs = base_model_outputs.sum(2)
    else:
        raise ValueError(f'Unknown multi-piece token handling strategy: {config.multi_piece_strategy}')
    base_model_outputs = base_model_dropout(base_model_outputs)
    return base_model_outputs
# Token-level encodings of the batch text.
base_model_outputs = encode(enc_idxs, enc_attn, token_lens)

In [31]:
# Encode the prompt sequences with the same base model; [0] takes the first
# element of the model output, the same element encode() uses. The result is
# piece-level: no token pooling and no dropout module is applied here.
pro_model_outputs = base_model(pro_idxs, attention_mask=enc_attn_pro)[0]

In [39]:
class CrossAttention(nn.Module):
    """Single-head scaled dot-product cross attention between text and prompt.

    ``fc`` is created alongside the Q/K/V projections but is not applied in
    ``forward``; the returned tensor is the attention-weighted value projection.
    """

    def __init__(self, d_model, d_k, d_v):
        """
        :param d_model: feature size of the query/key/value inputs.
        :param d_k: size of the projected queries and keys.
        :param d_v: size of the projected values (and of the output).
        """
        super().__init__()
        self.d_k = d_k
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_v, bias=False)
        self.fc = nn.Linear(d_v, d_model, bias=False)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        """Attend from ``input_Q`` over ``input_K``/``input_V``.

        :param input_Q: [batch_size, len_q, d_model]
        :param input_K: [batch_size, len_k, d_model]
        :param input_V: [batch_size, len_k, d_model]
        :param attn_mask: [batch_size, len_k]; key positions equal to 0 are
            masked out for every query.
        :return: [batch_size, len_q, d_v]
        """
        queries = self.W_Q(input_Q)  # [batch_size, len_q, d_k]
        keys = self.W_K(input_K)     # [batch_size, len_k, d_k]
        values = self.W_V(input_V)   # [batch_size, len_k, d_v]

        scores = torch.matmul(queries, keys.transpose(-1, -2)) / np.sqrt(self.d_k)
        # [batch_size, 1, len_k] broadcasts over the query axis.
        masked_keys = (attn_mask == 0).unsqueeze(1)
        scores = scores.masked_fill(masked_keys, -1e9)
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, values)

In [40]:
# Instantiate the cross-attention layer with all three sizes equal to the
# encoder hidden size (base_model_dim is re-read here; it was already set in
# the model-loading cell).
base_model_dim = base_config.hidden_size
cross_att = CrossAttention(d_model = base_model_dim,
                            d_k = base_model_dim ,
                            d_v = base_model_dim)

In [44]:
# Queries: token-level text encodings; keys/values: prompt encodings;
# mask: the prompt attention mask, so padded prompt positions are ignored.
atten = cross_att(base_model_outputs,pro_model_outputs,pro_model_outputs,enc_attn_pro)

In [45]:
# Shape of the cross-attention output (one vector per query position).
atten.size()

torch.Size([4, 27, 768])

In [108]:
# Shape of the token-level text encodings.
# NOTE(review): the recorded output here shows 29 tokens while the neighbouring
# cells show 27; the execution counts (108 vs 45/43) indicate the saved outputs
# come from different runs - re-run the notebook top to bottom before relying
# on them.
base_model_outputs.size()

torch.Size([4, 29, 768])

In [43]:
# Concatenate text encodings and attended prompt features along the feature
# axis and check the resulting shape.
torch.cat((base_model_outputs,atten),-1).size()

torch.Size([4, 27, 1536])