In [2]:
import numpy as np
import pandas as pd
import os
from collections import Counter
from itertools import chain
import math
import scipy
import json
%load_ext line_profiler

POS_OPTIONS = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ",
               "SYM","VERB","X"]

DEPS_OPTIONS = json.load(open("../data/deps_list.json", "r"))

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [3]:
import spacy
from spacy import en
import nltk

DATA_DIR = "../data"
nlp = spacy.load('en')

In [4]:
tokens = nltk.corpus.brown.words()
parsed = [nlp(unicode(y)) for y in tokens]
chained = [x for doc in parsed for x in doc]
lemmas = [x.lemma_ for x in [y for y in chained] if x.is_alpha] 
words = [x.orth_.lower() for x in [y for y in chained] if x.is_alpha]

In [5]:
lemma_counts = Counter(lemmas)
word_counts = Counter(words)

top_words = dict(word_counts.most_common(5000))
top_lemmas = dict(lemma_counts.most_common(5000))

freq_idx_lemmas = {lemma: idx for idx, (lemma, cnt) in enumerate(lemma_counts.iteritems())}

In [129]:
def human_readable(text):
    return text.replace("@@ ","")

def tokenize(text):
    return text.split(' ')

def split_by_phrases(tokens):
    res = []
    curr_phrase = []
    for token in tokens:
        if token.endswith("speaker>"):
            if curr_phrase:
                res.append(curr_phrase)
                curr_phrase = []
        else:
            curr_phrase.append(token)
    
    if curr_phrase:
        res.append(curr_phrase)
    
    return res

def parse_and_flatten(tokens):
    cleaned = [token for phrase in split_by_phrases(tokens) for token in phrase if not token.startswith("<")]
    parsed = [nlp(unicode(x)) for x in cleaned]
    return [x for doc in parsed for x in doc]

def join_sentence(tokens):
    phrases = split_by_phrases(tokens)
    filtered = map(lambda x: [y for y in x if not y.startswith("<")], phrases)
    glued = map(lambda x: " ".join(x), filtered)

    return unicode(".".join(glued))

In [146]:
train = pd.DataFrame.from_csv(os.path.join(DATA_DIR, "train.small.tsv"), sep='\t')
train["context"] = train["context"].map(human_readable).map(tokenize)
train["response"] = train["response"].map(human_readable).map(tokenize)

train["nlp_context"] = train["context"].map(lambda x: nlp(join_sentence(x)))
train["nlp_response"] = train["response"].map(lambda x: nlp(join_sentence(x)))

In [8]:
def basestat(f, label):
    def target(context, response):
        response_value = f(response)
        context_value = f(context)
        ratio = float(response_value)/context_value if context_value != 0 else 0
        return pd.Series([context_value,response_value,ratio],
                        index=["context_{}".format(label), "response_{}".format(label), "ratio_{}".format(label)])
    
    return target

In [151]:
def basestat_series(f):
    def target(context, response):
        if context and response:
            response_ser = f(response)
            context_ser = f(context)

            response_ser.index = map(lambda x: "response_{}".format(x), response_ser.index)
            context_ser.index = map(lambda x: "context_{}".format(x), context_ser.index)

            return response_ser.append(context_ser)
        else:
            return pd.Series()
    
    return target

## General features

In [9]:
def lenstat(context, response):
    return basestat(len, "length")(context, response)

## Morphological features

In [124]:
def stopwords_stat(context, response):
    def stopword_count(text):
        return len(filter(lambda x: x.lower() in en.STOP_WORDS, text))
    
    return basestat(stopword_count, "stopwords")(context, response)

In [125]:
def freq_stat(parsed_context, parsed_response):
    def topN_count(parsed_text):
        return len(filter(lambda x: x.orth_.lower() in top_words, parsed_text))
    
    def topN_count_lemma(parsed_text):
        return len(filter(lambda x: x.lemma_ in top_lemmas, parsed_text))
    
    def no_vocab_tokens(parsed_text):
        return len(filter(lambda x: x.lemma_ not in lemma_counts, parsed_text))
    
    def avg_index_lemma(parsed_text):
        freqs = [freq_idx_lemmas.get(token.lemma_,None) for token in parsed_text]
        return np.mean([math.log(x) for x in filter(lambda x: x, freqs)])
    
    res = pd.Series()
    
    if parsed_context and parsed_response:
        res = res.append(basestat(topN_count, "topN_tokens_count")(parsed_context, parsed_response))
        res = res.append(basestat(lambda x: float(topN_count(x))/len(x), "topN_tokens_count_relative")(parsed_context, parsed_response))
        res = res.append(basestat(topN_count_lemma, "topN_count_lemma")(parsed_context, parsed_response))
        res = res.append(basestat(lambda x: float(topN_count_lemma(x))/len(x), "topN_count_lemma_relative")(parsed_context, parsed_response))
        res = res.append(basestat(no_vocab_tokens, "no_vocab_tokens")(parsed_context, parsed_response))
        res = res.append(basestat(lambda x: float(no_vocab_tokens(x))/len(x), "no_vocab_tokens_relative")(parsed_context, parsed_response))
        res = res.append(basestat(avg_index_lemma, "avg_index_lemma")(parsed_context, parsed_response))
    
    return res

