In [2]:
import os
import sys
import json
import torch
import argparse
import pickle as cPickle
import logging

import spacy
from nltk.corpus import wordnet as wn

for pack in os.listdir("src"):
    sys.path.append(os.path.join("src", pack))

sys.path.append("/src/shared/")

from collections import defaultdict
from swirl_parsing import parse_swirl_output
from allen_srl_reader import read_srl
from create_bert_embeddings import *
from classes import Document, Sentence, Token, EventMention, EntityMention
from extraction_utils import *

In [6]:
config_path = "build_features_config.json"
output_path="data/feature"

In [7]:
nlp = spacy.load('en_core_web_sm')

with open(config_path, 'r') as js_file:
    config_dict = json.load(js_file)

with open(os.path.join(output_path,'build_features_config.json'), "w") as js_file:
    json.dump(config_dict, js_file, indent=4, sort_keys=True)



In [8]:

def load_mentions_from_json(mentions_json_file, docs, is_event, is_gold_mentions):
    '''
    Loading mentions from JSON file and add those to the documents objects
    :param mentions_json_file: the JSON file contains the mentions
    :param docs:  set of document objects
    :param is_event: a boolean indicates whether the function extracts event or entity mentions
    :param is_gold_mentions: a boolean indicates whether the function extracts gold or predicted
    mentions
    '''
    with open(mentions_json_file, 'r') as js_file:
        js_mentions = json.load(js_file)

    for js_mention in js_mentions:
        doc_id = js_mention["doc_id"].replace('.xml', '')
        sent_id = js_mention["sent_id"]
        tokens_numbers = js_mention["tokens_number"]
        mention_type = js_mention["mention_type"]
        is_singleton = js_mention["is_singleton"]
        is_continuous = js_mention["is_continuous"]
        mention_str = js_mention["tokens_str"]
        coref_chain = js_mention["coref_chain"]
        if mention_str is None:
            print(js_mention)
        head_text, head_lemma = find_head(mention_str)
        score = js_mention["score"]
        try:
            token_objects = docs[doc_id].get_sentences()[sent_id].find_mention_tokens(tokens_numbers)
        except:
            print('error when looking for mention tokens')
            print('doc id {} sent id {}'.format(doc_id, sent_id))
            print('token numbers - {}'.format(str(tokens_numbers)))
            print('mention string {}'.format(mention_str))
            print('sentence - {}'.format(docs[doc_id].get_sentences()[sent_id].get_raw_sentence()))
            raise

        # Sanity check - check if all mention's tokens can be found
        if not token_objects:
            print('Can not find tokens of a mention - {} {} {}'.format(doc_id, sent_id,tokens_numbers))

        # Mark the mention's gold coref chain in its tokens
        if is_gold_mentions:
            for token in token_objects:
                if is_event:
                    token.gold_event_coref_chain.append(coref_chain)
                else:
                    token.gold_entity_coref_chain.append(coref_chain)

        if is_event:
            mention = EventMention(doc_id, sent_id, tokens_numbers,token_objects,mention_str, head_text,
                                   head_lemma, is_singleton, is_continuous, coref_chain)
        else:
            mention = EntityMention(doc_id, sent_id, tokens_numbers,token_objects, mention_str, head_text,
                                    head_lemma, is_singleton, is_continuous, coref_chain, mention_type)

        mention.probability = score  # a confidence score for predicted mentions (if used), set gold mentions prob to 1.0
        if is_gold_mentions:
            docs[doc_id].get_sentences()[sent_id].add_gold_mention(mention, is_event)
        else:
            docs[doc_id].get_sentences()[sent_id]. \
                add_predicted_mention(mention, is_event,
                                      relaxed_match=config_dict["relaxed_match_with_gold_mention"])


def load_gold_mentions(docs,events_json, entities_json):
    '''
    A function loads gold event and entity mentions
    :param docs: set of document objects
    :param events_json:  a JSON file contains the gold event mentions (of a specific split - train/dev/test)
    :param entities_json: a JSON file contains the gold entity mentions (of a specific split - train/dev/test)
    '''
    load_mentions_from_json(events_json,docs,is_event=True, is_gold_mentions=True)
    load_mentions_from_json(entities_json,docs,is_event=False, is_gold_mentions=True)


