#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility functions and classes for detecting offensive language.
"""
from parlai.agents.transformer.transformer import TransformerClassifierAgent
from parlai.core.agents import create_agent, create_agent_from_shared
from parlai.tasks.dialogue_safety.agents import OK_CLASS, NOT_OK_CLASS
from parlai.utils.typing import TShared
from parlai.utils.io import PathManager
import parlai.utils.logging as logging
import os


class OffensiveLanguageClassifier:
"""
Load model trained to detect offensive language in the context of single- turn
dialogue utterances.
This model was trained to be robust to adversarial examples created by humans. See
<http://parl.ai/projects/dialogue_safety/> for more information.
"""
def __init__(
self,
shared: TShared = None,
custom_model_file='zoo:dialogue_safety/single_turn/model',
):
if not shared:
self.model = self._create_safety_model(custom_model_file)
else:
self.model = create_agent_from_shared(shared['model'])
        self.classes = {OK_CLASS: False, NOT_OK_CLASS: True}

    def share(self):
shared = {'model': self.model.share()}
        return shared

    def _create_safety_model(self, custom_model_file):
        from parlai.core.params import ParlaiParser

        parser = ParlaiParser(False, False)
TransformerClassifierAgent.add_cmdline_args(parser, partial_opt=None)
parser.set_params(
model='transformer/classifier',
model_file=custom_model_file,
print_scores=True,
data_parallel=False,
)
safety_opt = parser.parse_args([])
        return create_agent(safety_opt, requireModelExists=True)

    def contains_offensive_language(self, text):
        """
        Classify a message and return a tuple (pred_not_ok, prob): whether the message
        was classified as NOT OK, and the classifier's confidence in that prediction.
        """
if not text:
return False, 1.0
act = {'text': text, 'episode_done': True}
self.model.observe(act)
response = self.model.act()['text']
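        # With print_scores=True, the classifier's text response is expected to look
        # roughly like 'Predicted class: __notok__\nwith probability: 0.95'; the
        # exact wording comes from the underlying ParlAI classifier agent.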
pred_class, prob = [x.split(': ')[-1] for x in response.split('\n')]
pred_not_ok = self.classes[pred_class] # check whether classified as NOT OK
prob = float(prob) # cast string to float
        return pred_not_ok, prob

    def __contains__(self, key):
        """
A simple way of checking whether the model classifies an utterance as offensive.
Returns True if the input phrase is offensive.
"""
pred_not_ok, prob = self.contains_offensive_language(key)
return pred_not_ok
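

# A minimal usage sketch (an illustration, not part of the original module): the
# classifier downloads the dialogue_safety zoo model on first use, and `in` treats
# membership as "this text is offensive". A second instance can reuse the loaded
# model weights via OffensiveLanguageClassifier(shared=clf.share()).
#
#   clf = OffensiveLanguageClassifier()
#   pred_not_ok, prob = clf.contains_offensive_language('hello friend')
#   if 'hello friend' in clf:
#       print('flagged as offensive')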

class OffensiveStringMatcher:
    """
    Detect offensive language using a list of offensive words and phrases from
    https://github.com/LDNOOBW.
    """

    def __init__(self, datapath: str = None):
"""
Get data from external sources and build data representation.
If datapath ends in '.txt' it is assumed a custom model file is already given.
"""
import parlai.core.build_data as build_data
from parlai.core.dict import DictionaryAgent
self.tokenize = DictionaryAgent.split_tokenize
def _path():
# Build the data if it doesn't exist.
build()
return os.path.join(
self.datapath, 'OffensiveLanguage', 'OffensiveLanguage.txt'
            )

        def build():
version = 'v1.0'
dpath = os.path.join(self.datapath, 'OffensiveLanguage')
if not build_data.built(dpath, version):
logging.info(f'building data: {dpath}')
if build_data.built(dpath):
# An older version exists, so remove these outdated files.
build_data.remove_dir(dpath)
build_data.make_dir(dpath)
# Download the data.
fname = 'OffensiveLanguage.txt'
url = 'http://parl.ai/downloads/offensive_language/' + fname
build_data.download(url, dpath, fname)
# Mark the data as built.
                build_data.mark_done(dpath, version)

        if datapath is not None and datapath.endswith('.txt'):
# Load custom file.
self.datafile = datapath
else:
# Build data from zoo, and place in given datapath.
if datapath is None:
# Build data from zoo.
                from parlai.core.params import ParlaiParser

                parser = ParlaiParser(False, False)
self.datapath = parser.parse_args([])['datapath']
else:
self.datapath = datapath
            self.datafile = _path()

        # Store a token trie, e.g.:
        # {'2': {'girls': {'1': {'cup': {'__END__': True}}}}}
self.END = '__END__'
self.max_len = 1
self.offensive_trie = {}
self.word_prefixes = [
'de',
'de-',
'dis',
'dis-',
'ex',
'ex-',
'mis',
'mis-',
'pre',
'pre-',
'non',
'non-',
'semi',
'semi-',
'sub',
'sub-',
'un',
'un-',
]
self.word_suffixes = [
'a',
'able',
'as',
'dom',
'ed',
'er',
'ers',
'ery',
'es',
'est',
'ful',
'fy',
'ies',
'ify',
'in',
'ing',
'ish',
'less',
'ly',
's',
'y',
]
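
        # Each base phrase from the word list is also added with these prefixes and
        # suffixes attached (e.g. 'jerk' -> 'unjerk', 'jerks'), except for any benign
        # variants listed in allow_list below.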
self.allow_list = [
'butter',
'buttery',
'spicy',
'spiced',
'spices',
'spicier',
'spicing',
'twinkies',
        ]

        with PathManager.open(self.datafile, 'r') as f:
for p in f.read().splitlines():
mod_ps = [p]
mod_ps += [pref + p for pref in self.word_prefixes]
mod_ps += [p + suff for suff in self.word_suffixes]
for mod_p in mod_ps:
if mod_p not in self.allow_list:
                    self.add_phrase(mod_p)

    def add_phrase(self, phrase):
        """
Add a single phrase to the filter.
"""
toks = self.tokenize(phrase)
curr = self.offensive_trie
for t in toks:
if t not in curr:
curr[t] = {}
curr = curr[t]
curr[self.END] = True
        self.max_len = max(self.max_len, len(toks))

    def add_words(self, phrase_list):
        """
Add list of custom phrases to the filter.
"""
for phrase in phrase_list:
            self.add_phrase(phrase)

    def _check_sequence(self, toks, idx, node):
        """
        Check if words from the sequence are in the trie.

        This checks phrases made from toks[idx], toks[idx:idx + 2], ...,
        toks[idx:idx + self.max_len].
        """
right = min(idx + self.max_len, len(toks))
for i in range(idx, right):
if toks[i] in node:
node = node[toks[i]]
if self.END in node:
return ' '.join(toks[j] for j in range(idx, i + 1))
else:
break
return False
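
    # For example, with '2 girls 1 cup' in the trie, _check_sequence(toks, idx, trie)
    # walks trie nodes token by token starting from toks[idx] and returns the matched
    # phrase '2 girls 1 cup' when it reaches an '__END__' marker; it returns False as
    # soon as a token falls off the trie.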
    def contains_offensive_language(self, text):
        """
        Determine if text contains any offensive words in the filter.

        Returns the first matched offensive phrase, or None if nothing matches.
        """
        if not text:
            return None
        if isinstance(text, str):
            toks = self.tokenize(text.lower())
        elif isinstance(text, (list, tuple)):
            toks = text
        else:
            raise TypeError(f'Expected str, list, or tuple; got {type(text)}')
        for i in range(len(toks)):
            res = self._check_sequence(toks, i, self.offensive_trie)
            if res:
                return res
        return None

    def find_all_offensive_language(self, text):
        """
        Find all offensive words from text in the filter.

        Returns a list of every matched offensive phrase (possibly empty).
        """
        if isinstance(text, str):
            toks = self.tokenize(text.lower())
        elif isinstance(text, (list, tuple)):
            toks = text
        else:
            raise TypeError(f'Expected str, list, or tuple; got {type(text)}')
        all_offenses = []
        for i in range(len(toks)):
            res = self._check_sequence(toks, i, self.offensive_trie)
            if res:
                all_offenses.append(res)
        return all_offenses

    def __contains__(self, key):
        """
        Determine if text contains any offensive words in the filter.

        Returns the matched phrase (truthy) or None, so `key in matcher` acts as a
        boolean check.
        """
        return self.contains_offensive_language(key)
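

# A minimal demonstration sketch (not part of the original module). Running this
# file directly downloads the LDNOOBW word list into ParlAI's default datapath on
# first use; the sample sentence below is purely illustrative.
if __name__ == '__main__':
    matcher = OffensiveStringMatcher()
    sample = 'hello friend'
    print(f'match found: {matcher.contains_offensive_language(sample)}')
    print(f'is offensive: {sample in matcher}')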