-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
128 lines (103 loc) · 3.52 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import re
import csv
import string
import nltk
from nltk.stem import *
# Ensure the NLTK English stopword corpus is available, downloading it on
# first run.  LookupError is what NLTK raises for a missing corpus; the
# original bare `except:` also swallowed SystemExit/KeyboardInterrupt.
try:
    from nltk.corpus import stopwords
except LookupError:
    nltk.download('stopwords')
    from nltk.corpus import stopwords
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset
# Paths to the pre-split IMDB review CSVs (resolved by data_file() below).
TRAIN_FILE = 'dataset/Train.csv'
VALID_FILE = 'dataset/Valid.csv'
TEST_FILE = 'dataset/Test.csv'
# Default stemmer instance and the English stopword set used for filtering.
stemmer = PorterStemmer()
STOPS = set(stopwords.words('english'))
def is_stopword(word):
    """Return True when the given token is in the English stopword set."""
    return word in STOPS
def data_file(data):
    """Map a split name ('train' / 'valid' / 'test') to its CSV path.

    Raises AssertionError for any other split name.
    """
    assert data in ['train', 'valid', 'test'], f'invalid dataset: {data}'
    split_paths = {
        'train': TRAIN_FILE,
        'valid': VALID_FILE,
        'test': TEST_FILE,
    }
    return split_paths[data]
class IMDBDataset(Dataset):
    """IMDB review dataset read from a CSV of (text, label) rows.

    Each row's text is stripped of punctuation, lower-cased, split into
    tokens, stopword-filtered and optionally stemmed.  Examples are kept
    in per-class lists and exposed through (x, y) pairs.

    Args:
        data: split name — 'train', 'valid' or 'test'.
        data_limit: stop after this many CSV rows (None = read all).
        balanced_limit: when True, keep at most data_limit // 2 examples
            per class; requires data_limit to be set.
        load_class: 'both', 'pos' or 'neg' — which class(es) to expose.
        stemmer: False for no stemming, True for a default PorterStemmer,
            or any object exposing a .stem(word) method.
    """

    def __init__(self, data, data_limit=None, balanced_limit=False, load_class='both', stemmer=False):
        if balanced_limit and data_limit is None:
            # The original crashed later with TypeError at `data_limit // 2`.
            raise ValueError('balanced_limit=True requires data_limit to be set')
        # Resolve the stemming function up front.  The parameter shadows the
        # module-level stemmer, so stemmer=True used to crash (bool has no
        # .stem); accept True as "use a default PorterStemmer".
        stem = None
        if stemmer:
            stem = stemmer.stem if hasattr(stemmer, 'stem') else PorterStemmer().stem
        per_class = data_limit // 2 if balanced_limit else None
        neg, pos = [], []
        with open(data_file(data), 'r', encoding="utf8") as file:
            reader = csv.reader(file)
            next(reader)  # skip the header row
            for idx, line in enumerate(reader):
                text = line[0].translate(str.maketrans('', '', string.punctuation))
                words = [word for word in text.lower().split() if not is_stopword(word)]
                if stem is not None:
                    # Stem the same lower-cased, stopword-filtered tokens;
                    # the original stemming branch skipped lower-casing.
                    words = [stem(word) for word in words]
                if line[1] == '1':
                    if per_class is None or len(pos) < per_class:
                        pos.append(words)
                else:
                    if per_class is None or len(neg) < per_class:
                        neg.append(words)
                if per_class is not None:
                    # Keep reading until BOTH classes are full so the result
                    # is actually balanced (the original stopped at data_limit
                    # rows regardless of the class counts).
                    if len(pos) >= per_class and len(neg) >= per_class:
                        break
                elif data_limit is not None and (idx + 1) >= data_limit:
                    break
        self.pos = pos
        self.neg = neg
        if load_class == 'both':
            self.x = self.pos + self.neg
            self.y = [1] * len(pos) + [0] * len(neg)
        elif load_class == 'pos':
            self.x = self.pos
            self.y = [1] * len(pos)
        elif load_class == 'neg':
            self.x = self.neg
            self.y = [0] * len(neg)
        else:
            # Previously an invalid value left self.x/self.y undefined.
            raise ValueError(f'invalid load_class: {load_class}')

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return self.x[i], self.y[i]

    def shuffle(self):
        """Shuffle x and y in unison, in place."""
        order = np.random.permutation(len(self))
        # x and y are plain Python lists: index element-by-element.  The
        # original `self.x[order]` raised TypeError (ndarray index on a
        # list), and `np.int` was removed in NumPy 1.24.
        self.x = [self.x[i] for i in order]
        self.y = [self.y[i] for i in order]
def load_LDA_data(model, dataset, include_llk=False):
    """Infer LDA topic vectors for every example in `dataset`, one doc at a time.

    Returns (lda_x, lda_y) as numpy arrays; when include_llk is True the
    per-document log-likelihood array is inserted in the middle:
    (lda_x, llk_x, lda_y).
    """
    topic_vecs, log_liks = [], []
    for words, _ in dataset:
        doc = model.make_doc(words)
        vec, llk = model.infer(doc)
        topic_vecs.append(vec)
        log_liks.append(llk)
    lda_x = np.array(topic_vecs)
    lda_y = np.array(dataset.y)
    if include_llk:
        return lda_x, np.array(log_liks), lda_y
    return lda_x, lda_y
def load_LDA_data_batch(model, dataset, include_llk=False):
    """Batched variant of load_LDA_data: a single model.infer call on all docs.

    Returns (lda_x, lda_y), or (lda_x, llk, lda_y) when include_llk is True.
    """
    docs = [model.make_doc(words) for words, _ in dataset]
    topic_vecs, log_liks = model.infer(docs)
    lda_x = np.array(topic_vecs)
    lda_y = np.array(dataset.y)
    if include_llk:
        return lda_x, np.array(log_liks), lda_y
    return lda_x, lda_y