In [34]:
import pickle
import csv
import os
import pandas as pd
import numpy as np
from collections import defaultdict

In [41]:
def open_pkl(dir):
    with open(dir, 'rb') as file:
        return pickle.load(file)
    
def open_word_list_csv(csv_path):
    with open(csv_path, newline='\n') as csv_file:
        word_list = []
        csv_contents = csv.DictReader(csv_file, delimiter=',', quotechar='|')
        for row in csv_contents:
            word_list.append(row)
        return word_list

In [130]:
data_file = "../../../Data/model_datasets/spa/train_vocab_size_5000.pkl"
vocab_file = "../../../Data/model_datasets/spa/encoding_dictionary_vocab_size_5000.pkl"
word_list_file = "../../../Analyses/data/surprisal-and-frequency/lstm-surprisals/Spanish (Mexican)/all_child_directed_data_spanish_(mexican)_singleword_average_surprisal_perplexity_run0.csv"

In [131]:
data = open_pkl(data_file)
vocab = open_pkl(vocab_file)

In [132]:
word_list = open_word_list_csv(word_list_file)

In [123]:
def cooccurrence_matrix(sentences, window_size):
    d = defaultdict(int)
    vocab = set()
    for s in sentences:
        for i in range(len(s)):
            if not s[i] == 0:
                token = s[i]
                vocab.add(token)  # add to vocab
                next_token = s[i+1 : i+1+window_size]
                for t in next_token:
                    key = tuple( sorted([t, token]) )
                    d[key] += 1
    # formulate the dictionary into dataframe
    vocab = sorted(vocab) # sort vocab
    df = pd.DataFrame(data=np.zeros((len(vocab), len(vocab)), dtype=np.int16),
                      index=vocab,
                      columns=vocab)
    for key, value in d.items():
        df.at[key[0], key[1]] = value
        df.at[key[1], key[0]] = value
    return df

In [124]:
matrix = cooccurrence_matrix(data, 5)

In [125]:
def get_contextual_diversity(word_list, vocab, matrix):
    def get_word_score(word):
        index = vocab[word]
        try:
            row = matrix[index]
            cd = 0
            for i in row:
                if not i == 0:
                    cd+=1
        except:
            cd = 'NA'
        return cd
    cd_word_list = []
    for item in word_list:
        word = item['word_clean']
        language = item['language']
        cd_score = get_word_score(word)
        cd_word_list.append({'language': language, 'word_clean':word, 'cd_score':cd_score})
    return cd_word_list
    

In [133]:
cd_word_list = get_contextual_diversity(word_list, vocab, matrix)

In [134]:
cd_word_list

[{'language': 'Spanish (Mexican)', 'word_clean': 'un', 'cd_score': 1591},
 {'language': 'Spanish (Mexican)', 'word_clean': 'una', 'cd_score': 1186},
 {'language': 'Spanish (Mexican)', 'word_clean': 'unas', 'cd_score': 152},
 {'language': 'Spanish (Mexican)', 'word_clean': 'unos', 'cd_score': 215},
 {'language': 'Spanish (Mexican)', 'word_clean': 'mucho', 'cd_score': 404},
 {'language': 'Spanish (Mexican)', 'word_clean': 'otro', 'cd_score': 562},
 {'language': 'Spanish (Mexican)', 'word_clean': 'avión', 'cd_score': 15},
 {'language': 'Spanish (Mexican)', 'word_clean': 'todo', 'cd_score': 494},
 {'language': 'Spanish (Mexican)', 'word_clean': 'ya', 'cd_score': 1114},
 {'language': 'Spanish (Mexican)', 'word_clean': 'también', 'cd_score': 744},
 {'language': 'Spanish (Mexican)', 'word_clean': 'enojado', 'cd_score': 7},
 {'language': 'Spanish (Mexican)', 'word_clean': 'animal', 'cd_score': 53},
 {'language': 'Spanish (Mexican)', 'word_clean': 'manzana', 'cd_score': 21},
 {'language': 'Span

In [128]:
def save_cd_csv(cd_word_list, experiment_dir):
    language = cd_word_list[0]['language']
    file_name = language + '_contextual_diversity_scores5000.csv'
    with open(os.path.join(experiment_dir, file_name), mode='w') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        writer.writerow(['language', 'word_clean', 'cd_score'])
        for word in cd_word_list:
            writer.writerow([word['language'], word['word_clean'], str(word['cd_score'])])

In [135]:
save_cd_csv(cd_word_list, "./contextual-diversity")