In [1]:
import json
from argparse import Namespace
from models import *
import torch.nn as nn
import numpy as np
from transformers import BertConfig, RobertaConfig, XLMRobertaConfig, BertModel, RobertaModel, XLMRobertaModel,RobertaTokenizer
from collections import namedtuple
import torch

In [2]:
task = "ED";
dataset = "phee";
split = 1;
model_type = "CRFTagging";
pretrained_model_name = "roberta-base";
pretrained_model_alias = {
    "roberta-base": "roberta-base", 
};
config_dict =  {
        #// general config
        "task": task, 
        "dataset": dataset,
        "model_type": model_type, 
        "gpu_device": 0, 
        "seed": 0, 
        "cache_dir": "./cache", 
        "output_dir": "./outputs/%s_%s_%s_split%s_%s" % (model_type, task, dataset, split, pretrained_model_alias[pretrained_model_name]), 
        "train_file": "../data/preprocessing/%s/split%s/train.json" % (dataset, split),
        "dev_file": "../data/preprocessing/%s/split%s/dev.json" % (dataset, split),
        "test_file": "../data/preprocessing/%s/split%s/test.json" % (dataset, split),
        
        
        #// model config
        "pretrained_model_name": pretrained_model_name,
        "base_model_dropout": 0.2,
        "use_crf": True,
        "type_feature_num": 100, 
        "linear_hidden_num": 150,
        "linear_dropout": 0.2,
        "linear_bias": True, 
        "linear_activation": "relu",
        "multi_piece_strategy": "average", 
        "max_length": 200, 
        
        # // train config
        "max_epoch": 30,
        "warmup_epoch": 5,
        "accumulate_step": 1,
        "train_batch_size": 6,
        "eval_batch_size": 12,
        "learning_rate": 0.001,
        "base_model_learning_rate": 1e-05,
        "weight_decay": 0.001,
        "base_model_weight_decay": 1e-05,
        "grad_clipping": 5.0,
    }
config = Namespace(**config_dict)

In [3]:
# load trainer
VALID_TASKS = ["E2E", "ED", "EAE", "EARL"]

TRAINER_MAP = {

    ("CRFTagging", "ED"): CRFTaggingEDTrainer, 
    ("CRFTagging", "EAE"): CRFTaggingEAETrainer
}
trainer_class = TRAINER_MAP[(config.model_type, config.task)]

In [4]:
def load_ED_data(file, add_extra_info_fn, config):

    with open(file, 'r', encoding='utf-8') as fp:
        lines = fp.readlines()
    data = [json.loads(line) for line in lines]
    
    instances = []
    for dt in data:

        event_mentions = dt['event_mentions']
        event_mentions.sort(key=lambda x: x['trigger']['start'])

        triggers = []
        for i, event_mention in enumerate(event_mentions):
            # trigger = (start index, end index, event type, text span)
            trigger = (event_mention['trigger']['start'], 
                       event_mention['trigger']['end'], 
                       event_mention['event_type'], 
                       event_mention['trigger']['text'])

            triggers.append(trigger)

        triggers.sort(key=lambda x: (x[0], x[1]))
        
        instance = {"doc_id": dt["doc_id"], 
                    "wnd_id": dt["wnd_id"], 
                    "tokens": dt["tokens"], 
                    "text": dt["text"], 
                    "triggers": triggers,
                   }

        instances.append(instance)

    trigger_type_set = set()
    for instance in instances:
        for trigger in instance['triggers']:
            trigger_type_set.add(trigger[2])

    type_set = {"trigger": trigger_type_set}
    
    # approach-specific preprocessing
    new_instances = add_extra_info_fn(instances, data, config)
    assert len(new_instances) == len(instances)
    
    print('Loaded {} ED instances ({} trigger types) from {}'.format(
        len(new_instances), len(trigger_type_set), file))
    
    return new_instances, type_set
train_data, train_type_set = load_ED_data(config.train_file, trainer_class.add_extra_info_fn, config)
dev_data, dev_type_set = load_ED_data(config.dev_file, trainer_class.add_extra_info_fn, config)
test_data, test_type_set = load_ED_data(config.test_file, trainer_class.add_extra_info_fn, config)
type_set = {"trigger": train_type_set["trigger"] | dev_type_set["trigger"] | test_type_set["trigger"]}
print("There are {} trigger types in total".format(len(type_set["trigger"])))


Loaded 2897 ED instances (2 trigger types) from ../data/preprocessing/phee/split1/train.json
Loaded 965 ED instances (2 trigger types) from ../data/preprocessing/phee/split1/dev.json
Loaded 965 ED instances (2 trigger types) from ../data/preprocessing/phee/split1/test.json
There are 2 trigger types in total


In [5]:
from transformers import RobertaTokenizer, AutoTokenizer, get_linear_schedule_with_warmup
tokenizer = RobertaTokenizer.from_pretrained(config.pretrained_model_name, cache_dir=config.cache_dir, do_lower_case=False, add_prefix_space=True)

'HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /roberta-base/resolve/main/vocab.json (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x000001CC82CD5E10>, 'Connection to huggingface.co timed out. (connect timeout=10)'))' thrown while requesting HEAD https://huggingface.co/roberta-base/resolve/main/vocab.json


