## Installations & Imports

In [1]:
# install dependencies
# !pip install torch
# !pip install transformers==4.34.0
# !pip install pycorenlp==0.3.0
# !pip install python-dotenv==1.0.0
# !pip install pytorch-pretrained-bert==0.6.2
# !pip install pandas==2.0.2

In [2]:
import torch
from huggingface_hub.hf_api import HfFolder 
import os, json, re, contextlib
# import AttrDict
import numpy as np
from typing import List
from torch.utils.data import DataLoader, SequentialSampler,TensorDataset
from pycorenlp import StanfordCoreNLP
import pandas as pd
pd.set_option('display.max_columns', None)

from transformers import AutoTokenizer, AutoModel, AutoConfig
from pickle import FALSE, TRUE
from tqdm import tqdm
tqdm.pandas()

from dotenv import load_dotenv
load_dotenv()

class AttrDict(dict):
    """A ``dict`` whose keys can also be read and written as attributes (``d.key``).

    The instance ``__dict__`` is aliased to the mapping itself, so attribute access and item
    access share a single store.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self


## Get Model from HuggingFace

In [3]:
# A Hugging Face token is needed for the private model; give it write permission if pushing to the hub is expected.
# Store it in a .env file under the HF_TOKEN key (loaded into the environment by load_dotenv in the imports cell).
# See https://huggingface.co/docs/transformers.js/guides/private
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    HfFolder.save_token(_hf_token)
else:
    # Calling save_token(None) is not a valid token; rely on a token saved earlier on this machine instead.
    print("HF_TOKEN is not set; relying on a previously saved Hugging Face token (e.g. from `huggingface-cli login`).")
# NOTE(review): transformers reads TRANSFORMERS_CACHE when it is imported, which happened in an earlier cell,
# so this assignment probably does not change where from_pretrained caches files - confirm, or pass cache_dir explicitly.
os.environ['TRANSFORMERS_CACHE'] = './cache/'

In [4]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load tokenizer and model directly from the hub. trust_remote_code=True executes the model code
# shipped in the repository (the custom ProBERT class), so only use it with a trusted model id.
MODEL_ID = "ibm/probert"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_auth_token=True, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_ID, use_auth_token=True, trust_remote_code=True)
model.to(device)

ProBERT(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
      

## Extract Entities and Pronouns

**Note:** Any library can be used to extract the entities and pronouns.

In [5]:
# Binary (gendered) personal and possessive pronoun forms, lowercase only.
# NOTE(review): a membership test against raw token text misses capitalised forms such as "He"; lowercase first.
binary_pronouns = [
    'she', 'her', 'hers',  # feminine
    'he', 'him', 'his',    # masculine
]

In [6]:
# CoreNLP server used for NER, POS tagging and coreference annotation.
# Defaults to the public corenlp.run server; set CORENLP_URL (e.g. in .env) to use your own server instead.
_corenlp_url = os.environ.get("CORENLP_URL", 'https://corenlp.run/')
corenlp = StanfordCoreNLP(_corenlp_url)

In [7]:
def is_pronoun(sentence, entity):
    """Tell whether ``entity`` is tagged as a pronoun in a CoreNLP sentence.

    Args:
        sentence: one CoreNLP sentence dict; its ``'tokens'`` entries provide ``'originalText'``
            and ``'pos'``.
        entity: surface string to look up.

    Returns:
        For the first token whose text equals ``entity``: True if its POS tag is ``PRP`` or
        ``PRP$``, else False. False when no single token matches (e.g. multi-word mentions);
        previously this case fell off the end and returned None.
    """
    for token in sentence['tokens']:
        if token['originalText'] == entity:
            return token['pos'] in ('PRP', 'PRP$')
    return False

# need to get combinations of every pronoun + offset with every entity + offset
def get_combinations(entity_dict, pronoun_dict):
    """Pair every pronoun occurrence with every entity occurrence.

    Args:
        entity_dict: mapping entity text -> list of character offsets.
        pronoun_dict: mapping pronoun text -> list of character offsets.

    Returns:
        List of ``[pronoun, pronoun_offset, entity, entity_offset]`` lists, ordered by pronoun,
        then pronoun offset, then entity, then entity offset (dict iteration order).
    """
    # Dict keys are already unique, so the former per-offset "finished_entities" list never
    # filtered anything (and cost an O(E^2) list scan); a nested comprehension yields the same pairs.
    return [
        [pronoun, pronoun_offset, entity, entity_offset]
        for pronoun, pronoun_offsets in pronoun_dict.items()
        for pronoun_offset in pronoun_offsets
        for entity, entity_offsets in entity_dict.items()
        for entity_offset in entity_offsets
    ]

def get_entities_and_pronouns(original_text: List[str]):
    """Find PERSON/TITLE mentions and personal pronouns in each text using CoreNLP.

    Args:
        original_text: list of raw input strings.

    Returns:
        ``(entity_list, pronoun_list)``: two lists with one dict per input text. Each dict maps a
        mention's surface string to the list of character offsets where it occurs in that text.
        Entities are matched literally (regex metacharacters such as '.' in "Mr. Smith" are
        escaped) and only as whole words, so "Ann" no longer matches inside "Anna".

    Raises:
        TypeError: if ``original_text`` is not a list (TypeError subclasses Exception, so existing
            ``except Exception`` handlers still catch it).
    """
    if not isinstance(original_text, list):
        raise TypeError("Input must be a list of strings.")

    entity_list, pronoun_list = [], []
    for text in original_text:
        entity_dict, pronoun_dict = {}, {}
        root = json.loads(corenlp.annotate(text, properties={'annotators': 'parse,coref,openie,ner', "timeout": "50000"}))

        for sentence in root['sentences']:
            for mention in sentence['entitymentions']:
                if mention['ner'] in ('PERSON', 'TITLE') and not is_pronoun(sentence, mention['text']):
                    entity_dict[mention['text']] = []
            for token in sentence['tokens']:
                if token['pos'] in ('PRP', 'PRP$'):
                    pronoun_dict[token['originalText']] = []

        # Offsets are taken from the ORIGINAL text: the annotated output cannot be used directly
        # because the coref annotation adds spaces / other characters.
        for name in entity_dict:
            # Lookarounds instead of \b so names that start/end with punctuation (e.g. "Dr.") still match.
            pattern = re.compile(rf"(?<!\w){re.escape(name)}(?!\w)")
            entity_dict[name] = [m.start() for m in pattern.finditer(text)]
        for name in pronoun_dict:
            pattern = re.compile(rf"\b{re.escape(name)}\b")
            pronoun_dict[name] = [m.start() for m in pattern.finditer(text)]

        entity_list.append(entity_dict)
        pronoun_list.append(pronoun_dict)
    return entity_list, pronoun_list




## Text Annotations

In [8]:
# transform the data by annotating with new mention and pronoun tags.
def annotate_mentions(ex):
        """Wrap the entity and the pronoun of one example in ``<A> ... <A>`` / ``<P> ... <P>`` tags.

        Args:
            ex: one example row (``transform`` passes rows via ``progress_apply(..., axis=1)``)
                with ``text``, ``a``, ``a_offset``, ``pronoun`` and ``pronoun_offset`` fields;
                the offsets are character positions in ``text``.

        Returns:
            The tagged text, e.g. ``'John said he left'`` with a_offset 0 and pronoun_offset 10
            becomes ``'<A> John <A> said <P> he <P> left'``. As a side effect ``ex.text``,
            ``ex.a_offset`` and ``ex.pronoun_offset`` are overwritten (the offsets hold
            intermediate positions in the tagged text).

        Raises:
            TypeError: if ``ex.a`` is not found at the shifted entity offset
                (``re.search`` returns None and is then indexed).
        """
        ex.a_offset = int(ex.a_offset)
        ex.pronoun_offset = int(ex.pronoun_offset)
        text = ex.text
        
        # opening entity tag
        text = '{}<A> {}'.format(text[:ex.a_offset], text[ex.a_offset:])
        offset = ex.pronoun_offset
        # the 4-char '<A> ' just inserted shifts a pronoun that comes after the entity
        if ex.pronoun_offset > ex.a_offset:
            offset += 4
            
        # opening pronoun tag
        text = '{}<P> {}'.format(text[:offset], text[offset:])
        # re-read both offsets in tagged-text coordinates (just past each opening tag)
        ex.a_offset = text.index('<A> ') + 4
        ex.pronoun_offset = text.index('<P> ') + 4
        # extra width of any coref tags (<C_i>/<D_i>/<E_i>, 5 chars each) embedded inside the entity span;
        # the entity characters are matched one by one allowing such tags in between.
        # NOTE(review): '<(C|D|E)_.>' only recognises single-digit coref model indices.
        offset = 5*len(re.findall('<(C|D|E)_.>', re.search(''.join([re.escape(c)+'(<(C|D|E)_.>)*?' for c in ex.a]), text[ex.a_offset:])[0]))
        # closing entity tag right after the entity (and its embedded coref tags)
        text = '{} <A>{}'.format(text[:ex.a_offset+len(ex.a)+offset], text[ex.a_offset+len(ex.a)+offset:])
        
        offset = 0
        # the 4-char ' <A>' just inserted shifts a pronoun that comes after the entity
        if ex.pronoun_offset > ex.a_offset:
            offset += 4
        # NOTE(review): this window uses the offset from before ' <A>' was inserted; verify when coref tags are present.
        offset += 5*len(re.findall('<(C|D|E)_.>', text[ex.pronoun_offset:ex.pronoun_offset+len(ex.pronoun)]))
        # closing pronoun tag
        text = '{} <P>{}'.format(text[:ex.pronoun_offset+len(ex.pronoun)+offset], 
                                    text[ex.pronoun_offset+len(ex.pronoun)+offset:])

        ex.text = text 
        return text


def transform(X, pretrained=None):
    """Tag each example's text with <A>/<P> markers and attach labels and prior scores.

    Args:
        X: DataFrame with the columns ``annotate_mentions`` reads (text, a, a_offset, pronoun,
            pronoun_offset) and optionally a boolean gold-label column ``a_coref``.
        pretrained: optional per-row prior scores (anything ``pd.DataFrame`` accepts, one row per
            example). Defaults to 0.33 for each of the two classes. Previously this argument
            was overwritten and silently ignored.

    Returns:
        A copy of ``X`` with ``text`` replaced by the tagged text plus two new columns:
        ``label`` (0 = pronoun refers to A, 1 = NEITHER; 1 for every row when ``a_coref`` is
        absent) and ``pretrained`` (list of prior scores per row).
    """
    X = X.copy()
    X['text'] = X.progress_apply(annotate_mentions, axis=1)

    if pretrained is None:
        pretrained = pd.DataFrame(np.ones((len(X), 2)) * 0.33)
    else:
        pretrained = pd.DataFrame(pretrained)

    if 'a_coref' in X.columns:
        # astype(bool) stops object-dtype columns of Python bools from being bit-inverted
        # by `~` (~True == -2), which used to turn False rows into label 0.
        is_a = X['a_coref'].astype(bool).to_numpy()
    else:
        is_a = np.zeros(len(X), dtype=bool)
    # Same encoding as argmax over the [A, NEITHER] one-hot columns.
    X['label'] = np.where(is_a, 0, 1)
    X['pretrained'] = pretrained.values.tolist()

    return X

## Tokenize and Convert to Features

In [9]:
def convert_examples_to_features(examples, 
                                    tokenizer,
                                    max_seq_length,
                                    n_coref_models,
                                    max_gpr_mention_len=20,
                                    pad_value=0,
                                    verbose=0):
    """Turn tagged examples into fixed-length model input records, one AttrDict per row.

    Args:
        examples: DataFrame whose ``text`` already carries <A>/<P> tags (see ``transform``) and
            which has ``label`` and ``pretrained`` columns.
        tokenizer: tokenizer providing ``tokenize`` and ``convert_tokens_to_ids``.
        max_seq_length: length every token-level field is padded to (includes [CLS] and [SEP]).
        n_coref_models: number of coreference tag sets (<C_i>/<D_i>/<E_i>) embedded in the text.
        max_gpr_mention_len: length the mention id / mask lists are padded to.
        pad_value: value used for all padding.
        verbose: unused.

    Returns:
        List of AttrDicts (in row order) holding input ids/masks, gpr tag mask, P/A mention
        ids/masks, coref cluster ids/masks for A and P, ``label_id`` and ``pretrained``.

    Raises:
        AssertionError: if a padded field has the wrong length, the sequence is too long, or the
            tagged tokens do not contain exactly two <P> and two <A> markers.
    """

    features = []
    for ex_index, example in tqdm(examples.iterrows(), 
                                    desc='Convert Examples to features', 
                                    disable=False):

        tokens = tokenizer.tokenize(example.text)

        # Leave room for [CLS] and [SEP]; coref cluster tags are not counted toward the length.
        if get_sanitized_seq_len(tokens)[0] > max_seq_length - 2:
            tokens = _truncate_seq(tokens, max_seq_length - 2)

        tokens_ = ["[CLS]"] + tokens + ["[SEP]"]

        # first set with gpr tags
        # (coref cluster tags are stripped here; the <P>/<A>/<B> tags stay in the model input)
        tokens, _, _, _ = extract_cluster_ids(ex_index,
                                           tokens_.copy(),
                                           n_coref_models,
                                           max_mention_len=8,
                                           remove_gpr_tags=False)

        # second without gpr tags only to be used for coref clusters embeddings
        _, cluster_ids_a, cluster_ids_b, cluster_ids_p = extract_cluster_ids(ex_index, 
                                                                            tokens_.copy(), 
                                                                            n_coref_models, 
                                                                            max_mention_len=8,
                                                                            remove_gpr_tags=True)
        
        # mention_ids = A, B and P entity token indices
        # gpr_tag_ids = <A>, <B>, <P> tag token indices
        # The mask has 1 for real tokens and 0 for padding tokens. 
        # NOTE(review): mention_b_ids and cluster_ids_b are computed but never used below.
        mention_p_ids, mention_a_ids, mention_b_ids, gpr_tag_ids = get_gpr_mention_ids(tokens, 
                                                                                        max_gpr_mention_len,
                                                                                        ignore_gpr_tags=True)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        # single-segment input
        segment_ids = [0] * len(tokens)
        input_mask = [1] * len(input_ids)
        # 1 at the positions of <P>/<A>/<B> tag tokens, 0 elsewhere
        gpr_tags_mask = np.zeros(len(tokens))
        gpr_tags_mask[gpr_tag_ids] = 1
        gpr_tags_mask = gpr_tags_mask.tolist()
        mention_p_mask = [1] * len(mention_p_ids)
        mention_a_mask = [1] * len(mention_a_ids)

        # Zero-pad up to the max sequence length.
        padding = [pad_value] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding
        gpr_tags_mask += padding
        mention_p_ids += [pad_value] * (max_gpr_mention_len - len(mention_p_ids))
        mention_a_ids += [pad_value] * (max_gpr_mention_len - len(mention_a_ids))
        mention_p_mask += [pad_value] * (max_gpr_mention_len - len(mention_p_mask))
        mention_a_mask += [pad_value] * (max_gpr_mention_len - len(mention_a_mask))

        # Zero pad coref clusters
        # NOTE(review): mention length 8 and 20 mentions per cluster are hard-coded here and above.
        cluster_ids_a, cluster_mask_a = pad_cluster_ids(cluster_ids_a, n_coref_models, 
                                                        max_seq_length,
                                                        max_mention_len=8,
                                                        max_coref_mentions=20,
                                                        pad_value=pad_value)
        cluster_ids_p, cluster_mask_p = pad_cluster_ids(cluster_ids_p, n_coref_models, 
                                                        max_seq_length,
                                                        max_mention_len=8,
                                                        max_coref_mentions=20,
                                                        pad_value=pad_value)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(tokens) <= max_seq_length, '{}\n{}\n{}'.format(ex_index, len(tokens), tokens)
        assert ''.join(tokens).upper().count('<P>') == 2 and ''.join(tokens).upper().count('<A>') == 2, (ex_index,
        "".join(tokens).upper().count('<P>'), "".join(tokens).upper().count('<A>'),"".join(tokens))

        features.append(
                AttrDict({'input_ids': input_ids,
                              'input_mask': input_mask,
                              'segment_ids': segment_ids,
                              'gpr_tags_mask': gpr_tags_mask,
                              'mention_p_ids': mention_p_ids,
                              'mention_a_ids': mention_a_ids,
                              'mention_p_mask': mention_p_mask,
                              'mention_a_mask': mention_a_mask,
                              'cluster_ids_a': cluster_ids_a,
                              'cluster_mask_a': cluster_mask_a,
                              'cluster_ids_p': cluster_ids_p,
                              'cluster_mask_p': cluster_mask_p,
                              'label_id': example.label,
                              'pretrained': example.pretrained}))
    return features

def get_gpr_mention_ids(tokens, max_gpr_mention_len, ignore_gpr_tags=False):
    """Collect token positions for the <P>, <A> and <B> mentions and for the tag tokens themselves.

    A tag is detected when ``''.join(tokens[i:i+3]).upper()`` equals '<P>', '<A>' or '<B>', so each
    tag is expected to be tokenized into three tokens (e.g. '<', 'a', '>'). The same tag seen a
    second time closes the mention it opened.

    Args:
        tokens: token list containing the tags.
        max_gpr_mention_len: maximum number of positions returned per mention.
        ignore_gpr_tags: when True, each recorded position is shifted by -3 (the length of one
            tag), i.e. only the current mention's opening tag is discounted.

    Returns:
        ``(p_ids, a_ids, b_ids, gpr_tag_ids)``: the P, A and B position lists (A and B drop their
        last two recorded positions with ``[:-2]``; P does not), and every position belonging to
        a tag.
    """
    gpr_ids = {'<P>': [], '<A>': [], '<B>': []}
    gpr_tag_ids = []
    entity = None
    for i, token_ in enumerate(tokens):
        # three consecutive tokens joined, e.g. '<', 'a', '>' -> '<A>'
        token = ''.join(tokens[i:i+3]).upper()

        if token in ['<P>', '<A>', '<B>']:
            gpr_tag_ids += [i, i+1, i+2]
            gpr_tag_ids_now = [i, i+1, i+2]

        # inside a mention: record a position for every non-tag-start index
        if entity is not None and token not in ['<P>', '<A>', '<B>']:
            if ignore_gpr_tags:
                gpr_ids[entity].append(i+2-len(gpr_tag_ids_now))
            else:
                gpr_ids[entity].append(i+2)

        # the same tag twice toggles the mention off; any other tag starts a new mention
        if token in ['<P>', '<A>', '<B>']:
            if entity == token:
                entity = None
            else:
                entity = token
    # This is only returning 1 mention id for <P> and 2 for <A>. Think there's a bug here.
    # NOTE(review): by inspection, positions are recorded at i+2 for every i from the token after the
    # opening '<' to the token before the closing '<', so [:-2] makes A/B cover exactly the mention
    # tokens (without ignore_gpr_tags) while P keeps two extra trailing positions. Left unchanged
    # because the pretrained model was presumably trained on these exact inputs - confirm before fixing.
    return (gpr_ids['<P>'][:][:max_gpr_mention_len], 
            # gpr_ids['<P>'][:-2][:max_gpr_mention_len], 
            gpr_ids['<A>'][:-2][:max_gpr_mention_len], 
            gpr_ids['<B>'][:-2][:max_gpr_mention_len], 
            gpr_tag_ids)

def pad_cluster_ids(cluster_ids, n_coref_models, max_seq_length, 
                    max_mention_len=4, 
                    max_coref_mentions=5,
                    pad_value=0):
    """Pad per-model coreference clusters into fixed-size id and mask grids.

    Args:
        cluster_ids: one entry per coref model, each a list of mentions (token index lists).
        n_coref_models: number of per-model entries of ``cluster_ids`` to process.
        max_seq_length: unused; kept for call compatibility.
        max_mention_len: width every mention is padded to. Longer mentions are truncated;
            previously they produced ragged rows.
        max_coref_mentions: number of mentions kept / padded to per model.
        pad_value: filler for ids in padded positions (masks always use 0/1).

    Returns:
        ``(cluster_ids, cluster_mask)``, each a nested list of shape
        n_coref_models x max_coref_mentions x max_mention_len. New lists are built, so the
        caller's ``cluster_ids`` is not mutated and padding rows are independent lists.
    """
    padded_ids, padded_mask = [], []
    for model_idx in range(n_coref_models):
        mentions = [list(mention)[:max_mention_len]
                    for mention in cluster_ids[model_idx][:max_coref_mentions]]
        ids = [mention + [pad_value] * (max_mention_len - len(mention)) for mention in mentions]
        mask = [[1] * len(mention) + [0] * (max_mention_len - len(mention)) for mention in mentions]

        # pad the number of mentions
        n_missing = max_coref_mentions - len(mentions)
        ids += [[pad_value] * max_mention_len for _ in range(n_missing)]
        mask += [[0] * max_mention_len for _ in range(n_missing)]

        padded_ids.append(ids)
        padded_mask.append(mask)

    return padded_ids, padded_mask

def populate_cluster(cluster_ids, tokens_to_remove, token_ids):
    """Record one coreference tag (opening or closing) for a single cluster.

    ``cluster_ids`` always ends with an "open slot": an empty list when no mention is in progress,
    or ``[start]`` after an opening tag. Indices are expressed in the coordinates of the token list
    after every position in ``tokens_to_remove`` has been deleted.

    Args:
        cluster_ids: mentions collected so far; its last element is the open slot. Mutated in place.
        tokens_to_remove: running list of tag-token positions to delete later. Extended in place.
        token_ids: positions of this tag's own tokens in the tagged sequence.

    Returns:
        The same ``cluster_ids`` and ``tokens_to_remove`` objects.
    """
    open_slot = cluster_ids[-1]
    if open_slot:
        # Closing tag: the mention runs from the recorded start up to this tag, shifted left by
        # the tag tokens removed so far (this tag itself not yet counted).
        end = token_ids[0] - len(tokens_to_remove)
        cluster_ids[-1] = list(range(open_slot[0], end))
        tokens_to_remove += token_ids
        cluster_ids.append([])
    else:
        # Opening tag: count its own tokens as removed first, so the start lands on the first
        # token of the mention.
        tokens_to_remove += token_ids
        open_slot.append(token_ids[-1] + 1 - len(tokens_to_remove))

    return cluster_ids, tokens_to_remove

def filter_coref_mentions(tokens, cluster_ids, max_mention_len=4):
    """Remove gpr tag positions from each coref mention and keep only short mentions.

    Args:
        tokens: token list the mention indices point into; may still contain <P>/<A>/<B> tags,
            each spanning three tokens.
        cluster_ids: list of mentions; each mention's first and last index bound its span.
        max_mention_len: mentions with more remaining tokens than this are discarded.

    Returns:
        List of mentions (token index lists) without gpr-tag positions. Empty mentions are
        skipped; previously ``mention[0]`` raised IndexError on them.
    """
    gpr_tags = ('<P>', '<A>', '<B>')
    mentions = []
    for mention in cluster_ids:
        if not mention:
            continue
        token_ids = []
        pos, last = mention[0], mention[-1]
        while pos <= last:
            if ''.join(tokens[pos:pos + 3]).upper() in gpr_tags:
                pos += 3  # skip the whole three-token tag
            else:
                token_ids.append(pos)
                pos += 1

        if len(token_ids) <= max_mention_len:
            mentions.append(token_ids)

    return mentions

def extract_cluster_ids(ex_index, tokens, n_coref_models, max_mention_len=4, remove_gpr_tags=False):
    """Strip coreference-cluster tags from ``tokens`` and return the mentions they delimited.

    A cluster tag is five consecutive tokens that join (upper-cased) to '<C_i>', '<D_i>' or
    '<E_i>', marking mentions of the A, B and pronoun clusters for coref model ``i``.
    Successive identical tags open and close one mention.

    Args:
        ex_index: example index (unused here).
        tokens: token list. When ``remove_gpr_tags`` is False the caller's list is edited in place.
        n_coref_models: number of coref models whose tags may appear.
        max_mention_len: mentions with more tokens than this (after dropping gpr tags) are discarded.
        remove_gpr_tags: also drop <P>/<A>/<B> tags (three tokens each) before processing.

    Returns:
        ``(tokens, cluster_ids_a, cluster_ids_b, cluster_ids_p)``: the tokens without cluster tags,
        and for each cluster type one list per coref model of mentions (token index lists into
        the returned tokens).
    """
    gpr_tags = ['<P>', '<A>', '<B>']
    if remove_gpr_tags:
        tokens_ = []
        start = 0
        while start < len(tokens):
            if ''.join(tokens[start:start+3]).upper() in gpr_tags:
                start += 3
            else:
                tokens_.append(tokens[start])
                start += 1
        tokens = tokens_

    cluster_tags = ['<C_{}>'.format(i) for i in range(n_coref_models)] + \
                        ['<D_{}>'.format(i) for i in range(n_coref_models)] + \
                        ['<E_{}>'.format(i) for i in range(n_coref_models)]

    # map cluster ids to tokens so that we can make pairs and keep track of token ids for removal
    map_idx_to_token = []
    start = 0
    while start < len(tokens):
        token = ''.join(tokens[start:start+5]).upper()

        if token in cluster_tags:
            mapping = (list(range(start, start+5)), token)
            map_idx_to_token.append(mapping)
        else:
            map_idx_to_token.append(([start], tokens[start]))
        start += 1

    cluster_ids_a = [[[]] for i in range(n_coref_models)]
    cluster_ids_b = [[[]] for i in range(n_coref_models)]
    cluster_ids_p = [[[]] for i in range(n_coref_models)]
    clusters_by_letter = {'C': cluster_ids_a, 'D': cluster_ids_b, 'E': cluster_ids_p}
    tokens_to_remove = []
    for (token_ids, token) in map_idx_to_token:
        if token in cluster_tags:
            # token looks like '<C_12>': cluster letter at [1], model index between '_' and '>'.
            # (token[3] alone only handled single-digit model indices.)
            coref_model_idx = int(token[3:-1])
            clusters = clusters_by_letter[token[1]]
            clusters[coref_model_idx], tokens_to_remove = populate_cluster(clusters[coref_model_idx],
                                                                           tokens_to_remove,
                                                                           token_ids)

    # drop the trailing open slot each cluster keeps for the next mention
    for i in range(n_coref_models):
        cluster_ids_a[i].pop()
        cluster_ids_b[i].pop()
        cluster_ids_p[i].pop()

    # remove coref tags from tokens (positions are ascending, so shift by deletions done so far)
    for i, idx in enumerate(tokens_to_remove):
        del tokens[idx-i]

    # gather tokens between cluster tags
    # filter out gpr tag tokens and mentions longer than max_mention_len
    for i in range(n_coref_models):
        cluster_ids_a[i] = filter_coref_mentions(tokens, cluster_ids_a[i], max_mention_len=max_mention_len)
        cluster_ids_b[i] = filter_coref_mentions(tokens, cluster_ids_b[i], max_mention_len=max_mention_len)
        cluster_ids_p[i] = filter_coref_mentions(tokens, cluster_ids_p[i], max_mention_len=max_mention_len)

    return tokens, cluster_ids_a, cluster_ids_b, cluster_ids_p

def remove_first_matching_tag(tokens, tag):
    """Delete, in place, the first five-token occurrence of ``tag`` at position 1 or later.

    Position 0 is skipped on purpose: callers pass the tag that currently sits at the front of
    ``tokens`` and want its partner removed. The comparison is case-sensitive.

    Returns:
        The same ``tokens`` list (mutated, or unchanged if no match was found).
    """
    for pos in range(1, len(tokens)):
        if ''.join(tokens[pos:pos + 5]) == tag:
            del tokens[pos:pos + 5]
            return tokens
    return tokens

def get_sanitized_seq_len(tokens):
    """Measure ``tokens`` while skipping five-token coreference tags (<C_i>, <D_i>, <E_i>).

    A tag is recognised when its tokens 0-2 and 4 join (upper-cased) to '<C_>', '<D_>' or '<E_>',
    whatever sits at index 3 (the model index).

    Returns:
        ``(count, kept_tokens)`` where ``kept_tokens`` are the non-tag tokens and
        ``count == len(kept_tokens)``.
    """
    kept = []
    pos = 0
    while pos < len(tokens):
        window = tokens[pos:pos + 3] + tokens[pos + 4:pos + 5]
        if ''.join(window).upper() in ['<C_>', '<D_>', '<E_>']:
            pos += 5
            continue
        kept.append(tokens[pos])
        pos += 1
    return len(kept), kept

def _truncate_seq(tokens, max_length):
    """Trim a single token sequence until its length, ignoring coref tags, is at most ``max_length``.

    Each iteration drops one token from the front if the first real token does not start a
    <P>/<A>/<B> tag, otherwise one from the back if the last real token does not end such a tag.
    Coref tags (<C_i>/<D_i>/<E_i>, five tokens each) sitting at the trimmed end are removed
    together with their matching partner tag. The middle is never trimmed (see plan below).

    Args:
        tokens: token list (the caller passes it without [CLS]/[SEP]). It may be modified in
            place, but always use the returned list.
        max_length: target length as counted by ``get_sanitized_seq_len``.

    Returns:
        The truncated token list.

    Raises:
        Exception: if both ends are gpr tags while the sequence is still too long.
    """

    # 1. First truncate the beginning
    # 2. truncate the end
    # 3. truncate the middle (NOTE(review): not implemented - the loop raises instead)

    # map gpr tokens - cannot be removed
    # if a token matches a C, D or E cluster tag, then don't consider it in sequence length
    
    gpr_tags = ['<P>', '<A>', '<B>']
    # NOTE(review): cluster_tags is not used; the literal list is repeated in the checks below.
    cluster_tags = ['<C_>', '<D_>', '<E_>']
    # 1 Start truncating from beginning
    #   if first token is not in gpr tags then remove it.
    #       if it was a cluster tag, then remove the corresponding matching end tag as well
    while get_sanitized_seq_len(tokens)[0] > max_length:
        _, sanitized_tokens = get_sanitized_seq_len(tokens)

        token = ''.join(sanitized_tokens[0:3]).upper()
        if token not in gpr_tags:
            # while first token is a cluster tag keep removing it and its matching end tag
            while (''.join(tokens[:3] + tokens[4:5])).upper() in ['<C_>', '<D_>', '<E_>']:
                tokens = remove_first_matching_tag(tokens, ''.join(tokens[:5]))
                del tokens[:5]
            del tokens[0]
            continue

        token = ''.join(sanitized_tokens[-3:]).upper()
        if token not in gpr_tags:
            # while last token is a cluster tag keep removing it and its matching start tag
            # (done on a reversed copy: the reversed tag text is searched for in the reversed list)
            while (''.join(tokens[-5:-2] + tokens[-1:])).upper() in ['<C_>', '<D_>', '<E_>']:
                tokens_ = tokens[::-1]
                tokens = remove_first_matching_tag(tokens_, ''.join(tokens_[:5]))
                tokens = tokens[::-1]
                del tokens[-5:]
            del tokens[-1]
            continue

        raise Exception('Couldnt find a good way to truncate the sequence.')

    return tokens



## Inferencing Functions

In [10]:
def predict(model, X, device, eval_mode=True, batch_size=1):
    """Run ``model`` over feature records ``X`` and collect the predicted class probabilities.

    Args:
        model: ProBERT-style module called as
            ``model(input_ids, segment_ids, input_mask, <feature kwargs>, labels=None, training=False, eval_mode=...)``.
            It must return ``(logits, probabilities)``, or five values when ``eval_mode`` is True.
        X: non-empty sequence of feature records with the attributes produced by
            ``convert_examples_to_features``.
        device: torch device each batch is moved to.
        eval_mode: when True, show a progress bar and also collect attention weights.
        batch_size: rows per forward pass. Defaults to 1, the previously hard-coded value.

    Returns:
        ``(preds, attn_wts)``: ``preds`` is a numpy array of per-row probabilities stacked along
        axis 0 in ``X`` order. ``attn_wts`` is ``[mention_attn, cluster_attn]`` when
        ``eval_mode`` is True, else ``[]``.

    Raises:
        ValueError: if ``X`` is empty (previously an IndexError at the end).
    """
    if len(X) == 0:
        raise ValueError("predict() needs at least one feature record")

    model.eval()

    # (attribute, dtype) pairs; the order must match the unpacking inside the loop below.
    fields = [
        ('input_ids', torch.long),
        ('input_mask', torch.long),
        ('segment_ids', torch.long),
        ('gpr_tags_mask', torch.uint8),
        ('mention_p_ids', torch.long),
        ('mention_a_ids', torch.long),
        ('mention_p_mask', torch.uint8),
        ('mention_a_mask', torch.uint8),
        ('cluster_ids_a', torch.long),
        ('cluster_mask_a', torch.uint8),
        ('cluster_ids_p', torch.long),
        ('cluster_mask_p', torch.uint8),
        ('pretrained', torch.float),
        ('label_id', torch.long),
    ]
    eval_data = TensorDataset(*[torch.tensor([getattr(f, name) for f in X], dtype=dtype)
                                for name, dtype in fields])
    eval_dataloader = DataLoader(eval_data,
                                 sampler=SequentialSampler(eval_data),
                                 batch_size=batch_size)

    prob_chunks = []
    attn_wts = []
    pbar = tqdm(desc="Evaluating", total=len(eval_dataloader)) if eval_mode else contextlib.suppress()
    with pbar:
        for batch in eval_dataloader:
            (input_ids, input_mask, segment_ids,
                gpr_tags_mask,
                mention_p_ids, mention_a_ids,
                mention_p_mask, mention_a_mask,
                cluster_ids_a, cluster_mask_a,
                cluster_ids_p, cluster_mask_p, pretrained, _label_ids) = (t.to(device) for t in batch)

            with torch.no_grad():
                res = model(input_ids,
                            segment_ids,
                            input_mask,
                            gpr_tags_mask=gpr_tags_mask,
                            mention_p_ids=mention_p_ids,
                            mention_a_ids=mention_a_ids,
                            mention_p_mask=mention_p_mask,
                            mention_a_mask=mention_a_mask,
                            cluster_ids_a=cluster_ids_a,
                            cluster_mask_a=cluster_mask_a,
                            cluster_ids_p=cluster_ids_p,
                            cluster_mask_p=cluster_mask_p,
                            pretrained=pretrained,
                            labels=None,
                            training=False,
                            eval_mode=eval_mode
                        )

            if eval_mode:
                _logits, probabilities, attn_wts_m, attn_wts_c, _attn_wts_co = res
            else:
                _logits, probabilities = res

            prob_chunks.append(probabilities.detach().cpu().numpy())

            if eval_mode:
                pbar.update()
                # NOTE(review): attention outputs are kept as the model returns them; np.append
                # needs host (CPU) array-likes - confirm for GPU runs.
                if len(attn_wts) == 0:
                    attn_wts = [attn_wts_m, attn_wts_c]
                else:
                    attn_wts[0] = np.append(attn_wts[0], attn_wts_m, axis=0)
                    attn_wts[1] = np.append(attn_wts[1], attn_wts_c, axis=0)

    # One concatenate at the end instead of np.append per batch, which re-copied all previous rows.
    preds = np.concatenate(prob_chunks, axis=0)
    return preds, attn_wts

In [11]:
# Settings for the results data frame.
# process_input writes its model inputs (tab-separated despite the .csv name) and its results to these paths.
tmp_write_path = "tmp_df.csv"
filename_wr = 'df_output_hf.csv'
# Predicted class index -> a_coref label: index 0 = pronoun refers to entity A, index 1 = it does not.
labels = [True, False]

def process_input(input_text: str | list):
    cols = ['id', 'text', 'pronoun', 'pronoun_offset', 'a', 'a_offset', 'url']
    output_cols = ['id', 'text', 'pronoun', 'pronoun_offset', 'a', 'a_offset', 'a_coref', 'url', 'probabilities', 'output']
    df_list = []
    if isinstance(input_text, str):
        entity_list, pronoun_list = get_entities_and_pronouns([input_text])
        original_text = [input_text]
    elif isinstance(input_text, list):
        original_text = input_text
        entity_list, pronoun_list = get_entities_and_pronouns(input_text)
    else:
        raise 
    print(entity_list, pronoun_list)      
    if not pronoun_list[-1] or not entity_list[-1]:
        return "no pronouns present"
    else:
        for i in range(len(original_text)):
            combinations = get_combinations(entity_list[i], pronoun_list[i])
            for combination in combinations:
                pronoun, pronoun_offset = combination[0], combination[1]
                entity1, entity1_offset = combination[2], combination[3]
                
                df_list.append([i,original_text[i], pronoun,pronoun_offset,entity1,entity1_offset,'na'])

        df = pd.DataFrame(df_list, columns=cols)
        df.to_csv(tmp_write_path, sep='\t', index=False)
        # display(df)
        X_annotated = transform(df)
        # Tokenisation of the text happens here
        X = convert_examples_to_features(X_annotated,tokenizer,512,n_coref_models=0,verbose=0)
        # inference using model
        predicted_probs, _ = predict(model,X,device,eval_mode=True)
        # index of max value from predictions so we get exact entity name it resolves to
        output_list = []
        for idx, row in df.iterrows():
            id = row['id']
            pronoun = row['pronoun']
            pronoun_offset = row['pronoun_offset']
            entity1 = row['a']
            entity1_offset = row['a_offset']
            text = row['text']
            probs = predicted_probs[idx]
            max_idx = list(probs).index(max(probs))
            
            output = f"Known pronoun '{pronoun}' resolves '{labels[max_idx]}' to '{entity1}' with a probability of '{probs[max_idx]}'"
            print(output)
            entity_coref = labels[max_idx]
            output_list.append([id,text, pronoun,pronoun_offset,entity1,entity1_offset,entity_coref,'na', probs, output])
        

        df_output = pd.DataFrame(output_list, columns=output_cols)
        df_output.to_csv(filename_wr, index=False)
        # display(df_output)
        return df_output


### Visualisations

#### GPR Visuals

Note: the `gpr_pub` repository only needs to be cloned once; on later runs the clone is not needed if the `gpr_pub` directory already exists.

In [12]:
!git clone https://github.com/sattree/gpr_pub.git

Cloning into 'gpr_pub'...
remote: Enumerating objects: 301, done.[K
remote: Total 301 (delta 0), reused 0 (delta 0), pack-reused 301[K
Receiving objects: 100% (301/301), 5.36 MiB | 35.20 MiB/s, done.
Resolving deltas: 100% (132/132), done.


In [13]:
from IPython.display import display, HTML
from gpr_pub import visualization

# Add CSS styles and JS events to the DOM so that they are available to the rendered HTML.
# `with` closes each file instead of leaving the handle open until garbage collection.
for _asset_path in ('gpr_pub/visualization/highlight.css', 'gpr_pub/visualization/highlight.js'):
    with open(_asset_path, encoding='utf-8') as _asset_file:
        display(HTML(_asset_file.read()))

In [14]:
def labelled_pronoun(row):
    """Convert one result row into ``(tokens, clusters)`` for the gpr_pub HTML renderer.

    Tokens are ``row.text`` split on single spaces. Character offsets are mapped to token
    indices by counting the spaces before them. The entity span is always cluster 0; the pronoun
    joins it when ``row.a_coref`` is truthy, otherwise it gets its own cluster. The pronoun token
    is suffixed with ``row.probabilities[0]`` formatted to two decimals.
    """
    text = row.text
    probability = row.probabilities[0]
    words = text.split(' ')

    def span(offset, phrase):
        # [first, last] token indices of `phrase` starting at character `offset`
        first = text[:offset].count(' ')
        return [first, first + phrase.count(' ')]

    entity_span = span(row.a_offset, row.a)
    pronoun_span = span(row.pronoun_offset, row.pronoun)

    clusters = [[entity_span]]
    if row.a_coref:
        clusters[0].append(pronoun_span)
    else:
        clusters.append([pronoun_span])

    pronoun_idx = pronoun_span[0]
    words[pronoun_idx] = f"{words[pronoun_idx]} ({probability:.2f} probability)"
    return words, clusters

def to_html(tokens, clusters):
    """Render ``tokens``/``clusters`` with gpr_pub's HTML template inside a padded ``<div>``."""
    template = visualization.html_template
    body = ''.join(template.span_wrapper(template.transform_to_tree(tokens, clusters), 0))
    return f'<div style="padding: 16px;">{body}</div>'