def load_predicted_mentions(docs,events_json, entities_json):
    '''
    This function loads predicted event and entity mentions
    :param docs: set of document objects
    :param events_json:  a JSON file contains predicted event mentions (of a specific split - train/dev/test)
    :param entities_json: a JSON file contains predicted entity mentions (of a specific split - train/dev/test)
    '''
    load_mentions_from_json(events_json,docs,is_event=True, is_gold_mentions=False)
    load_mentions_from_json(entities_json,docs,is_event=False, is_gold_mentions=False)


def load_gold_data(split_txt_file, events_json, entities_json):
    '''
    This function loads the texts of each split and its gold mentions, create document objects
    and stored the gold mentions within their suitable document objects
    :param split_txt_file: the text file of each split is written as 5 columns (stored in data/intermid)
    :param events_json: a JSON file contains the gold event mentions
    :param entities_json: a JSON file contains the gold event mentions
    :return:
    '''
    logger.info('Loading gold mentions...')
    docs = load_ECB_plus(split_txt_file)
    load_gold_mentions(docs, events_json, entities_json)
    return docs


def load_predicted_data(docs, pred_events_json, pred_entities_json):
    '''
    This function loads the predicted mentions and stored them within their suitable document objects
    (suitable for loading the test data)
    :param docs: dictionary that contains the document objects
    :param pred_events_json: a JSON file contains predicted event mentions
    :param pred_entities_json: a JSON file contains predicted event mentions
    :return:
    '''
    logger.info('Loading predicted mentions...')
    load_predicted_mentions(docs, pred_events_json, pred_entities_json)


def find_head(x):
    '''
    This function finds the head and head lemma of a mention x
    :param x: A mention object
    :return: the head word and
    '''

    x_parsed = nlp(x)
    for tok in x_parsed:
        if tok.head == tok:
            if tok.lemma_ == u'-PRON-':
                return tok.text, tok.text.lower()
            return tok.text,tok.lemma_


def have_string_match(mention,arg_str ,arg_start, arg_end):
    '''
    This function checks whether a given entity mention has a string match (strict or relaxed)
    with a span of an extracted argument
    :param mention: a candidate entity mention
    :param arg_str: the argument's text
    :param arg_start: the start index of the argument's span
    :param arg_end: the end index of the argument's span
    :return: True if there is a string match (strict or relaxed) between the entity mention
    and the extracted argument's span, and false otherwise
    '''
    if mention.mention_str == arg_str and mention.start_offset == arg_start:  # exact string match + same start index
        return True
    if mention.mention_str == arg_str:  # exact string match
        return True
    if mention.start_offset >= arg_start and mention.end_offset <= arg_end:  # the argument span contains the mention span
        return True
    if arg_start >= mention.start_offset and arg_end <= mention.end_offset:  # the mention span contains the mention span
        return True
    if len(set(mention.tokens_numbers).intersection(set(range(arg_start,arg_end + 1)))) > 0: # intersection between the mention's tokens and the argument's tokens
        return True
    return False


def add_arg_to_event(entity, event, rel_name):
    '''
    Adds the entity mention as an argument (in a specific role) of an event mention and also adds the
    event mention as predicate (in a specific role) of the entity mention
    :param entity: an entity mention object
    :param event: an event mention object
    :param rel_name: the specific role
    '''
    if rel_name == 'A0':
        event.arg0 = (entity.mention_str, entity.mention_id)
        entity.add_predicate((event.mention_str, event.mention_id), 'A0')
    elif rel_name == 'A1':
        event.arg1 = (entity.mention_str, entity.mention_id)
        entity.add_predicate((event.mention_str, event.mention_id), 'A1')
    elif rel_name == 'AM-TMP':
        event.amtmp = (entity.mention_str, entity.mention_id)
        entity.add_predicate((event.mention_str, event.mention_id), 'AM-TMP')
    elif rel_name == 'AM-LOC':
        event.amloc = (entity.mention_str, entity.mention_id)
        entity.add_predicate((event.mention_str, event.mention_id), 'AM-LOC')


