In [None]:
import pandas as pd
import numpy as np
from collections import Counter, defaultdict
import json 
import os
import random
import sys
import string
from sklearn.feature_extraction.text import TfidfTransformer
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm_notebook as tqdm
sns.set_theme(style="whitegrid")
import sys

In [None]:
# Root directory of the analysis data. Set the BNC_ANALYSIS_DIR environment variable to
# point elsewhere instead of editing this hardcoded default.
analysis_data_dir = os.path.join(
    os.environ.get('BNC_ANALYSIS_DIR', '/Users/mario/code/exp-rep/data/BNC-2014/two-speakers/analysis/'),
    '',  # guarantee a trailing separator: later cells build paths by string concatenation
)
dialign_output = analysis_data_dir + 'dialign-output/'

# The lexicon-loading cells read from the 'nopos/' and 'pos/' subdirectories of dialign_output.

In [None]:
# Load the dialign lexicons computed without POS tags, keyed by dialogue id:
# the shared lexicon, and the self-repetition lexicons (A and B stacked into one frame).
nopos_dir = dialign_output + 'nopos/'

shared_frames = defaultdict(list)
self_frames = defaultdict(list)
for fname in os.listdir(nopos_dir):
    if fname.startswith('.'):
        continue  # skip hidden files
    if fname.endswith('_tsv-lexicon.tsv'):
        bucket = shared_frames
    elif fname.endswith(('_tsv-lexicon-self-rep-A.tsv', '_tsv-lexicon-self-rep-B.tsv')):
        bucket = self_frames
    else:
        continue
    # The dialogue id is the file-name prefix before the first underscore
    bucket[fname.split('_')[0]].append(
        pd.read_csv(os.path.join(nopos_dir, fname), sep='\t', header=0)
    )

shared_lexica = {key: (frames[0] if len(frames) == 1 else pd.concat(frames))
                 for key, frames in shared_frames.items()}
self_lexica = {key: (frames[0] if len(frames) == 1 else pd.concat(frames))
               for key, frames in self_frames.items()}

print(len(shared_lexica), len(self_lexica))

In [None]:
# Load the dialign lexicons computed on POS-tagged text ('word#TAG' tokens), keyed by
# dialogue id: the shared lexicon, and the self-repetition lexicons (A and B stacked).
pos_dir = dialign_output + 'pos/'

shared_pos_frames = defaultdict(list)
self_pos_frames = defaultdict(list)
for fname in os.listdir(pos_dir):
    if fname.startswith('.'):
        continue  # skip hidden files
    if fname.endswith('_tsv-lexicon.tsv'):
        bucket = shared_pos_frames
    elif fname.endswith(('_tsv-lexicon-self-rep-A.tsv', '_tsv-lexicon-self-rep-B.tsv')):
        bucket = self_pos_frames
    else:
        continue
    # The dialogue id is the file-name prefix before the first underscore
    bucket[fname.split('_')[0]].append(
        pd.read_csv(os.path.join(pos_dir, fname), sep='\t', header=0)
    )

shared_lexica_pos = {key: (frames[0] if len(frames) == 1 else pd.concat(frames))
                     for key, frames in shared_pos_frames.items()}
self_lexica_pos = {key: (frames[0] if len(frames) == 1 else pd.concat(frames))
                   for key, frames in self_pos_frames.items()}

print(len(shared_lexica_pos), len(self_lexica_pos))


In [None]:
# contexts[dialogue_id][turn_id] -> sequence whose 2nd/3rd items are speaker and text
contexts_path = analysis_data_dir + 'contexts.json'
with open(contexts_path, mode='r') as contexts_file:
    contexts = json.load(contexts_file)


In [None]:
shared_lexica_pos['SJV7'].iloc[:3]  # first three rows of one POS-tagged shared lexicon

In [None]:
def topical_or_referential(word_seq, pos_seq):
    """
    Decide whether a construction is topical/referential.

    A construction is topical or referential when it contains at least one noun
    (POS tag 'SUBST') that is not in a fixed list of generic nouns such as
    'thing', 'way' or 'lot'.

    Args
        word_seq: list of words of the construction.
        pos_seq: list of POS tags, parallel to word_seq.

    Returns
        True if some non-generic noun is present, False otherwise.

    Raises
        AssertionError: if word_seq and pos_seq have different lengths.
    """
    assert len(word_seq) == len(pos_seq), (word_seq, pos_seq)

    generic_nouns = set(
        'bit bunch fact god middle ones part rest side sort sorts stuff thanks loads lot lots '
        'kind kinds time times way ways problem problems thing things idea ideas reason reasons '
        'day days week weeks year years'.split(' ')
    )
    return any(
        tag == 'SUBST' and word not in generic_nouns
        for word, tag in zip(word_seq, pos_seq)
    )