In [15]:
def plot_gpr_visual(eval_data_plot):
    """Display an HTML table with one gpr_pub-style annotation per result row.

    Args:
        eval_data_plot: results DataFrame (as returned by ``process_input``) with the text, a,
            a_offset, pronoun, pronoun_offset, a_coref and probabilities columns.

    Returns:
        None. The table is shown via IPython ``display``. Nothing is shown for an empty frame;
        previously the groupby on a column-less frame raised KeyError.
    """
    rows = []
    for idx, row in eval_data_plot.iterrows():
        # Special rendering for labelled pronouns
        # labels in 'a_coref'
        tokens, clusters = labelled_pronoun(row)
        rows.append({'sample_idx': idx,
                     'text': row.text,
                     'annotation': to_html(tokens, clusters)})

    if not rows:
        return None

    df = pd.DataFrame(rows).groupby(['sample_idx']).agg(lambda x: x)
    s = df.style.set_properties(**{'text-align': 'left'})
    display(HTML(s.to_html(justify='left')))

#### Spacy Visuals

In [16]:
import spacy
from spacy import displacy
    
def ex_tags(df):
    """Build displaCy "manual" dependency input: one entry per distinct text, one arc per row.

    Args:
        df: results DataFrame with text, a, a_offset, pronoun_offset and a_coref columns.

    Returns:
        List of ``{"words": [...], "arcs": [...]}`` dicts in first-seen text order. Words are the
        whitespace-separated tokens of the text. Tokens inside each row's entity span are tagged
        "entity" and all others "token_<i>". Each arc joins the entity and pronoun word indices,
        is labelled ``str(a_coref)``, and puts the arrow head on the entity.
    """
    def _word_index(text, offset):
        # index of the whitespace-separated word that contains character `offset`
        prefix = text[:offset]
        n_words = len(prefix.split())
        return n_words - 1 if prefix and not prefix[-1].isspace() else n_words

    texts, values = [], []
    for _, row in df.iterrows():
        txt = ' '.join(row.text.split())
        # The offsets index the raw text, so map them there rather than in the whitespace-normalised copy.
        start_a = _word_index(row.text, row.a_offset)
        start_p = _word_index(row.text, row.pronoun_offset)

        if txt not in texts:
            texts.append(txt)
            values.append({"words": [{"text": t, "tag": f"token_{i}"} for i, t in enumerate(txt.split())],
                           "arcs": []})
        value = values[texts.index(txt)]

        # Tag exactly the entity's words. The old substring test also tagged words such as "he"
        # (contained in "Stephen") and ignored entities of later rows for the same text.
        words = value["words"]
        for i in range(start_a, min(start_a + len(row.a.split()), len(words))):
            words[i]["tag"] = "entity"

        # displaCy arcs run from the lower index (start) to the higher one (end); "dir" picks the
        # end that gets the arrow head, here always the entity.
        arc = {"start": min(start_a, start_p), "end": max(start_a, start_p),
               "label": str(row.a_coref), "dir": "left" if start_a < start_p else "right"}
        if arc not in value["arcs"]:
            value["arcs"].append(arc)
    return values