def find_argument(rel_name, rel_tokens, matched_event, sent_entities, sent_obj, is_gold, srl_obj):
    '''
    This function matches between an argument of an event mention and an entity mention.
    :param rel_name: the specific role of the argument
    :param rel_tokens: the argument's tokens
    :param matched_event: the event mention
    :param sent_entities: a entity mentions exist in the event's sentence.
    :param sent_obj: the object represents the sentence
    :param is_gold: whether the argument need to be matched with a gold mention or not
    :param srl_obj: an object represents the extracted SRL argument.
    :return True if the extracted SRL argument was matched with an entity mention.
    '''
    arg_start_ix = rel_tokens[0]
    if len(rel_tokens) > 1:
        arg_end_ix = rel_tokens[1]
    else:
        arg_end_ix = rel_tokens[0]

    if arg_end_ix >= len(sent_obj.get_tokens()):
        print('argument bound mismatch with sentence length')
        print('arg start index - {}'.format(arg_start_ix))
        print('arg end index - {}'.format(arg_end_ix))
        print('sentence length - {}'.format(len(sent_obj.get_tokens())))
        print('raw sentence: {}'.format(sent_obj.get_raw_sentence()))
        print('matched event: {}'.format(str(matched_event)))
        print('srl obj - {}'.format(str(srl_obj)))

    arg_str, arg_tokens = sent_obj.fetch_mention_string(arg_start_ix, arg_end_ix)

    entity_found = False
    matched_entity = None
    for entity in sent_entities:
        if have_string_match(entity, arg_str, arg_start_ix, arg_end_ix):
            if rel_name == 'AM-TMP' and entity.mention_type != 'TIM':
                continue
            if rel_name == 'AM-LOC' and entity.mention_type != 'LOC':
                continue
            entity_found = True
            matched_entity = entity
            break
    if entity_found:
        add_arg_to_event(matched_entity, matched_event, rel_name)
        if is_gold:
            return True
        else:
            if matched_entity.gold_mention_id is not None:
                return True
            else:
                return False
    else:
        return False


def match_allen_srl_structures(dataset, srl_data, is_gold):
    '''
    Matches between extracted predicates and event mentions and between their arguments and
    entity mentions, designed to handle the output of Allen NLP SRL system
    :param dataset: an object represents the spilt (train/dev/test)
    :param srl_data: a dictionary contains the predicate-argument structures
    :param is_gold: whether to match predicate-argument structures with gold mentions or with predicted mentions
    '''
    matched_events_count = 0
    matched_args_count = 0

    for topic_id, topic in dataset.topics.items():
        for doc_id, doc in topic.docs.items():
            for sent_id, sent in doc.get_sentences().items():
                # Handling nominalizations in case we don't use syntactic dependencies (which already handle this)
                if not config_dict["use_dep"]:
                    sent_str = sent.get_raw_sentence()
                    parsed_sent = nlp(sent_str)
                    find_nominalizations_args(parsed_sent, sent, is_gold)
                sent_srl_info = None

                if doc_id in srl_data:
                    doc_srl = srl_data[doc_id]
                    if int(sent_id) in doc_srl:
                        sent_srl_info = doc_srl[int(sent_id)]

                if sent_srl_info is not None:
                    for event_srl in sent_srl_info.srl:
                        event_text = event_srl.verb.text
                        event_ecb_tok_ids = event_srl.verb.ecb_tok_ids

                        if is_gold:
                            sent_events = sent.gold_event_mentions
                            sent_entities = sent.gold_entity_mentions
                        else:
                            sent_events = sent.pred_event_mentions
                            sent_entities = sent.pred_entity_mentions
                        event_found = False
                        matched_event = None

                        for event_mention in sent_events:
                            if event_ecb_tok_ids == event_mention.tokens_numbers or \
                                    event_text == event_mention.mention_str or \
                                    event_text in event_mention.mention_str or \
                                    event_mention.mention_str in event_text:
                                event_found = True
                                matched_event = event_mention
                                if is_gold:
                                    matched_events_count += 1
                                elif matched_event.gold_mention_id is not None:
                                    matched_events_count += 1
                            if event_found:
                                break
                        if event_found:
                            if event_srl.arg0 is not None:
                                if match_entity_with_srl_argument(sent_entities, matched_event,
                                                                  event_srl.arg0, 'A0', is_gold):
                                    matched_args_count += 1

                            if event_srl.arg1 is not None:
                                if match_entity_with_srl_argument(sent_entities, matched_event,
                                                                  event_srl.arg1, 'A1', is_gold):
                                    matched_args_count += 1
                            if event_srl.arg_tmp is not None:
                                if match_entity_with_srl_argument(sent_entities, matched_event,
                                                                  event_srl.arg_tmp, 'AM-TMP', is_gold):
                                    matched_args_count += 1

                            if event_srl.arg_loc is not None:
                                if match_entity_with_srl_argument(sent_entities, matched_event,
                                                                  event_srl.arg_loc, 'AM-LOC', is_gold):
                                    matched_args_count += 1

    logger.info('SRL matched events - ' + str(matched_events_count))
    logger.info('SRL matched args - ' + str(matched_args_count))


