In [1]:
import math
import matplotlib.pyplot as plt
from nltk.tree import Tree
import numpy as np
from copy import deepcopy
import pandas as pd
import re
import seaborn as sns
import statistics
from transformers import AutoModel, AutoTokenizer
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dir = "ch_copy/00/"
doc_id = '0037'

#returns dictionary containing plain sentence, constituency parse and coreference-id-labeled tokens 
def get_sentence_profiles(doc_id):
    profiles = {}
    filename = "ch_" + doc_id + ".onf"
    fp = dir + filename
    with open(fp, "r") as f:
        lines = f.readlines()
    #enumerated = enumerate(lines) -- don't do this
    plain_sent_idxs = [i for i, txt in enumerate(lines) if txt == ("Plain sentence:\n")]
    treebanked_idxs = [i for i, txt in enumerate(lines) if txt.find("Treebanked sentence:\n") > -1]
    tree_idxs = [i for i, txt in enumerate(lines) if txt == ("Tree:\n")]
    leaves_idxs = [i for i, txt in enumerate(lines) if txt == ("Leaves:\n")]

    n_sents = len(plain_sent_idxs)

    for i in range(n_sents):
        profile = doc_id + "_" + str(i)
        profiles[profile] = {}
        profiles[profile]["plain"] = lines[(plain_sent_idxs[i] + 2):treebanked_idxs[i]][0].strip()
        profiles[profile]["plain"] = profiles[profile]["plain"].replace("--", "")
        raw_tree = lines[(tree_idxs[i]+2):leaves_idxs[i]]
        profiles[profile]["tree"] = process_tree(raw_tree)

        if i < n_sents - 1:
            profiles[profile]["leafnotes"] = lines[(leaves_idxs[i]+2):(plain_sent_idxs[i+1]-3)]
        else:
            profiles[profile]["leafnotes"] = lines[(leaves_idxs[i]+2):-3]

        profiles[profile]["leafnotes"] = process_leaves(profiles[profile]["leafnotes"])
        profiles[profile]["leaves"] = [leaf for leaf in Tree.fromstring(profiles[profile]["tree"]).leaves() if '*' not in leaf]

        #i += 1
    return profiles

def process_leaves(leafnotes):
    leaves_dict = {}
    i = 0
    for line in leafnotes:
        line = line.strip()
        if len(line) > 0:
            line = line.split()
            if line[0].isdigit():
                if line[1] != "--":
                    leaves_dict[i] = {"token":"", "info": {}}
                    leaves_dict[i]["token"] = line[1]
                    i += 1
            elif i-1 in leaves_dict:
                leaves_dict[i-1]["info"][line[0]] = line[1:]
       
    return leaves_dict

def process_tree(raw_tree):
    tree = ""
    for line in raw_tree:
        tree += line.strip()
    return tree

def get_doc_profile(doc_id):
    profile = {}
    sentence_profiles = get_sentence_profiles(doc_id)
    profile["plain"] = ""
    for id in sentence_profiles:
        profile["plain"] += sentence_profiles[id]["plain"]
    profile["trees"] = [sentence_profiles[profile]["tree"] for profile in sentence_profiles]

    #creates a list of dictionaries containing leafnotes for each sentence
    profile["leaves_per_tree"] = []
    for id in sentence_profiles:
        profile["leaves_per_tree"].append(deepcopy(sentence_profiles[id]["leafnotes"])) #must deepcopy or dictionaries will merge

    #merged leaves
    profile["leafnotes"] = sentence_profiles[doc_id + "_0"]["leafnotes"]
    sentence_profiles_tail = dict(list(sentence_profiles.items())[1:])
    i = list(profile["leafnotes"].keys())[-1]
    for id in sentence_profiles_tail:
        for leaf in sentence_profiles[id]['leafnotes']:
            i += 1
            profile["leafnotes"][i] = sentence_profiles_tail[id]['leafnotes'][leaf]


    #words
    profile["leaves"] = []
    for tree in profile["trees"]:
        leaves = Tree.fromstring(tree).leaves()
        profile["leaves"] += ([leaf for leaf in leaves if '*' not in leaf])

    return profile


In [3]:
s_profiles = get_sentence_profiles(doc_id)
doc_profile = get_doc_profile(doc_id)
sample = s_profiles['0037_0']#['plain']

In [13]:
print(doc_profile['leafnotes'])
'''print(s_profiles[doc_id+"_0"]["leafnotes"])
print(s_profiles[doc_id+"_0"]["leaves"])'''