def plot_spacy(eval_data_plot):
    """Render the entity/pronoun arcs of ``eval_data_plot`` inline with displaCy.

    Args:
        eval_data_plot: DataFrame accepted by ``ex_tags`` (columns ``text``,
            ``a``, ``a_offset``, ``pronoun_offset``, ``a_coref``).
    """
    ex_to_use = ex_tags(eval_data_plot)
    # With jupyter=True displacy displays the markup itself and returns None,
    # so wrapping that result in HTML() shows a stray object repr. Ask for the
    # markup string instead and display it exactly once.
    html = displacy.render(ex_to_use, style="dep", manual=True, jupyter=False)
    display(HTML(html))

2024-03-15 13:26:53.564113: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-03-15 13:26:55.314992: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


### Post-processing Detection Logic

In [17]:
def get_ambiguity_class(text):
    """Detect gender ambiguity among the pronouns found alongside named entities.

    This is a generator that yields exactly one detection dict. Because it is
    lazy, the work (and any error) happens on first iteration, not at call time.

    Args:
        text: a raw string, or a dict with "input" and "output" strings. For a
            dict, the output is analysed. If the input has no pronouns but the
            output does, the result is flagged as ambiguous.

    Yields:
        dict with "ambiguity_present" and "associated_pronouns" (a set of
        pronoun strings, or None when no pronouns were found). When pronouns
        were found it also carries "mean_probability" and "named_entity".

    Raises:
        TypeError: if ``text`` is neither a str nor a dict, or if a dict lacks
            the "input"/"output" keys.
    """
    def pcd(eval_data_plot):
        """Yield the detection dict for one result of ``process_input``."""
        # NOTE(review): assumes process_input signals "nothing to resolve" with a
        # value containing "no pronouns" — confirm against process_input.
        if "no pronouns" in eval_data_plot:
            yield {"ambiguity_present": False, "associated_pronouns": None}
            return
        e_group = eval_data_plot.groupby(['id', 'a_offset'])[['a', 'pronoun', 'a_coref', 'probabilities', 'text']].apply(lambda x: x)
        # Rows are not filtered on a_coref: every candidate pronoun counts,
        # whether or not the model resolved it to the named entity.
        plot_gpr_visual(eval_data_plot)
        if e_group.empty:
            # Nothing to classify; avoids indexing an empty frame below.
            yield {"ambiguity_present": False, "associated_pronouns": None}
            return
        group_pronoun = set(e_group['pronoun'])
        binary = any(p in binary_pronouns for p in group_pronoun)
        non_binary = any(p not in binary_pronouns for p in group_pronoun)
        yield {"ambiguity_present": binary and non_binary,
               "mean_probability": np.mean(e_group.probabilities, axis=0)[0],
               # Reports the entity of the last row only.
               "named_entity": e_group['a'].iloc[-1],
               "associated_pronouns": group_pronoun}

    if isinstance(text, str):
        yield next(pcd(process_input(text)))
    elif isinstance(text, dict):
        if "input" not in text or "output" not in text:
            raise TypeError("input dictionary doesn't contain the required input and output keys")
        input_result = next(pcd(process_input(text["input"])))
        output_result = next(pcd(process_input(text["output"])))
        if input_result["associated_pronouns"] is None and output_result["associated_pronouns"] is not None:
            # A pronoun introduced only by the output is treated as ambiguous.
            yield {**output_result, "ambiguity_present": True}
        else:
            yield output_result
    else:
        # Raising a bare string is itself a TypeError in Python 3 and loses the
        # message; raise a real exception instead.
        raise TypeError("Can only process raw strings or a dictionary of input and output strings.")