def more_than_half_filled_pauses(construction):
    """
    Return True if strictly more than half of the tokens in `construction` are filled pauses.

    Args
        construction: whitespace-separated construction string, e.g. 'uh you know'.

    Returns
        True when the number of filled-pause tokens exceeds half of all tokens;
        False otherwise, including for an empty construction.
    """
    FILLED_PAUSES = {'huh', 'uh', 'erm', 'hm', 'mm', 'er'}
    tokens = construction.split()
    if not tokens:
        return False
    n_filled_pauses = sum(1 for w in tokens if w in FILLED_PAUSES)
    # Strict '>' as the name says ("more than half"): a construction in which exactly
    # half of the tokens are filled pauses still has half of its tokens as real words.
    return n_filled_pauses > len(tokens) / 2
    

In [None]:
more_than_half_filled_pauses(construction='mm mm mm')  # sanity check: all tokens are filled pauses

In [None]:
def _split_tagged_tokens(tagged_constr):
    """
    Split a 'word#TAG word#TAG ...' string into parallel lists of words and POS tags.

    Returns None if any space-separated token does not contain exactly one '#'.
    """
    words, tags = [], []
    for token in tagged_constr.split(' '):
        parts = token.split('#')
        if len(parts) != 2:
            return None
        words.append(parts[0])
        tags.append(parts[1])
    return words, tags


# Map each POS-tagged construction, keyed by its words concatenated without spaces,
# to its POS sequence; topical/referential constructions are kept separately.
pos_tagged_constructions = {}
pos_tagged_constructions_topical = {}

for d_id in shared_lexica_pos:
    lexicon_df = pd.concat((self_lexica_pos[d_id], shared_lexica_pos[d_id]))
    for _, row in lexicon_df.iterrows():
        constr = row['Surface Form']

        if not isinstance(constr, str):
            continue
        # Re-attach a '?' that was split from its #STOP tag
        constr = constr.replace('? #STOP', '?#STOP').strip()

        parsed = _split_tagged_tokens(constr)
        if parsed is None:
            # Reject the whole construction as soon as one token is malformed,
            # instead of carrying over the word/tag of a previous token.
            print('Illegal construction:', constr)
            continue   # only exception is: "made . com#SUBST"
        w_seq, pos_seq = parsed

        concat_tokens = ''.join(w_seq)

        # Referential or topical constructions?
        if topical_or_referential(w_seq, pos_seq):
            pos_tagged_constructions_topical[concat_tokens] = pos_seq
        else:
            pos_tagged_constructions[concat_tokens] = pos_seq

In [None]:
example_dialogue, example_turn = 'SJV7', '205'
contexts[example_dialogue][example_turn]

In [None]:
def find_subsequence(subsequence, sequence, boundaries=" ',.!:;?"):
    """
    Find every occurrence of `subsequence` in `sequence` that is delimited by word boundaries.

    An occurrence counts only if the character just before it and the character just
    after it are in `boundaries`, or the occurrence touches the start/end of `sequence`.
    Overlapping occurrences are all reported.

    Args
        subsequence: the string to look for (e.g. a construction).
        sequence: the text to search in (e.g. the text of a turn).
        boundaries: characters accepted as delimiters on either side of a match.

    Returns
        A list of (start, end) character offsets, ordered by start. Empty if
        `subsequence` is empty.
    """
    if not subsequence:
        # An empty pattern "matches" at every position; treat it as no match
        return []

    length = len(subsequence)
    ranges = []
    start = sequence.find(subsequence)
    while start != -1:
        end = start + length
        bounded_before = start == 0 or sequence[start - 1] in boundaries
        bounded_after = end >= len(sequence) or sequence[end] in boundaries
        if bounded_before and bounded_after:
            ranges.append((start, end))
        # Advance by one character so overlapping occurrences are also found
        start = sequence.find(subsequence, start + 1)
    return ranges


In [None]:
ss, s = "and i was just like", "and i was just like oh my god"
find_subsequence(subsequence=ss, sequence=s)

