-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9cd9ae8
commit b044995
Showing
5 changed files
with
578 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
def batcher(file_input, batch_size):
    """Load annotated sentences from a JSON file and group them into
    equal-length batches.

    Each entry in the JSON file is expected to be a dict with keys
    'original' and 'annotated' (space-separated integer id strings) plus
    'offsets', 'doc_offset', 'stems' and 'pos'.

    Sentences are bucketed by token length so each batch stacks into a
    rectangular 2-D array; each bucket is then chunked into groups of at
    most ``batch_size``.

    Returns:
        list of 6-tuples
        ``(input_ids_2d, output_ids_2d, offsets, doc_offsets, stems, pos)``
        where the first two items are 2-D int numpy arrays of shape
        ``(<=batch_size, sentence_length)``.
    """
    import json

    import numpy as np

    # Use a context manager so the file handle is closed
    # (original `json.load(open(...))` leaked it).
    with open(file_input) as fh:
        data = json.load(fh)

    records = []
    for sent in data:
        # np.fromstring with sep is deprecated for text parsing;
        # str.split + np.array is the supported equivalent.
        ip_array = np.array(sent['original'].split(), dtype=int)
        op_array = np.array(sent['annotated'].split(), dtype=int)
        try:
            records.append({
                'original': ip_array,
                'annotated': op_array,
                'offsets': sent['offsets'],
                'doc_offset': sent['doc_offset'],
                'stems': sent['stems'],
                'pos': sent['pos'],
            })
        except Exception as err:
            # Best-effort: skip malformed entries but report them
            # (original used Python 2 `print e`).
            print(err)

    # Bucket sentences by token length so arrays in a batch can be stacked.
    lengths = {rec['original'].shape[0] for rec in records}
    buckets = [[rec for rec in records if rec['original'].shape[0] == length]
               for length in lengths]

    final_ip, final_op = [], []
    final_offsets, final_doc_offsets = [], []
    final_stems, final_pos = [], []
    for bucket in buckets:
        # Chunk each equal-length bucket into batches of batch_size
        # (range replaces the Python 2 `xrange`).
        for start in range(0, len(bucket), batch_size):
            chunk = bucket[start:start + batch_size]
            final_ip.append(np.array([rec['original'] for rec in chunk]))
            final_op.append(np.array([rec['annotated'] for rec in chunk]))
            final_offsets.append([rec['offsets'] for rec in chunk])
            final_doc_offsets.append([rec['doc_offset'] for rec in chunk])
            final_stems.append([rec['stems'] for rec in chunk])
            final_pos.append([rec['pos'] for rec in chunk])

    # list() keeps the Python 2 "zip returns a list" contract for callers.
    return list(zip(final_ip, final_op, final_offsets, final_doc_offsets,
                    final_stems, final_pos))
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# coding: utf-8 | ||
from collections import Counter | ||
|
||
pos_dict = {'NOUN':'n', 'PROPN':'n', 'VERB':'v', 'AUX':'v', 'ADJ':'a', 'ADV':'r'} | ||
|
||
def build_mfs(batches, use_stem=True, return_lemma_freq=False):
    """Build a most-frequent-sense lookup table from batched data.

    For every annotated position in every batch, a key is formed — either
    the ``(stem, POS-tag)`` pair when ``use_stem`` is true, or the raw
    word id otherwise — and the lemma id found at that position in ``y``
    is recorded under it.

    Each batch is a 6-tuple ``(x, y, positions, tag, stems, pos_tags)``
    where ``x``/``y`` are 2-D arrays of shape (batch, seqlen); ``tag``
    is unused here.

    Returns:
        word_dict: key -> single most common lemma id for that key.
        word_freq: key -> {lemma id: relative frequency under that key}.
        lemma_freq (only if ``return_lemma_freq``): Counter of
            lemma id -> global relative frequency.
    """
    lemma_counts = Counter()
    sense_lists = {}
    word_freq = {}

    for batch in batches:
        x, y, positions, _tag, stems, pos_tags = batch
        n_sents = x.shape[0]
        # Transpose so each sentence is a column, matching the indexing below.
        x_cols, y_cols = x.T, y.T
        for sent_idx in range(n_sents):
            x_col = x_cols[:, sent_idx]
            y_col = y_cols[:, sent_idx]
            for pos_idx, pos in enumerate(positions[sent_idx]):
                lemma = y_col[pos]
                if use_stem:
                    key = (stems[sent_idx][pos_idx], pos_tags[sent_idx][pos_idx])
                else:
                    key = x_col[pos]
                lemma_counts[lemma] += 1
                sense_lists.setdefault(key, []).append(lemma)

    # Normalize global lemma counts to relative frequencies.
    total = sum(lemma_counts.values())
    for lemma in lemma_counts:
        lemma_counts[lemma] = lemma_counts[lemma] * 1.0 / total

    word_dict = {}
    for key, lemmas in sense_lists.items():
        counts = Counter(lemmas)
        # Pick the single most common lemma as this key's MFS.
        word_dict[key] = counts.most_common(1)[0][0]
        denom = sum(counts.values())
        word_freq[key] = {lem: cnt * 1.0 / denom for lem, cnt in counts.items()}

    if return_lemma_freq:
        return word_dict, word_freq, lemma_counts
    return word_dict, word_freq
Oops, something went wrong.