Skip to content

Commit

Permalink
Merge pull request #153 from pagreene/belief-sort
Browse files Browse the repository at this point in the history
Enable sorting by belief in queries
  • Loading branch information
pagreene committed Dec 1, 2020
2 parents 56df7db + 3c0913e commit e35b010
Show file tree
Hide file tree
Showing 8 changed files with 493 additions and 242 deletions.
385 changes: 240 additions & 145 deletions indra_db/client/readonly/query.py

Large diffs are not rendered by default.

120 changes: 95 additions & 25 deletions indra_db/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def wrapper(obj, tbl_name, data, cols=None, commit=True, *args, **kwargs):
if not CAN_COPY:
raise RuntimeError("Cannot use copy methods. `pg_copy` is not "
"available.")
if obj.is_protected():
raise RuntimeError("Attempt to copy while in protected mode!")
if len(data) is 0:
return get_null_return() # Nothing to do....

Expand Down Expand Up @@ -238,10 +240,11 @@ class DatabaseManager(object):
_instance_name_fmt = NotImplemented
_db_name = NotImplemented

def __init__(self, url, label=None):
def __init__(self, url, label=None, protected=False):
self.url = make_url(url)
self.session = None
self.label = label
self.__protected = protected
self._conn = None

# To stringify table classes, we must merge the two meta classes.
Expand All @@ -262,7 +265,7 @@ class BaseMeta(DeclarativeMeta, IndraDBTableMetaClass):
return

# Create the engine (connection manager).
self.engine = create_engine(self.url)
self.__engine = create_engine(self.url)
return

def _init_foreign_key_map(self, foreign_key_map):
Expand All @@ -275,6 +278,21 @@ def _init_foreign_key_map(self, foreign_key_map):
else:
self.__foreign_key_graph = None

def is_protected(self):
    """Return the protected-mode flag this manager was constructed with."""
    return self.__protected

def get_raw_connection(self):
    """Return a raw DBAPI connection obtained from the engine.

    Returns
    -------
    The object returned by the engine's `raw_connection()` method.

    Raises
    ------
    RuntimeError
        If the manager is in protected mode. Raising (rather than returning
        None) gives callers that immediately use the connection, e.g. by
        calling `.cursor()`, a clear error instead of an AttributeError
        on None.
    """
    if self.__protected:
        raise RuntimeError("Cannot get a raw connection if protected mode "
                           "is on.")
    return self.__engine.raw_connection()

def get_conn(self):
    """Return a direct connection obtained from the engine.

    Returns
    -------
    The object returned by the engine's `connect()` method.

    Raises
    ------
    RuntimeError
        If the manager is in protected mode. Raising (rather than returning
        None) keeps `with db.get_conn() as conn:` from failing with an
        unrelated error about None not being a context manager.
    """
    if self.__protected:
        raise RuntimeError("Cannot get a direct connection in protected "
                           "mode.")
    return self.__engine.connect()

def __del__(self, *args, **kwargs):
if not self.available:
return
Expand Down Expand Up @@ -368,11 +386,16 @@ def grab_session(self):
return
if self.session is None or not self.session.is_active:
logger.debug('Attempting to get session...')
DBSession = sessionmaker(bind=self.engine)
DBSession = sessionmaker(bind=self.__engine,
autoflush=self.__protected,
autocommit=self.__protected)
logger.debug('Got session.')
self.session = DBSession()
if self.session is None:
raise IndraDbException("Failed to grab session.")
if self.__protected:
self.session.flush = \
lambda *a, **k: logger.error("Write not allowed!")

