Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 13 additions & 40 deletions datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,28 +83,6 @@ def __eq__(self, other):
"""
return self.conn_info == other.conn_info

def is_same(self, host, user):
"""
true if the connection host and user name are the same
"""
if host is None:
host = self.conn_info['host']
port = self.conn_info['port']
else:
try:
host, port = host.split(':')
port = int(port)
except ValueError:
port = default_port

if user is None:
user = self.conn_info['user']

return self.conn_info['host'] == host and \
self.conn_info['port'] == port and \
self.conn_info['user'] == user


@property
def is_connected(self):
return self._conn.ping()
Expand Down Expand Up @@ -180,14 +158,14 @@ def _load_headings(self, dbname, force=False):
Setting force=True will result in reloading of the heading even if one
already exists.
"""
if not dbname in self.headings or force:
if dbname not in self.headings or force:
logger.info('Loading table definitions from `{dbname}`...'.format(dbname=dbname))
self.table_names[dbname] = {}
self.headings[dbname] = {}
self.tableInfo[dbname] = {}

cur = self.query('SHOW TABLE STATUS FROM `{dbname}` WHERE name REGEXP "{sqlPtrn}"'.format(
dbname=dbname, sqlPtrn=table_name_regexp_sql.pattern), asDict=True)
dbname=dbname, sqlPtrn=table_name_regexp_sql.pattern), as_dict=True)

for info in cur:
info = {k.lower(): v for k, v in info.items()} # lowercase it
Expand All @@ -202,27 +180,27 @@ def _load_headings(self, dbname, force=False):

def load_dependencies(self, dbname): # TODO: Perhaps consider making this "private" by preceding with underscore?
"""
load dependencies (foreign keys) between tables by examnining their
load dependencies (foreign keys) between tables by examining their
respective CREATE TABLE statements.
"""

ptrn = r"""
FOREIGN\ KEY\s+\((?P<attr1>[`\w ,]+)\)\s+ # list of keys in this table
foreign_key_regexp = re.compile(r"""
FOREIGN KEY\s+\((?P<attr1>[`\w ,]+)\)\s+ # list of keys in this table
REFERENCES\s+(?P<ref>[^\s]+)\s+ # table referenced
\((?P<attr2>[`\w ,]+)\) # list of keys in the referenced table
"""
""", re.X)

logger.info('Loading dependencies for `{dbname}`'.format(dbname=dbname))

for tabName in self.tableInfo[dbname]:
cur = self.query('SHOW CREATE TABLE `{dbname}`.`{tabName}`'.format(dbname=dbname, tabName=tabName),
asDict=True)
as_dict=True)
table_def = cur.fetchone()
full_table_name = '`%s`.`%s`' % (dbname, tabName)
self.parents[full_table_name] = []
self.referenced[full_table_name] = []

for m in re.finditer(ptrn, table_def["Create Table"], re.X): # iterate through foreign key statements
for m in foreign_key_regexp.finditer(table_def["Create Table"]): # iterate through foreign key statements
assert m.group('attr1') == m.group('attr2'), \
'Foreign keys must link identically named attributes'
attrs = m.group('attr1')
Expand All @@ -234,11 +212,7 @@ def load_dependencies(self, dbname): # TODO: Perhaps consider making this "priv
if not re.search(r'`\.`', ref): # if referencing other table in same schema
ref = '`%s`.%s' % (dbname, ref) # convert to full-table name

if is_primary:
self.parents[full_table_name].append(ref)
else:
self.referenced[full_table_name].append(ref)

(self.parents if is_primary else self.referenced)[full_table_name].append(ref)
self.parents.setdefault(ref, [])
self.referenced.setdefault(ref, [])

Expand Down Expand Up @@ -298,7 +272,6 @@ def __del__(self):
logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info))
self._conn.close()


def erd(self, databases=None, tables=None, fill=True, reload=True):
"""
Creates Entity Relation Diagram for the database or specified subset of
Expand All @@ -321,14 +294,14 @@ def erd(self, databases=None, tables=None, fill=True, reload=True):

graph.plot()

def query(self, query, args=(), asDict=False):
def query(self, query, args=(), as_dict=False):
"""
Execute the specified query and return the tuple generator.

If asDict is set to True, the returned cursor objects returns
If as_dict is set to True, the returned cursor objects returns
query results as dictionary.
"""
cursor = pymysql.cursors.DictCursor if asDict else pymysql.cursors.Cursor
cursor = pymysql.cursors.DictCursor if as_dict else pymysql.cursors.Cursor
cur = self._conn.cursor(cursor=cursor)

# Log the query
Expand All @@ -343,4 +316,4 @@ def cancel_transaction(self):
self.query('ROLLBACK')

def commit_transaction(self):
self.query('COMMIT')
self.query('COMMIT')
Loading