In [6]:
def process_data(data):
        
        print("Removing overlapping triggers and over-length examples")
        
        # greedily remove overlapping triggers
        n_total = 0
        new_data = []
        for dt in data:
            
            n_total += 1
            
            if len(dt["tokens"]) > config.max_length:
                continue
            
            
            no_overlap_flag = np.ones((len(dt["tokens"]), ), dtype=bool)
            new_triggers = []
            for trigger in dt["triggers"]:
                start, end = trigger[0], trigger[1]
                if np.all(no_overlap_flag[start:end]):
                    new_triggers.append(trigger)
                    no_overlap_flag[start:end] = False
                            
            pieces = [tokenizer.tokenize(t, is_split_into_words=True) for t in dt["tokens"]]
            token_lens = [len(p) for p in pieces] 

            new_dt = {"doc_id": dt["doc_id"], 
                      "wnd_id": dt["wnd_id"], 
                      "tokens": dt["tokens"], 
                      "pieces": [p for w in pieces for p in w], 
                      "token_lens": token_lens, 
                      "token_num": len(dt["tokens"]), 
                      "text": dt["text"], 
                      "triggers": new_triggers
                     }
            
            new_data.append(new_dt)
                
        print(f"There are {len(new_data)}/{n_total} ED instances after removing overlapping triggers and over-length examples")

        return new_data

In [7]:
internal_train_data = process_data(train_data)
internal_dev_data = process_data(dev_data)

Removing overlapping triggers and over-length examples
There are 2897/2897 ED instances after removing overlapping triggers and over-length examples
Removing overlapping triggers and over-length examples
There are 965/965 ED instances after removing overlapping triggers and over-length examples


In [8]:
internal_train_data[0]

{'doc_id': '10030778_1',
 'wnd_id': '10030778_1_1',
 'tokens': ['Intravenous', 'azithromycin', '-', 'induced', 'ototoxicity', '.'],
 'pieces': ['ĠInt',
  'ra',
  'ven',
  'ous',
  'Ġaz',
  'ith',
  'romy',
  'cin',
  'Ġ-',
  'Ġinduced',
  'Ġot',
  'ot',
  'oxicity',
  'Ġ.'],
 'token_lens': [4, 4, 1, 1, 3, 1],
 'token_num': 6,
 'text': 'Intravenous azithromycin - induced ototoxicity .',
 'triggers': [(3, 4, 'Adverse_event', 'induced')]}

In [9]:
prefix = ['B', 'I']
trigger_label_stoi = {'O': 0}
for t in sorted(type_set["trigger"]):
    for p in prefix:
        trigger_label_stoi['{}-{}'.format(p, t)] = len(trigger_label_stoi)

label_stoi = {"trigger": trigger_label_stoi}
print(label_stoi)

{'trigger': {'O': 0, 'B-Adverse_event': 1, 'I-Adverse_event': 2, 'B-Potential_therapeutic_event': 3, 'I-Potential_therapeutic_event': 4}}


In [10]:
from collections import namedtuple
EDBatch_fields = ['batch_doc_id', 'batch_wnd_id', 'batch_tokens', 'batch_pieces', 'batch_token_lens', 'batch_token_num', 'batch_text', 'batch_triggers']
EDBatch = namedtuple('EDBatch', field_names=EDBatch_fields, defaults=[None] * len(EDBatch_fields))
def ED_collate_fn(batch):
    return EDBatch(
        batch_doc_id=[instance["doc_id"] for instance in batch],
        batch_wnd_id=[instance["wnd_id"] for instance in batch],
        batch_tokens=[instance["tokens"] for instance in batch], 
        batch_pieces=[instance["pieces"] for instance in batch], 
        batch_token_lens=[instance["token_lens"] for instance in batch], 
        batch_token_num=[instance["token_num"] for instance in batch], 
        batch_text=[instance["text"] for instance in batch], 
        batch_triggers=[instance["triggers"] for instance in batch], 
    )


In [11]:
from torch.utils.data import DataLoader
train = DataLoader(internal_train_data, batch_size=3, 
                                                         shuffle=True, drop_last=False, collate_fn=ED_collate_fn)

In [12]:
batch = next(iter(train))

In [13]:
batch