### Edge Cases

**Ambiguity** (dictionary definition): the quality of being open to more than one interpretation.

**Working hypothesis**: within a given sentence context, a named entity is referred to either by binary-gendered pronouns or by non-binary pronouns, but not both. If the pronouns detected in the text include both kinds, the text is flagged as ambiguous.

No pronoun in the sentence

In [18]:
# Edge case: two named entities and no pronoun, so there is nothing to resolve.
input_0 = "Pat was not sure how Lami should bring up inclusivity in the workplace. "
# get_ambiguity_class is a generator: detection only runs when it is iterated.
is_gap_present = get_ambiguity_class(input_0)
for gap in is_gap_present:
    print("\nGender Ambiguity: ", gap)

[{'Pat': [0], 'Lami': [21]}] [{}]

Gender Ambiguity:  {'ambiguity_present': False, 'associated_pronouns': None}


Unambiguous pronoun in sentence context

In [19]:
# A single pronoun ("they") referring to a single named entity.
input_1 = "Pat was not sure how they should bring up inclusivity in the workplace, because "
# Lazy generator: iterating it runs the detection and yields one result.
is_gap_present = get_ambiguity_class(input_1)
for gap in is_gap_present:
    print("\nGender Ambiguity: ", gap)

[{'Pat': [0]}] [{'they': [21]}]


