# coding: utf-8
import sys
import os
import time
import joblib
import argparse
import pandas as pd
import numpy as np
from gensim import utils
from gensim.parsing.preprocessing import preprocess_string, remove_stopwords
from sklearn.datasets import fetch_20newsgroups
# NLP TOOLS #
# A custom stopword remover based on gensim's, supporting non-English languages and special stopwords
def my_remove_stopwords(s, language):
    if language == 'english':
        return remove_stopwords(s)
    path = "./datasets/stopwords-{}.txt".format(language)
    if not os.path.exists(path):
        print("{} is not a built-in language yet. Please provide a '{}' file containing appropriate stopwords "
              "(one word per line, lower case).".format(language.capitalize(), path))
        sys.exit(1)
    with open(path, encoding='utf-8') as f:
        stopwords = f.read().splitlines()
    specials = "./datasets/special-stopwords.txt"
    if os.path.exists(specials):
        with open(specials, encoding='utf-8') as f:
            specials = f.read().splitlines()
        stopwords = set(stopwords + specials)
    s = utils.to_unicode(s)
    return " ".join(w for w in s.split() if w.lower() not in stopwords)
# Convert a gensim sparse transformed corpus, where each document is a list of
# (index, weight) pairs such as [(0, 12), (1, 15)], to a dense matrix row [12, 15, ...].
def transcorp2matrix(transcorp, bow_corpus, vector_size):
    x = np.zeros((len(bow_corpus), vector_size))
    for i, doc in enumerate(transcorp):
        for idx, weight in doc:
            x[i][idx] = weight
    return x
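# Illustrative sketch with made-up values: each (index, weight) pair fills one cell
# of the document's row; bow_corpus is only used for its length here.
#   transcorp2matrix([[(0, 0.7), (2, 0.3)], [(1, 1.0)]], [None, None], 3)
#   # -> array([[0.7, 0. , 0.3],
#   #           [0. , 1. , 0. ]])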
# DATA IMPORT TOOLS #
# Load user's .csv file data set or 20News data set
def load_corpus(datafile, embedding, preprocess=True, language='english'):
    corpus, slices, data = None, None, None
    if datafile == '20News':
        source = fetch_20newsgroups(subset='all', remove=('headers', 'footers'))  # , 'quotes'
        res = pd.Series(source.data, name='res')
        if preprocess:
            print("Pre-processing text...")
            corpus = [preprocess_string(remove_stopwords(x)) for x in res]
        else:
            corpus = [x.split() for x in res]
    elif not os.path.exists(datafile):
        print("No such file: '{}'".format(datafile))
        sys.exit(1)
    elif not datafile.endswith('.csv'):
        print("Currently supported inputs: '20News' or a .csv file containing a column called 'text'.")
        sys.exit(1)
    else:
        data = pd.read_csv(datafile, encoding='utf-8').dropna()
        if 'text' not in data.columns:
            print("The column containing text must be called 'text'. Please check your datafile format.")
            sys.exit(1)
        if preprocess:
            print("Pre-processing text...")
            corpus = [preprocess_string(my_remove_stopwords(x, language)) for x in data['text']]
        else:
            corpus = [x.split() for x in data['text'].tolist()]
        # Optional 'toremove' column: per-document lists of tokens stripped from the corpus.
        if 'toremove' in data.columns:
            rm = [preprocess_string(my_remove_stopwords(' '.join(x), language)) for x in data['toremove'].apply(eval)]
            corpus = [[y for y in x if y not in rm[i]] for i, x in enumerate(corpus)]
    if embedding == 'DTM':
        if datafile == '20News':
            print("DTM cannot be used with the '20News' datafile as time information is not provided.")
            sys.exit(1)
        elif datafile.endswith('.csv') and 'year' in data.columns:
            slices = data['year'].value_counts().sort_index().tolist()
        else:
            print("DTM cannot be used with this datafile as time information is required. "
                  "Try a .csv file with a 'year' column.")
            sys.exit(1)
    return corpus, slices
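# Illustrative sketch of the .csv layout load_corpus() expects (rows are made up;
# only 'text' is mandatory, 'year' is needed for DTM, and 'toremove' must hold
# string representations of lists, since it is parsed with eval):
#   text,label,year,toremove
#   "First document ...",0,2001,"['token', 'to', 'drop']"
#   "Second document ...",1,2002,"[]"
#   corpus, slices = load_corpus('mydata.csv', 'DTM')  # 'mydata.csv' is hypothetical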
# Load labels from the user's file or 20News
def load_labels(datafile, embedding, vector_size):
    if datafile == '20News':
        source = fetch_20newsgroups(subset='all', remove=('headers', 'footers'))
        y = pd.Series(source.target, name='label')
    elif embedding in ('STM', 'CTM'):
        y = np.loadtxt('./external/raw_embeddings/tmp_{}_LABELS_{}.csv'.format(embedding, vector_size))
    elif not os.path.exists(datafile):
        print("No such file: '{}'".format(datafile))
        sys.exit(1)
    elif not datafile.endswith('.csv'):
        print("Currently supported inputs: '20News' or a .csv file containing a column called 'label'.")
        sys.exit(1)
    else:
        data = pd.read_csv(datafile, encoding='utf-8').dropna()
        if 'label' not in data.columns:
            print("A column called 'label' is required to perform the classification task.")
            sys.exit(1)
        else:
            y = data['label']
    return y
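# Illustrative usage (a sketch): with the built-in 20News corpus, labels come from
# scikit-learn's target array, so no label column is needed.
#   y = load_labels('20News', 'LDA', 200)  # pd.Series of newsgroup indices 0-19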
# Load embeddings computed in Step 1
def load_embeddings(project, embedding, k):
    filename = './results/{}/embeddings/{}_embedding_{}.csv'.format(project, embedding, k)
    try:
        return np.genfromtxt(filename, delimiter=',')
    except OSError:
        print("The embedding you're trying to work with has not been computed. "
              "Run the same command with '-mode encode' first (instead of '-mode classify') to compute it.")
        sys.exit(1)
# Load classifiers fitted in Step 2
def load_model(project, embedding, algo, k):
    filename = './results/{}/classifiers/{}_{}_{}.joblib'.format(project, embedding, algo, k)
    try:
        return joblib.load(filename)
    except OSError:
        print("The classifier you're trying to work with has not been computed. "
              "Run the same command with '-mode classify' first (instead of '-mode interpret') to compute it.")
        sys.exit(1)
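# Both loaders resolve paths of the following form (project name 'demo' is hypothetical):
#   load_embeddings('demo', 'LDA', 200)      -> ./results/demo/embeddings/LDA_embedding_200.csv
#   load_model('demo', 'LDA', 'LOGIT', 200)  -> ./results/demo/classifiers/LDA_LOGIT_200.joblib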
# A parser to read the user's command-line inputs
def read_options():
    parser = argparse.ArgumentParser()
    parser.add_argument('-mode',
                        choices=['all', 'encode', 'classify', 'interpret'],
                        required=True,
                        help="Step you want to perform (can be 'all').")
    parser.add_argument('-input',
                        type=str,
                        required=True,
                        help="Path to your .csv input file, or '20News'.")
    parser.add_argument('-embed',
                        choices=['BOW', 'DOC2VEC', 'POOL', 'BOREP', 'LSA', 'LDA', 'HDP', 'DTM', 'STM', 'CTM', 'PTM', 'BERT'],
                        required=True,
                        help='Embedding to use.')
    parser.add_argument('-project',
                        type=str,
                        default=time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()),
                        help='Name of the project folder where results will be stored.')
    parser.add_argument('-k',
                        type=int,
                        default=200,
                        help='Size of the embedding vectors.')
    parser.add_argument('-prep',
                        type=str2bool,
                        nargs='?',
                        const=True,  # so a bare '-prep' flag yields True instead of argparse's None
                        default=True,
                        help='Whether to pre-process text (lowercase, lemmatize, etc.).')
    parser.add_argument('-langu',
                        type=str,
                        default='english',
                        help='Language to use for text pre-processing.')
    parser.add_argument('-algo',
                        choices=['LOGIT', 'NBAYES', 'ADAB', 'DTREE', 'KNN', 'ANN', 'SVM'],
                        default='LOGIT',
                        help='Classifier to use.')
    parser.add_argument('-samp',
                        choices=['OVER', 'UNDER', 'NONE'],
                        default='NONE',
                        help='Sampling strategy to use to counter imbalanced data sets.')
    return parser.parse_args()
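# Example invocation (a sketch; 'main.py' stands in for the actual entry script,
# which is not defined in this module):
#   python main.py -mode encode -input 20News -embed LDA -k 100 -project demo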
# Converts string to corresponding boolean
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
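# Illustrative behaviour:
#   str2bool('yes')    # -> True
#   str2bool('0')      # -> False
#   str2bool('maybe')  # raises argparse.ArgumentTypeError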