def match_entity_with_srl_argument(sent_entities, matched_event ,srl_arg,rel_name, is_gold):
    '''
    This function matches between an argument of an event mention and an entity mention.
    Designed to handle the output of Allen NLP SRL system
    :param sent_entities: the entity mentions in the event's sentence
    :param matched_event: the event mention
    :param srl_arg: the extracted argument
    :param rel_name: the role name
    :param is_gold: whether to match the argument with gold entity mention or with predicted entity mention
    :return:
    '''
    found_entity = False
    matched_entity = None
    for entity in sent_entities:
        if srl_arg.ecb_tok_ids == entity.tokens_numbers or \
                srl_arg.text == entity.mention_str or \
                srl_arg.text in entity.mention_str or \
                entity.mention_str in srl_arg.text:
            if rel_name == 'AM-TMP' and entity.mention_type != 'TIM':
                continue
            if rel_name == 'AM-LOC' and entity.mention_type != 'LOC':
                continue
            found_entity = True
            matched_entity = entity

        if found_entity:
            break

    if found_entity:
        add_arg_to_event(matched_entity, matched_event, rel_name)
        if is_gold:
            return True
        else:
            if matched_entity.gold_mention_id is not None:
                return True
            else:
                return False
    else:
        return False


def load_srl_info(dataset, srl_data, is_gold):
    '''
    Matches between extracted predicates and event mentions and between their arguments and
    entity mentions.
    :param dataset: an object represents the spilt (train/dev/test)
    :param srl_data: a dictionary contains the predicate-argument structures
    :param is_gold: whether to match predicate-argument structures with gold mentions or with predicted mentions
    '''
    matched_events_count = 0
    unmatched_event_count = 0
    matched_args_count = 0

    matched_identified_events = 0
    matched_identified_args = 0
    for topic_id, topic in dataset.topics.items():
        for doc_id, doc in topic.docs.items():
            for sent_id, sent in doc.get_sentences().items():
                # Handling nominalizations if we don't use dependency parsing (that already handles it)
                if not config_dict["use_dep"]:
                    sent_str = sent.get_raw_sentence()
                    parsed_sent = nlp(sent_str)
                    find_nominalizations_args(parsed_sent, sent, is_gold)
                sent_srl_info = {}

                if doc_id in srl_data:
                    doc_srl = srl_data[doc_id]
                    if int(sent_id) in doc_srl:
                        sent_srl_info = doc_srl[int(sent_id)]
                else:
                    print('doc not in srl data - ' + doc_id)

                for event_key, srl_obj in sent_srl_info.items():
                    if is_gold:
                        sent_events = sent.gold_event_mentions
                        sent_entities = sent.gold_entity_mentions
                    else:
                        sent_events = sent.pred_event_mentions
                        sent_entities = sent.pred_entity_mentions
                    event_found = False
                    matched_event = None
                    for event_mention in sent_events:
                        if event_key in event_mention.tokens_numbers:
                            event_found = True
                            matched_event = event_mention
                            if is_gold:
                                matched_events_count += 1
                            elif matched_event.gold_mention_id is not None:
                                    matched_events_count += 1
                        if event_found:
                            break
                    if event_found:
                        for rel_name, rel_tokens in srl_obj.arg_info.items():
                            if find_argument(rel_name, rel_tokens, matched_event, sent_entities, sent, is_gold,srl_obj):
                                matched_args_count += 1
                    else:
                        unmatched_event_count += 1
    logger.info('SRL matched events - ' + str(matched_events_count))
    logger.info('SRL unmatched events - ' + str(unmatched_event_count))
    logger.info('SRL matched args - ' + str(matched_args_count))