'print(s_profiles[doc_id+"_0"]["leafnotes"])\nprint(s_profiles[doc_id+"_0"]["leaves"])'

In [3]:
#set model and tokenizer
model_name = "TsinghuaAI/CPM-Generate"#"ckiplab/gpt2-base-chinese" #"hfl/chinese-bert-wwm" #"bert-base-chinese"
#tokenizer = "hfl/chinese-bert-wwm"
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

In [15]:
#align annotated leaves with model inputs
'''def filter_inputs(inputs, tokens, leaf_labels): #manipulate output matrix instead?
    inputs['input_ids'] = inputs['input_ids'].numpy()
    inputs[]
    for i in range(len(tokens)):
        if i >= len(leaf_labels):
            break
        if tokens[i][:1] != leaf_labels[i]:

            del inputs[i]
    return inputs'''

"def filter_inputs(inputs, tokens, leaf_labels): #manipulate output matrix instead?\n    inputs['input_ids'] = inputs['input_ids'].numpy()\n    inputs[]\n    for i in range(len(tokens)):\n        if i >= len(leaf_labels):\n            break\n        if tokens[i][:1] != leaf_labels[i]:\n\n            del inputs[i]\n    return inputs"

In [4]:
def filter_attn(attn_matrix, leaves, tokens):
    #attn_matrix: 2-D numpy array
    #print("tokens:",tokens)
    #print("leaves:",leaves)
    filter = []
    for i in range(len(leaves)):
        if i >= len(tokens):
            break
        if tokens[i] == '▁': #▁ is not an underscore (_)
            #print(tokens[i])
            del tokens[i]
        #print(tokens[i], leaves[i])
        if tokens[i].replace('▁', '') != leaves[i]:
            filter.append(i)
    attn_matrix = np.delete(attn_matrix, filter, axis = 0)
    attn_matrix = np.delete(attn_matrix, filter, axis = 1)
    return attn_matrix