def get_tables(self):
"""Get a list of available tables."""
Expand Down Expand Up @@ -408,12 +431,12 @@ def get_active_tables(self, schema=None):
The name of the schema whose tables you wish to see. The default is
public.
"""
return inspect(self.engine).get_table_names(schema=schema)
return inspect(self.__engine).get_table_names(schema=schema)

def get_schemas(self):
"""Return the list of schema names currently in the database."""
res = []
with self.engine.connect() as con:
with self.__engine.connect() as con:
raw_res = con.execute('SELECT schema_name '
'FROM information_schema.schemata;')
for r, in raw_res:
Expand All @@ -422,13 +445,19 @@ def get_schemas(self):

def create_schema(self, schema_name):
"""Create a schema with the given name."""
with self.engine.connect() as con:
if self.__protected:
logger.error("Running in protected mode, writes not allowed!")
return
with self.__engine.connect() as con:
con.execute('CREATE SCHEMA IF NOT EXISTS %s;' % schema_name)
return

def drop_schema(self, schema_name, cascade=True):
"""Drop a schema (rather forcefully by default)"""
with self.engine.connect() as con:
if self.__protected:
logger.error("Running in protected mode, writes not allowed!")
return
with self.__engine.connect() as con:
logger.info("Dropping schema %s." % schema_name)
con.execute('DROP SCHEMA IF EXISTS %s %s;'
% (schema_name, 'CASCADE' if cascade else ''))
Expand Down Expand Up @@ -599,7 +628,7 @@ def get_copy_cursor(self):
"""Execute SQL queries in the context of a copy operation."""
# Prep the connection.
if self._conn is None:
self._conn = self.engine.raw_connection()
self._conn = self.__engine.raw_connection()
self._conn.rollback()
return self._conn.cursor()

Expand All @@ -613,6 +642,9 @@ def make_copy_batch_id(self):
return random.randint(-2**30, 2**30)

def _prep_copy(self, tbl_name, data, cols):
if self.__protected:
logger.error("Manager is in protected mode, no writes allowed!")
return

# If cols is not specified, use all the cols in the table, else check
# to make sure the names are valid.
Expand Down Expand Up @@ -668,7 +700,7 @@ def _prep_copy(self, tbl_name, data, cols):

# Prep the connection.
if self._conn is None:
self._conn = self.engine.raw_connection()
self._conn = self.__engine.raw_connection()
self._conn.rollback()

return cols, data_bts
Expand Down Expand Up @@ -986,6 +1018,10 @@ def pg_dump(self, dump_file, **options):
dump_file : S3Path or str
The location on s3 where the content should be dumped.
"""
if self.__protected:
logger.error("Cannot execute pg_dump in protected mode.")
return

if isinstance(dump_file, str):
dump_file = S3Path.from_string(dump_file)
elif dump_file is not None and not isinstance(dump_file, S3Path):
Expand Down Expand Up @@ -1022,14 +1058,21 @@ def pg_dump(self, dump_file, **options):
return dump_file

def vacuum(self, analyze=True):
conn = self.engine.raw_connection()
if self.__protected:
logger.error("Vacuuming not allowed in protected mode.")
return
conn = self.__engine.raw_connection()
conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
cursor = conn.cursor()
cursor.execute('VACUUM' + (' ANALYZE;' if analyze else ''))
return

def pg_restore(self, dump_file, **options):
"""Load content into the database from a dump file on s3."""
if self.__protected:
logger.error("Cannot execute pg_restore in protected mode.")
return

if isinstance(dump_file, str):
dump_file = S3Path.from_string(dump_file)
elif dump_file is not None and not isinstance(dump_file, S3Path):
Expand Down Expand Up @@ -1073,10 +1116,12 @@ class PrincipalDatabaseManager(DatabaseManager):
_instance_name_fmt = 'indradb-{name}'
_db_name = 'indradb_principal'

def __init__(self, host, label=None):
super(self.__class__, self).__init__(host, label)
def __init__(self, host, label=None, protected=False):
super(self.__class__, self).__init__(host, label, protected)
if not self.available:
return
self.__protected = self._DatabaseManager__protected
self.__engine = self._DatabaseManager__engine

self.public = principal_schema.get_schema(self.Base)
self.readonly = readonly_schema.get_schema(self.Base)
Expand All @@ -1096,13 +1141,19 @@ def __init__(self, host, label=None):