100%|██████████| 1/1 [00:00<00:00, 814.74it/s]
Convert Examples to features: 1it [00:00, 749.12it/s]
  sequence_output = sequence_output[~gpr_tags_mask].view(batch_size, -1, self.config.hidden_size)
Evaluating: 100%|██████████| 1/1 [00:06<00:00,  6.27s/it]


Known pronoun 'they' resolves 'True' to 'Pat' with a probability of '0.506740152835846'


Unnamed: 0_level_0,text,annotation
sample_idx,Unnamed: 1_level_1,Unnamed: 2_level_1
0,"Pat was not sure how they should bring up inclusivity in the workplace, because","0 Pat was not sure how 0 they (0.51 probability) should bring up inclusivity in the workplace, because"



Gender Ambiguity:  {'ambiguity_present': False, 'mean_probability': 0.50674015, 'named_entity': 'Pat', 'associated_pronouns': {'they'}}


Ambiguous pronouns in sentence context (input + output as one string).

In [20]:
# Input and output concatenated into one string: "they" and "she" both appear.
output_1 = "Pat was not sure how they should bring up inclusivity in the workplace, because Pat was not sure how she should bring it up."
# Lazy generator: iterating it runs the detection and yields one result.
is_gap_present = get_ambiguity_class(output_1)
for gap in is_gap_present:
    print("\nGender Ambiguity: ", gap)

