In [3]:
import math
import matplotlib.pyplot as plt
from nltk.tree import Tree
import numpy as np
from copy import deepcopy
import pandas as pd
import re
import seaborn as sns
import statistics
from transformers import AutoModel, AutoTokenizer
import torch

In [4]:
# Select the pretrained checkpoint and load its tokenizer and model.
# `output_attentions=True` is passed at load time; attention_map later reads
# `outputs.attentions` from the forward pass.
model_name = "TsinghuaAI/CPM-Generate"#"ckiplab/gpt2-base-chinese" #"hfl/chinese-bert-wwm" #"bert-base-chinese"
#tokenizer = "hfl/chinese-bert-wwm"
# NOTE(review): AutoModelForCausalLM is imported here but not used in the visible
# cells (AutoModel is used instead), and AutoTokenizer is already imported in the
# imports cell -- consider moving/removing this import.
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

In [5]:
def get_sentence_profiles(doc_id):
    """Parse one ``ch_<doc_id>.onf`` file into per-sentence profiles.

    The file path is built by concatenating the module-level ``dir`` prefix
    with ``"ch_" + doc_id + ".onf"``.

    Args:
        doc_id: document id string used in the file name (e.g. ``'0037'``).

    Returns:
        dict keyed by ``"<doc_id>_<sentence number>"``. Each value is a dict with:
          * ``"plain"``     -- first line of the plain sentence, stripped, with
                               every ``"--"`` removed;
          * ``"tree"``      -- the bracketed parse collapsed onto one line
                               (see ``process_tree``);
          * ``"leafnotes"`` -- per-leaf annotations (see ``process_leaves``);
          * ``"leaves"``    -- leaves of the parse that do not contain ``'*'``.

    Raises:
        OSError: if the file cannot be opened.
        IndexError: if a section marker is missing for some sentence.
    """
    profiles = {}
    fp = dir + "ch_" + doc_id + ".onf"
    # The annotations contain Chinese text, so decode explicitly instead of
    # relying on the platform default encoding.
    # NOTE(review): assumes the .onf files are UTF-8 -- confirm for this corpus copy.
    with open(fp, "r", encoding="utf-8") as f:
        lines = f.readlines()

    # Line numbers of the section markers; the i-th entry of each list is
    # assumed to belong to the i-th sentence.
    plain_sent_idxs = [i for i, txt in enumerate(lines) if txt == "Plain sentence:\n"]
    treebanked_idxs = [i for i, txt in enumerate(lines) if "Treebanked sentence:\n" in txt]
    tree_idxs = [i for i, txt in enumerate(lines) if txt == "Tree:\n"]
    leaves_idxs = [i for i, txt in enumerate(lines) if txt == "Leaves:\n"]

    n_sents = len(plain_sent_idxs)

    for i in range(n_sents):
        sent = {}

        # Content starts two lines below each marker (marker line + one more).
        # NOTE(review): only the first line of the plain-sentence section is kept;
        # if long sentences wrap onto several lines the rest is dropped -- verify.
        plain = lines[(plain_sent_idxs[i] + 2):treebanked_idxs[i]][0].strip()
        sent["plain"] = plain.replace("--", "")

        sent["tree"] = process_tree(lines[(tree_idxs[i] + 2):leaves_idxs[i]])

        # The leaves section runs up to three lines before the next sentence's
        # marker, or three lines before end of file for the last sentence.
        if i < n_sents - 1:
            raw_leaves = lines[(leaves_idxs[i] + 2):(plain_sent_idxs[i + 1] - 3)]
        else:
            raw_leaves = lines[(leaves_idxs[i] + 2):-3]
        sent["leafnotes"] = process_leaves(raw_leaves)

        # Leaves containing '*' (e.g. '*pro*') are dropped from the word list.
        sent["leaves"] = [leaf for leaf in Tree.fromstring(sent["tree"]).leaves() if '*' not in leaf]

        profiles[doc_id + "_" + str(i)] = sent

    return profiles