EDBatch(batch_doc_id=['19354059_1', '18421192_3', '3365032_2'], batch_wnd_id=['19354059_1_1', '18421192_3_1', '3365032_2_1'], batch_tokens=[['The', 'most', 'common', 'complication', 'of', 'warfarin', 'use', 'is', 'adverse', 'bleeding', '.'], ['Type', '1', 'diabetes', 'mellitus', 'provoked', 'by', 'peginterferon', 'alpha', '-', '2b', 'plus', 'ribavirin', 'treatment', 'for', 'chronic', 'hepatitis', 'C.'], ['Polymyositis', 'after', 'propylthiouracil', 'treatment', 'for', 'hyperthyroidism', '.']], batch_pieces=[['ĠThe', 'Ġmost', 'Ġcommon', 'Ġcomplication', 'Ġof', 'Ġwar', 'far', 'in', 'Ġuse', 'Ġis', 'Ġadverse', 'Ġbleeding', 'Ġ.'], ['ĠType', 'Ġ1', 'Ġdiabetes', 'Ġmell', 'itus', 'Ġprovoked', 'Ġby', 'Ġpe', 'gin', 'ter', 'fer', 'on', 'Ġalpha', 'Ġ-', 'Ġ2', 'b', 'Ġplus', 'Ġrib', 'av', 'irin', 'Ġtreatment', 'Ġfor', 'Ġchronic', 'Ġhepatitis', 'ĠC', '.'], ['ĠPoly', 'my', 'osit', 'is', 'Ġafter', 'Ġprop', 'yl', 'th', 'iour', 'ac', 'il', 'Ġtreatment', 'Ġfor', 'Ġhyper', 'thy', 'roid', 'ism', 'Ġ.']], batch

In [14]:
def get_trigger_seqlabels(triggers, token_num, specify_trigger=None):
    labels = ['O'] * token_num
    count = 0
    for trigger in triggers:
        start, end = trigger[0], trigger[1]
        if end > token_num:
            continue
        trigger_type = trigger[2]

        if specify_trigger is not None:
            if trigger_type != specify_trigger:
                continue

        if any([labels[i] != 'O' for i in range(start, end)]):
            count += 1
            continue

        labels[start] = 'B-{}'.format(trigger_type)
        for i in range(start + 1, end):
            labels[i] = 'I-{}'.format(trigger_type)
            
    return labels

In [15]:
batch.batch_token_num

[11, 17, 7]

In [37]:
def adj_matrix_creat(tokens:list):
    nlp = spacy.load('en_core_web_sm')
    doc = Doc(nlp.vocab, words=tokens)
    adj_matrix = np.eye(len(tokens))
    nlp(doc)
    for j in range(len(tokens)):
        adj_matrix[j][doc[j].head.i] = 1
        adj_matrix[doc[j].head.i][j] = 1
    return adj_matrix

def Zeropad(adj_matrix,pad_length):
    adj_pad = np.pad(adj_matrix,(0,pad_length),'constant')
    return adj_pad

In [22]:
import spacy
import numpy as np
import torch
from spacy.tokens import Doc

In [40]:
import torch
def process_data(batch):
        enc_idxs = []
        enc_attn = []
        trigger_seqidxs = []
        token_lens = []
        token_nums = []
        max_token_num = max(batch.batch_token_num)
        adjs = []
        
        for tokens, pieces, triggers, token_len, token_num in zip(batch.batch_tokens, batch.batch_pieces, batch.batch_triggers, 
                                                                      batch.batch_token_lens, batch.batch_token_num):
            adj = adj_matrix_creat(tokens)
            piece_id = tokenizer.convert_tokens_to_ids(pieces)
            enc_idx = [tokenizer.convert_tokens_to_ids(tokenizer.bos_token)] + piece_id + [tokenizer.convert_tokens_to_ids(tokenizer.eos_token)]
            
            adjs.append(Zeropad(adj,(max_token_num - token_num)))
            enc_idxs.append(enc_idx)
            enc_attn.append([1]*len(enc_idx))  
            
            trigger_seq = get_trigger_seqlabels(triggers, len(tokens))
            
            token_lens.append(token_len)
            token_nums.append(token_num)
            if config.use_crf:
                trigger_seqidxs.append([label_stoi["trigger"][s] for s in trigger_seq] + [0] * (max_token_num-len(tokens)))
            else:
                trigger_seqidxs.append([label_stoi["trigger"][s] for s in trigger_seq] + [-100] * (max_token_num-len(tokens)))
        max_len = max([len(enc_idx) for enc_idx in enc_idxs])
        adjs = torch.LongTensor(adjs)
        enc_idxs = torch.LongTensor([enc_idx + [tokenizer.convert_tokens_to_ids(tokenizer.pad_token)]*(max_len-len(enc_idx)) for enc_idx in enc_idxs])
        enc_attn = torch.LongTensor([enc_att + [0]*(max_len-len(enc_att)) for enc_att in enc_attn])
        trigger_seqidxs = torch.LongTensor(trigger_seqidxs)
        return enc_idxs, enc_attn, trigger_seqidxs, token_lens, torch.LongTensor(token_nums),adjs

In [41]:
enc_idxs, enc_attn, trigger_seqidxs, token_lens, token_nums,adjs = process_data(batch=batch)

OSError: [E050] Can't find model 'en_core_web_sm'. It doesn't seem to be a Python package or a valid path to a data directory.

In [20]:
def token_lens_to_idxs(token_lens):
        """Map token lengths to a word piece index matrix (for torch.gather) and a
        mask tensor.
        For example (only show a sequence instead of a batch):
        token lengths: [1,1,1,3,1]
        =>
        indices: [[0,0,0], [1,0,0], [2,0,0], [3,4,5], [6,0,0]]
        masks: [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0],
                [0.33, 0.33, 0.33], [1.0, 0.0, 0.0]]
        Next, we use torch.gather() to select vectors of word pieces for each token,
        and average them as follows (incomplete code):
        outputs = torch.gather(bert_outputs, 1, indices) * masks
        outputs = bert_outputs.view(batch_size, seq_len, -1, self.bert_dim)
        outputs = bert_outputs.sum(2)
        :param token_lens (list): token lengths.
        :return: a index matrix and a mask tensor.
        """
        max_token_num = max([len(x) for x in token_lens])  # 最大token数
        max_token_len = max([max(x) for x in token_lens])  # token中最大的
        idxs, masks = [], []
        for seq_token_lens in token_lens:
            seq_idxs, seq_masks = [], []
            offset = 0
            for token_len in seq_token_lens:
                # max_token_len为一组
                seq_idxs.extend([i + offset for i in range(token_len)]
                                + [-1] * (max_token_len - token_len))
                seq_masks.extend([1.0 / token_len] * token_len
                                 + [0.0] * (max_token_len - token_len))
                offset += token_len
            #  补全 
            seq_idxs.extend([-1] * max_token_len * (max_token_num - len(seq_token_lens)))
            seq_masks.extend([0.0] * max_token_len * (max_token_num - len(seq_token_lens)))
            idxs.append(seq_idxs)
            masks.append(seq_masks)
        return idxs, masks, max_token_num, max_token_len

In [21]:
base_model = RobertaModel.from_pretrained(config.pretrained_model_name, 
                                                           cache_dir=config.cache_dir, 
                                                           output_hidden_states=True)
base_config = RobertaConfig.from_pretrained(config.pretrained_model_name, 
                                                             cache_dir=config.cache_dir)
base_model_dim = base_config.hidden_size
print(base_model_dim)

'HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /roberta-base/resolve/main/config.json (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x000001B4F03511D0>, 'Connection to huggingface.co timed out. (connect timeout=10)'))' thrown while requesting HEAD https://huggingface.co/roberta-base/resolve/main/config.json
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
'HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /roberta-base/resolve/main/config.json (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x000001B4F03B7B50>, 'Connection to huggingface.co timed out. (connect timeout=10)'))' thrown while requesting HEAD https://hugg

768


In [22]:

def encode( piece_idxs, attention_masks, token_lens):
    """Encode input sequences with BERT
    :param piece_idxs (LongTensor): word pieces indices
    :param attention_masks (FloatTensor): attention mask
    :param token_lens (list): token lengths
    """
    batch_size, _ = piece_idxs.size()
    all_base_model_outputs = base_model(piece_idxs, attention_mask=attention_masks)
    base_model_outputs = all_base_model_outputs[0] # (bs , seq_len , hidden_dim)
    if config.multi_piece_strategy == 'first':
        # select the first piece for multi-piece words
        offsets = token_lens_to_offsets(token_lens)
        offsets = piece_idxs.new(offsets) # batch x max_token_num
        # + 1 because the first vector is for [CLS]
        offsets = offsets.unsqueeze(-1).expand(batch_size, -1, bert_dim) + 1
        base_model_outputs = torch.gather(base_model_outputs, 1, offsets)
    elif config.multi_piece_strategy == 'average':
        # average all pieces for multi-piece words
        idxs, masks, token_num, token_len = token_lens_to_idxs(token_lens)   # idx,masks:(bs , max_token_num * max_token_len)
        
        idxs = piece_idxs.new(idxs).unsqueeze(-1).expand(batch_size, -1, base_model_dim) + 1 # idx:(bs , max_token_num * max_token_len, base_model_dim)

        masks = base_model_outputs.new(masks).unsqueeze(-1) # masks:(bs , max_token_num * max_token_len, 1)
 
        base_model_outputs = torch.gather(base_model_outputs, 1, idxs) * masks # (bs, max_token_num * max_token_len,base_model_dim)

        base_model_outputs = base_model_outputs.view(batch_size, token_num, token_len, base_model_dim)
        base_model_outputs = base_model_outputs.sum(2)
    else:
        raise ValueError(f'Unknown multi-piece token handling strategy: {config.multi_piece_strategy}')
    superbase_model_dropout = nn.Dropout(p=config.base_model_dropout)
    base_model_outputs = superbase_model_dropout(base_model_outputs)
    return base_model_outputs

In [23]:
# encoding
base_model_outputs = encode(enc_idxs, enc_attn, token_lens)

In [24]:
print(base_model_outputs.size())

torch.Size([3, 35, 768])


In [25]:
class Linears(nn.Module):
    """Multiple linear layers with Dropout."""
    def __init__(self, dimensions, activation='relu', dropout_prob=0.0, bias=True):
        super().__init__()
        assert len(dimensions) > 1
        self.layers = nn.ModuleList([nn.Linear(dimensions[i], dimensions[i + 1], bias=bias)
                                     for i in range(len(dimensions) - 1)])
        self.activation = getattr(torch, activation)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, inputs):
        outputs = []
        for i, layer in enumerate(self.layers):
            if i > 0:
                inputs = self.activation(inputs)
                inputs = self.dropout(inputs)
            inputs = layer(inputs)
            outputs.append(inputs)
        return outputs[-1]