In [None]:
ss, s = "bad for you", "yeah it is bad for you bad for your teeth"
find_subsequence(subsequence=ss, sequence=s)  # 'bad for your' must not match

In [None]:
# corpus_counts[dialogue][construction] = frequency of the construction in that dialogue,
# for constructions that pass all filtering conditions below.
corpus_counts = defaultdict(dict)
# _data[dialogue][construction] = metadata of every construction kept below.
_data = defaultdict(dict)

# For each construction that passes conditions 1-3: 1 if a POS-tagged equivalent
# was found, 0 if it was skipped for lack of one (reported in the next cell).
cnt = []
for d_id in tqdm(shared_lexica):
    lexicon_df = pd.concat((self_lexica[d_id], shared_lexica[d_id]))
    dialogue = contexts[d_id]

    for _, row in lexicon_df.iterrows():
        constr = row['Surface Form']

        if not isinstance(constr, str):
            continue
        constr = constr.strip()

        # str() first: pandas can infer an integer dtype for a 'Turns' column in which
        # no entry lists more than one turn, and integers have no .split().
        turns = str(row['Turns']).split(', ')

        # Recount occurrences in the transcript with word-boundary matching
        _freq = 0
        for turn in turns:
            _, _, text = dialogue[turn]
            _freq += len(find_subsequence(constr, text))

        assert _freq >= row['Freq.'], (d_id, constr, _freq, row['Freq.'])

        # Condition 1: at least 3 tokens long
        if row['Size'] < 3:
            continue

        # Condition 2: frequency >= 3 in the dialogue
        if _freq < 3:
            continue

        # Condition 3: free form frequency >= 2 in the dialogue
        if row['Free Freq.'] < 2:
            continue

        # The POS-tagged dictionaries are keyed by the words joined without spaces
        concat_tokens = constr.replace(' ', '')
        if concat_tokens in pos_tagged_constructions_topical:
            _pos_seq = pos_tagged_constructions_topical[concat_tokens]
            topical = True
        elif concat_tokens in pos_tagged_constructions:
            _pos_seq = pos_tagged_constructions[concat_tokens]
            topical = False
        else:
            # Skip constructions for which we find no POS-tagged equivalent
            cnt.append(0)
            continue
        cnt.append(1)

        # Condition 4: no punctuation in the construction
        if "STOP" in _pos_seq:
            continue

        # Condition 5: at least half of the construction should not correspond to filled pauses
        if more_than_half_filled_pauses(constr):
            continue

        _data[d_id][constr] = {
            'Frequency': _freq,
            'Free frequency': row['Free Freq.'],
            'Length': row['Size'],
            'POS sequence': _pos_seq,
            'First speaker': row['First Speaker'],
            'Turns': turns,
            'Spanning turns': row['Spanning'],
            'Establishment turn': row['Establishment turn'],
            'Topical': topical
        }

        corpus_counts[d_id][constr] = _freq


In [None]:
# Share of constructions (passing conditions 1-3) dropped for lack of a POS-tagged match
n_skipped = cnt.count(0)
n_total = len(cnt)
if n_total:
    print('Skipped {} out of {} constructions ({:.2f}%) as we find no POS-tagged equivalent.'.format(
        n_skipped,
        n_total,
        n_skipped / n_total * 100
    ))
else:
    # Avoid ZeroDivisionError when no construction reached the POS lookup
    print('No constructions passed conditions 1-3; nothing was skipped.')

