In [9]:
from main import *
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Configure a local Spark context that uses all available cores.
# getOrCreate() returns the already-running context if there is one, so
# re-executing this cell does not fail with
# "Cannot run multiple SparkContexts at once" (note: in that case `conf` is
# not re-applied to the existing context).
conf = SparkConf().setAppName('playground-main').setMaster('local[*]')
sc = SparkContext.getOrCreate(conf=conf)

In [3]:
# Load the static lookup resources shared by the train and test pipelines.

# relationship rules; indexed by entity id further down (rules[entid])
rules = load_rules()

# entity id -> name and name -> id lookups, plus the set of known entity names
# (the set is intersected with sentence n-grams when building the test data)
entid2name, entname2id, entname_set = load_ent2name()

# word statistics, words to filter out of n-gram generation, and the stemmer
global_word_log_cts, filter_words, stemmer = load_word_resources()

# lookup tables consumed by generate_mispelled_ents for misspelling matching
primes_map, ent_spell_scores, possible_spelling_ratios = make_mispelling_resources(entname_set)

In [None]:
#################################################################################################
### Make Train Dataset
#################################################################################################

In [10]:
# Parse train.txt: one tab-separated example per line. Blank lines (e.g. the
# empty string produced by a trailing newline) are skipped instead of being
# handed to process_dataset_line as a bogus record.
with open('train.txt', 'r') as f:
    lines = [line.split('\t') for line in f.read().split('\n') if line.strip()]

# stores the weighted word bag (word -> accumulated weight) for each label
relationship_bags = defaultdict(dict)

all_word_cts = {}  # dataset-wide word counts, fed one de-duplicated word set per example
X_words = []       # per-example set of stemmed, non-entity words
labels = []        # per-example relationship type
entities = []      # per-example entity id

# each line is a dataset entry
for line in lines:
    (text, entid, ent, rel_type) = process_dataset_line(line, entid2name)
    if ent is None:  # some entities are oddly not found in the entity id to name dataset
        continue

    # stemmed, de-duplicated words of the sentence with the entity mention removed
    text_minus_ent = replace_complete_word(text, ent, '')
    words = {stemmer.stem(word) for word in text_minus_ent.split()}

    # track dataset word frequencies
    add_list_to_count_dict(all_word_cts, words)

    # every word gets an equal share of a total weight of 1, so each example
    # contributes equally to its label's bag (an empty word set contributes nothing)
    word_weights = {word: 1.0 / len(words) for word in words}
    relationship_bags[rel_type] = outer_join_dicts(op.add, relationship_bags[rel_type], word_weights)

    # write dataset
    X_words.append(words)
    labels.append(rel_type)
    entities.append(entid)

# normalize each bag to total weight 1
for rel in relationship_bags:
    relationship_bags[rel] = normalize_dict(relationship_bags[rel])

In [11]:
# Drop rare words from every relationship bag so the vocabulary (and hence
# the model input) stays small enough to train on a CPU.
RARE_WORD_THRESHOLD = 0.03
filtered_relationship_bags = {
    rel: threshold_bag(bag, RARE_WORD_THRESHOLD)
    for rel, bag in relationship_bags.items()
}

In [12]:
# Vocabulary: every word that survived filtering in any relationship bag.
# Map each word to a numeric id (id offset 0) for vectorizing inputs.
all_words = {word for bag in filtered_relationship_bags.values() for word in bag}
word2id = list_to_id_map(all_words, 0)

In [13]:
# Give every relationship type an integer label id, counting from 1, and
# collect its dense class-wide tf vector as one row of an nclass x nword matrix.
# Ids start at 1 because the filter vectors are built with len(label2id)+1
# slots and slot 0 is sliced off.
labeli = 1
label2id = {}
tf_matrix = []
for reltype, bag in filtered_relationship_bags.items():
    tf_matrix.append(bag_to_dense_vector(bag, word2id, len(word2id)))
    label2id[reltype] = labeli
    labeli += 1
write_2d_matrix_as_csv('tf_matrix.csv', tf_matrix)

In [14]:
# Vectorize the train dataset for torch and write train/validation splits.
# The first N_TRAIN examples form the train split and the next N_VAL form the
# validation split; any examples beyond that are not written.
N_TRAIN = 10000
N_VAL = 5000
train_rows = slice(0, N_TRAIN)
val_rows = slice(N_TRAIN, N_TRAIN + N_VAL)

# dense bag-of-words input vectors
X = [sparse_vector_to_dense_vector(words, word2id, len(word2id)) for words in X_words]
write_2d_matrix_as_csv('X_train.csv', X[train_rows])
write_2d_matrix_as_csv('X_val.csv', X[val_rows])