In [12]:
def special_terms_stat(context, response):
    def spec_count(term, text):
        return len([x for x in text if x == "<{}>".format(term)])
    
    res = pd.Series()
    res = res.append(basestat(lambda x: spec_count("at", x), "mentions")(context, response))
    res = res.append(basestat(lambda x: spec_count("number", x), "numbers")(context, response))
    res = res.append(basestat(lambda x: spec_count("url", x), "links")(context, response))
    return res

In [127]:
def pos_stats(parsed_context, parsed_response):
    def tag(parsed):
        
        if parsed:
            counter = Counter()
            for token in parsed:
                counter[token.pos_] += 1

            abs_counts = pd.Series(data = [counter[pos] for pos in POS_OPTIONS],
                                  index = ["pos_{}_abs".format(pos) for pos in POS_OPTIONS])

            relative_counts = pd.Series(data = [float(counter[pos])/len(parsed) for pos in POS_OPTIONS],
                                  index = ["pos_{}_rel".format(pos) for pos in POS_OPTIONS])

            return abs_counts.append(relative_counts)
        else:
            return pd.Series()
    
    return basestat_series(tag)(parsed_context, parsed_response)


## Syntax features

In [128]:
def find_root(sentence):
    for token in sentence:
        if token.dep_ == "ROOT":
            return token
    
    raise ValueError()
    
def depth(node):
    return 1 + np.max([depth(child) for child in node.children])

def depth_up(node):
    if node.dep_ == "ROOT":
        return 0
    
    return 1 + depth_up(node.head)

In [117]:
def desctiptive_stats(arr):
    stats = scipy.stats.describe(arr).__dict__
    min_val, max_val = stats["minmax"]
    stats["min"] = min_val
    stats["max"] = max_val
    stats["median"] = np.median(arr)
    del stats["minmax"]
    
    ser = pd.Series.from_array(stats)
    return ser

In [182]:
def syntax_stats(nlp_context, nlp_response): 
    
    def depth_stats(nlp_sentence):
        stats = desctiptive_stats([depth_up(x) for x in nlp_sentence])
        stats.index = map(lambda x: "syntax_depth_{}".format(x), stats.index)
        return stats

    def child_stats(nlp_sentence):
        stats = desctiptive_stats([len(list(x.children)) for x in nlp_sentence])
        stats.index = map(lambda x: "syntax_children_{}".format(x), stats.index)
        return stats

    def dependency_stats(nlp_sentence):
        counter = Counter()
        for token in nlp_sentence:
            for dep in token.dep_.split("||"):
                counter[dep] += 1

        abs_counts = pd.Series(data = [counter[pos] for pos in DEPS_OPTIONS],
                              index = ["syntax_{}_abs".format(pos) for pos in DEPS_OPTIONS])
        
        relative_counts = pd.Series(data = [float(counter[pos])/len(nlp_sentence) for pos in DEPS_OPTIONS],
                              index = ["syntax_{}_rel".format(pos) for pos in DEPS_OPTIONS])

        return abs_counts.append(relative_counts)

    res = pd.Series()
    res = res.append(basestat_series(depth_stats)(nlp_context, nlp_response))
    res = res.append(basestat_series(child_stats)(nlp_context, nlp_response))
    res = res.append(basestat_series(dependency_stats)(nlp_context, nlp_response))
    
    return res

In [58]:
#train_no_labels[:100].apply(lambda x: syntax_stats(*x), axis=1)

## Word2Vec