def find_topic_gold_clusters(topic):
    '''
    Finds the gold clusters of a specific topic
    :param topic: a topic object
    :return: a mapping of coref chain to gold cluster (for a specific topic) and the topic's mentions
    '''
    event_mentions = []
    entity_mentions = []
    # event_gold_tag_to_cluster = defaultdict(list)
    # entity_gold_tag_to_cluster = defaultdict(list)

    event_gold_tag_to_cluster = {}
    entity_gold_tag_to_cluster = {}

    for doc_id, doc in topic.docs.items():
        for sent_id, sent in doc.sentences.items():
            event_mentions.extend(sent.gold_event_mentions)
            entity_mentions.extend(sent.gold_entity_mentions)

    for event in event_mentions:
        if event.gold_tag != '-':
            if event.gold_tag not in event_gold_tag_to_cluster:
                event_gold_tag_to_cluster[event.gold_tag] = []
            event_gold_tag_to_cluster[event.gold_tag].append(event)
    for entity in entity_mentions:
        if entity.gold_tag != '-':
            if entity.gold_tag not in entity_gold_tag_to_cluster:
                entity_gold_tag_to_cluster[entity.gold_tag] = []
            entity_gold_tag_to_cluster[entity.gold_tag].append(entity)

    return event_gold_tag_to_cluster, entity_gold_tag_to_cluster, event_mentions, entity_mentions


def write_dataset_statistics(split_name, dataset, check_predicted):
    '''
    Prints the split statistics
    :param split_name: the split name (a string)
    :param dataset: an object represents the split
    :param check_predicted: whether to print statistics of predicted mentions too
    '''
    docs_count = 0
    sent_count = 0
    event_mentions_count = 0
    entity_mentions_count = 0
    event_chains_count = 0
    entity_chains_count = 0
    topics_count = len(dataset.topics.keys())
    predicted_events_count = 0
    predicted_entities_count = 0
    matched_predicted_event_count = 0
    matched_predicted_entity_count = 0


    for topic_id, topic in dataset.topics.items():
        event_gold_tag_to_cluster, entity_gold_tag_to_cluster, \
        event_mentions, entity_mentions = find_topic_gold_clusters(topic)

        docs_count += len(topic.docs.keys())
        sent_count += sum([len(doc.sentences.keys()) for doc_id, doc in topic.docs.items()])
        event_mentions_count += len(event_mentions)
        entity_mentions_count += len(entity_mentions)

        entity_chains = set()
        event_chains = set()

        for mention in entity_mentions:
            entity_chains.add(mention.gold_tag)

        for mention in event_mentions:
            event_chains.add(mention.gold_tag)

        # event_chains_count += len(set(event_gold_tag_to_cluster.keys()))
        # entity_chains_count += len(set(entity_gold_tag_to_cluster.keys()))

        event_chains_count += len(event_chains)
        entity_chains_count += len(entity_chains)

        if check_predicted:
            for doc_id, doc in topic.docs.items():
                for sent_id, sent in doc.sentences.items():
                    pred_events = sent.pred_event_mentions
                    pred_entities = sent.pred_entity_mentions

                    predicted_events_count += len(pred_events)
                    predicted_entities_count += len(pred_entities)

                    for pred_event in pred_events:
                        if pred_event.has_compatible_mention:
                            matched_predicted_event_count += 1

                    for pred_entity in pred_entities:
                        if pred_entity.has_compatible_mention:
                            matched_predicted_entity_count += 1

    with open(os.path.join(args.output_path, '{}_statistics.txt'.format(split_name)), 'w') as f:
        f.write('Number of topics - {}\n'.format(topics_count))
        f.write('Number of documents - {}\n'.format(docs_count))
        f.write('Number of sentences - {}\n'.format(sent_count))
        f.write('Number of event mentions - {}\n'.format(event_mentions_count))
        f.write('Number of entity mentions - {}\n'.format(entity_mentions_count))

        if check_predicted:
            f.write('Number of predicted event mentions  - {}\n'.format(predicted_events_count))
            f.write('Number of predicted entity mentions - {}\n'.format(predicted_entities_count))
            f.write('Number of predicted event mentions that match gold mentions- '
                    '{} ({}%)\n'.format(matched_predicted_event_count,
                                        (matched_predicted_event_count/float(event_mentions_count)) *100 ))
            f.write('Number of predicted entity mentions that match gold mentions- '
                    '{} ({}%)\n'.format(matched_predicted_entity_count,
                                        (matched_predicted_entity_count / float(entity_mentions_count)) * 100))