feature_dim = base_model_dim
trigger_label_ffn = Linears([feature_dim, config.linear_hidden_num, len(label_stoi["trigger"])],
                                      dropout_prob=config.linear_dropout, 
                                      bias=config.linear_bias, 
                                      activation=config.linear_activation)
entity_label_scores = trigger_label_ffn(base_model_outputs)

In [26]:
entity_label_scores.size()

torch.Size([3, 35, 5])

In [27]:
def sequence_mask(lens, max_len=None):
    """Generate a sequence mask tensor from sequence lengths, used by CRF."""
    batch_size = lens.size(0)
    if max_len is None:
        max_len = lens.max().item()
    ranges = torch.arange(0, max_len, device=lens.device).long()
    ranges = ranges.unsqueeze(0).expand(batch_size, max_len)
    lens_exp = lens.unsqueeze(1).expand_as(ranges)
    mask = ranges < lens_exp
    return mask

In [29]:
token_nums

tensor([32, 29, 35])

In [30]:
mask = sequence_mask(token_nums)

In [32]:
mask.long()

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

In [33]:
label_stoi

{'trigger': {'O': 0,
  'B-Adverse_event': 1,
  'I-Adverse_event': 2,
  'B-Potential_therapeutic_event': 3,
  'I-Potential_therapeutic_event': 4}}

In [35]:
from torchcrf import CRF
num_tags = len(label_stoi['trigger'])  # number of tags is 5
model = CRF(num_tags,batch_first=True)

In [36]:
trigger_seqidxs

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

In [38]:
emissions = entity_label_scores
tags = trigger_seqidxs
model(emissions, tags,mask = mask,reduction = 'mean')

