-
Notifications
You must be signed in to change notification settings - Fork 4.4k
/
sql_storage.py
371 lines (274 loc) · 11.5 KB
/
sql_storage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
from chatterbot.storage import StorageAdapter
class SQLStorageAdapter(StorageAdapter):
    """
    The SQLStorageAdapter allows ChatterBot to store conversation
    data in any database supported by the SQL Alchemy ORM.

    All parameters are optional, by default a sqlite database is used.

    It will check if tables are present, if they are not, it will attempt
    to create the required tables.

    :keyword database_uri: eg: sqlite:///database_test.sqlite3
        The database_uri can be specified to choose database driver.
        When omitted (or any falsy value other than ``None``) a file
        database at ``sqlite:///db.sqlite3`` is used; passing ``None``
        selects an in-memory sqlite database.
    :type database_uri: str
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        from sqlalchemy import create_engine
        from sqlalchemy.orm import sessionmaker

        self.database_uri = kwargs.get('database_uri', False)

        # None results in a sqlite in-memory database as the default
        if self.database_uri is None:
            self.database_uri = 'sqlite://'

        # Create a file database if the database is not a connection string
        if not self.database_uri:
            self.database_uri = 'sqlite:///db.sqlite3'

        # NOTE(review): ``convert_unicode`` and ``dialect.has_table(engine, ...)``
        # are deprecated/removed in newer SQLAlchemy releases — confirm the
        # SQLAlchemy version this project pins before upgrading.
        self.engine = create_engine(self.database_uri, convert_unicode=True)

        if self.database_uri.startswith('sqlite://'):
            from sqlalchemy import event

            # Attach the listener to *this* engine only. Listening on the
            # global ``Engine`` class would run these SQLite-only PRAGMAs on
            # every connection of every engine in the process (including
            # non-sqlite databases) and add a duplicate listener per adapter.
            @event.listens_for(self.engine, 'connect')
            def set_sqlite_pragma(dbapi_connection, connection_record):
                dbapi_connection.execute('PRAGMA journal_mode=WAL')
                dbapi_connection.execute('PRAGMA synchronous=NORMAL')

        if not self.engine.dialect.has_table(self.engine, 'Statement'):
            self.create_database()

        self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)

    def get_statement_model(self):
        """
        Return the statement model.
        """
        from chatterbot.ext.sqlalchemy_app.models import Statement
        return Statement

    def get_tag_model(self):
        """
        Return the tag model.
        """
        from chatterbot.ext.sqlalchemy_app.models import Tag
        return Tag

    def model_to_object(self, statement):
        """
        Convert a database ``Statement`` model into a
        ``chatterbot.conversation.Statement`` object built from the model's
        ``serialize()`` output.
        """
        from chatterbot.conversation import Statement as StatementObject

        return StatementObject(**statement.serialize())

    def count(self):
        """
        Return the number of entries in the database.
        """
        Statement = self.get_model('statement')

        session = self.Session()
        try:
            return session.query(Statement).count()
        finally:
            session.close()

    def remove(self, statement_text):
        """
        Remove the first statement whose text exactly matches
        ``statement_text``. Does nothing if no such statement exists.

        NOTE(review): responses pointing at this text (``in_response_to``)
        are not touched by this method.
        """
        Statement = self.get_model('statement')
        session = self.Session()

        record = session.query(Statement).filter_by(text=statement_text).first()

        # ``session.delete(None)`` raises, so only delete an actual match.
        if record is not None:
            session.delete(record)

        self._session_finish(session, statement_text)

    def filter(self, **kwargs):
        """
        Yield statement objects from the database.

        Plain keyword arguments are matched by equality against model
        columns; only statements matching all of them are yielded.

        Special keyword arguments:

        - ``page_size``: number of rows fetched per query slice (default 1000).
        - ``order_by``: list of column names (a single name is accepted);
          ``'created_at'`` is sorted ascending.
        - ``tags``: tag name or list of tag names; a statement must carry one.
        - ``exclude_text``: collection of texts to exclude.
        - ``exclude_text_words``: exclude statements whose text contains any
          of these words (case-insensitive).
        - ``persona_not_startswith``: exclude statements whose persona starts
          with this prefix (a non-string truthy value means ``'bot:'``).
        - ``search_text_contains``: space-separated words; a statement's
          ``search_text`` must contain at least one of them.
        """
        from sqlalchemy import or_

        Statement = self.get_model('statement')
        Tag = self.get_model('tag')

        session = self.Session()

        # try/finally so the session is released even when the caller stops
        # iterating early (the generator is closed and GeneratorExit runs it).
        try:
            page_size = kwargs.pop('page_size', 1000)
            order_by = kwargs.pop('order_by', None)
            tags = kwargs.pop('tags', [])
            exclude_text = kwargs.pop('exclude_text', None)
            exclude_text_words = kwargs.pop('exclude_text_words', [])
            persona_not_startswith = kwargs.pop('persona_not_startswith', None)
            search_text_contains = kwargs.pop('search_text_contains', None)

            # Convert a single string into a list if only one tag is provided
            if isinstance(tags, str):
                tags = [tags]

            if len(kwargs) == 0:
                statements = session.query(Statement).filter()
            else:
                statements = session.query(Statement).filter_by(**kwargs)

            if tags:
                statements = statements.join(Statement.tags).filter(
                    Tag.name.in_(tags)
                )

            if exclude_text:
                statements = statements.filter(
                    ~Statement.text.in_(exclude_text)
                )

            if exclude_text_words:
                or_word_query = [
                    Statement.text.ilike('%' + word + '%') for word in exclude_text_words
                ]
                statements = statements.filter(
                    ~or_(*or_word_query)
                )

            if persona_not_startswith:
                # Honour the supplied prefix; legacy callers that pass a
                # non-string flag keep the historical 'bot:' prefix.
                if isinstance(persona_not_startswith, str):
                    persona_prefix = persona_not_startswith
                else:
                    persona_prefix = 'bot:'
                statements = statements.filter(
                    ~Statement.persona.startswith(persona_prefix)
                )

            if search_text_contains:
                or_query = [
                    Statement.search_text.contains(word) for word in search_text_contains.split(' ')
                ]
                statements = statements.filter(
                    or_(*or_query)
                )

            if order_by:
                if isinstance(order_by, str):
                    order_by = [order_by]
                # Build a new list so the caller's argument is not mutated.
                order_by = [
                    Statement.created_at.asc()
                    if isinstance(field, str) and field == 'created_at'
                    else field
                    for field in order_by
                ]
                statements = statements.order_by(*order_by)

            total_statements = statements.count()

            for start_index in range(0, total_statements, page_size):
                for statement in statements.slice(start_index, start_index + page_size):
                    yield self.model_to_object(statement)
        finally:
            session.close()

    def create(self, **kwargs):
        """
        Creates a new statement matching the keyword arguments specified.
        Returns the created statement.

        ``tags`` may be a single tag name or an iterable of tag names. Missing
        ``search_text`` / ``search_in_response_to`` values are computed with
        ``self.tagger.get_text_index_string``.
        """
        Statement = self.get_model('statement')
        Tag = self.get_model('tag')

        session = self.Session()

        tags = kwargs.pop('tags', [])
        # A bare string would otherwise become one tag per character.
        if isinstance(tags, str):
            tags = [tags]
        tags = set(tags)

        if 'search_text' not in kwargs:
            kwargs['search_text'] = self.tagger.get_text_index_string(kwargs['text'])

        if 'search_in_response_to' not in kwargs:
            in_response_to = kwargs.get('in_response_to')
            if in_response_to:
                kwargs['search_in_response_to'] = self.tagger.get_text_index_string(in_response_to)

        statement = Statement(**kwargs)

        for tag_name in tags:
            tag = session.query(Tag).filter_by(name=tag_name).first()

            if not tag:
                # Create the tag
                tag = Tag(name=tag_name)

            statement.tags.append(tag)

        session.add(statement)

        # Flush + refresh so database-generated fields (e.g. the id) are
        # present on the returned object.
        session.flush()
        session.refresh(statement)

        statement_object = self.model_to_object(statement)

        self._session_finish(session)

        return statement_object

    def create_many(self, statements):
        """
        Creates multiple statement entries.

        Tags are looked up once per new name and shared between statements
        in the batch, so each tag name maps to a single ``Tag`` object.
        """
        Statement = self.get_model('statement')
        Tag = self.get_model('tag')

        session = self.Session()
        try:
            create_statements = []
            create_tags = {}

            for statement in statements:

                statement_data = statement.serialize()
                tag_data = statement_data.pop('tags', [])

                statement_model_object = Statement(**statement_data)

                if not statement.search_text:
                    statement_model_object.search_text = self.tagger.get_text_index_string(statement.text)

                if not statement.search_in_response_to and statement.in_response_to:
                    statement_model_object.search_in_response_to = self.tagger.get_text_index_string(statement.in_response_to)

                # Only query the database for tag names not seen in this batch.
                new_tags = set(tag_data) - set(create_tags.keys())

                if new_tags:
                    existing_tags = session.query(Tag).filter(
                        Tag.name.in_(new_tags)
                    )

                    for existing_tag in existing_tags:
                        create_tags[existing_tag.name] = existing_tag

                for tag_name in tag_data:
                    if tag_name in create_tags:
                        tag = create_tags[tag_name]
                    else:
                        # Create the tag if it does not exist
                        tag = Tag(name=tag_name)

                        create_tags[tag_name] = tag

                    statement_model_object.tags.append(tag)
                create_statements.append(statement_model_object)

            session.add_all(create_statements)
            session.commit()
        finally:
            # The original never released this session.
            session.close()

    def update(self, statement):
        """
        Modifies an entry in the database.
        Creates an entry if one does not exist.

        The existing row is looked up by ``statement.id`` when set, otherwise
        by matching ``text`` and ``conversation``. Does nothing when
        ``statement`` is ``None``.
        """
        Statement = self.get_model('statement')
        Tag = self.get_model('tag')

        if statement is not None:
            session = self.Session()
            record = None

            if getattr(statement, 'id', None) is not None:
                record = session.query(Statement).get(statement.id)
            else:
                record = session.query(Statement).filter(
                    Statement.text == statement.text,
                    Statement.conversation == statement.conversation,
                ).first()

                # Create a new statement entry if one does not already exist
                if not record:
                    record = Statement(
                        text=statement.text,
                        conversation=statement.conversation,
                        persona=statement.persona
                    )

            # Update the response value
            record.in_response_to = statement.in_response_to

            record.created_at = statement.created_at

            record.search_text = self.tagger.get_text_index_string(statement.text)

            if statement.in_response_to:
                record.search_in_response_to = self.tagger.get_text_index_string(statement.in_response_to)

            for tag_name in statement.get_tags():
                tag = session.query(Tag).filter_by(name=tag_name).first()

                if not tag:
                    # Create the record
                    tag = Tag(name=tag_name)

                record.tags.append(tag)

            session.add(record)

            self._session_finish(session)

    def get_random(self):
        """
        Returns a random statement from the database.

        :raises self.EmptyDatabaseException: if there are no statements.
        """
        import random

        Statement = self.get_model('statement')

        # Count before opening a session (``count`` uses its own session) so
        # raising on an empty database cannot leak an open session.
        count = self.count()
        if count < 1:
            raise self.EmptyDatabaseException()

        session = self.Session()
        try:
            random_index = random.randrange(0, count)
            random_statement = session.query(Statement)[random_index]

            # Convert while the session is still open.
            return self.model_to_object(random_statement)
        finally:
            session.close()

    def drop(self):
        """
        Drop the database.

        Deletes all rows from the statement and tag tables; the tables
        themselves are kept.
        """
        Statement = self.get_model('statement')
        Tag = self.get_model('tag')

        session = self.Session()
        try:
            session.query(Statement).delete()
            session.query(Tag).delete()

            session.commit()
        finally:
            session.close()

    def create_database(self):
        """
        Populate the database with the tables.
        """
        from chatterbot.ext.sqlalchemy_app.models import Base

        Base.metadata.create_all(self.engine)

    def _session_finish(self, session, statement_text=None):
        """
        Commit ``session`` and always close it.

        An ``InvalidRequestError`` raised by the commit is logged (with
        ``statement_text`` as the message) rather than propagated; any other
        exception propagates after the session is closed.
        """
        from sqlalchemy.exc import InvalidRequestError

        try:
            session.commit()
        except InvalidRequestError:
            # Log the statement text and the exception
            self.logger.exception(statement_text)
        finally:
            session.close()