prepare_data.py
"""
This needs to import "data" from https://github.com/bazingagin/npc_gzip
Save all datasets as pickle files to $outdir/$name.pkl
data['train_data'] : [str]
data['test_data'] : [str]
data['train_labels'] : ndarray (n_train,) dtype=uint32
data['test_labels'] : ndarray (n_test,) dtype=uint32
"""
from data import (
load_kinnews,
load_kirnews,
load_filipino,
load_swahili,
load_20news,
)
import torchtext.datasets
import os
import pickle
import numpy as np
import argparse
# Needed for the SogouNews dataset: its CSV fields exceed the default parser limit.
import sys
import csv
print("setting csv.field_size_limit to", sys.maxsize)
csv.field_size_limit(sys.maxsize)
def load_torch(name):
    """Load a torchtext dataset by class name; returns (train, test) iterators."""
    Cls = getattr(torchtext.datasets, name)
    return Cls(root="data")
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--outdir',
        required=True,
        help="write $outdir/$name.pkl")
    args = parser.parse_args()
    outdir = args.outdir
    # construct list of (name, thunk) pairs; each thunk loads its dataset lazily
    DSS = []
    for name in (
        "AG_NEWS",
        "DBpedia",
        "YahooAnswers",
    ):
        # bind name as a default argument so each lambda keeps its own dataset name
        # instead of closing over the (reassigned) loop variable
        DSS.append((name, lambda name=name: load_torch(name)))
    DSS.append(('20News', load_20news))
    # not yet handled: ohsumed, R8, R52
    DSS.extend([
        ('kinnews',  load_kinnews),
        ('kirnews',  load_kirnews),
        ('filipino', load_filipino),
        ('swahili',  load_swahili),
    ])
name = "SogouNews"
DSS.append((name,lambda : load_torch(name)))
    for name, fn in DSS:
        tr, te = fn()
        # unpack generators / datapipes into lists of (label, text) pairs
        tr = list(tr)
        te = list(te)
        outfile = os.path.join(outdir, name + ".pkl")
        dtype = 'uint32'
        print(name, "tr,te:", (len(tr), len(te)))
        with open(outfile, 'wb') as f:
            pickle.dump({
                'train_data': [t for (l, t) in tr],
                'test_data': [t for (l, t) in te],
                'train_labels': np.array([l for (l, t) in tr], dtype),
                'test_labels': np.array([l for (l, t) in te], dtype),
            }, f)
        print("wrote:", outfile)
if __name__ == "__main__":
main()
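# Example invocation (a sketch; "data_pkl" is an illustrative output directory,
# and the npc_gzip repo's data.py must be importable as `data`):
#
#     mkdir -p data_pkl
#     python prepare_data.py --outdir data_pkl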