# integer class labels (ids start at 1, see label2id)
y = [label2id[label] for label in labels]
write_vector_as_csv('y_train.csv', y[train_rows])
# the validation split needs its own file; writing it to 'y_train.csv'
# would overwrite the train labels written just above
write_vector_as_csv('y_val.csv', y[val_rows])

# per-example vector over the relationship types the rules data lists for the
# entity (presumably a 0/1 mask -- confirm against sparse_vector_to_dense_vector);
# built with len(label2id)+1 slots and slot 0 dropped because label ids start at 1
filters = [sparse_vector_to_dense_vector(rules[entid], label2id, len(label2id) + 1)[1:] for entid in entities]
write_2d_matrix_as_csv('filters_train.csv', filters[train_rows])
# same here: the validation filters must not overwrite 'filters_train.csv'
write_2d_matrix_as_csv('filters_val.csv', filters[val_rows])

In [None]:
#################################################################################################
### Make Test Dataset (more complicated, since we don't know which entity to extract)
#################################################################################################

In [18]:
# Build the candidate-entity dataset for a slice of test.txt and write it to
# intermediate_info.txt so it can be read in torch. Each kept example becomes
# one output line:
#   <label id>|
# followed, for every candidate entity found in the sentence, by
#   <is gold entity>|<spelling score>|<entity score>|<input vector>|<filter vector>|
TEST_START = 11000  # index of the first test.txt line to process
TEST_END = 12000    # one past the index of the last test.txt line to process

with open('test.txt', 'r') as f:
    lines = [line.split('\t') for line in f.read().split('\n')]

with open('intermediate_info.txt', 'w') as out:
    for line in lines[TEST_START:TEST_END]:
        (text, entid, ent1, rel_type) = process_dataset_line(line, entid2name)

        # we either don't have entid in the dataset or the label was not seen in the train data
        if ent1 is None or rel_type not in label2id:
            continue

        # progress trace: sentence | gold entity | gold relationship
        # (single-argument print behaves the same under Python 2 and 3)
        print('%s | %s | %s' % (text, ent1, rel_type))

        # generate all n-grams of the sentence; those that are known entity
        # names are exact matches (true spelling == spelling in the text)
        grams = generate_grams(text, filter_words)
        exact_match_ents = [(x, x) for x in set.intersection(grams, entname_set)]

        # get possible entity misspellings
        mispelled_ents = generate_mispelled_ents(grams, exact_match_ents, global_word_log_cts, primes_map, ent_spell_scores, possible_spelling_ratios)

        possible_ents = exact_match_ents + mispelled_ents
        if not possible_ents:
            continue

        # assign 1.0 score to exact matches and 0.0 to misspellings
        P_spelling = [1.0] * len(exact_match_ents) + [0.0] * len(mispelled_ents)

        out.write(str(label2id[rel_type]) + '|')

        # one group of fields per candidate entity extracted from the text
        for (true_ent_words, present_ent_words), spelling_score in zip(possible_ents, P_spelling):
            # get score for entity rarity
            ent_score = get_ent_score(true_ent_words, global_word_log_cts)

            # stemmed, de-duplicated words of the sentence with this candidate removed
            text_minus_ent = replace_complete_word(text, ' '.join(present_ent_words), '')
            words = {stemmer.stem(word) for word in set(text_minus_ent.split())}

            # get list of all relationship types attached to this entity in the rules data
            present_rels = get_present_rels(true_ent_words, entname2id, rules)

            # make vectorized data for torch; the filter vector drops slot 0
            # because label ids start at 1
            data_input = sparse_vector_to_dense_vector(words, word2id, len(word2id))
            data_filter = sparse_vector_to_dense_vector(present_rels, label2id, len(label2id) + 1)[1:]

            # write this candidate's fields, each terminated by '|'
            out.write(str(' '.join(true_ent_words) == ent1) + '|')
            out.write(str(spelling_score) + '|')
            out.write(str(ent_score) + '|')
            out.write(','.join(str(x) for x in data_input) + '|')
            out.write(','.join(str(x) for x in data_filter) + '|')

        out.write('\n')

who made tonight i 'm lovin' you  | tonight | /music/recording/artist
what is the name of an album by leo delibes | lo delibes | /music/artist/album
what type of institution is fremont-elizabeth city high school | fremont-elizabeth city high school | /education/educational_institution/school_type
what type of organization is heritage foundation | heritage foundation | /organization/organization/organization_type
which time zone does sylvia belong to | sylvia | /location/location/time_zones
which country is christian edward elder from  | christian edward elder | /people/person/nationality
what song is on the release taxi | taxi | /music/release/track
what 's a game on ti-99 | texas instruments ti-994a | /cvg/cvg_platform/games_on_this_platform
what netflix genre is 1918 | 1918 | /media_common/netflix_title/netflix_genres
what 's a town in louisiana whose name starts with an n | louisiana | /location/location/contains
what is lia roberts known as | lia roberts | /common/topic/notable_typ