prepare_data.py
import gzip
import pickle as pkl
import argparse
import collections.abc
import json
import logging
from collections import Counter
from random import randint

import gensim
import spacy
from tqdm import tqdm

logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)  # gensim logging output


def count_lines(file):
    """Count the lines of an open file, then rewind it so it can be read again."""
    count = 0
    for _ in file:
        count += 1
    file.seek(0)
    return count


def build_dataset(args):

    def preprocess(datas):
        # 1-5 star ratings become 0-4 class indices (class 0 = 1 star).
        for data in datas:
            yield (data['reviewText'], max(1, int(round(float(data["overall"])))) - 1)

    def preprocess_rescale(datas):
        # Binary labels on the shifted 0-4 scale: class 4 (5 stars) -> 1,
        # class 3 (4 stars) -> dropped (yields None), classes 0-2 (1-3 stars) -> 0.
        for data in datas:
            rating = max(1, int(round(float(data["overall"])))) - 1
            if rating > 3:
                rating = 1
            elif rating == 3:
                yield None
                continue
            else:
                rating = 0
            yield (data['reviewText'], rating)

    def data_generator(path):
        # Stream reviews (one JSON object per line) from the gzipped input file.
        with gzip.open(path, "rt") as f:
            for x in tqdm(f, desc="Reviews", total=count_lines(f)):
                yield json.loads(x)

    class TokIt(collections.abc.Iterator):
        # Restartable iterator over the tokenized reviews: it rewinds itself after a
        # full pass, so gensim's Word2Vec can stream the corpus once per training epoch.
        def __init__(self, tokenized):
            self.tok = tokenized
            self.x = 0
            self.stop = len(tokenized)

        def __iter__(self):
            return self

        def next(self):
            if self.x < self.stop:
                self.x += 1
                return list(w.orth_ for w in self.tok[self.x - 1] if len(w.orth_.strip()) >= 1)  # whitespace shouldn't be a word
            else:
                self.x = 0  # rewind for the next pass
                raise StopIteration

        __next__ = next
print("Building dataset from : {}".format(args.input))
print("-> Building {} random splits".format(args.nb_splits))
nlp = spacy.load('en')
tokenized = [tok for tok in tqdm(nlp.tokenizer.pipe((x["reviewText"] for x in data_generator(args.input)),batch_size=10000, n_threads=8),desc="Tokenizing")]
if args.create_emb:
w2vmodel = gensim.models.Word2Vec(TokIt(tokenized), size=args.emb_size, window=5, min_count=5, iter=args.epochs, max_vocab_size=args.dic_size, workers=4)
print(len(w2vmodel.wv.vocab))
w2vmodel.wv.save_word2vec_format(args.emb_file,total_vec=len(w2vmodel.wv.vocab))

    if args.rescale:
        print("-> Rescaling data to 0-1 (3's are discarded)")
        data = [dt for dt in tqdm(preprocess_rescale(data_generator(args.input)), desc="Processing") if dt is not None]
    else:
        data = [dt for dt in tqdm(preprocess(data_generator(args.input)), desc="Processing")]

    splits = [randint(0, args.nb_splits - 1) for _ in range(0, len(data))]
    count = Counter(splits)

    print("Split distribution is the following:")
    print(count)

    return {"data": data, "splits": splits, "rows": ("review", "rating")}


def main(args):
    ds = build_dataset(args)
    with open(args.output, "wb") as f:
        pkl.dump(ds, f)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str)
    parser.add_argument("output", type=str, nargs='?', default="sentences.pkl")  # nargs='?' so the default is actually usable
    parser.add_argument("--rescale", action="store_true")
    parser.add_argument("--nb_splits", type=int, default=5)
    parser.add_argument("--create-emb", action="store_true")
    parser.add_argument("--emb-file", type=str, default=None)
    parser.add_argument("--emb-size", type=int, default=100)
    parser.add_argument("--dic-size", type=int, default=10000000)
    parser.add_argument("--epochs", type=int, default=1)
    args = parser.parse_args()

    if args.emb_file is None:
        args.emb_file = args.output + "_emb.txt"

    main(args)
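
# A minimal sketch of how the pickled output could be consumed downstream. It relies
# only on the dict returned by build_dataset() above ("data", "splits", "rows");
# the "sentences.pkl" filename is just the default output name, and holding out
# fold 0 as a test split is an assumption, not something this script enforces.
#
#   import pickle as pkl
#   with open("sentences.pkl", "rb") as f:
#       ds = pkl.load(f)   # {"data": [(review, rating), ...], "splits": [...], "rows": ("review", "rating")}
#   train = [d for d, s in zip(ds["data"], ds["splits"]) if s != 0]
#   test  = [d for d, s in zip(ds["data"], ds["splits"]) if s == 0]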