def process_leaves(leafnotes):
    """Turn the raw lines of a "Leaves:" section into a dict of leaf records.

    A line whose first whitespace-separated field is all digits starts a new
    leaf (its second field is the token) unless that token is ``"--"``, in
    which case no record is created. Any other non-blank line is stored as an
    annotation on the most recently created leaf: first field is the key, the
    remaining fields are the value list. Annotation lines seen before any
    leaf exists are ignored.

    Returns:
        dict mapping consecutive ints (0, 1, ...) to
        ``{"token": str, "info": {key: [fields...]}}``.
    """
    records = {}
    next_key = 0
    for raw in leafnotes:
        fields = raw.strip().split()
        if not fields:
            continue

        head = fields[0]
        if head.isdigit():
            # "--" tokens are skipped and do not consume a key.
            if fields[1] == "--":
                continue
            records[next_key] = {"token": fields[1], "info": {}}
            next_key += 1
            continue

        latest = next_key - 1
        if latest in records:
            records[latest]["info"][head] = fields[1:]

    return records

def process_tree(raw_tree):
    """Collapse the lines of a bracketed parse into one string.

    Each line is stripped of surrounding whitespace and the pieces are
    concatenated with no separator.
    """
    return "".join(piece.strip() for piece in raw_tree)

def get_doc_profile(doc_id):
    """Build a document-level profile by merging all sentence profiles.

    Args:
        doc_id: document id string, forwarded to ``get_sentence_profiles``.

    Returns:
        dict with:
          * ``"plain"``           -- all plain sentences concatenated (no separator);
          * ``"trees"``           -- list of one-line bracketed parses, one per sentence;
          * ``"leaves_per_tree"`` -- list of deep-copied per-sentence leaf-note dicts,
                                     each keyed from 0 within its sentence;
          * ``"leafnotes"``       -- one dict with every leaf note of the document,
                                     re-keyed consecutively from 0 in sentence order;
          * ``"leaves"``          -- all parse leaves that do not contain ``'*'``.
    """
    sentences = list(get_sentence_profiles(doc_id).values())

    profile = {}
    profile["plain"] = "".join(sent["plain"] for sent in sentences)
    profile["trees"] = [sent["tree"] for sent in sentences]

    # Deep copies keep the per-sentence view independent of the merged dict below.
    profile["leaves_per_tree"] = [deepcopy(sent["leafnotes"]) for sent in sentences]

    # Merge into a fresh dict with a running key. Unlike seeding the merge from
    # the first sentence's dict, this also works when the first sentence has
    # no leaf notes or the document has no sentences at all.
    merged = {}
    for sent in sentences:
        for note in sent["leafnotes"].values():
            merged[len(merged)] = note
    profile["leafnotes"] = merged

    # Each sentence profile already carries its '*'-free leaves.
    profile["leaves"] = [leaf for sent in sentences for leaf in sent["leaves"]]

    return profile


In [6]:
# Path prefix of the .onf files. get_sentence_profiles reads this module-level
# name directly when it builds the file path.
# NOTE(review): `dir` shadows the Python builtin of the same name; renaming it
# requires changing get_sentence_profiles at the same time.
dir = "ch_copy/00/"
doc_id = '0037'

#s_profiles = get_sentence_profiles(doc_id)
doc_profile = get_doc_profile(doc_id)
#sample = s_profiles['0037_0']
# Display the merged leaf annotations of the document as the cell output.
doc_profile["leafnotes"]