def obj_dict(obj):
    obj_d = obj.__dict__
    obj_d = stringify_keys(obj_d)
    return obj_d


def stringify_keys(d):
    """Convert a dict's keys to strings if they are not."""
    for key in d.keys():

        # check inner dict
        if isinstance(d[key], dict):
            value = stringify_keys(d[key])
        else:
            value = d[key]

        # convert nonstring to string if needed
        if not isinstance(key, str):
            try:
                d[str(key)] = value
            except Exception:
                try:
                    d[repr(key)] = value
                except Exception:
                    pass

            # delete old key
            del d[key]
    return d


def set_embed_to_mention(mention, sent_embeddings):
    '''
    Sets the Bert embeddings of a mention
    :param mention: event/entity mention object
    :param sent_embeddings: the embedding for each word in the sentence produced by Bert model
    :return:
    '''
    print()
    head_index = int(mention.get_head_index())+1     # Bert embedding has pads
    print(head_index)
    print(len(sent_embeddings))
    head_embeddings = sent_embeddings[head_index]
    mention.head_bert_embeddings = head_embeddings


def set_embeddings_to_mentions(embedder, sentence, set_pred_mentions):
    '''
     Sets the ELMo embeddings for all the mentions in the sentence
    :param embedder: a wrapper object
    :param sentence: a sentence object
    '''
    embedding = embedder.get_embedding(sentence)
    event_mentions = sentence.gold_event_mentions
    entity_mentions = sentence.gold_entity_mentions

    for event in event_mentions:
        set_embed_to_mention(event, embedding)
    for entity in entity_mentions:
        set_embed_to_mention(entity, embedding)

    # Set the contextualized vector also for predicted mentions
    if set_pred_mentions:
        for event in sentence.pred_event_mentions:
            set_embed_to_mention(event, embedding)  # set the head contextualized vector

        for entity in sentence.pred_entity_mentions:
            set_embed_to_mention(entity, embedding)  # set the head contextualized vector


def load_embeddings(dataset, embedder, set_pred_mentions):
    '''
    Sets the ELMo embeddings for all the mentions in the split
    :param dataset: an object represents a split (train/dev/test)
    :param embedder: a wrapper object
    :return:
    '''
    for topic_id, topic in dataset.topics.items():
        for doc_id, doc in topic.docs.items():
            for sent_id, sent in doc.get_sentences().items():
                set_embeddings_to_mentions(embedder, sent, set_pred_mentions)


In [9]:
training_data = load_gold_data(config_dict["train_text_file"],config_dict["train_event_mentions"],
                               config_dict["train_entity_mentions"])


In [10]:
train_set = order_docs_by_topics(training_data)

In [17]:
topic_id, topic = list(train_set.topics.items())[0]
_, doc = list(topic.docs.items())[0]
_, sent = list(doc.get_sentences().items())[0]
sent

<classes.Sentence at 0x7f84f36934a8>

In [117]:
text = sent.get_raw_sentence()
text

'Perennial party girl Tara Reid checked herself into Promises Treatment Center , her rep told People .'

In [140]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base',
                                     output_hidden_states = True)
model.eval()