In [132]:
def word2vec_stats(nlp_context, nlp_response):
    
    def avg_vector(nlp_sentence):
        w2v = np.mean(np.array([x.vector for x in nlp_sentence]), axis=0)
        return pd.Series(data=w2v,
                        index=["w2v_{}".format(i) for i in range(0,w2v.shape[0])])

    def corr_stats(nlp_sentence):
        corrs = []
        for i in range(0, len(nlp_sentence)-1):
            for j in range(i+1,len(nlp_sentence)):
                corrs.append(nlp_sentence[i].similarity(nlp_sentence[j]))

        if corrs:
            stats = desctiptive_stats(corrs)
            stats.index = map(lambda x: "w2v_correlations_{}".format(x), stats.index)
            return stats
        else:
            return pd.Series()

    res = pd.Series()
    res = res.append(basestat_series(avg_vector)(nlp_context, nlp_response))
#     res = res.append(basestat_series(corr_stats)(nlp_context, nlp_response))
    
    return res
    
    

## Combine

In [186]:
funcs = [lenstat, stopwords_stat, special_terms_stat]
nlp_funcs = [freq_stat, pos_stats, syntax_stats, word2vec_stats]

In [187]:
def run_all(ser, funcs, nlp_funcs):
    res = pd.Series()
    for f in funcs:
        res = res.append(f(ser["context"],ser["response"]))
    
    for f in nlp_funcs:
        res = res.append(f(ser["nlp_context"],ser["nlp_response"]))
    return res

In [188]:
#%lprun -f run_all train[:100].apply(lambda x: run_all(x,funcs,nlp_funcs), axis=1)

In [195]:
df = train[100:200].apply(lambda x: run_all(x,funcs, nlp_funcs), axis=1)

In [192]:
df.to_csv("../data/features.small.csv")

In [196]:
df

Unnamed: 0_level_0,context_avg_index_lemma,context_length,context_links,context_mentions,context_no_vocab_tokens,context_no_vocab_tokens_relative,context_numbers,context_pos_ADJ_abs,context_pos_ADJ_rel,context_pos_ADP_abs,...,response_w2v_90,response_w2v_91,response_w2v_92,response_w2v_93,response_w2v_94,response_w2v_95,response_w2v_96,response_w2v_97,response_w2v_98,response_w2v_99
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
100,9.445460,23.0,0.0,0.0,4.0,0.181818,0.0,1.0,0.045455,2.0,...,-0.025873,0.094332,-0.061377,0.156775,0.095398,0.158056,-0.026449,-0.041327,-0.232958,-0.049234
101,9.521465,16.0,0.0,1.0,3.0,0.250000,0.0,0.0,0.000000,0.0,...,-0.212550,-0.096326,-0.095313,0.067674,0.092909,0.167461,0.045127,-0.100422,-0.110122,0.003493
102,9.373579,38.0,1.0,2.0,3.0,0.100000,0.0,2.0,0.066667,2.0,...,-0.072232,0.080743,0.073073,0.072042,0.011515,-0.005366,-0.151322,-0.095845,0.053204,0.149230
103,9.120075,41.0,0.0,4.0,8.0,0.266667,0.0,3.0,0.100000,2.0,...,-0.018270,0.049539,-0.173140,-0.063550,0.053127,-0.173004,-0.023782,0.047856,-0.213482,0.056170
104,9.081136,38.0,0.0,1.0,5.0,0.142857,0.0,1.0,0.028571,4.0,...,-0.100851,0.000924,-0.052806,0.054733,0.037270,-0.045845,-0.052774,-0.137185,-0.040650,0.051717
105,8.447171,57.0,0.0,6.0,35.0,0.813953,0.0,2.0,0.046512,2.0,...,-0.171655,-0.012006,-0.009915,-0.007866,0.029608,0.117592,-0.035336,-0.000342,-0.069520,-0.040705
106,9.482224,11.0,0.0,0.0,4.0,0.400000,0.0,0.0,0.000000,1.0,...,-0.036011,-0.033962,-0.028033,0.050998,0.199511,0.005663,0.000123,-0.015097,-0.083666,0.109085
107,9.323169,105.0,0.0,6.0,34.0,0.369565,0.0,5.0,0.054348,4.0,...,0.063002,0.276403,0.175363,0.119677,-0.184731,-0.008864,-0.365806,-0.241272,-0.062670,-0.044146
108,9.409440,48.0,0.0,3.0,6.0,0.146341,0.0,2.0,0.048780,4.0,...,-0.045132,0.077311,0.004194,0.073569,-0.031825,-0.063781,-0.129833,-0.074753,-0.094611,0.086981
109,8.871454,44.0,0.0,3.0,13.0,0.361111,0.0,4.0,0.111111,0.0,...,-0.126070,-0.105178,-0.096401,0.014453,0.269048,-0.005560,0.105984,0.003253,-0.152837,-0.079070