{0: {'token': '*pro*', 'info': {'coref:': ['IDENT', '4', '0-0', '*pro*']}},
 1: {'token': '就', 'info': {}},
 2: {'token': '觉',
  'info': {'prop:': ['觉.01'],
   'v': ['*', '->', '2:0,', '觉'],
   'ARG0': ['*', '->', '0:0,', '*pro*'],
   'ARGM-DIS': ['*', '->', '1:1,', '就'],
   'ARG1': ['*', '->', '3:3,', '*pro*', '挺', '不', '舒服', '的']}},
 3: {'token': '*pro*', 'info': {'coref:': ['IDENT', '5', '3-3', '*pro*']}},
 4: {'token': '挺', 'info': {}},
 5: {'token': '不', 'info': {}},
 6: {'token': '舒服',
  'info': {'prop:': ['舒服.02'],
   'v': ['*', '->', '6:0,', '舒服'],
   'ARG0': ['*', '->', '3:0,', '*pro*'],
   'ARGM-ADV': ['*', '->', '5:1,', '不']}},
 7: {'token': '的', 'info': {}},
 8: {'token': ',', 'info': {}},
 9: {'token': '本来', 'info': {}},
 10: {'token': '就是', 'info': {}},
 11: {'token': '信心', 'info': {}},
 12: {'token': '挺', 'info': {}},
 13: {'token': '足',
  'info': {'prop:': ['足.01'],
   'v': ['*', '->', '13:0,', '足'],
   'ARGM-ADV': ['*', '->', '12:1,', '挺'],
   'ARGM-DIS': ['*', '->', '

In [2]:
def filter_attn(attn_matrix, leaves, tokens):
    """Remove attention rows/columns whose token does not match the paired leaf.

    Walks ``leaves`` and ``tokens`` in parallel by index. A token that is
    exactly ``'▁'`` is dropped from the (local) token list before comparing.
    Every index whose token (with ``'▁'`` removed) differs from the leaf at
    the same index is deleted from both axes of ``attn_matrix``. Indices at
    or beyond ``len(leaves)`` are never examined and are therefore kept.

    Args:
        attn_matrix: 2-D numpy array.
        leaves: list of leaf strings.
        tokens: list of tokenizer token strings. It is not modified.

    Returns:
        A new numpy array with the mismatching rows and columns removed.
    """
    # Work on a private copy. Deleting from the caller's list made repeated
    # calls with the same list (one per layer/head) see different tokens.
    tokens = list(tokens)
    mismatched = []
    for i in range(len(leaves)):
        if i >= len(tokens):
            break
        if tokens[i] == '▁': #▁ is not an underscore (_)
            del tokens[i]
            # The deleted piece may have been the last token.
            if i >= len(tokens):
                break
        if tokens[i].replace('▁', '') != leaves[i]:
            mismatched.append(i)
    # NOTE(review): after a '▁' is dropped, indices in `mismatched` refer to the
    # shifted token list while the matrix is still unshifted -- confirm that
    # this alignment is what the analysis intends.
    attn_matrix = np.delete(attn_matrix, mismatched, axis = 0)
    attn_matrix = np.delete(attn_matrix, mismatched, axis = 1)
    return attn_matrix

In [7]:
def attention_map(model, profile):
    """Run the model on a document and collect its attention matrices.

    Uses the module-level ``tokenizer``. If the tokenized text is longer than
    512 ids, the plain text is cut to its first 520 characters and
    re-tokenized.

    Args:
        model: model whose forward output exposes ``.attentions``.
        profile: dict with ``"plain"`` (text) and ``"leaves"`` (word list).

    Returns:
        Tuple ``(filtered, raw)``:
          * ``filtered`` -- flat list with one 2-D numpy array per (layer, head),
            layer-major, each passed through ``filter_attn``;
          * ``raw`` -- the stacked attentions as a numpy array (``squeeze(0)``
            is applied, which only drops the first axis when its size is 1).
    """
    inputs = tokenizer(profile["plain"], return_tensors="pt")
    if inputs["input_ids"].shape[1] > 512:
        # NOTE(review): cutting to 520 characters does not by itself guarantee
        # at most 512 tokens -- confirm this is sufficient for the tokenizer.
        inputs = tokenizer(profile["plain"][:520], return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    # Stacked attentions are indexed below as [layer][batch item 0][head].
    attentions = torch.stack(outputs.attentions).cpu()

    filtered = []
    for layer in range(attentions.shape[0]):
        for head in range(attentions.shape[2]):
            attention_matrix = attentions[layer][0][head].squeeze(0).detach().numpy()
            # Pass a fresh copy of the token list so every head is filtered
            # against the same tokens, even if the filter edits its argument.
            filtered.append(filter_attn(attention_matrix, profile["leaves"], list(tokens)))
    return (filtered, attentions.squeeze(0).detach().numpy())

In [9]:
def get_verb_idx(tree, nsubj_idx):
    """Return the leaf index of a verb-like leaf associated with a subject leaf.

    Starting from the leaf at ``nsubj_idx``, the function steps backwards
    through the preorder list of tree positions until it reaches a subtree
    labelled ``'INC'`` or ``'IP'``. It then scans that subtree in preorder and
    returns the whole-tree leaf index of the first leaf whose preceding
    preorder node has a label containing ``'V'`` or ``'NT'``.

    Args:
        tree: tree object providing ``treepositions``, ``leaf_treeposition``,
            ``leaves`` and position indexing (presumably an ``nltk.tree.Tree``
            -- confirm against the caller).
        nsubj_idx: index of the subject leaf among ``tree.leaves()``.

    Returns:
        The leaf index (int), or ``None`` if no such leaf is found.
    """
    t_positions = tree.treepositions(order = 'preorder')
    leaf_positions = [tree.leaf_treeposition(idx) for idx in range(len(tree.leaves()))]
    nsubj_leaf_position = tree.leaf_treeposition(nsubj_idx)
    #print("get_verb_idx nsubj_idx:", nsubj_idx)
    # Start at the node immediately before the subject leaf in preorder.
    IP_idx = t_positions.index(nsubj_leaf_position) - 1
    IP_pos = t_positions[IP_idx]
    IP_tree = tree[IP_pos]

    # Walk backwards in preorder, skipping leaves (plain strings) and subtrees
    # with other labels, until an 'INC' or 'IP' subtree is found.
    # NOTE(review): this is the nearest *preceding* node in preorder, which is
    # not necessarily an ancestor of the subject leaf; IP_idx can also go
    # negative (wrapping around the list) if no such node precedes the leaf.
    while (type(IP_tree) == type("s") or IP_tree.label() not in ['INC', 'IP']):
        IP_idx -= 1
        IP_pos = t_positions[IP_idx]
        IP_tree = tree[IP_pos]

    '''if type(IP_tree) == type("string"):
        print(tree.leaves()[nsubj_idx])
        print(IP_tree)
        print(type(IP_tree), tree)'''
        
    IPt_positions = IP_tree.treepositions(order='preorder')
    mv_idx = None
    for i in range(len(IPt_positions)):
        position = IPt_positions[i]
        # A leaf whose preceding preorder node has a label containing 'V' or 'NT'.
        if type(IP_tree[position]) == str and ('V' in IP_tree[IPt_positions[i-1]].label() or 'NT' in IP_tree[IPt_positions[i-1]].label()): #to-do: optimize
            #print(IP_tree[IPt_positions[i-1]])
            # Offset i within the subtree's preorder is added to the subtree
            # root's index in the whole tree's preorder to recover the leaf's
            # position in the full tree.
            mv_idx = leaf_positions.index(t_positions[IP_idx + i])
            #print(mv_idx)
            return mv_idx
    
    return None

In [10]:

def compare_rankings(hyp_rankings, attn_rankings):
    """Fraction of positions at which the two rankings hold the same item.

    Iterates over the positions of ``attn_rankings`` and counts those where
    ``hyp_rankings`` has an equal entry. Returns the count divided by
    ``len(attn_rankings)``, or the bare count (0) when ``attn_rankings`` is
    empty.
    """
    n_positions = len(attn_rankings)
    agreements = 0
    for pos in range(n_positions):
        agreements += hyp_rankings[pos] == attn_rankings[pos]

    if n_positions == 0:
        return agreements
    return agreements / n_positions

#returns list of candidates ranked by sums of attentions corresponding to each candidate for a verb
'''def aggregate_attentions(candidate_attns, coref_ids): 
    coref_set = set(coref_ids)
    aggregates_per_mention = {id:0 for id in coref_set}
    for i in range(len(candidate_attns)):
        aggregates_per_mention[coref_ids[i]] += candidate_attns[i][0]
    print(aggregates_per_mention)
    new_rankings = sorted([(id, aggregates_per_mention[id]) for id in coref_ids], key = lambda tup: tup[1], reverse=True)
    new_rankings = list(dict.fromkeys(new_rankings))
    return new_rankings'''
def aggregate_attentions(candidate_attns, coref_ids):
    """Rank coreference ids by their summed attention scores.

    ``candidate_attns`` is a sequence of ``(index, score)`` pairs aligned with
    ``coref_ids``. Scores of candidates sharing an id are added together and
    the ids are returned from highest to lowest total; ids with equal totals
    keep their first-seen order.
    """
    totals = {}
    for (_, weight), mention_id in zip(candidate_attns, coref_ids):
        totals[mention_id] = totals.get(mention_id, 0.0) + weight

    by_total_desc = sorted(totals.items(), key=lambda pair: -pair[1])
    return [mention_id for mention_id, _ in by_total_desc]

In [31]:
def _coref_id_sort_key(coref_id):
    """Ordering key that compares all-digit coreference ids by numeric value."""
    text = str(coref_id)
    if text.isdigit():
        return (0, int(text), "")
    # Non-numeric ids sort after numeric ones, by their text.
    return (1, 0, text)


def test_hyp_2(profile):
    """Compare ARG0/ARG1 coreference order with attention order for each predicate.

    For every leaf whose annotations contain both ``'ARG0'`` and ``'ARG1'``,
    and whose two argument leaves both carry a ``'coref:'`` annotation, the
    two argument positions are ranked by the attention the predicate row
    assigns to them, once per matrix returned by ``attention_map``.

    Args:
        profile: document profile with ``"leaves_per_tree"`` (per-sentence
            leaf-note dicts) plus whatever ``attention_map`` reads.

    Returns:
        Tuple ``(comparisons, in_order)``:
          * ``comparisons`` -- dict mapping the predicate's document-level leaf
            index to a list with one ``compare_rankings`` score per attention
            matrix (hypothesis order is ``[ARG0 id, ARG1 id]``);
          * ``in_order`` -- flat list of bools, one per (predicate, matrix),
            True when the most-attended argument's coref id is not greater
            than the other's.
        Processing stops early, returning what was collected so far, as soon
        as an index falls outside the filtered attention matrix.
    """
    comparisons = {}
    in_order = []
    leafnotes_split = profile["leaves_per_tree"]
    intermediate, _ = attention_map(model, profile)
    matrix_size = len(intermediate[0])

    # Offset that converts a within-sentence leaf index to a document-level one.
    counter = 0
    for sent_leaves in leafnotes_split:
        for word_idx in sent_leaves:
            word_idx_global = word_idx + counter
            if word_idx_global >= matrix_size:
                # This leaf is outside the attention matrix.
                return comparisons, in_order

            info = sent_leaves[word_idx]['info']
            if 'ARG0' not in info or 'ARG1' not in info:
                continue

            # Third field looks like '0:0,'; the part before ':' is the
            # within-sentence leaf index of the argument.
            coref_idx0 = int(info['ARG0'][2].split(':')[0])
            coref_idx1 = int(info['ARG1'][2].split(':')[0])
            coref_idx0_global = coref_idx0 + counter
            coref_idx1_global = coref_idx1 + counter
            if coref_idx0_global >= matrix_size or coref_idx1_global >= matrix_size:
                print("out of bounds")
                print(coref_idx0_global, coref_idx1_global)
                return comparisons, in_order

            arg0_info = sent_leaves[coref_idx0]['info']
            arg1_info = sent_leaves[coref_idx1]['info']
            if 'coref:' not in arg0_info or 'coref:' not in arg1_info:
                continue

            hypothesis = [arg0_info['coref:'][1], arg1_info['coref:'][1]]
            comparisons[word_idx_global] = []
            for attn in intermediate:
                scored = [
                    (candidate_idx, attn[word_idx_global][candidate_idx])
                    for candidate_idx in [coref_idx0_global, coref_idx1_global]
                ]
                scored.sort(key=lambda tup: tup[1], reverse=True)
                rank_by_attn = [
                    sent_leaves[candidate_idx - counter]['info']["coref:"][1]
                    for candidate_idx, _ in scored
                ]
                comparisons[word_idx_global].append(compare_rankings(hypothesis, rank_by_attn))
                # Coref ids are strings; compare numerically so that e.g.
                # '10' is not treated as smaller than '9'.
                in_order.append(
                    _coref_id_sort_key(rank_by_attn[0]) <= _coref_id_sort_key(rank_by_attn[1])
                )
        counter += len(sent_leaves)

    return comparisons, in_order

    

In [28]:
# Run the ARG0/ARG1 comparison on the single document profile and display the
# returned (comparisons, in_order) tuple as the cell output.
test_hyp_2(doc_profile)

['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5'] ['4', '5']
['4', '5']

({2: [1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.

In [None]:
# Run test_hyp_2 on documents 0001-0024 and collect one DataFrame per document.
# Each column holds the per-attention-matrix scores of one predicate.
hyp2_dfs = []
for i in range(1, 25):
    print(i)
    doc_id = str(i).zfill(4)  # zero-padded to four digits, e.g. 7 -> "0007"
    doc_profile = get_doc_profile(doc_id)
    comparisons, in_order = test_hyp_2(doc_profile)
    comparisons_df = pd.DataFrame.from_dict(comparisons)
    # Prefix columns with the document number so they stay unique after concat.
    comparisons_df.columns = [f"{i}_{colname}" for colname in comparisons_df.columns]
    hyp2_dfs.append(comparisons_df)

hyp2_comparisons = pd.concat(hyp2_dfs, axis=1)
# to_csv is a DataFrame method; pandas has no module-level `pd.to_csv`.
hyp2_comparisons.to_csv("hyp2_comparisons.csv", index = False)


['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7'] ['7', '7']
['7', '7']

KeyboardInterrupt: 