In [2]:
import matplotlib.pyplot as plt
from nltk.tree import Tree
import numpy as np
import re
import seaborn as sns
from transformers import AutoModel, AutoTokenizer
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Notebook configuration: directory holding the .onf annotation files and the document to analyse.
# NOTE(review): `dir` shadows the builtin; get_sentence_profiles reads this global, so rename both together.
dir = "ch_copy/00/"
doc_id = "0037"

#returns dictionary containing plain sentence, constituency parse and coreference-id-labeled tokens 
def get_sentence_profiles(doc_id):
    profiles = {}
    filename = "ch_" + doc_id + ".onf"
    fp = dir + filename
    with open(fp, "r") as f:
        lines = f.readlines()
    #enumerated = enumerate(lines) -- don't do this
    plain_sent_idxs = [i for i, txt in enumerate(lines) if txt == ("Plain sentence:\n")]
    treebanked_idxs = [i for i, txt in enumerate(lines) if txt.find("Treebanked sentence:\n") > -1]
    tree_idxs = [i for i, txt in enumerate(lines) if txt == ("Tree:\n")]
    leaves_idxs = [i for i, txt in enumerate(lines) if txt == ("Leaves:\n")]

    n_sents = len(plain_sent_idxs)

    for i in range(n_sents):
        profile = doc_id + "_" + str(i)
        profiles[profile] = {}
        profiles[profile]["plain"] = lines[(plain_sent_idxs[i] + 2):treebanked_idxs[i]][0].strip()
        profiles[profile]["plain"] = profiles[profile]["plain"].replace("--", "")
        raw_tree = lines[(tree_idxs[i]+2):leaves_idxs[i]]
        profiles[profile]["tree"] = process_tree(raw_tree)

        if i < n_sents - 1:
            profiles[profile]["leaves"] = lines[(leaves_idxs[i]+2):(plain_sent_idxs[i+1]-3)]
        else:
            profiles[profile]["leaves"] = lines[(leaves_idxs[i]+2):-3]

        profiles[profile]["leaves"] = process_leaves(profiles[profile]["leaves"])

        i += 1
    return profiles

def process_leaves(leafnotes):
    """Parse the raw lines of a "Leaves:" section into tokens and their annotations.

    A line whose first field is all digits starts a new leaf; its second field is the token.
    Leaves whose token is "--" are dropped, and so are their annotation lines. Any other
    non-blank line annotates the current leaf: its first field (kept verbatim, e.g. "coref:",
    "prop:", "v", "ARG0") maps to the list of the remaining fields.

    Args:
        leafnotes: list of raw text lines.

    Returns:
        dict {0..n-1: {"token": str, "info": {key: [field, ...]}}}, numbered consecutively
        over the kept leaves.
    """
    leaves_dict = {}
    current = None  # key of the leaf annotations attach to; None before any leaf / after a dropped one
    for raw in leafnotes:
        fields = raw.split()
        if not fields:
            continue
        if fields[0].isdigit():
            if len(fields) > 1 and fields[1] != "--":
                current = len(leaves_dict)
                leaves_dict[current] = {"token": fields[1], "info": {}}
            else:
                # A dropped leaf's annotations must not overwrite those of the previous leaf.
                current = None
        elif current is not None:
            leaves_dict[current]["info"][fields[0]] = fields[1:]
    return leaves_dict

def process_tree(raw_tree):
    """Flatten a multi-line bracketed parse into a single string.

    Every line is stripped of surrounding whitespace and the pieces are concatenated
    with no separator.
    """
    return "".join(line.strip() for line in raw_tree)

def get_doc_profile(doc_id):
    """Merge all sentence profiles of a document into one document-level profile.

    Returns:
        dict with
          - "plain":  the sentences' plain text concatenated with no separator
          - "tree":   list of per-sentence tree strings, in sentence order
          - "leaves": every sentence's leaves, re-keyed 0..N-1 consecutively across the document
        An empty document (or sentences without leaves) yields empty values instead of raising.
    """
    sentence_profiles = get_sentence_profiles(doc_id)
    sentences = list(sentence_profiles.values())

    # Build a fresh dict so the per-sentence leaves dicts are not mutated; the running
    # length doubles as the next document-level leaf index.
    doc_leaves = {}
    for sentence in sentences:
        for leaf in sentence["leaves"].values():
            doc_leaves[len(doc_leaves)] = leaf

    return {
        "plain": "".join(sentence["plain"] for sentence in sentences),
        "tree": [sentence["tree"] for sentence in sentences],
        "leaves": doc_leaves,
    }


In [45]:
s_profiles = get_sentence_profiles(doc_id)
doc_profile = get_doc_profile(doc_id)
# Derive the key from doc_id so that changing doc_id in the config cell keeps this cell working.
sample = s_profiles[doc_id + "_0"]

In [46]:
# Inspect the parsed leaves (token + annotation info) of the sample sentence.
print(sample["leaves"])

