In [None]:
import argparse
import json
import numpy as np
import os
import pandas as pd
import time
import torch
from torch.nn.functional import log_softmax
from tqdm import tqdm
from transformers import GPT2Tokenizer
from collections import Counter

In [None]:
# Construction chains per dialogue, used below as
# chains[dialogue_id][construction] -> list of occurrence records.
with open('chains_all.json', mode='r') as f:
    chains = json.load(f)
len(chains)

In [None]:
# Per-turn table; columns used below: 'Dialogue ID', 'Turn index', 'Tokens', 'Surprisal'.
df = pd.read_csv('surprisal_SBNC_gpt2_50_1e-3_agg.csv')

In [None]:
# Distinct dialogues present in the surprisal table.
dialogue_ids = set(df['Dialogue ID'].tolist())
print(f'{len(dialogue_ids)} dialogues')

In [None]:
# GPT-2 tokenizer, used to re-tokenize constructions so they can be matched
# against the token lists stored in the surprisal table.
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path='gpt2')


In [None]:
def find_subsequence(subsequence, sequence):
    """Locate every contiguous occurrence of `subsequence` inside `sequence`.

    Works on any sliceable sequence (token lists, strings). Overlapping
    occurrences are all reported.

    Returns:
        List of (start, end) index pairs, `end` exclusive, in ascending order.
    """
    width = len(subsequence)
    return [
        (start, start + width)
        for start in range(len(sequence))
        if sequence[start:start + width] == subsequence
    ]

def find_subsequence_plain_text(subsequence, sequence, delimiters=" ',.!:;?"):
    """Find whole-word occurrences of `subsequence` in the text `sequence`.

    A match counts only when it is bounded on both sides by the start/end of
    `sequence` or by one of the `delimiters` characters. This rejects partial
    word matches such as "it was er" inside "it was erm".

    Args:
        subsequence: text to search for.
        sequence: text to search in.
        delimiters: characters that count as word boundaries.

    Returns:
        List of (start, end) character index pairs, `end` exclusive.

    Raises:
        TypeError: if `subsequence` has no length (e.g. None). Previously this
            only printed the value and then crashed with UnboundLocalError.
    """
    try:
        width = len(subsequence)
    except TypeError:
        raise TypeError(
            'subsequence must be a sized sequence, got {!r}'.format(subsequence)
        ) from None

    ranges = []
    for i in range(len(sequence)):
        if sequence[i:i + width] != subsequence:
            continue
        bounded_before = i == 0 or sequence[i - 1] in delimiters
        bounded_after = i + width >= len(sequence) or sequence[i + width] in delimiters
        if bounded_before and bounded_after:
            ranges.append((i, i + width))
    return ranges


In [None]:
def facilitating_effect(turn_surprisal_values, construction_indices, window=10):
    """Log-ratio of context surprisal to construction surprisal within a turn.

    Computes log2(mean surprisal of context tokens / mean surprisal of the
    construction tokens). Context tokens are those within `window` tokens on
    either side of the construction, with the construction itself excluded.
    The value is positive when the construction is less surprising than its
    surroundings.

    Args:
        turn_surprisal_values: per-token surprisal values of the turn
            (list or 1-D array).
        construction_indices: (start, end) token span of the construction,
            `end` exclusive.
        window: number of context tokens on each side, clipped to the turn
            boundaries. None uses the whole turn as context. 0 means no
            context; previously `if window:` treated 0 like None.

    Returns:
        The log2 ratio, or 0 when there are no context tokens.
        NOTE(review): a construction with zero mean surprisal would yield
        inf (with a numpy warning); left as-is.
    """
    start_constr, end_constr = construction_indices

    if window is not None:
        start_ctx = max(start_constr - window, 0)
        end_ctx = min(end_constr + window, len(turn_surprisal_values))
    else:
        start_ctx = 0
        end_ctx = len(turn_surprisal_values)

    # Context on the left and right of the construction, in token order.
    # Slicing replaces the original per-element `i in <list>` scan (O(n^2)).
    # np.concatenate (not `+`) keeps this correct for numpy-array input.
    context_values = np.concatenate([
        turn_surprisal_values[start_ctx:start_constr],
        turn_surprisal_values[end_constr:end_ctx],
    ])
    if context_values.size == 0:
        return 0

    surprisal_wo_constr = np.mean(context_values)
    surprisal_constr = np.mean(turn_surprisal_values[start_constr:end_constr])

    return np.log2(surprisal_wo_constr / surprisal_constr)