tensor(-56.9406, grad_fn=<MeanBackward0>)

In [49]:
def tag_paths_to_spans( paths, token_nums, vocab):
        """
        Convert predicted tag paths to a list of spans (entity mentions or event
        triggers).
        :param paths: predicted tag paths.
        :return (list): a list (batch) of lists (sequence) of spans.
        """
        batch_mentions = []
        itos = {i: s for s, i in vocab.items()}
        for i, path in enumerate(paths):
            mentions = []
            cur_mention = None
            path = path[:token_nums[i].item()]
            for j, tag in enumerate(path):
                tag = itos[tag]
                if tag == 'O':
                    prefix = tag = 'O'
                else:
                    prefix, tag = tag.split('-', 1)
                if prefix == 'B':
                    if cur_mention:
                        mentions.append(cur_mention)
                    cur_mention = [j, j + 1, tag]
                elif prefix == 'I':
                    if cur_mention is None:
                        # treat it as B-*
                        cur_mention = [j, j + 1, tag]
                    elif cur_mention[-1] == tag:
                        cur_mention[1] = j + 1
                    else:
                        # treat it as B-*
                        mentions.append(cur_mention)
                        cur_mention = [j, j + 1, tag]
                else:
                    if cur_mention:
                        mentions.append(cur_mention)
                    cur_mention = None
            if cur_mention:
                mentions.append(cur_mention)
            batch_mentions.append(mentions)
            
        return batch_mentions

In [46]:
token_nums

tensor([32, 29, 35])

In [47]:
label_stoi["trigger"]

{'O': 0,
 'B-Adverse_event': 1,
 'I-Adverse_event': 2,
 'B-Potential_therapeutic_event': 3,
 'I-Potential_therapeutic_event': 4}

In [59]:
model.decode(emissions,mask = mask)

[[3,
  2,
  4,
  0,
  4,
  3,
  2,
  3,
  3,
  4,
  4,
  3,
  4,
  3,
  3,
  4,
  0,
  1,
  3,
  4,
  3,
  4,
  0,
  4,
  3,
  4,
  3,
  2,
  4,
  3,
  4,
  2],
 [3,
  3,
  4,
  0,
  3,
  4,
  3,
  3,
  2,
  4,
  0,
  3,
  2,
  4,
  3,
  2,
  3,
  2,
  4,
  0,
  1,
  4,
  1,
  3,
  2,
  1,
  4,
  1,
  4],
 [2,
  4,
  1,
  3,
  2,
  4,
  0,
  4,
  0,
  4,
  3,
  3,
  2,
  3,
  4,
  3,
  3,
  2,
  3,
  2,
  2,
  4,
  0,
  1,
  3,
  0,
  2,
  4,
  4,
  0,
  1,
  3,
  2,
  4,
  4]]

In [50]:
tag_paths_to_spans(model.decode(emissions,mask = mask), 
                    token_nums, 
                    label_stoi["trigger"])