In [None]:
class DialogueSpecificity:
    """
    Dialogue specificity statistics for expressions in a dialogue corpus.

    Dialogues get a uniform prior, P(dial) = 1 / number of dialogues, and P(exp|dial) is
    the frequency of an expression divided by the total expression frequency in that
    dialogue. All scores are precomputed in __init__ and stored per dialogue as Counters
    keyed by expression. Logarithms are base 2.
    """

    def __init__(self, corpus_counts):
        """
        Dialogue specificity statistics for expressions in a dialogue corpus.

        Args
            corpus_counts: A dictionary of dictionaries:
                           corpus_counts[dialogue][expression] = expression frequency in dialogue

        Raises
            ValueError: if corpus_counts contains no dialogues (P(dial) would be undefined).
        """
        if len(corpus_counts) == 0:
            raise ValueError('corpus_counts must contain at least one dialogue')

        self.corpus_counts = corpus_counts
        self.dialogues = list(self.corpus_counts.keys())

        # Probability distribution over dialogues is uniform -- P(dial)
        self.p_dial = 1 / len(corpus_counts)

        # The number of expressions (tokens) in a dialogue -- N_dial
        self.n_tokens_in_dial = {}  # N_dial
        for dial in corpus_counts:
            self.n_tokens_in_dial[dial] = sum(corpus_counts[dial].values())

        # The (token) frequency of an expression in the corpus -- N_exp
        self.exp_freq_in_corpus = defaultdict(int)

        # The total number of expressions (tokens) in the corpus -- N_corpus
        self.n_exp_tokens = 0  # N_corpus

        # The probability of an expression in the corpus -- P(exp)
        self.p_exp = defaultdict(int)

        # The probability of an expression in a dialogue -- P(exp|dial), keyed by (exp, dial)
        self.p_exp_given_dial = defaultdict(int)

        for dial in corpus_counts:
            for exp in corpus_counts[dial]:
                self.exp_freq_in_corpus[exp] += corpus_counts[dial][exp]
                self.n_exp_tokens += corpus_counts[dial][exp]

                # P(exp|dial) = fr(exp, dial) / N_dial
                self.p_exp_given_dial[(exp, dial)] = self.corpus_counts[dial][exp] / self.n_tokens_in_dial[dial]

                # P(exp) = sum(dial' in corpus) [ P(exp|dial') * P(dial') ]
                self.p_exp[exp] += self.p_exp_given_dial[(exp, dial)] * self.p_dial

        # The total number of expression types in the corpus -- E_corpus
        self.n_exp_types = len(self.exp_freq_in_corpus)

        # P(dial|exp) for all expressions in all dialogues
        self.dialogue_posteriors = {}
        for dial in self.corpus_counts:
            self.dialogue_posteriors[dial] = self.get_dialogue_posteriors(dial)

        # Specificity [P(exp|dial) - P(exp)] for all expressions in all dialogues
        self.dialogue_specificity = {}
        for dial in self.corpus_counts:
            self.dialogue_specificity[dial] = self.get_dialogue_specificity(dial)

        # PMI(exp, dial) for all expressions in all dialogues
        self.pmi = {}
        for dial in self.corpus_counts:
            self.pmi[dial] = self.get_dialogue_pmi(dial)

        # MD(exp, dial) for all expressions in all dialogues
        self.mutual_dependency = {}
        for dial in self.corpus_counts:
            self.mutual_dependency[dial] = self.get_mutual_dependency(dial)

        # LFMD(exp, dial) for all expressions in all dialogues; must come after MD,
        # because get_lfmd reads self.mutual_dependency
        self.lf_mutual_dependency = {}
        for dial in self.corpus_counts:
            self.lf_mutual_dependency[dial] = self.get_lfmd(dial)

    def posterior(self, expression, dialogue):
        """
        Compute P(dialogue|expression) for a given expression in a given dialogue.
        P(dialogue|expression) = P(expression|dialogue) * P(dialogue) / P(expression)
        """
        return self.p_exp_given_dial[(expression, dialogue)] * self.p_dial / self.p_exp[expression]

    def get_dialogue_posteriors(self, dialogue):
        """
        Compute P(dialogue|expression) for all expressions in a dialogue.
        """
        posteriors = Counter()
        for exp in self.corpus_counts[dialogue]:
            posteriors[exp] = self.posterior(exp, dialogue)
        return posteriors

    def get_dialogue_specificity(self, dialogue):
        """
        Compute P(expression|dialogue) - P(expression) for all expressions in a dialogue.
        """
        specificity = Counter()
        for exp in self.corpus_counts[dialogue]:
            specificity[exp] = self.p_exp_given_dial[(exp, dialogue)] - self.p_exp[exp]
        return specificity

    def get_dialogue_pmi(self, dialogue):
        """
        Compute pointwise mutual information for all expressions in a dialogue.
        PMI(expression, dialogue) = log2[ P(expression|dialogue) / P(expression) ]
        """
        pmi = Counter()
        for exp in self.corpus_counts[dialogue]:
            pmi[exp] = np.log2(self.p_exp_given_dial[(exp, dialogue)] / self.p_exp[exp])
        return pmi

    def get_mutual_dependency(self, dialogue):
        """
        Compute mutual dependence MD(expression, dialogue) for all expressions in a dialogue (Thanopoulos et al, 2002).
        MD(expression, dialogue) = log2[ P(expression, dialogue)**2 / (P(expression) * P(dialogue)) ]
                                 = log2[ P(expression|dialogue)**2 * P(dialogue) / P(expression) ]
        """
        md = Counter()
        for exp in self.corpus_counts[dialogue]:
            md[exp] = np.log2(((self.p_exp_given_dial[(exp, dialogue)] ** 2) * self.p_dial) / self.p_exp[exp])
        return md

    def get_lfmd(self, dialogue):
        """
        Compute log-frequency biased mutual dependence LFMD(expression, dialogue) for all expressions in a dialogue (Thanopoulos et al, 2002).
        LFMD(expression, dialogue) = MD(expression, dialogue) + log2 P(expression, dialogue),
        with P(expression, dialogue) = P(expression|dialogue) * P(dialogue).
        Requires self.mutual_dependency[dialogue] to be computed already.
        """
        lfmd = Counter()
        for exp in self.corpus_counts[dialogue]:
            lfmd[exp] = self.mutual_dependency[dialogue][exp] + np.log2(self.p_exp_given_dial[(exp, dialogue)] * self.p_dial)
        return lfmd