In [9]:
#returns attention weights for specified head, layer
def attention_map(model, profile):
    inputs = tokenizer(profile["plain"], return_tensors="pt")
    if inputs["input_ids"].shape[1] > 512:
        inputs = tokenizer(profile["plain"][:520], return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    outputs = model(**inputs)
    print(torch.stack(outputs.attentions).shape)
    attention_map = []
    attentions = torch.stack(outputs.attentions).cpu()#.squeeze(0).detach().numpy()
    for layer in range(attentions.shape[0]):
        for head in range(attentions.shape[2]):
            attention_matrix = attentions[layer][0][head].squeeze(0).detach().numpy()
            attention_matrix = filter_attn(attention_matrix, profile["leaves"], tokens)
            attention_map.append(attention_matrix)
    # = attention[layers.unsqueeze(0), :, heads, :, :]
    #print(selection.shape)
    #aggregate = attention.sum(dim = 0).detach().numpy()
    #print(aggregate.shape)
    
    
    #print(len(tokens), tokens)
    #punct_idxs = [idx for idx, token in enumerate(tokens) if token in [",", "[CLS]", "[SEP]"]]
    #tokens = [token for i, token in enumerate(tokens) if i not in punct_idxs]
    #attention_matrix = np.delete(attention_matrix, punct_idxs, axis = 1)
    #attention_matrix = np.delete(attention_matrix, punct_idxs, axis = 0)
    #attention_matrix = filter_attn(attention_matrix, profile["leaves"], tokens)
    return (attention_map, attentions.squeeze(0).detach().numpy()) #return aggregate.detach().numpy()
    #tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

In [10]:
#returns the index of some verb in the predicate corresponding to the null subject
def get_verb_idx(tree, nsubj_idx):
    t_positions = tree.treepositions(order = 'preorder')
    leaf_positions = [tree.leaf_treeposition(idx) for idx in range(len(tree.leaves()))]
    nsubj_leaf_position = tree.leaf_treeposition(nsubj_idx)
    #print("get_verb_idx nsubj_idx:", nsubj_idx)
    IP_idx = t_positions.index(nsubj_leaf_position) - 1
    IP_pos = t_positions[IP_idx]
    IP_tree = tree[IP_pos]

    while (type(IP_tree) == type("s") or IP_tree.label() not in ['INC', 'IP']):
        IP_idx -= 1
        IP_pos = t_positions[IP_idx]
        IP_tree = tree[IP_pos]

    '''if type(IP_tree) == type("string"):
        print(tree.leaves()[nsubj_idx])
        print(IP_tree)
        print(type(IP_tree), tree)'''
        
    IPt_positions = IP_tree.treepositions(order='preorder')
    mv_idx = None
    for i in range(len(IPt_positions)):
        position = IPt_positions[i]
        if type(IP_tree[position]) == str and ('V' in IP_tree[IPt_positions[i-1]].label() or 'NT' in IP_tree[IPt_positions[i-1]].label()): #to-do: optimize
            #print(IP_tree[IPt_positions[i-1]])
            mv_idx = leaf_positions.index(t_positions[IP_idx + i])
            #print(mv_idx)
            return mv_idx
    
    return None

In [None]:
'''def reindex_attn_matrix(attn_matrix, leaves):
    for leaf in leaves'''

In [11]:
#Look for layers adhering to Cf rankings
#...examine effect of context_size on first-mention effect
#returns ordered indices of candidate antecedents
'''def mention_sequence(tree):
    lleaves = tree.pos()
    return [(leaf[0], i) for i, leaf in enumerate(lleaves) if "N" in leaf[1] and "*-" not in leaf[0]]'''
def mention_idxs(trees):
    joined_trees_tags = []
    for tree in trees:
        joined_trees_tags += [leaf for leaf in tree.pos()]
    return [idx for idx, leaf in enumerate(joined_trees_tags) if "N" in leaf[1] and "*-" not in leaf[0]]
'''def mention_set(string):
    return set(mention_sequence(string)) #ensure that order is preserved'''

#Hypothesis 1：linear order (looking for first-mention and recency effect) #filter out effect of confound---generate new set of sentences, 
##check if effect disappears w/ bidirectional model
def coref_map(mention_idxs, leafnotes): #convert mention indices to coreference ids
    coref_ids = []
    for idx in mention_idxs:
        #print(idx, leafnotes[idx])
        if idx in leafnotes and "coref:" in leafnotes[idx]["info"]:
            #print(leafnotes[idx]["info"]["coref:"])
            try:
                coref_ids.append(int(leafnotes[idx]["info"]["coref:"][1]))
            except:
                #print("coref:", leafnotes[idx]["info"]["coref:"])
                coref_ids.append(-1)
        else:
            if idx not in leafnotes:
                print("leafnotes:", leafnotes)
            coref_ids.append(-1)
    return coref_ids

'''def hyp_1_pred(profile):
     #todo
     return None'''
                
def compare_rankings(hyp_rankings, attn_rankings):
     #diffs = [abs(attn_rankings.index(attn_idx) - hyp_rankings.index(attn_idx)) for attn_idx in attn_rankings]
     #print(hyp_rankings)
     #print(attn_rankings)
     #print(hyp_rankings, attn_rankings)
     total_diff = sum([hyp_rankings[i] == attn_rankings[i] for i in range(len(attn_rankings))])
     return total_diff if len(attn_rankings) == 0 else total_diff/len(attn_rankings)

#returns list of candidates ranked by sums of attentions corresponding to each candidate for a verb
'''def aggregate_attentions(candidate_attns, coref_ids): 
    coref_set = set(coref_ids)
    aggregates_per_mention = {id:0 for id in coref_set}
    for i in range(len(candidate_attns)):
        aggregates_per_mention[coref_ids[i]] += candidate_attns[i][0]
    print(aggregates_per_mention)
    new_rankings = sorted([(id, aggregates_per_mention[id]) for id in coref_ids], key = lambda tup: tup[1], reverse=True)
    new_rankings = list(dict.fromkeys(new_rankings))
    return new_rankings'''
def aggregate_attentions(candidate_attns, coref_ids):
    coref_to_score = {}
    for (idx, score), coref_id in zip(candidate_attns, coref_ids):
        if coref_id not in coref_to_score:
            coref_to_score[coref_id] = 0.0
        coref_to_score[coref_id] += score
    ranked = sorted(coref_to_score.items(), key=lambda x: -x[1])  # Descending
    return [coref_id for coref_id, score in ranked]

#to-do: aggregate attentions for each coreference id, then rank and compare with order of coreference ids as predicted by hypothesis -- done?

def test_hyp1_doc(profile, attention_maps, context_len=512):

    matrix_size = len(attention_maps[0])

    comparisons = {} #dictionary mapping attention matrix to ranking comparisons
    trees = [Tree.fromstring(tree) for tree in profile["trees"]]
    mentions = mention_idxs(trees)

    nsubj_counter = 0 #keep track of doc-level nsubj positions

    for i in range(len(trees)):
        leafnotes = profile["leaves_per_tree"][i]

        nsubj_idxs = [i for i in leafnotes if leafnotes[i]["token"] == "*pro*" and "info" in leafnotes[i]]

        for nsubj_idx in nsubj_idxs:
            
            #print(mentions)
            n_subj_doc_level_pos = nsubj_idx + nsubj_counter
            candidate_idxs = [idx for idx in mentions if idx < n_subj_doc_level_pos and idx < matrix_size]
            hyp_rankings = list(dict.fromkeys(coref_map(candidate_idxs, profile['leafnotes'])))
            try:
                verb_idx = get_verb_idx(trees[i], nsubj_idx) + nsubj_counter
            except:
                continue

            if verb_idx >= matrix_size:
                break
            #print("verb idx:", verb_idx)
            for map_idx in range(len(attention_maps)):
                map = attention_maps[map_idx]
                #print(verb_idx)
                
                candidate_attentions = [(candidate_idx, map[verb_idx][candidate_idx]) for candidate_idx in candidate_idxs] #don't sort, will be sorted by aggregate
                #print("candidates sorted:", candidate_attentions)
                corefs = coref_map([tup[0] for tup in candidate_attentions], profile['leafnotes'])
                #print("coref ids sorted:", corefs)
                attn_ranking = aggregate_attentions(candidate_attentions, corefs)
                #print("coref ranking by aggregate attention:", attn_ranking)
                comparison = compare_rankings(hyp_rankings, attn_ranking)
                if map_idx in comparisons:
                    comparisons[map_idx].append(comparison)
                else:
                    comparisons[map_idx] = [comparison]

        nsubj_counter += len(leafnotes)
                
    return comparisons
         
def test_hyp1(profile, attention_maps, doc_level=False, context_len=512):
    
    if doc_level:
         print("whole doc")
         return test_hyp1_doc(profile, attention_maps, context_len)
    else: 
         tree = Tree.fromstring(profile["tree"])
         all_mentions = mention_idxs([tree])

    comparisons = {}
    leafnotes = profile["leafnotes"]
    nsubj_idxs = [i for i in leafnotes if leafnotes[i]["token"] == "*pro*" and "info" in leafnotes[i]]
    for nsubj_idx in nsubj_idxs:
        verb_idx = get_verb_idx(tree, nsubj_idx)
        #print("verb idx:", verb_idx)
        candidate_idxs = [idx for idx in all_mentions if idx < nsubj_idx]
        hyp_rankings = coref_map(candidate_idxs, leafnotes)
        if verb_idx >= context_len:
            break
        for map_idx in range(len(attention_maps)):
            candidate_attentions = [(candidate_idx, map[verb_idx][candidate_idx]) for candidate_idx in candidate_idxs]
            corefs = coref_map([tup[0] for tup in candidate_attentions], profile['leafnotes'])
            attn_ranking = aggregate_attentions(candidate_attentions, corefs)
            #print("coref ranking by aggregate attention:", attn_ranking)
            comparison = compare_rankings(hyp_rankings, attn_ranking)
            if map_idx in comparisons:
                comparisons[map_idx].append(comparison)
            else:
                comparisons[map_idx] = [comparison]

    return comparisons
            
#to-do: character-level implementation

#Hypothesis 2: grammatical role (ranking by grammatical role)
#to-do: finish implementation

In [None]:
#compares attention rankings to hypothesized rankings
'''def test_hyp_1(profile, attention_maps, is_bidirectional = True):
    tree = Tree.fromstring(profile["tree"])
    leafnotes = profile["leafnotes"]
    candidate_idxs = mention_idxs(tree)
    nsubj_idxs = [i for i in leafnotes if leafnotes[i]["token"] == "*pro*" and "info" in leafnotes[i]]
    for nsubj_idx in nsubj_idxs:
        verb_idx = get_verb_idx(tree, nsubj_idx)
        candidate_idxs = [idx for idx in candidate_idxs if idx < nsubj_idx]
        for map in attention_maps:
            candidate_attentions = [map[verb_idx][candidate] for candidate in candidate_idxs]
            
            if is_bidirectional:
                candidate_attentions_reverse = [map[candidate][verb_idx] for candidate in candidate_idxs]'''
    #to-do: finish implementation
    
#hypotheses = {1: test_hyp_1}      

In [None]:
#viewer
attm_map = attention_map(model, doc_profile)
'''for map in attention_maps:
    figure = sns.heatmap(map, xticklabels=2, yticklabels=2)
    figure.invert_yaxis()
    plt.show()'''

torch.Size([32, 1, 32, 504, 504])


'for map in attention_maps:\n    figure = sns.heatmap(map, xticklabels=2, yticklabels=2)\n    figure.invert_yaxis()\n    plt.show()'

In [None]:
test_hyp1(doc_profile, attm_map, doc_level=True, context_len = 512)
#print(test_hyp1(sample, attention_map, doc_level=False, context_len = 512))


In [26]:
dfs = []
final_outputs = []
for i in range(1, 25):
    doc_id = ("000" if i < 10 else "00") + str(i)
    doc_profile = get_doc_profile(doc_id)
    attn_map, final = attention_map(model, doc_profile)
    final_outputs.append(final)
    rank_comps = test_hyp1(doc_profile, attn_map, doc_level = True) #dictionary mapping attention matrix to list of comparison scores
    df = pd.DataFrame.from_dict(rank_comps, orient = "index")
    df.columns = [(f"{str(i)}_" + str(colname)) for colname in df.columns]
    dfs.append(df)

dfs_all = pd.concat(dfs, axis=1)



torch.Size([32, 1, 32, 504, 504])
whole doc
torch.Size([32, 1, 32, 521, 521])
whole doc
torch.Size([32, 1, 32, 533, 533])
whole doc
torch.Size([32, 1, 32, 482, 482])
whole doc
torch.Size([32, 1, 32, 526, 526])
whole doc
torch.Size([32, 1, 32, 532, 532])
whole doc
torch.Size([32, 1, 32, 541, 541])
whole doc
torch.Size([32, 1, 32, 563, 563])
whole doc
torch.Size([32, 1, 32, 485, 485])
whole doc
torch.Size([32, 1, 32, 581, 581])
whole doc
torch.Size([32, 1, 32, 478, 478])
whole doc
torch.Size([32, 1, 32, 513, 513])
whole doc
torch.Size([32, 1, 32, 438, 438])
whole doc
torch.Size([32, 1, 32, 469, 469])
whole doc
torch.Size([32, 1, 32, 528, 528])
whole doc
torch.Size([32, 1, 32, 508, 508])
whole doc
torch.Size([32, 1, 32, 521, 521])
whole doc
torch.Size([32, 1, 32, 514, 514])
whole doc
torch.Size([32, 1, 32, 562, 562])
whole doc
torch.Size([32, 1, 32, 547, 547])
whole doc
torch.Size([32, 1, 32, 511, 511])
whole doc
torch.Size([32, 1, 32, 524, 524])
whole doc
torch.Size([32, 1, 32, 503, 503]

In [None]:
comparisons_by_doc = pd.DataFrame()
for i in range(24, 42):
    doc_id = ("000" if i < 10 else "00") + str(i)
    doc_profile = get_doc_profile(doc_id)
    attn_map, final = attention_map(model, doc_profile)
    #final_outputs.append(final)
    rank_comps = test_hyp1(doc_profile, attn_map, doc_level = True)
    comparisons_by_doc[doc_id] = [statistics.mean(rank_comps[i]) for i in rank_comps]

torch.Size([32, 1, 32, 516, 516])
whole doc
torch.Size([32, 1, 32, 571, 571])
whole doc
torch.Size([32, 1, 32, 510, 510])
whole doc
torch.Size([32, 1, 32, 477, 477])
whole doc
torch.Size([32, 1, 32, 501, 501])
whole doc
torch.Size([32, 1, 32, 491, 491])
whole doc
torch.Size([32, 1, 32, 515, 515])
whole doc
torch.Size([32, 1, 32, 494, 494])
whole doc
torch.Size([32, 1, 32, 520, 520])
whole doc
torch.Size([32, 1, 32, 487, 487])
whole doc
torch.Size([32, 1, 32, 495, 495])
whole doc
torch.Size([32, 1, 32, 562, 562])
whole doc
torch.Size([32, 1, 32, 551, 551])
whole doc
torch.Size([32, 1, 32, 516, 516])
whole doc
torch.Size([32, 1, 32, 493, 493])
whole doc
torch.Size([32, 1, 32, 499, 499])
whole doc
torch.Size([32, 1, 32, 540, 540])
whole doc
torch.Size([32, 1, 32, 524, 524])
whole doc


In [28]:
dfs_all.to_csv('comparisons_by_doc_hyp1_1.csv', index=False)

In [None]:
pooled = {}
for matrix in final_outputs:
    rank_comps = test_hyp1(doc_profile, [matrix], doc_level = True)
    pooled[doc_id] = [statistics.mean(rank_comps[i]) for i in rank_comps]

pooled.csv('pooled_hyp1.csv', index = False)