def std_surprisal(turn_surprisal_values, construction_indices, window=None):
    """Z-score of the construction's mean surprisal relative to its context.

    Unlike `facilitating_effect`, the context here INCLUDES the construction
    tokens: mu and sigma are the mean and (population) std of all surprisal
    values in [start_ctx, end_ctx).

    Args:
        turn_surprisal_values: per-token surprisal values of the turn.
        construction_indices: (start, end) token span, `end` exclusive.
        window: context tokens on each side, clipped to the turn. None uses
            the whole turn. 0 restricts the context to the construction
            itself; previously `if window:` treated 0 like None.

    Returns:
        (mean construction surprisal - mu) / sigma. Returns NaN when sigma is
        0 (e.g. a single-token turn), instead of dividing by zero with a
        RuntimeWarning.
    """
    start_constr, end_constr = construction_indices
    surprisal_constr = np.mean(
        [h for i, h in enumerate(turn_surprisal_values) if start_constr <= i < end_constr]
    )

    if window is not None:
        start_ctx = max(start_constr - window, 0)
        end_ctx = min(end_constr + window, len(turn_surprisal_values))
    else:
        start_ctx = 0
        end_ctx = len(turn_surprisal_values)

    context = turn_surprisal_values[start_ctx:end_ctx]
    mu = np.mean(context)
    sigma = np.std(context)
    if sigma == 0:
        # All context values identical: the z-score is undefined.
        return float('nan')

    return (surprisal_constr - mu) / sigma


def surprisal(turn_surprisal_values, construction_indices):
    """Mean surprisal over the construction's token span.

    Args:
        turn_surprisal_values: per-token surprisal values of the turn.
        construction_indices: (start, end) token span, `end` exclusive.

    Returns:
        np.mean of the values whose position lies in [start, end).
    """
    first, stop = construction_indices
    span_values = [
        value
        for position, value in enumerate(turn_surprisal_values)
        if first <= position < stop
    ]
    return np.mean(span_values)


In [None]:
new_chains = {}

# Count of non-topical occurrences per construction (only for turns found in the surprisal table).
constrs = Counter()

# Context windows (tokens on each side of the construction) for the facilitating effect.
# 'FE' itself uses the whole turn as context (window=None).
FE_WINDOWS = (1, 2, 3, 4, 5, 10, 15, 20, 25, 30)
# Metric keys appended to each occurrence record, in output-column order.
METRIC_KEYS = ['FE'] + ['FE{}'.format(w) for w in FE_WINDOWS] + ['SS', 'S']
# Placeholder metrics for occurrences whose token span cannot be located.
NAN_METRICS = dict.fromkeys(METRIC_KEYS, float('NaN'))