def __getattribute__(self, item):
if item == '_PaStmtSrc':
self.__PaStmtSrc.load_cols(self.engine)
self.load_pa_stmt_src_cols()
return self.__PaStmtSrc
elif item == 'SourceMeta':
self.__SourceMeta.load_cols(self.engine)
self.load_source_meta_cols()
return self.__SourceMeta
return super(DatabaseManager, self).__getattribute__(item)

def load_pa_stmt_src_cols(self, cols=None):
    """Call `load_cols` on the PaStmtSrc table class with this engine.

    Parameters
    ----------
    cols : optional
        Passed through unchanged as the second argument of `load_cols`.
    """
    pa_stmt_src = self.__PaStmtSrc
    pa_stmt_src.load_cols(self.__engine, cols)

def load_source_meta_cols(self, cols=None):
    """Call `load_cols` on the SourceMeta table class with this engine.

    Parameters
    ----------
    cols : optional
        Passed through unchanged as the second argument of `load_cols`.
    """
    source_meta = self.__SourceMeta
    source_meta.load_cols(self.__engine, cols)

def generate_readonly(self, belief_dict, allow_continue=True):
"""Manage the materialized views.
Expand All @@ -1114,6 +1165,10 @@ def generate_readonly(self, belief_dict, allow_continue=True):
If True (default), continue to build the schema if it already
exists. If False, give up if the schema already exists.
"""
if self.__protected:
logger.error("Cannot generate readonly in protected mode.")
return

# Optionally create the schema.
if 'readonly' in self.get_schemas():
if allow_continue:
Expand Down Expand Up @@ -1154,7 +1209,7 @@ def table_is_used(tbl, other_tables):
f"extra in tables={in_ro-to_create}."

# Dump the belief dict into the database.
self.Belief.__table__.create(bind=self.engine)
self.Belief.__table__.create(bind=self.__engine)
self.copy(self.Belief.full_name(),
[(int(h), n) for h, n in belief_dict.items()],
('mk_hash', 'belief'))
Expand Down Expand Up @@ -1199,8 +1254,14 @@ def dump_readonly(self, dump_file=None):
dump_file = dump_loc.get_element_path('readonly-%s.dump' % now_str)
return self.pg_dump(dump_file, schema='readonly')

def create_table(self, table_obj):
    """Create the database table backing the given table class.

    Parameters
    ----------
    table_obj
        An object whose `__table__` attribute exposes a `create` method
        that accepts the engine.

    In protected mode nothing is created: an error is logged and the
    method returns, matching the guard used by `create_tables`.
    """
    # Table creation is a write, so it must respect protected mode just
    # like the other schema-modifying methods on this manager.
    if self.__protected:
        logger.error("Cannot create tables in protected mode.")
        return
    table_obj.__table__.create(self.__engine)