In [None]:
ds = DialogueSpecificity(corpus_counts=corpus_counts)  # precomputes PMI, MD, LFMD per dialogue

In [None]:
def _n_tokens_between(dial_turns, from_turn, from_idx, to_turn, to_idx):
    """
    Count whitespace-separated tokens from character offset `from_idx` in turn `from_turn`
    up to character offset `to_idx` in turn `to_turn` of one dialogue.

    `dial_turns` maps turn ids to (_, speaker, text) entries, as `contexts[dial_id]` does.
    Turn ids must convert with int(); every turn strictly between the two is counted whole.
    """
    if from_turn == to_turn:
        return len(dial_turns[to_turn][2][from_idx:to_idx].split())
    count = len(dial_turns[from_turn][2][from_idx:].split())
    for t in range(int(from_turn) + 1, int(to_turn)):
        count += len(dial_turns[str(t)][2].split())
    count += len(dial_turns[to_turn][2][:to_idx].split())
    return count


# data[dialogue][construction] = list of per-occurrence records, in order of occurrence
data = defaultdict(lambda: defaultdict(list))

# For each dialogue
for dial_id in tqdm(_data):
    dial_turns = contexts[dial_id]

    for constr, info in _data[dial_id].items():
        turns = info['Turns']

        current_freq = 0
        prev_turn_any = None
        prev_start_idx_any = None
        prev_turn_by_speaker = {}
        prev_start_idx_by_speaker = {}
        first_occ_turn = None
        first_occ_start_idx = None
        speakers = set()      # speakers of the earlier turns listed for this construction
        prev_speaker = None   # speaker of the previous occurrence

        for turn in turns:
            _, turn_speaker, text = dial_turns[turn]

            other_speakers = [s for s in speakers if s != turn_speaker]
            if len(other_speakers) == 0:
                other_speaker = None
            elif len(other_speakers) == 1:
                other_speaker = other_speakers[0]
            else:
                raise ValueError('There should be maximum two speakers: {}'.format(speakers))

            ranges = find_subsequence(constr, text)

            for j, (start_idx, end_idx) in enumerate(ranges):
                current_freq += 1
                shared_currently = len(speakers) == 2

                if prev_turn_any is None:
                    # First occurrence: no distance can be measured yet
                    recency_any = recency_same = recency_other = current_spanning = -1
                    first_occ_start_idx = start_idx
                    first_occ_turn = turn
                else:
                    # Tokens since the first occurrence of the construction
                    current_spanning = _n_tokens_between(
                        dial_turns, first_occ_turn, first_occ_start_idx, turn, start_idx)
                    # Tokens since the previous occurrence by anyone
                    recency_any = _n_tokens_between(
                        dial_turns, prev_turn_any, prev_start_idx_any, turn, start_idx)

                    # Tokens since the previous occurrence by the same speaker
                    if turn_speaker in prev_turn_by_speaker:
                        recency_same = _n_tokens_between(
                            dial_turns, prev_turn_by_speaker[turn_speaker],
                            prev_start_idx_by_speaker[turn_speaker], turn, start_idx)
                    else:
                        recency_same = -1

                    # Tokens since the previous occurrence by the other speaker
                    if other_speaker in prev_turn_by_speaker:
                        recency_other = _n_tokens_between(
                            dial_turns, prev_turn_by_speaker[other_speaker],
                            prev_start_idx_by_speaker[other_speaker], turn, start_idx)
                    else:
                        recency_other = -1

                prev_turn_any = turn
                prev_start_idx_any = start_idx
                prev_turn_by_speaker[turn_speaker] = turn
                prev_start_idx_by_speaker[turn_speaker] = start_idx

                data[dial_id][constr].append({
                    'CurrentTurn': turn,
                    'PositionInTurn': (start_idx, end_idx),
                    'IndexInTurn': j,
                    'PreviousInSameTurn': j > 0,
                    'FrequencyInTurn': len(ranges),
                    'Turns': turns,
                    'RepetitionIndex': current_freq,
                    'Frequency': info['Frequency'],
                    'FreeFrequency': info['Free frequency'],
                    'Length': info['Length'],
                    'POSsequence': info['POS sequence'],
                    'Speaker': turn_speaker,
                    'FirstSpeaker': info['First speaker'],
                    'TotalSpanningTurns': info['Spanning turns'],
                    'CurrentSpanningTokens': current_spanning,
                    'RecencyBoth': recency_any,
                    'RecencySame': recency_same,
                    'RecencyOther': recency_other,
                    'SharedCurrently': shared_currently,
                    'Shared': info['Establishment turn'] >= 0,
                    'EstablishmentTurn': info['Establishment turn'],
                    'SameAsFirstSpeaker': turn_speaker == info['First speaker'],
                    # prev_speaker is updated only after this record is built
                    'SameAsPreviousSpeaker': turn_speaker == prev_speaker,
                    'Topical': info['Topical'],
                    'PMI': ds.pmi[dial_id][constr],
                })
                prev_speaker = turn_speaker

            speakers.add(turn_speaker)
            assert len(speakers) in [1, 2]


