Skip to content

Commit

Permalink
Use lowercase table names for SQL Alchemy tables
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Aug 8, 2017
1 parent f02b2a6 commit 1c81c51
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 49 deletions.
45 changes: 23 additions & 22 deletions chatterbot/ext/sqlalchemy_app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,31 +30,23 @@ def __tablename__(cls):
'tag_association',
Base.metadata,
Column('tag_id', Integer, ForeignKey('tag.id')),
Column('statement_id', Integer, ForeignKey('StatementTable.id'))
Column('statement_id', Integer, ForeignKey('statement.id'))
)


class Tag(Base):
"""
A tag that describes a statement.
"""

name = Column(String)

class StatementTable(Base):

class Statement(Base):
"""
StatementTable, placeholder for a sentence or phrase.
A Statement represents a sentence or phrase.
"""

__tablename__ = 'StatementTable'

def get_statement(self):
from chatterbot.conversation import Statement as StatementObject

statement = StatementObject(self.text, extra_data=self.extra_data)
for response in self.in_response_to:
statement.add_response(response.get_response())
return statement

text = Column(String, unique=True)

tags = relationship(
Expand All @@ -66,17 +58,24 @@ def get_statement(self):
extra_data = Column(PickleType)

in_response_to = relationship(
'ResponseTable',
'Response',
back_populates='statement_table'
)

class ResponseTable(Base):
def get_statement(self):
from chatterbot.conversation import Statement as StatementObject

statement = StatementObject(self.text, extra_data=self.extra_data)
for response in self.in_response_to:
statement.add_response(response.get_response())
return statement


class Response(Base):
"""
ResponseTable, contains responses related to a givem statment.
Response, contains responses related to a givem statment.
"""

__tablename__ = 'ResponseTable'

text = Column(String)

created_at = Column(
Expand All @@ -86,10 +85,10 @@ class ResponseTable(Base):

occurrence = Column(Integer, default=1)

statement_text = Column(String, ForeignKey('StatementTable.text'))
statement_text = Column(String, ForeignKey('statement.text'))

statement_table = relationship(
'StatementTable',
'Statement',
back_populates='in_response_to',
cascade='all',
uselist=False
Expand All @@ -100,20 +99,22 @@ def get_response(self):
occ = {'occurrence': self.occurrence}
return ResponseObject(text=self.text, **occ)


conversation_association_table = Table(
'conversation_association',
Base.metadata,
Column('conversation_id', Integer, ForeignKey('conversation.id')),
Column('statement_id', Integer, ForeignKey('StatementTable.id'))
Column('statement_id', Integer, ForeignKey('statement.id'))
)


class Conversation(Base):
"""
A conversation.
"""

statements = relationship(
'StatementTable',
'Statement',
secondary=lambda: conversation_association_table,
backref='conversations'
)
54 changes: 27 additions & 27 deletions chatterbot/storage/sql_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

try:
from chatterbot.ext.sqlalchemy_app.models import (
Base, Conversation, StatementTable, ResponseTable, Tag
Base, Conversation, Statement, Response
)
except ImportError:
pass


def get_response_table(response):
return ResponseTable(text=response.text, occurrence=response.occurrence)
return Response(text=response.text, occurrence=response.occurrence)


class SQLStorageAdapter(StorageAdapter):
Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(self, **kwargs):
"read_only", False
)

if not self.engine.dialect.has_table(self.engine, 'StatementTable'):
if not self.engine.dialect.has_table(self.engine, 'Statement'):
self.create()

self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)
Expand All @@ -79,17 +79,17 @@ def count(self):
Return the number of entries in the database.
"""
session = self.Session()
statement_count = session.query(StatementTable).count()
statement_count = session.query(Statement).count()
session.close()
return statement_count

def __statement_filter(self, session, **kwargs):
"""
Apply filter operation on StatementTable
Apply filter operation on Statement
rtype: query
"""
_query = session.query(StatementTable)
_query = session.query(Statement)
return _query.filter_by(**kwargs)

def find(self, statement_text):
Expand Down Expand Up @@ -137,32 +137,32 @@ def filter(self, **kwargs):
# _response_query = None
_query = None
if len(filter_parameters) == 0:
_response_query = session.query(StatementTable)
_response_query = session.query(Statement)
statements.extend(_response_query.all())
else:
for i, fp in enumerate(filter_parameters):
_filter = filter_parameters[fp]
if fp in ['in_response_to', 'in_response_to__contains']:
_response_query = session.query(StatementTable)
_response_query = session.query(Statement)
if isinstance(_filter, list):
if len(_filter) == 0:
_query = _response_query.filter(
StatementTable.in_response_to == None) # NOQA Here must use == instead of is
Statement.in_response_to == None) # NOQA Here must use == instead of is
else:
for f in _filter:
_query = _response_query.filter(
StatementTable.in_response_to.contains(get_response_table(f)))
Statement.in_response_to.contains(get_response_table(f)))
else:
if fp == 'in_response_to__contains':
_query = _response_query.join(ResponseTable).filter(ResponseTable.text == _filter)
_query = _response_query.join(Response).filter(Response.text == _filter)
else:
_query = _response_query.filter(StatementTable.in_response_to == None) # NOQA
_query = _response_query.filter(Statement.in_response_to == None) # NOQA
else:
if _query:
_query = _query.filter(ResponseTable.statement_text.like('%' + _filter + '%'))
_query = _query.filter(Response.statement_text.like('%' + _filter + '%'))
else:
_response_query = session.query(ResponseTable)
_query = _response_query.filter(ResponseTable.statement_text.like('%' + _filter + '%'))
_response_query = session.query(Response)
_query = _response_query.filter(Response.statement_text.like('%' + _filter + '%'))

if _query is None:
return []
Expand All @@ -172,7 +172,7 @@ def filter(self, **kwargs):
results = []

for statement in statements:
if isinstance(statement, ResponseTable):
if isinstance(statement, Response):
if statement and statement.statement_table:
results.append(statement.statement_table.get_statement())
else:
Expand All @@ -195,14 +195,14 @@ def update(self, statement):

# Create a new statement entry if one does not already exist
if not record:
record = StatementTable(text=statement.text)
record = Statement(text=statement.text)

record.extra_data = dict(statement.extra_data)

if statement.in_response_to:
# Get or create the response records as needed
for response in statement.in_response_to:
_response = session.query(ResponseTable).filter_by(
_response = session.query(Response).filter_by(
text=response.text,
statement_text=statement.text
).first()
Expand All @@ -211,7 +211,7 @@ def update(self, statement):
_response.occurrence += 1
else:
# Create the record
_response = ResponseTable(
_response = Response(
text=response.text,
statement_text=statement.text,
occurrence=response.occurrence
Expand Down Expand Up @@ -250,23 +250,23 @@ def add_to_converation(self, conversation_id, statement, response):

conversation = session.query(Conversation).get(conversation_id)

statement_query = session.query(StatementTable).filter_by(
statement_query = session.query(Statement).filter_by(
text=statement.text
).first()
response_query = session.query(StatementTable).filter_by(
response_query = session.query(Statement).filter_by(
text=response.text
).first()

# Make sure the statements exist
if not statement_query:
self.update(statement)
statement_query = session.query(StatementTable).filter_by(
statement_query = session.query(Statement).filter_by(
text=statement.text
).first()

if not response_query:
self.update(response)
response_query = session.query(StatementTable).filter_by(
response_query = session.query(Statement).filter_by(
text=response.text
).first()

Expand All @@ -285,10 +285,10 @@ def get_latest_response(self, conversation_id):
statement = None

statement_query = session.query(
StatementTable
Statement
).filter(
StatementTable.conversations.any(id=conversation_id)
).order_by(StatementTable.id).limit(2).first()
Statement.conversations.any(id=conversation_id)
).order_by(Statement.id).limit(2).first()

if statement_query:
statement = statement_query.get_statement()
Expand All @@ -307,7 +307,7 @@ def get_random(self):
raise self.EmptyDatabaseException()

rand = random.randrange(0, count)
stmt = session.query(StatementTable)[rand]
stmt = session.query(Statement)[rand]

statement = stmt.get_statement()

Expand Down

0 comments on commit 1c81c51

Please sign in to comment.