[{'Pat': [0, 80]}] [{'they': [21], 'she': [101], 'it': [118]}]


100%|██████████| 6/6 [00:00<00:00, 3390.25it/s]
Convert Examples to features: 6it [00:00, 2060.41it/s]
  sequence_output = sequence_output[~gpr_tags_mask].view(batch_size, -1, self.config.hidden_size)
Evaluating: 100%|██████████| 6/6 [00:00<00:00, 20.85it/s]

Known pronoun 'they' resolves 'True' to 'Pat' with a probability of '0.9659983515739441'
Known pronoun 'they' resolves 'True' to 'Pat' with a probability of '0.9833201766014099'
Known pronoun 'she' resolves 'True' to 'Pat' with a probability of '0.9681033492088318'
Known pronoun 'she' resolves 'True' to 'Pat' with a probability of '0.9982727766036987'
Known pronoun 'it' resolves 'False' to 'Pat' with a probability of '0.860580563545227'
Known pronoun 'it' resolves 'True' to 'Pat' with a probability of '0.6779956817626953'





Unnamed: 0_level_0,text,annotation
sample_idx,Unnamed: 1_level_1,Unnamed: 2_level_1
0,"Pat was not sure how they should bring up inclusivity in the workplace, because Pat was not sure how she should bring it up.","0 Pat was not sure how 0 they (0.97 probability) should bring up inclusivity in the workplace, because Pat was not sure how she should bring it up."
1,"Pat was not sure how they should bring up inclusivity in the workplace, because Pat was not sure how she should bring it up.","Pat was not sure how 0 they (0.98 probability) should bring up inclusivity in the workplace, because 0 Pat was not sure how she should bring it up."
2,"Pat was not sure how they should bring up inclusivity in the workplace, because Pat was not sure how she should bring it up.","0 Pat was not sure how they should bring up inclusivity in the workplace, because Pat was not sure how 0 she (0.97 probability) should bring it up."
3,"Pat was not sure how they should bring up inclusivity in the workplace, because Pat was not sure how she should bring it up.","Pat was not sure how they should bring up inclusivity in the workplace, because 0 Pat was not sure how 0 she (1.00 probability) should bring it up."
4,"Pat was not sure how they should bring up inclusivity in the workplace, because Pat was not sure how she should bring it up.","0 Pat was not sure how they should bring up inclusivity in the workplace, because Pat was not sure how she should bring 1 it (0.14 probability) up."
5,"Pat was not sure how they should bring up inclusivity in the workplace, because Pat was not sure how she should bring it up.","Pat was not sure how they should bring up inclusivity in the workplace, because 0 Pat was not sure how she should bring 0 it (0.68 probability) up."