----

## Example constructions

In [None]:
# Print up to N_EXAMPLES_PER_DIALOGUE constructions with PMI > 1 for a random sample of
# dialogues, formatted as LaTeX table cells (topical constructions in italics).
N_EXAMPLE_DIALOGUES = 10
N_EXAMPLES_PER_DIALOGUE = 10
EXAMPLE_SEED = 0

# Sample without replacement (so no dialogue is printed twice) with a fixed seed (so the
# examples are reproducible). To inspect specific dialogues instead, use e.g.
# example_dialogues = ['SPXV', 'S38F', 'SAXQ', 'SJDK', 'SJV7']
all_dialogues = list(ds.pmi.keys())
example_rng = np.random.default_rng(EXAMPLE_SEED)
example_dialogues = [
    str(d) for d in example_rng.choice(
        all_dialogues, size=min(N_EXAMPLE_DIALOGUES, len(all_dialogues)), replace=False)
]

for d_id in example_dialogues:
    print(d_id)
    print('----')
    n_printed = 0
    # most_common() yields constructions in decreasing PMI order
    for exp, score in ds.pmi[d_id].most_common():
        if n_printed >= N_EXAMPLES_PER_DIALOGUE:
            break
        if score > 1:
            if data[d_id][exp][0]['Topical']:
                print('\\textit{', exp, '}', '  &  ', sep='')
            else:
                print(exp, ' &  ')
            n_printed += 1
    print()

## Construction length

In [None]:
# Length (in tokens) of each construction type, taken from its first recorded occurrence
constr_lens = {}
for dial_constrs in data.values():
    for constr, occurrences in dial_constrs.items():
        if constr in constr_lens:
            continue
        constr_lens[constr] = occurrences[0]['Length']

lens = list(constr_lens.values())
tuple(stat(lens) for stat in (np.min, np.max, np.mean, np.std, np.median))

In [None]:
ax = sns.histplot(lens)  # histplot returns the Axes it drew on
ax.set_yscale('log')

In [None]:
# Number and share of construction types for every length between the min and max
print('Total unique :', len(lens))
length_counts = Counter(lens)
for length in range(np.min(lens), np.max(lens) + 1):
    n_with_length = length_counts[length]
    print('{:2d}   {:4d}   {:5.2f}%'.format(length,
                                           n_with_length,
                                           n_with_length / len(lens) * 100))


#### Remove outlier constructions of length 9 and 12

