In [1]:
import sys; sys.path.insert(0, '..')

In [2]:
import pickle
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import math
import sklearn.metrics as metrics
import statistics

from fields import IncParseField
from scipy.stats import entropy

In [3]:
# Path to the preprocessed parser outputs (biaffine / RoBERTa / English, per the file name).
biaffine = './outputs/preprocessed_dep-biaffine-roberta-en.pkl'

# NOTE(review): unpickling can execute arbitrary code, so only load files from a trusted source.
with open(biaffine, mode='rb') as f:
    biaffine_data = pickle.load(f)

# Wrap each record in an IncParseField. Records are fetched by position rather than
# by iterating the container, because the container type is not visible from here.
record_count = len(biaffine_data)
biaffine_parse = [
    IncParseField(biaffine_data[position])
    for position in range(record_count)
]

In [4]:
def compute_ent(parse):
    '''
    Build a square matrix of arc-distribution entropies, one row per step.

    parse: mapping whose 'arc_attn' entry is a sequence of torch tensors.
        The tensor at index `step` must yield `step + 1` entropy values
        (or a single broadcastable value) when reduced over axis 1,
        otherwise the row assignment below fails.

    Returns a (len, len) tensor where row `step` holds those entropies in
    its first `step + 1` columns. Every other cell keeps the +inf filler.
    Entropy comes from scipy.stats.entropy with its default base.
    '''
    arc_attn = parse['arc_attn']
    n_steps = len(arc_attn)

    # Cells that are never written stay at +inf.
    ent_tensor = torch.full((n_steps, n_steps), float('Inf'))

    for step, attn in enumerate(arc_attn):
        # One entropy value per row of this step's attention matrix.
        row_entropy = entropy(attn.numpy(), axis=1)
        ent_tensor[step, :step + 1] = torch.from_numpy(row_entropy)

    return ent_tensor

def compute_var(ent_tensor):
    '''
    Absolute change of each entry relative to the row above it.

    For every row r >= 1 and every column c < r, the result holds
    |ent_tensor[r, c] - ent_tensor[r - 1, c]|. Row 0 and all entries on or
    above the main diagonal are 0. The input is not modified, and the
    result has the same shape and dtype as the input.
    '''
    variation = torch.zeros_like(ent_tensor)

    for row in range(1, variation.size(0)):
        # Only columns strictly below the diagonal are compared.
        current = ent_tensor[row, :row]
        previous = ent_tensor[row - 1, :row]
        variation[row, :row] = (current - previous).abs()

    return variation

## Variation
We compute the variation of the entropy of the arc distribution with respect to the previous state, using the biaffine parser with RoBERTa. This measures how much, on average, token $t+i$ (for $i = 1, \dots, T$) impacts token $t$. We only use the unambiguous stimuli here.

## NNC

In [5]:
T = 4  # Only 5 tokens for NNC baseline
# One bucket of variation values per look-ahead offset t = 1..T.
NNC_var_dict = {t: [] for t in range(1, T + 1)}

for field in biaffine_parse:
    if field.source != 'nnc':
        continue

    baseline = field.parses['baseline']
    baseline_ent_self = compute_var(compute_ent(baseline))

    # The t-th sub-diagonal holds the entries [i + t, i] of the variation matrix.
    for t in NNC_var_dict:
        NNC_var_dict[t].append(baseline_ent_self.diagonal(-t))

# Pool the values from all matching fields and report one mean per offset.
for key, value in NNC_var_dict.items():
    mean = torch.cat(value).mean().item()
    print('t+{}: {:.2f}'.format(key, mean))

t+1: 0.14
t+2: 0.21
t+3: 0.29
t+4: 0.08


## NP/S

In [6]:
T = 10
# One bucket of variation values per look-ahead offset t = 1..T.
NPS_var_dict = {t: [] for t in range(1, T + 1)}

for field in biaffine_parse:
    if field.source != 'classic-nps':
        continue

    baseline = field.parses['baseline']
    baseline_ent_self = compute_var(compute_ent(baseline))

    # The t-th sub-diagonal holds the entries [i + t, i] of the variation matrix.
    for t in NPS_var_dict:
        NPS_var_dict[t].append(baseline_ent_self.diagonal(-t))

# Pool the values from all matching fields and report one mean per offset.
for key, value in NPS_var_dict.items():
    mean = torch.cat(value).mean().item()
    print('t+{}: {:.2f}'.format(key, mean))

t+1: 0.19
t+2: 0.06
t+3: 0.05
t+4: 0.01
t+5: 0.00
t+6: 0.00
t+7: 0.00
t+8: 0.00
t+9: 0.00
t+10: 0.00


## MVRR

In [7]:
T = 10
# One bucket of variation values per look-ahead offset t = 1..T.
MVRR_var_dict = {t: [] for t in range(1, T + 1)}

for field in biaffine_parse:
    if field.source != 'classic-mvrr':
        continue

    baseline = field.parses['baseline']
    baseline_ent_prev = compute_var(compute_ent(baseline))

    # The t-th sub-diagonal holds the entries [i + t, i] of the variation matrix.
    for t in MVRR_var_dict:
        MVRR_var_dict[t].append(baseline_ent_prev.diagonal(-t))

# Pool the values from all matching fields and report one mean per offset.
for key, value in MVRR_var_dict.items():
    mean = torch.cat(value).mean().item()
    print('t+{}: {:.2f}'.format(key, mean))

t+1: 0.11
t+2: 0.02
t+3: 0.01
t+4: 0.02
t+5: 0.01
t+6: 0.00
t+7: 0.00
t+8: 0.00
t+9: 0.00
t+10: 0.00