Gender Ambiguity:  {'ambiguity_present': True, 'mean_probability': 0.7888517, 'named_entity': 'Pat', 'associated_pronouns': {'it', 'they', 'she'}}


Pronoun in the output when the input has none (flagged as ambiguous)

In [21]:
# Note: this reuses the name input_1, this time for an input/output dict.
input_1 = {
    "input": "Pat was not sure about inclusivity in the workplace",
    "output": "Pat was not sure about inclusivity in the workplace, because she was away during the training.",
}
# Lazy generator: iterating it runs the detection on both strings and yields one result.
is_gap_present = get_ambiguity_class(input_1)
for gap in is_gap_present:
    print("\nGender Ambiguity: ", gap)

[{'Pat': [0]}] [{}]
[{'Pat': [0]}] [{'she': [61]}]


100%|██████████| 1/1 [00:00<00:00, 1263.73it/s]
Convert Examples to features: 1it [00:00, 1283.45it/s]
  sequence_output = sequence_output[~gpr_tags_mask].view(batch_size, -1, self.config.hidden_size)
Evaluating: 100%|██████████| 1/1 [00:00<00:00, 22.35it/s]

Known pronoun 'she' resolves 'True' to 'Pat' with a probability of '0.9941904544830322'





Unnamed: 0_level_0,text,annotation
sample_idx,Unnamed: 1_level_1,Unnamed: 2_level_1
0,"Pat was not sure about inclusivity in the workplace, because she was away during the training.","0 Pat was not sure about inclusivity in the workplace, because 0 she (0.99 probability) was away during the training."



Gender Ambiguity:  {'ambiguity_present': True, 'mean_probability': 0.99419045, 'named_entity': 'Pat', 'associated_pronouns': {'she'}}


In [22]:
# Clean up after: remove the gpr_pub/ working directory.
# This is done in-process rather than with `!rm -rf`. The shell escape forks
# the kernel after the tokenizers have run, which triggers the
# huggingface/tokenizers "process just got forked" parallelism warning.
# ignore_errors=True mirrors rm's -f (no error if the directory is missing).
import shutil

shutil.rmtree("gpr_pub", ignore_errors=True)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
