-
Notifications
You must be signed in to change notification settings - Fork 9
/
preprocess.py
101 lines (84 loc) · 2.92 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gzip
import json
import re
import string
import unicodedata
from collections import defaultdict
from pathlib import Path
import click
from sentencepiece import SentencePieceProcessor
from tqdm import tqdm
# Characters that strip_text() keeps; anything outside Python's
# string.printable set is dropped during cleaning.
PRINTABLE = set(string.printable)
# Inclusive bounds on len(spm.Encode(text)). When a tokenizer is supplied,
# the yelp()/amzn() parsers discard reviews whose encoded length falls
# outside [MIN_REV_LEN, MAX_REV_LEN].
MIN_REV_LEN = 4
MAX_REV_LEN = 128
def strip_text(s: str) -> str:
    """Return *s* reduced to printable characters with spaces collapsed.

    The text is NFD-normalised so accented letters split into a base
    character plus combining marks; the combining marks (category "Mn")
    are dropped, as is every character not in PRINTABLE. Newlines become
    spaces and runs of spaces are collapsed to a single space. Leading
    and trailing spaces are not removed.
    """
    # https://stackoverflow.com/a/518232/2809427
    # https://stackoverflow.com/a/8689826
    decomposed = unicodedata.normalize("NFD", s)
    kept = [
        ch
        for ch in decomposed
        if unicodedata.category(ch) != "Mn" and ch in PRINTABLE
    ]
    single_line = "".join(kept).replace("\n", " ")
    return re.sub(" +", " ", single_line)
def yelp(file_path: str, spm: SentencePieceProcessor = None):
    """Yield cleaned Yelp reviews for businesses with more than 10 kept reviews.

    The input is read as one JSON document per line. Each record must
    provide the keys "stars", "text", "business_id" and "review_id".

    Args:
        file_path: Path to the line-delimited JSON review file.
        spm: Optional SentencePiece tokenizer. When given, a review is kept
            only if MIN_REV_LEN <= len(spm.Encode(text)) <= MAX_REV_LEN, and
            the encoded result is stored under "piece". When None, every
            review is kept and no "piece" key is added.

    Yields:
        Dicts with keys "business_id", "review_id", "rating", "text" (and
        "piece" when a tokenizer is supplied), grouped by business. Only
        businesses with more than 10 kept reviews are emitted.

    Note:
        All kept reviews are held in memory until the whole file has been
        read, because the per-business count is only known at the end.
    """
    by_business = defaultdict(list)

    # Open the file in a context manager so the handle is released even if
    # parsing fails, and decode explicitly as UTF-8 (the JSON interchange
    # encoding) instead of relying on the platform's locale default.
    with open(file_path, encoding="utf-8") as fh:
        for ins in tqdm(map(json.loads, fh), desc="Yelp"):
            business_id = ins["business_id"]
            text = strip_text(ins["text"])
            x = {"business_id": business_id,
                 "review_id": ins["review_id"],
                 "rating": int(ins["stars"]),
                 "text": text}
            if spm is not None:
                piece = spm.Encode(text)
                # Drop reviews that are too short or too long once tokenized.
                if not MIN_REV_LEN <= len(piece) <= MAX_REV_LEN:
                    continue
                x["piece"] = piece
            by_business[business_id].append(x)

    for reviews in by_business.values():
        if len(reviews) > 10:
            yield from reviews
def amzn(dir_path: str, spm: SentencePieceProcessor = None):
    """Yield cleaned Amazon reviews for products with more than 10 kept reviews.

    Every "*.gz" file directly inside *dir_path* is read as gzip-compressed,
    line-delimited JSON. Each record must provide the keys "asin",
    "reviewText", "overall" and "reviewerID".

    Files are processed in sorted path order. A product ("asin") that was
    collected from an earlier file is skipped in all later files, so each
    product is emitted from at most one file.

    Args:
        dir_path: Directory containing the ".gz" review dumps.
        spm: Optional SentencePiece tokenizer. When given, a review is kept
            only if MIN_REV_LEN <= len(spm.Encode(text)) <= MAX_REV_LEN, and
            the encoded result is stored under "piece". When None, every
            review is kept and no "piece" key is added.

    Yields:
        Dicts with keys "business_id" (the asin), "review_id", "rating",
        "text" (and "piece" when a tokenizer is supplied). Only products
        with more than 10 kept reviews in their file are emitted.
    """
    p = tqdm()
    obs = set()  # asins already collected from previously processed files
    try:
        # glob() gives no ordering guarantee, and the cross-file
        # de-duplication depends on which file is seen first, so sort the
        # paths to make the output reproducible.
        for fp in sorted(Path(dir_path).glob("*.gz")):
            p.set_description(desc=fp.stem)
            d = defaultdict(list)
            # Context manager so each archive handle is closed promptly.
            with gzip.open(fp, "rb") as fh:
                for ins in map(json.loads, fh):
                    asin = ins["asin"]
                    if asin in obs:
                        # Already handled in an earlier file; these records
                        # are skipped without advancing the progress bar.
                        continue
                    text = strip_text(ins["reviewText"])
                    # NOTE(review): "review_id" is filled from "reviewerID",
                    # i.e. it identifies the reviewer - confirm this is the
                    # intended identifier for downstream consumers.
                    x = {"business_id": asin,
                         "review_id": ins["reviewerID"],
                         "rating": int(float(ins["overall"])),
                         "text": text}
                    keep = True
                    if spm is not None:
                        piece = spm.Encode(text)
                        keep = MIN_REV_LEN <= len(piece) <= MAX_REV_LEN
                        if keep:
                            x["piece"] = piece
                    if keep:
                        d[asin].append(x)
                    p.update()

            for reviews in d.values():
                if len(reviews) > 10:
                    yield from reviews
            obs.update(d)
    finally:
        # Close the progress bar even when a record fails to parse or the
        # consumer stops iterating early.
        p.close()
# Command-line entry point: parse RAW_FILE as the dataset named by DATA_TYPE
# and print one JSON object per kept review to stdout.
# (Documented with comments rather than a docstring, because click would
# surface a docstring in the --help output.)
@click.command()
@click.argument("data_type", type=click.Choice(("yelp", "amzn")))
@click.argument("raw_file", type=click.Path(exists=True))
def main(data_type, raw_file):
    # The tokenizer is optional: it is loaded only when a trained model for
    # this dataset exists at the expected relative path.
    spm = None
    model_path = Path(f"./data/sentencepiece/{data_type}.model")
    if model_path.exists():
        spm = SentencePieceProcessor()
        spm.Load(str(model_path))

    # Map each supported dataset name to its parser generator.
    parsers = {"yelp": yelp, "amzn": amzn}
    if data_type not in parsers:
        raise KeyError()
    parse = parsers[data_type]

    for record in parse(raw_file, spm):
        print(json.dumps(record))
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()