def create_tables(self, tbl_list=None):
"""Create the public tables for INDRA database."""
if self.__protected:
logger.error("Cannot create tables in protected mode.")
return
ordered_tables = ['text_ref', 'mesh_ref_annotations', 'text_content',
'reading', 'db_info', 'raw_statements', 'raw_agents',
'raw_mods', 'raw_muts', 'pa_statements', 'pa_agents',
Expand All @@ -1221,16 +1282,16 @@ def create_tables(self, tbl_list=None):
if tbl_name in tbl_name_list:
tbl_name_list.remove(tbl_name)
logger.debug("Creating %s..." % tbl_name)
if not self.public[tbl_name].__table__.exists(self.engine):
self.public[tbl_name].__table__.create(bind=self.engine)
if not self.public[tbl_name].__table__.exists(self.__engine):
self.public[tbl_name].__table__.create(bind=self.__engine)
logger.debug("Table created.")
else:
logger.debug("Table already existed.")

# The rest can be started any time.
for tbl_name in tbl_name_list:
logger.debug("Creating %s..." % tbl_name)
self.public[tbl_name].__table__.create(bind=self.engine)
self.public[tbl_name].__table__.create(bind=self.__engine)
logger.debug("Table created.")
return

Expand All @@ -1241,6 +1302,10 @@ def drop_tables(self, tbl_list=None, force=False):
is False, a warning prompt will be raised to asking for confirmation,
as this action will remove all data from that table.
"""
if self.__protected:
logger.error("Cannot drop tables in protected mode.")
return False

if tbl_list is not None:
for i, tbl in enumerate(tbl_list[:]):
if isinstance(tbl, str):
Expand All @@ -1265,13 +1330,13 @@ def drop_tables(self, tbl_list=None, force=False):
return False
if tbl_list is None:
logger.info("Removing all tables...")
self.Base.metadata.drop_all(self.engine)
self.Base.metadata.drop_all(self.__engine)
logger.debug("All tables removed.")
else:
for tbl in tbl_list:
logger.info("Removing %s..." % tbl.__tablename__)
if tbl.__table__.exists(self.engine):
tbl.__table__.drop(self.engine)
if tbl.__table__.exists(self.__engine):
tbl.__table__.drop(self.__engine)
logger.debug("Table removed.")
else:
logger.debug("Table doesn't exist.")
Expand All @@ -1298,10 +1363,12 @@ class ReadonlyDatabaseManager(DatabaseManager):
_instance_name_fmt = 'indradb-readonly-{name}'
_db_name = 'indradb_readonly'

def __init__(self, host, label=None):
super(self.__class__, self).__init__(host, label)
def __init__(self, host, label=None, protected=True):
super(self.__class__, self).__init__(host, label, protected)
if not self.available:
return
self.__protected = self._DatabaseManager__protected
self.__engine = self._DatabaseManager__engine

self.tables = readonly_schema.get_schema(self.Base)
for tbl in self.tables.values():
Expand All @@ -1325,13 +1392,13 @@ def get_source_names(self) -> set:

def __getattribute__(self, item):
if item == '_PaStmtSrc':
self.__PaStmtSrc.load_cols(self.engine)
self.__PaStmtSrc.load_cols(self.__engine)
return self.__PaStmtSrc
elif item == 'SourceMeta':
if self.__non_source_cols is None:
self.__non_source_cols = \
set(self.get_column_names(self.__SourceMeta))
self.__SourceMeta.load_cols(self.engine)
self.__SourceMeta.load_cols(self.__engine)
return self.__SourceMeta
return super(DatabaseManager, self).__getattribute__(item)

Expand All @@ -1348,6 +1415,9 @@ def get_active_tables(self, schema='readonly'):

def load_dump(self, dump_file, force_clear=True):
"""Load from a dump of the readonly schema on s3."""
if self.__protected:
logger.error("Cannot load a dump while in protected mode.")
return

# Make sure the database is clear.
if 'readonly' in self.get_schemas():
Expand Down
7 changes: 4 additions & 3 deletions indra_db/managers/dump_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ def _choose_db(cls, **kwargs):
else:
raise ValueError("No database specified.")
if db_opt == 'principal':
return get_db('primary')
return get_db('primary', protected=False)
else: # if db_opt == 'readonly'
return get_ro('primary')
return get_ro('primary', protected=False)

def get_s3_path(self):
if self.s3_dump_path is None:
Expand Down Expand Up @@ -535,5 +535,6 @@ def parse_args():

if __name__ == '__main__':
args = parse_args()
dump(get_db(args.database), get_ro(args.readonly), args.delet_existing,
dump(get_db(args.database, protected=False),
get_ro(args.readonly, protected=False), args.delet_existing,
args.allow_continue, args.load_only, args.dump_only)
2 changes: 1 addition & 1 deletion indra_db/schemas/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def build_indices(cls, db):

@staticmethod
def execute(db, sql):
conn = db.engine.raw_connection()
conn = db.get_raw_connection()
cursor = conn.cursor()
cursor.execute(sql)
conn.commit()
Expand Down
2 changes: 1 addition & 1 deletion indra_db/schemas/readonly_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def definition(cls, db):
db.grab_session()

# Make sure the necessary extension is installed.
with db.engine.connect() as conn:
with db.get_conn() as conn:
conn.execute('CREATE EXTENSION IF NOT EXISTS tablefunc;')

logger.info("Discovering the possible sources...")
Expand Down

0 comments on commit e35b010

Please sign in to comment.