In [None]:
outlier_lengths = [9, 12]

# new_data mirrors data but without constructions whose length is an outlier
new_data = {}
for dial, dial_constrs in data.items():
    kept = {}
    for constr, occurrences in dial_constrs.items():
        if occurrences[0]['Length'] in outlier_lengths:
            print('Removed "{}", with freq. {} in dialogue {}'.format(constr, len(occurrences), dial))
            continue
        kept[constr] = occurrences
    new_data[dial] = kept


In [None]:
# Type-level length statistics after outlier removal
constr_lens = {}
for dial_constrs in new_data.values():
    for constr, occurrences in dial_constrs.items():
        if constr not in constr_lens:
            constr_lens[constr] = occurrences[0]['Length']

lens = list(constr_lens.values())
print(np.min(lens), np.max(lens), np.mean(lens), np.std(lens), np.median(lens))


print('Total:', len(lens))
length_counts = Counter(lens)
for length in range(np.min(lens), np.max(lens) + 1):
    n_with_length = length_counts[length]
    print('{:2d}   {:4d}   {:5.2f}%'.format(length,
                                           n_with_length,
                                           n_with_length / len(lens) * 100))

In [None]:
# Persist the outlier-free occurrence records next to the notebook
with open('chains_all.json', mode='w') as chains_file:
    json.dump(obj=new_data, fp=chains_file, indent=4)

### Length across all occurrences (tokens, not types)

In [None]:
# One length per occurrence (token), not per construction type
token_lens = [
    occ['Length']
    for dial_constrs in new_data.values()
    for occurrences in dial_constrs.values()
    for occ in occurrences
]

print('Total:', len(token_lens))

token_length_counts = Counter(token_lens)
for length in range(np.min(token_lens), np.max(token_lens) + 1):
    n_occ = token_length_counts[length]
    print('{:2d}   {:5d}   {:5.2f}%'.format(length,
                                           n_occ,
                                           n_occ / len(token_lens) * 100))


## Length of turns containing constructions

In [None]:
# Word count of every turn that contains a construction occurrence; consecutive
# occurrences of the same construction in the same turn count that turn once.
# NOTE(review): this iterates `data`, which still contains the length-outlier
# constructions removed from `new_data`, and counts words with split(' ') while the
# recency features use split() -- confirm both are intended.
n_words_in_target_turn = []

for dial_id, dial_constrs in data.items():
    for exp, instances in dial_constrs.items():
        last_counted_turn = None
        for instance in instances:
            turn_id = instance['CurrentTurn']
            if turn_id == last_counted_turn:
                continue
            n_words_in_target_turn.append(len(contexts[dial_id][turn_id][2].split(' ')))
            last_counted_turn = turn_id


In [None]:
# mean, std, median, min, max of the number of words in turns containing a construction
print(np.mean(n_words_in_target_turn), np.std(n_words_in_target_turn), np.median(n_words_in_target_turn),
      np.min(n_words_in_target_turn), np.max(n_words_in_target_turn))

ax = sns.histplot(n_words_in_target_turn, log_scale=True, bins=6)
# The x axis is log-scaled but its ticks are raw word counts, not their logarithm
ax.set(xlabel='Number of words in a turn containing a construction (log scale)')
# plt.show() takes no figure argument; passing one lands in the backend's close/block flag
plt.show()


## Number of constructions in a dialogue

In [None]:
n_in_dial = [count for count in ds.n_tokens_in_dial.values()]  # construction tokens per dialogue
tuple(stat(n_in_dial) for stat in (np.min, np.max, np.mean, np.std, np.median))

In [None]:
constr_freqs = []
for dial_constrs in data.values():
    for occurrences in dial_constrs.values():
        expected_freq = occurrences[0]['Frequency']
        # Every counted occurrence must have produced exactly one record
        assert expected_freq == len(occurrences)
        constr_freqs.append(expected_freq)

tuple(stat(constr_freqs) for stat in (np.min, np.max, np.mean, np.std, np.median))

## Total number of constructions (tokens and types)

In [None]:
# Distinct construction strings across all dialogues, and total number of occurrences
types = {constr for dial_constrs in data.values() for constr in dial_constrs}
n_tokens = sum(
    len(occurrences)
    for dial_constrs in data.values()
    for occurrences in dial_constrs.values()
)


In [None]:
(n_tokens, len(types))  # (construction tokens, construction types)