for d_id in tqdm(chains):
    new_chains[d_id] = {}
    df_d = df[df['Dialogue ID'] == d_id]

    for constr in chains[d_id]:
        new_chains[d_id][constr] = []

        # Search for both tokenizations of the construction: with and without a leading space.
        constr_tokens_w_space = tokenizer.convert_ids_to_tokens(tokenizer(' ' + constr)['input_ids'])
        constr_tokens_wo_space = tokenizer.convert_ids_to_tokens(tokenizer(constr)['input_ids'])

        for occurrence in chains[d_id][constr]:
            df_row = df_d[df_d['Turn index'] == int(occurrence['CurrentTurn'])]

            turn_tokens = df_row['Tokens'].to_list()
            if not turn_tokens:
                print(d_id, constr)
                print('skip')
                continue

            if not occurrence['Topical']:
                constrs[constr] += 1

            # NOTE(review): eval() executes whatever is stored in the CSV cell. It is kept because
            # the stored repr format is not visible here; consider ast.literal_eval once confirmed.
            turn_tokens = eval(turn_tokens[0])

            ranges = sorted(
                find_subsequence(constr_tokens_w_space, turn_tokens)
                + find_subsequence(constr_tokens_wo_space, turn_tokens),
                key=lambda r: r[0],
            )

            if len(ranges) > occurrence['FrequencyInTurn']:
                # Remove extra ranges due to the LM tokenizer splitting differently:
                # "it was er" is found in "it was erm" because "erm" is split into ("er", "m").
                # Keep a range only if the construction appears as whole words in the plain
                # text of the span extended by one token on each side.
                new_ranges = []
                for start, end in ranges:
                    # Clamp at 0: `start - 1 == -1` would wrap to the end of the turn and yield
                    # an empty span, silently dropping matches at the start of the turn.
                    extended_span = turn_tokens[max(start - 1, 0): end + 1]
                    extended_span_plain = tokenizer.convert_tokens_to_string(extended_span)
                    if find_subsequence_plain_text(constr, extended_span_plain):
                        new_ranges.append((start, end))
                ranges = new_ranges

            elif len(ranges) < occurrence['FrequencyInTurn']:
                # Fewer matches than expected: presumably the construction occurs late in a long
                # turn, which was truncated after 1024 tokens during surprisal estimation.
                new_chains[d_id][constr].append({**occurrence, **NAN_METRICS})
                continue

            if occurrence['IndexInTurn'] >= len(ranges):
                # The filtering above removed the match this occurrence refers to; record NaN
                # instead of raising IndexError on ranges[IndexInTurn].
                new_chains[d_id][constr].append({**occurrence, **NAN_METRICS})
                continue

            span = ranges[occurrence['IndexInTurn']]

            tok_surprisal = eval(df_row['Surprisal'].to_list()[0])
            assert len(tok_surprisal) == len(turn_tokens)

            metrics = {'FE': facilitating_effect(tok_surprisal, span, window=None)}
            for w in FE_WINDOWS:
                metrics['FE{}'.format(w)] = facilitating_effect(tok_surprisal, span, window=w)
            metrics['SS'] = std_surprisal(tok_surprisal, span)
            metrics['S'] = surprisal(tok_surprisal, span)

            new_chains[d_id][constr].append({**occurrence, **metrics})


In [None]:
# (total non-topical occurrences counted, number of distinct constructions among them)
(sum(constrs.values()), len(constrs))

---

In [None]:
# One row per occurrence: (dialogue ID, construction form, *occurrence record values).
df_data = [
    (d_id, constr) + tuple(occurrence.values())
    for d_id, constr_chains in new_chains.items()
    for constr, occurrences in constr_chains.items()
    for occurrence in occurrences
]


In [None]:
# Column names: the two identifiers prepended in `df_data`, then the keys of an occurrence
# record. Assumes every record has the same keys in the same order (the row tuples are built
# from `.values()`). Previously the template record was hard-coded as
# new_chains['SD8N']["it's like a"][0], which raised KeyError whenever that record was absent.
first_occurrence = next(
    (
        occurrence
        for constr_chains in new_chains.values()
        for occurrences in constr_chains.values()
        for occurrence in occurrences
    ),
    {},
)
columns = ['Dialogue ID', 'Form'] + list(first_occurrence.keys())

# NOTE(review): this rebinds `df`, which previously held the surprisal table; re-running the
# metrics cell after this one would filter the wrong frame. Restart & Run All is safe.
df = pd.DataFrame(df_data, columns=columns)

In [None]:
# Preview the first rows of the per-occurrence table.
df.head(n=5)

In [None]:
# Persist the per-occurrence table (the default integer index is written as the first column).
df.to_csv(path_or_buf='chains_all_SBNC_gpt2_50_1e-3.csv')