RobertaModel(
  (embeddings): RobertaEmbeddings(
    (word_embeddings): Embedding(50265, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (token_type_embeddings): Embedding(1, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): RobertaEncoder(
    (layer): ModuleList(
      (0): RobertaLayer(
        (attention): RobertaAttention(
          (self): RobertaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): RobertaSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout): Drop

In [202]:
text = " The  Mean Girls  star fooled an L . A Superior Court judge , a Santa Monica prosecutor and apparently her own lawyer , by pretending to check into a rehabilitation facility , but chickening out when she got there ."
# text = "Perennial party girl Tara Reid checked herself into Promises Treatment Center , her rep told People ."

In [203]:
# Split the sentence into tokens.
tokenized_text = tokenizer.tokenize(text)

# Map the token strings to their vocabulary indeces.
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

In [204]:
for tup in zip(tokenized_text, indexed_tokens):
    print('{:<12} {:>6,}'.format(tup[0], tup[1]))
len(tokenized_text)

ĠThe             20
Ġ             1,437
ĠMean        30,750
ĠGirls        7,707
Ġ             1,437
Ġstar           999
Ġfooled      31,952
Ġan              41
ĠL              226
Ġ.              479
ĠA               83
ĠSuperior    11,486
ĠCourt          837
Ġjudge        1,679
Ġ,            2,156
Ġa               10
ĠSanta        2,005
ĠMonica      12,811
Ġprosecutor   5,644
Ġand              8
Ġapparently   4,100
Ġher             69
Ġown            308
Ġlawyer       2,470
Ġ,            2,156
Ġby              30
Ġpretending  23,748
Ġto               7
Ġcheck        1,649
Ġinto            88
Ġa               10
Ġrehabilitation 11,226
Ġfacility     2,122
Ġ,            2,156
Ġbut             53
Ġchick       30,802
ening         4,226
Ġout             66
Ġwhen            77
Ġshe             79
Ġgot            300
Ġthere           89
Ġ.              479


43

In [205]:
segments_ids = [1] * len(tokenized_text)

tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

In [186]:
with torch.no_grad(): 
    outputs = model(tokens_tensor, segments_tensors) 
    hidden_states = outputs[2]


In [187]:
print ("Number of layers:", len(hidden_states), "  (initial embeddings + 12 BERT layers)")
layer_i = 0

print ("Number of batches:", len(hidden_states[layer_i]))
batch_i = 0

print ("Number of tokens:", len(hidden_states[layer_i][batch_i]))
token_i = 0

print ("Number of hidden units:", len(hidden_states[layer_i][batch_i][token_i]))

Number of layers: 13   (initial embeddings + 12 BERT layers)
Number of batches: 1
Number of tokens: 43
Number of hidden units: 768


In [188]:


# Concatenate the tensors for all layers. We use `stack` here to
# create a new dimension in the tensor.
token_embeddings = torch.stack(hidden_states, dim=0)

token_embeddings.size()


torch.Size([13, 1, 43, 768])

In [189]:
token_embeddings = torch.stack(hidden_states, dim=0)

In [190]:
token_embeddings = torch.squeeze(token_embeddings, dim=1) 
token_embeddings = token_embeddings.permute(1,0,2)

In [191]:
token_vecs_sum = []
for token in token_embeddings:
    sum_vec = torch.sum(token[-4:], dim=0)       # Sum the vectors from the last four layers.
    token_vecs_sum.append(sum_vec)

In [192]:
len(token_vecs_sum)

43

In [193]:
tokenized_text

['The',
 'Ġ',
 'ĠMean',
 'ĠGirls',
 'Ġ',
 'Ġstar',
 'Ġfooled',
 'Ġan',
 'ĠL',
 'Ġ.',
 'ĠA',
 'ĠSuperior',
 'ĠCourt',
 'Ġjudge',
 'Ġ,',
 'Ġa',
 'ĠSanta',
 'ĠMonica',
 'Ġprosecutor',
 'Ġand',
 'Ġapparently',
 'Ġher',
 'Ġown',
 'Ġlawyer',
 'Ġ,',
 'Ġby',
 'Ġpretending',
 'Ġto',
 'Ġcheck',
 'Ġinto',
 'Ġa',
 'Ġrehabilitation',
 'Ġfacility',
 'Ġ,',
 'Ġbut',
 'Ġchick',
 'ening',
 'Ġout',
 'Ġwhen',
 'Ġshe',
 'Ġgot',
 'Ġthere',
 'Ġ.']

In [194]:
def find_token_index(tokenized_text): 

    output = []
    current_word = []

    for i in range(0,len(tokenized_text)):
        if "Ġ" in tokenized_text[i]:
            output.append(current_word)     # store the previous word
            current_word = [i]              # start documenting the current one
        else:
            current_word.append(i)
    output.append(current_word)      # the last word was not handled in the loop
    
    return output

find_token_index(tokenized_text)

[[0],
 [1],
 [2],
 [3],
 [4],
 [5],
 [6],
 [7],
 [8],
 [9],
 [10],
 [11],
 [12],
 [13],
 [14],
 [15],
 [16],
 [17],
 [18],
 [19],
 [20],
 [21],
 [22],
 [23],
 [24],
 [25],
 [26],
 [27],
 [28],
 [29],
 [30],
 [31],
 [32],
 [33],
 [34],
 [35, 36],
 [37],
 [38],
 [39],
 [40],
 [41],
 [42]]

In [195]:
def get_mean_embedding(embeddings):
    '''
    :param embeddings: a list of embeddings
    '''
    arrays = [np.array(x) for x in embeddings]
    return [np.mean(k) for k in zip(*arrays)]

In [196]:
get_mean_embedding(token_vecs_sum[0:3])

[0.7046022,
 -1.4052979,
 0.80795664,
 0.38080928,
 4.24604,
 1.202809,
 -0.2752114,
 -1.423124,
 0.23219852,
 0.6524685,
 -0.34678546,
 -0.67077875,
 -0.51994246,
 1.3943027,
 -0.26988962,
 -0.0012127757,
 -1.1921684,
 0.3630447,
 -0.73556954,
 0.12920384,
 1.170262,
 0.17122309,
 -0.6413205,
 0.4025062,
 1.2127565,
 -1.2118548,
 -0.21689375,
 0.84364134,
 0.7030599,
 -0.78874034,
 -0.2520011,
 0.8663475,
 0.038869243,
 -0.7443525,
 -0.19262664,
 -0.40860426,
 0.8592184,
 -0.3714654,
 0.7567107,
 1.3348843,
 2.0979197,
 0.34157583,
 -0.02488198,
 -0.7637529,
 0.58489853,
 -0.37604877,
 0.60647196,
 1.5122467,
 -0.2804472,
 0.36757317,
 0.41305304,
 -0.16455531,
 0.60271543,
 -0.13022399,
 1.3673507,
 0.48400116,
 -1.0657852,
 -2.95632,
 -0.21301039,
 0.34348455,
 -0.33216098,
 8.539382,
 1.0291876,
 -0.013810138,
 0.17909224,
 -0.97374874,
 0.24259706,
 -1.533796,
 -0.07433193,
 -0.13367741,
 0.37843534,
 -0.04649973,
 0.43960524,
 -0.33108437,
 -0.44653288,
 -0.74886894,
 1.399036,
 

In [200]:
token_indices = find_token_index(tokenized_text)
embeddings = []
for token_index in token_indices:
    start_index = token_index[0]
    end_index = token_index[-1] + 1
    if len(token_index) != 1:
        embeddings.append(get_mean_embedding(token_vecs_sum[start_index:end_index]))
    else:
        print(token_index)
        embeddings.append(token_vecs_sum[token_index[0]])

[0]
[1]
[2]
[3]
[4]
[5]
[6]
[7]
[8]
[9]
[10]
[11]
[12]
[13]
[14]
[15]
[16]
[17]
[18]
[19]
[20]
[21]
[22]
[23]
[24]
[25]
[26]
[27]
[28]
[29]
[30]
[31]
[32]
[33]
[34]
[37]
[38]
[39]
[40]
[41]
[42]


In [198]:
len(embeddings)

42

In [201]:
print(token_indices)

[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [16], [17], [18], [19], [20], [21], [22], [23], [24], [25], [26], [27], [28], [29], [30], [31], [32], [33], [34], [35, 36], [37], [38], [39], [40], [41], [42]]