[[[0, 1, 'Potential_therapeutic_event'],
  [1, 2, 'Adverse_event'],
  [2, 3, 'Potential_therapeutic_event'],
  [4, 5, 'Potential_therapeutic_event'],
  [5, 6, 'Potential_therapeutic_event'],
  [6, 7, 'Adverse_event'],
  [7, 8, 'Potential_therapeutic_event'],
  [8, 11, 'Potential_therapeutic_event'],
  [11, 13, 'Potential_therapeutic_event'],
  [13, 14, 'Potential_therapeutic_event'],
  [14, 16, 'Potential_therapeutic_event'],
  [17, 18, 'Adverse_event'],
  [18, 20, 'Potential_therapeutic_event'],
  [20, 22, 'Potential_therapeutic_event'],
  [23, 24, 'Potential_therapeutic_event'],
  [24, 26, 'Potential_therapeutic_event'],
  [26, 27, 'Potential_therapeutic_event'],
  [27, 28, 'Adverse_event'],
  [28, 29, 'Potential_therapeutic_event'],
  [29, 31, 'Potential_therapeutic_event'],
  [31, 32, 'Adverse_event']],
 [[0, 1, 'Potential_therapeutic_event'],
  [1, 3, 'Potential_therapeutic_event'],
  [4, 6, 'Potential_therapeutic_event'],
  [6, 7, 'Potential_therapeutic_event'],
  [7, 8, 'Potenti

In [51]:
def log_sum_exp(tensor, dim=0, keepdim: bool = False):
    """LogSumExp operation used by CRF."""
    m, _ = tensor.max(dim, keepdim=keepdim)
    if keepdim:
        stable_vec = tensor - m
    else:
        stable_vec = tensor - m.unsqueeze(dim)
    return m + (stable_vec.exp().sum(dim, keepdim=keepdim)).log()
def sequence_mask(lens, max_len=None):
    """Generate a sequence mask tensor from sequence lengths, used by CRF."""
    batch_size = lens.size(0)
    if max_len is None:
        max_len = lens.max().item()
    ranges = torch.arange(0, max_len, device=lens.device).long()
    ranges = ranges.unsqueeze(0).expand(batch_size, max_len)
    lens_exp = lens.unsqueeze(1).expand_as(ranges)
    mask = ranges < lens_exp
    return mask
class CRF(nn.Module):
    def __init__(self, label_vocab, bioes=False):
        super(CRF, self).__init__()

        self.label_vocab = label_vocab
        self.label_size = len(label_vocab) + 2
        self.bioes = bioes

        self.start = self.label_size - 2
        self.end = self.label_size - 1
        transition = torch.randn(self.label_size, self.label_size)
        self.transition = nn.Parameter(transition)
        self.initialize()

    def initialize(self):
        self.transition.data[:, self.end] = -100.0
        self.transition.data[self.start, :] = -100.0

        for label, label_idx in self.label_vocab.items():
            if label.startswith('I-') or label.startswith('E-'):
                self.transition.data[label_idx, self.start] = -100.0
            if label.startswith('B-') or label.startswith('I-'):
                self.transition.data[self.end, label_idx] = -100.0

        for label_from, label_from_idx in self.label_vocab.items():
            if label_from == 'O':
                label_from_prefix, label_from_type = 'O', 'O'
            else:
                label_from_prefix, label_from_type = label_from.split('-', 1)

            for label_to, label_to_idx in self.label_vocab.items():
                if label_to == 'O':
                    label_to_prefix, label_to_type = 'O', 'O'
                else:
                    label_to_prefix, label_to_type = label_to.split('-', 1)

                if self.bioes:
                    is_allowed = any(
                        [
                            label_from_prefix in ['O', 'E', 'S']
                            and label_to_prefix in ['O', 'B', 'S'],

                            label_from_prefix in ['B', 'I']
                            and label_to_prefix in ['I', 'E']
                            and label_from_type == label_to_type
                        ]
                    )
                else:
                    is_allowed = any(
                        [
                            label_to_prefix in ['B', 'O'],

                            label_from_prefix in ['B', 'I']
                            and label_to_prefix == 'I'
                            and label_from_type == label_to_type
                        ]
                    )
                if not is_allowed:
                    self.transition.data[
                        label_to_idx, label_from_idx] = -100.0

    def pad_logits(self, logits):
        """Pad the linear layer output with <SOS> and <EOS> scores.
        :param logits: Linear layer output (no non-linear function).
        """
        batch_size, seq_len, _ = logits.size()
        pads = logits.new_full((batch_size, seq_len, 2), -100.0,
                               requires_grad=False)
        logits = torch.cat([logits, pads], dim=2)
        return logits
    
    # 计算转移分数
    def calc_binary_score(self, labels, lens):
        batch_size, seq_len = labels.size()

        # A tensor of size batch_size * (seq_len + 2)
        labels_ext = labels.new_empty((batch_size, seq_len + 2))
        labels_ext[:, 0] = self.start
        labels_ext[:, 1:-1] = labels
        mask = sequence_mask(lens + 1, max_len=(seq_len + 2)).long()
        pad_stop = labels.new_full((1,), self.end, requires_grad=False)
        pad_stop = pad_stop.unsqueeze(-1).expand(batch_size, seq_len + 2)
        labels_ext = (1 - mask) * pad_stop + mask * labels_ext
        labels = labels_ext

        trn = self.transition
        trn_exp = trn.unsqueeze(0).expand(batch_size, self.label_size,
                                          self.label_size)
        lbl_r = labels[:, 1:]
        lbl_rexp = lbl_r.unsqueeze(-1).expand(*lbl_r.size(), self.label_size)
        # score of jumping to a tag
        trn_row = torch.gather(trn_exp, 1, lbl_rexp)

        lbl_lexp = labels[:, :-1].unsqueeze(-1)
        trn_scr = torch.gather(trn_row, 2, lbl_lexp)
        trn_scr = trn_scr.squeeze(-1)

        mask = sequence_mask(lens + 1).float()
        trn_scr = trn_scr * mask
        score = trn_scr

        return score
    
    # 计算状态分数
    def calc_unary_score(self, logits, labels, lens):
        """Checked"""
        labels_exp = labels.unsqueeze(-1)
        scores = torch.gather(logits, 2, labels_exp).squeeze(-1)
        mask = sequence_mask(lens).float()
        scores = scores * mask
        return scores

    def calc_gold_score(self, logits, labels, lens):
        """Checked"""
        unary_score = self.calc_unary_score(logits, labels, lens).sum(
            1).squeeze(-1)
        binary_score = self.calc_binary_score(labels, lens).sum(1).squeeze(-1)
        return unary_score + binary_score

    def calc_norm_score(self, logits, lens):
        batch_size, _, _ = logits.size()
        alpha = logits.new_full((batch_size, self.label_size), -100.0)
        alpha[:, self.start] = 0
        lens_ = lens.clone()

        logits_t = logits.transpose(1, 0)
        for logit in logits_t:
            logit_exp = logit.unsqueeze(-1).expand(batch_size,
                                                   self.label_size,
                                                   self.label_size)
            alpha_exp = alpha.unsqueeze(1).expand(batch_size,
                                                  self.label_size,
                                                  self.label_size)
            trans_exp = self.transition.unsqueeze(0).expand_as(alpha_exp)
            mat = logit_exp + alpha_exp + trans_exp
            alpha_nxt = log_sum_exp(mat, 2).squeeze(-1)

            mask = (lens_ > 0).float().unsqueeze(-1).expand_as(alpha)
            alpha = mask * alpha_nxt + (1 - mask) * alpha
            lens_ = lens_ - 1

        alpha = alpha + self.transition[self.end].unsqueeze(0).expand_as(alpha)
        norm = log_sum_exp(alpha, 1).squeeze(-1)

        return norm

    def loglik(self, logits, labels, lens):
        norm_score = self.calc_norm_score(logits, lens)
        gold_score = self.calc_gold_score(logits, labels, lens)
        return gold_score - norm_score

    def viterbi_decode(self, logits, lens):
        """Borrowed from pytorch tutorial
        Arguments:
            logits: [batch_size, seq_len, n_labels] FloatTensor
            lens: [batch_size] LongTensor
        """
        batch_size, _, n_labels = logits.size()
        vit = logits.new_full((batch_size, self.label_size), -100.0)
        vit[:, self.start] = 0
        c_lens = lens.clone()

        logits_t = logits.transpose(1, 0)
        pointers = []
        for logit in logits_t:
            vit_exp = vit.unsqueeze(1).expand(batch_size, n_labels, n_labels)
            trn_exp = self.transition.unsqueeze(0).expand_as(vit_exp)
            vit_trn_sum = vit_exp + trn_exp
            vt_max, vt_argmax = vit_trn_sum.max(2)

            vt_max = vt_max.squeeze(-1)
            vit_nxt = vt_max + logit
            pointers.append(vt_argmax.squeeze(-1).unsqueeze(0))

            mask = (c_lens > 0).float().unsqueeze(-1).expand_as(vit_nxt)
            vit = mask * vit_nxt + (1 - mask) * vit

            mask = (c_lens == 1).float().unsqueeze(-1).expand_as(vit_nxt)
            vit += mask * self.transition[self.end].unsqueeze(
                0).expand_as(vit_nxt)

            c_lens = c_lens - 1

        pointers = torch.cat(pointers)
        scores, idx = vit.max(1)
        paths = [idx.unsqueeze(1)]
        for argmax in reversed(pointers):
            idx_exp = idx.unsqueeze(-1)
            idx = torch.gather(argmax, 1, idx_exp)
            idx = idx.squeeze(-1)

            paths.insert(0, idx.unsqueeze(1))

        paths = torch.cat(paths[1:], 1)
        scores = scores.squeeze(-1)

        return scores, paths

    def calc_conf_score_(self, logits, labels):
        batch_size, _, _ = logits.size()

        logits_t = logits.transpose(1, 0)
        scores = [[] for _ in range(batch_size)]
        pre_labels = [self.start] * batch_size
        for i, logit in enumerate(logits_t):
            logit_exp = logit.unsqueeze(-1).expand(batch_size,
                                                   self.label_size,
                                                   self.label_size)
            trans_exp = self.transition.unsqueeze(0).expand(batch_size,
                                                            self.label_size,
                                                            self.label_size)
            score = logit_exp + trans_exp
            score = score.view(-1, self.label_size * self.label_size) \
                .softmax(1)
            for j in range(batch_size):
                cur_label = labels[j][i]
                cur_score = score[j][cur_label * self.label_size + pre_labels[j]]
                scores[j].append(cur_score)
                pre_labels[j] = cur_label
        return scores

In [52]:
trigger_crf = CRF(label_stoi["trigger"], bioes=False)
            

In [54]:
entity_label_scores_.size()

torch.Size([3, 35, 7])

In [55]:
def tag_paths_to_spans( paths, token_nums, vocab):
        """
        Convert predicted tag paths to a list of spans (entity mentions or event
        triggers).
        :param paths: predicted tag paths.
        :return (list): a list (batch) of lists (sequence) of spans.
        """
        batch_mentions = []
        itos = {i: s for s, i in vocab.items()}
        for i, path in enumerate(paths):
            mentions = []
            cur_mention = None
            path = path.tolist()[:token_nums[i].item()]
            for j, tag in enumerate(path):
                tag = itos[tag]
                if tag == 'O':
                    prefix = tag = 'O'
                else:
                    prefix, tag = tag.split('-', 1)
                if prefix == 'B':
                    if cur_mention:
                        mentions.append(cur_mention)
                    cur_mention = [j, j + 1, tag]
                elif prefix == 'I':
                    if cur_mention is None:
                        # treat it as B-*
                        cur_mention = [j, j + 1, tag]
                    elif cur_mention[-1] == tag:
                        cur_mention[1] = j + 1
                    else:
                        # treat it as B-*
                        mentions.append(cur_mention)
                        cur_mention = [j, j + 1, tag]
                else:
                    if cur_mention:
                        mentions.append(cur_mention)
                    cur_mention = None
            if cur_mention:
                mentions.append(cur_mention)
            batch_mentions.append(mentions)
            
        return batch_mentions

In [56]:
entity_label_scores_ = trigger_crf.pad_logits(entity_label_scores)
_, entity_label_preds = trigger_crf.viterbi_decode(entity_label_scores_,
                                                        token_nums)
entities = tag_paths_to_spans(entity_label_preds, 
                                    token_nums, 
                                    label_stoi["trigger"])

In [58]:
trigger_crf.loglik(entity_label_scores_, 
                                                           trigger_seqidxs, 
                                                           token_nums)

tensor([-76.2858, -68.1855, -83.8794], grad_fn=<SubBackward0>)

In [31]:
torch.set_printoptions(precision=3,sci_mode=False)
trigger_crf.transition

Parameter containing:
tensor([[    -0.112,      1.147,     -0.032,     -0.184,     -0.351,      1.541,
           -100.000],
        [    -0.821,      0.051,      0.238,     -0.445,     -0.877,     -0.638,
           -100.000],
        [  -100.000,      0.813,      0.533,   -100.000,   -100.000,   -100.000,
           -100.000],
        [    -0.096,      0.842,     -0.645,      0.967,     -1.543,     -0.322,
           -100.000],
        [  -100.000,   -100.000,   -100.000,      0.066,     -1.322,   -100.000,
           -100.000],
        [  -100.000,   -100.000,   -100.000,   -100.000,   -100.000,   -100.000,
           -100.000],
        [     0.985,   -100.000,   -100.000,   -100.000,   -100.000,      0.813,
           -100.000]], requires_grad=True)

In [43]:
entity_label_scores_ = trigger_crf.pad_logits(entity_label_scores)

In [53]:
lens = token_nums
labels = trigger_seqidxs
logits = entity_label_scores_

In [126]:
labels

tensor([[5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6],
        [5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6],
        [5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])

In [93]:
labels.new_full((1,), trigger_crf.end, requires_grad=False).unsqueeze(-1).expand(batch_size, seq_len + 2)

tensor([[6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
         6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6],
        [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
         6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6],
        [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
         6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]])

In [95]:

batch_size, seq_len = labels.size()
# A tensor of size batch_size * (seq_len + 2)
labels_ext = labels.new_empty((batch_size, seq_len + 2))
labels_ext[:, 0] = trigger_crf.start
labels_ext[:, 1:-1] = labels
mask = sequence_mask(lens + 1, max_len=(seq_len + 2)).long()

pad_stop = labels.new_full((1,), trigger_crf.end, requires_grad=False)
pad_stop = pad_stop.unsqueeze(-1).expand(batch_size, seq_len + 2)
labels_ext = (1 - mask) * pad_stop + mask * labels_ext
labels = labels_ext
print(labels)

torch.Size([3, 46])
tensor([[5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6],
        [5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6],
        [5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])


In [128]:
trigger_crf.transition

Parameter containing:
tensor([[    -0.112,      1.147,     -0.032,     -0.184,     -0.351,      1.541,
           -100.000],
        [    -0.821,      0.051,      0.238,     -0.445,     -0.877,     -0.638,
           -100.000],
        [  -100.000,      0.813,      0.533,   -100.000,   -100.000,   -100.000,
           -100.000],
        [    -0.096,      0.842,     -0.645,      0.967,     -1.543,     -0.322,
           -100.000],
        [  -100.000,   -100.000,   -100.000,      0.066,     -1.322,   -100.000,
           -100.000],
        [  -100.000,   -100.000,   -100.000,   -100.000,   -100.000,   -100.000,
           -100.000],
        [     0.985,   -100.000,   -100.000,   -100.000,   -100.000,      0.813,
           -100.000]], requires_grad=True)

In [105]:
trn = trigger_crf.transition
trn_exp = trn.unsqueeze(0).expand(batch_size, trigger_crf.label_size,
                                    trigger_crf.label_size)
lbl_r = labels[:, 1:]
lbl_rexp = lbl_r.unsqueeze(-1).expand(*lbl_r.size(), trigger_crf.label_size)

In [116]:
# score of jumping to a tag
trn_row = torch.gather(trn_exp, 1, lbl_rexp)

lbl_lexp = labels[:, :-1].unsqueeze(-1)
trn_scr = torch.gather(trn_row, 2, lbl_lexp)
trn_scr = trn_scr.squeeze(-1)

mask = sequence_mask(lens + 1).float()
trn_scr = trn_scr * mask
score = trn_scr

In [125]:
score

tensor([[ 1.541, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112,
         -0.112, -0.112, -0.112, -0.112, -0.821,  1.147, -0.112, -0.112, -0.112,
         -0.112, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112,
         -0.112, -0.112, -0.112,  0.985, -0.000, -0.000, -0.000, -0.000, -0.000,
         -0.000, -0.000, -0.000, -0.000, -0.000, -0.000, -0.000, -0.000, -0.000],
        [ 1.541, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112,
         -0.821,  1.147, -0.112, -0.112, -0.821,  1.147, -0.112, -0.112, -0.112,
         -0.112, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112,  0.985, -0.000,
         -0.000, -0.000, -0.000, -0.000, -0.000, -0.000, -0.000, -0.000, -0.000,
         -0.000, -0.000, -0.000, -0.000, -0.000, -0.000, -0.000, -0.000, -0.000],
        [ 1.541, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112,
         -0.112, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112, -0.112,
         -0.112, -0.112, -

In [113]:
lbl_lexp = labels[:, :-1].unsqueeze(-1)

In [118]:
lbl_lexp

tensor([[[5],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [1],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [6],
         [6],
         [6],
         [6],
         [6],
         [6],
         [6],
         [6],
         [6],
         [6],
         [6],
         [6],
         [6],
         [6]],

        [[5],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [1],
         [0],
         [0],
         [0],
         [1],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
    

In [117]:
torch.gather(trn_row, 2, lbl_lexp)

tensor([[[   1.541],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.821],
         [   1.147],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [  -0.112],
         [   0.985],
         [-100.000],
         [-100.000],
         [-100.000],
         [-100.000],
         [-100.000],
         [-100.000],
         [-100.000],
         [-100.000],
         [-100.000],
         [-100.000],
         [-100.000],
         [-100.000],
         [-100.000],
         [-100.000]],

        [[   1.541],
         [  -0.112],
         [ 

In [81]:
binary_score = trigger_crf.calc_binary_score(labels, lens).sum(1).squeeze(-1)

In [83]:
unary_score + binary_score

tensor([3.536, 4.560, 3.051], grad_fn=<AddBackward0>)

In [87]:
entity_label_loglik = trigger_crf.loglik(entity_label_scores_, 
                                                           trigger_seqidxs, 
                                                           token_nums)

In [89]:
entity_label_loglik.mean()

tensor(-78.1870, grad_fn=<MeanBackward0>)