{0: {'token': '*pro*', 'info': {'coref:': ['IDENT', '4', '0-0', '*pro*']}}, 1: {'token': '就', 'info': {}}, 2: {'token': '觉', 'info': {'prop:': ['觉.01'], 'v': ['*', '->', '2:0,', '觉'], 'ARG0': ['*', '->', '0:0,', '*pro*'], 'ARGM-DIS': ['*', '->', '1:1,', '就'], 'ARG1': ['*', '->', '3:3,', '*pro*', '挺', '不', '舒服', '的']}}, 3: {'token': '*pro*', 'info': {'coref:': ['IDENT', '5', '3-3', '*pro*']}}, 4: {'token': '挺', 'info': {}}, 5: {'token': '不', 'info': {}}, 6: {'token': '舒服', 'info': {'prop:': ['舒服.02'], 'v': ['*', '->', '6:0,', '舒服'], 'ARG0': ['*', '->', '3:0,', '*pro*'], 'ARGM-ADV': ['*', '->', '5:1,', '不']}}, 7: {'token': '的', 'info': {}}, 8: {'token': ',', 'info': {}}, 9: {'token': '本来', 'info': {}}, 10: {'token': '就是', 'info': {}}, 11: {'token': '信心', 'info': {}}, 12: {'token': '挺', 'info': {}}, 13: {'token': '足', 'info': {'prop:': ['足.01'], 'v': ['*', '->', '13:0,', '足'], 'ARGM-ADV': ['*', '->', '12:1,', '挺'], 'ARGM-DIS': ['*', '->', '10:1,', '就是'], 'ARG0': ['*', '->', '11:1,', '信心']

In [5]:
#tokenization

##word-level
def get_words(sentence_profile):
    """Return the leaf tokens of a sentence profile, in the order stored in its "leaves" dict."""
    return [leaf["token"] for leaf in sentence_profile["leaves"].values()]

##character-level
def get_characters(sentence_profile):
    """Split the profile's plain sentence into a list of single characters."""
    return list(sentence_profile["plain"])

In [None]:
# Load the tokenizer and model under study.
# NOTE(review): AutoTokenizer/AutoModel are already imported in the first cell, and
# AutoModelForCausalLM is not used in this notebook; consider removing this import.
from transformers import AutoTokenizer, AutoModelForCausalLM
# Previously tried: "hfl/chinese-bert-wwm", "bert-base-chinese".
# NOTE(review): the ckiplab model card recommends BertTokenizerFast ("bert-base-chinese") rather
# than AutoTokenizer for this model — confirm the tokenizer loaded here matches.
model_name = "ckiplab/gpt2-base-chinese"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# output_attentions=True so forward passes also return per-layer attention weights.
model = AutoModel.from_pretrained(model_name, output_attentions=True)

In [None]:
heads = 0
layers = 0

#returns attention weights for specified head, layer
def attention_map(model, text, heads, layers, tokenizer=None,
                  drop_tokens=(",", "[CLS]", "[SEP]")):
    """Return one head's attention matrix for `text`, with selected tokens removed.

    Args:
        model: a transformers model; it is called with ``output_attentions=True``.
        text: input string.
        heads: index of the attention head to select.
        layers: index of the layer to select.
        tokenizer: tokenizer to use; defaults to the notebook-level ``tokenizer``.
        drop_tokens: tokens whose rows AND columns are removed from the matrix
            (by default the ASCII comma and BERT special tokens).

    Returns:
        numpy array of shape (n_kept, n_kept). Rows index the attending (query) token and
        columns the attended-to (key) token, in the transformers attention layout; only
        tokens not in `drop_tokens` are kept.
    """
    if tokenizer is None:
        tokenizer = globals()["tokenizer"]

    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():  # inference only: no autograd graph needed
        outputs = model(**inputs, output_attentions=True)
    # Stacked attentions: (n_layers, batch, n_heads, seq, seq); select one layer and one head,
    # then drop the batch dimension (batch of 1).
    attention = torch.stack(outputs.attentions)[layers, :, heads, :, :]
    attention_matrix = attention.squeeze(0).detach().cpu().numpy()

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    drop_idxs = [idx for idx, token in enumerate(tokens) if token in drop_tokens]
    attention_matrix = np.delete(attention_matrix, drop_idxs, axis=0)
    attention_matrix = np.delete(attention_matrix, drop_idxs, axis=1)
    return attention_matrix

In [None]:
def get_verb_idx(tree, nsubj_idx):
    """Return the leaf index of a verb in the predicate that goes with a null subject.

    The clause is taken to be the great-grandparent of the null-subject leaf (the node above
    NP-SBJ in an (NP-SBJ (-NONE- *pro*)) configuration). The first VV-tagged word in a
    preorder walk of that clause is returned.

    Args:
        tree: nltk Tree of the sentence.
        nsubj_idx: index of the null subject in ``tree.leaves()``.

    Returns:
        Index (in ``tree.leaves()``) of that VV word, or None if the clause contains no VV.
    """
    # Map each leaf's tree position to its index in tree.leaves() for O(1) lookups.
    leaf_index = {tree.leaf_treeposition(k): k for k in range(len(tree.leaves()))}
    nsubj_leaf_position = tree.leaf_treeposition(nsubj_idx)
    # Go up three levels by tree position: leaf -> -NONE- -> NP-SBJ -> clause. Stepping back
    # three places in the preorder listing instead only reaches the clause when NP-SBJ is
    # its first child (e.g. it fails when an ADVP precedes the subject).
    ip_pos = nsubj_leaf_position[:-3]
    ip_tree = tree[ip_pos]
    for sub_pos in ip_tree.treepositions(order="preorder"):
        node = ip_tree[sub_pos]
        if isinstance(node, str) or node.label() != "VV":
            continue
        if len(node) > 0 and isinstance(node[0], str):
            return leaf_index[ip_pos + sub_pos + (0,)]
    return None

In [None]:
#Look for attention layers whose rankings follow Cf (forward-looking center) rankings

#returns ordered indices of candidate antecedents
# Earlier draft, kept for reference (returned (word, index) pairs):
# def mention_sequence(tree):
#     lleaves = tree.pos()
#     return [(leaf[0], i) for i, leaf in enumerate(lleaves) if "N" in leaf[1] and "*-" not in leaf[0]]
def mention_idxs(tree):
    """Return, in order, the leaf indices of candidate antecedent mentions in `tree`.

    A leaf qualifies when its POS tag contains "N" (e.g. NN, NR, NT, PN — but also -NONE-,
    so empty elements such as *pro* are kept) and its word does not contain "*-" (which
    excludes indexed traces such as *T*-1).
    """
    idxs = []
    for position, (word, tag) in enumerate(tree.pos()):
        if "N" in tag and "*-" not in word:
            idxs.append(position)
    return idxs
# def mention_set(string):
#     return set(mention_sequence(string)) #ensure that order is preserved

#Hypothesis 1: linear order (looking for first-mention effect)
def coref_map(tree, leafnotes):
    """Map leaf indices to coreference chain ids (work in progress).

    Args:
        tree: parsed sentence tree; currently unused.
        leafnotes: leaves dict as built by process_leaves ({idx: {"token", "info"}}).

    Returns:
        dict {idx: {idx: chain_id_or_None}} for every leaf. chain_id is the second field of
        the leaf's "coref:" annotation (e.g. "4" for "coref: IDENT 4 0-0 *pro*"), or None
        when the leaf has no such annotation.
    """
    coref_ids = {}
    for idx, leaf in leafnotes.items():
        # process_leaves keeps the trailing colon in annotation keys ("coref:", "prop:").
        coref = leaf["info"].get("coref:")
        coref_ids[idx] = {idx: coref[1] if coref and len(coref) > 1 else None}
    #...to-do: restrict to candidate antecedents (mention_idxs(tree)) and finish implementation
    return coref_ids
                
def test_hyp_1(profile, attention_maps, is_bidirectional = True):
    tree = Tree.fromstring(profile["tree"])
    leafnotes = profile["leaves"]
    candidate_idxs = mention_idxs(tree)
    nsubj_idxs = [i for i in leafnotes if leafnotes[i]["token"] == "*pro*" and leafnotes["info"]]
    for nsubj_idx in nsubj_idxs:
        verb_idx = get_verb_idx(tree, nsubj_idx)
        candidate_idxs = [idx for idx in candidate_idxs if idx < nsubj_idx]
        for map in attention_maps:
            candidate_attentions = [map[verb_idx][candidate] for candidate in candidate_idxs]
            
            '''if is_bidirectional:
                candidate_attentions_reverse = [map[candidate][verb_idx] for candidate in candidate_idxs]'''
            #...to-do: finish implementation
            
#to-do: character-level implementation

#Hypothesis 2: grammatical role (ranking by grammatical role)
#to-do: finish implementation

In [None]:
#compares attention rankings to hypothesized rankings
def test_hyp_1(profile, attention_maps, is_bidirectional = True):
    tree = Tree.fromstring(profile["tree"])
    leafnotes = profile["leaves"]
    candidate_idxs = mention_idxs(tree)
    nsubj_idxs = [i for i in leafnotes if leafnotes[i]["token"] == "*pro*" and leafnotes["info"]]
    for nsubj_idx in nsubj_idxs:
        verb_idx = get_verb_idx(tree, nsubj_idx)
        candidate_idxs = [idx for idx in candidate_idxs if idx < nsubj_idx]
        for map in attention_maps:
            candidate_attentions = [map[verb_idx][candidate] for candidate in candidate_idxs]
            
            '''if is_bidirectional:
                candidate_attentions_reverse = [map[candidate][verb_idx] for candidate in candidate_idxs]'''
    #to-do: finish implementation
    
hypotheses = {1: test_hyp_1}      

In [None]:
#viewer: heatmap of one head's attention for the sample sentence
attention_matrix = attention_map(model, sample['plain'], heads, layers)
fig, ax = plt.subplots()
sns.heatmap(attention_matrix, xticklabels=2, yticklabels=2, ax=ax)
ax.invert_yaxis()  # seaborn draws row 0 at the top; put the first token at the bottom instead
ax.set(
    title=f"Attention weights: layer {layers}, head {heads}",
    xlabel="attended-to token index (key)",
    ylabel="attending token index (query)",
)
plt.show()