Skip to content

Commit

Permalink
[IMP] carepoint: Globalize models
Browse files Browse the repository at this point in the history
* Create and use global models obj to hold instantiated declaratives
* Create session context manager to guarantee commits and closes
* Remove `_do_queries` in favor of session context manager
  • Loading branch information
lasley committed Sep 20, 2016
1 parent 6fa70e7 commit 108f91d
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 185 deletions.
170 changes: 83 additions & 87 deletions carepoint/db/carepoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,33 @@
# Copyright 2016-TODAY LasLabs Inc.
# License MIT (https://opensource.org/licenses/MIT).

import os
import imp
import operator
import os
import urllib2
from sqlalchemy import text, bindparam

from contextlib import contextmanager

from sqlalchemy import bindparam
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.inspection import inspect
from smb.SMBHandler import SMBHandler
from sqlalchemy.orm import sessionmaker
from sqlalchemy import text


from .db import Db
from smb.SMBHandler import SMBHandler


Base = declarative_base()
Base.get = lambda s, k, v=None: getattr(s, k, v)
Base.__getitem__ = lambda s, k, v=None: getattr(s, k, v)
Base.__setitem__ = lambda s, k, v: setattr(s, k, v)


models, env, dbs = {}, {}, {}


class Carepoint(dict):
""" Base CarePoint db connector object """

Expand All @@ -37,11 +47,29 @@ class Carepoint(dict):
'==': operator.eq,
}

def __init__(self, server, user, passwd, smb_user=None, smb_passwd=None,
db_args=None,
):
def __init__(
self, server, user, passwd, smb_user=None, smb_passwd=None,
db_args=None, **engine_args
):
""" It initializes new Carepoint object
Args:
server (str): IP or Hostname to database
user (str): Username for database
passwd (str): Password for database
smb_user (str): Username to use for SMB connection, ``None`` to
use the database user
smd_passwd (str): Password to use for the SMB connection, ``None``
to use the database password
db_args (dict): Dictionary of arguments to send during initial
db creation
**engine_args (mixed): Kwargs to pass to ``create_engine``
"""

super(Carepoint, self).__init__()
global env, dbs
self.env = env
self.dbs = dbs
self.iter_refresh = False
params = {
'user': user,
Expand All @@ -51,32 +79,35 @@ def __init__(self, server, user, passwd, smb_user=None, smb_passwd=None,
}
if db_args is not None:
params.update(db_args)
if engine_args:
params.update(engine_args)
# @TODO: Lazy load, once other dbs needed
self.dbs = {
'cph': Db(**params),
}
self.env = {
'cph': sessionmaker(bind=self.dbs['cph']),
}
self.sessions = {}
if not self.dbs.get('cph'):
self.dbs['cph'] = Db(**params)
if not self.env.get('cph'):
self.env['cph'] = sessionmaker(bind=self.dbs['cph'])
if smb_user is None:
self.smb_creds = {
'user': user,
'passwd': passwd,
}
else:
self.smb_creds = {
'user': user,
'passwd': passwd,
'user': smb_user,
'passwd': smb_passwd,
}

@contextmanager
def _get_session(self, model_obj):
session = self.env[model_obj.__dbname__]()
try:
return self.sessions[model_obj.__dbname__]
except KeyError:
session = self.env[model_obj.__dbname__]()
self.sessions[model_obj.__dbname__] = session
return session
yield session
session.commit()
except:
session.rollback()
raise
finally:
session.close()

@property
def _smb_prefix(self):
Expand Down Expand Up @@ -181,30 +212,6 @@ def _create_entities(self, model_obj, cols):
pass
return out

def _do_queries(self, session, *queries, **kwargs):
""" Wrapper method for running any query against the DB safely
Will create a transaction, then loop and run `queries` methods
:param model_obj: Table class to search
:type model_obj: :class:`sqlalchemy.Table`
:param queries: Query method(s) to be run in context of transaction
:type queries: Methods to be run against db
:kwarg no_commit: True to not commit transaction
:returns: Query result(s)
:rtype: ResultProxy if singleton, else List of ResultProxies
"""
res = []
try:
for query in queries:
res.append(query())
if not kwargs.get('no_commit'):
session.commit()
except:
session.rollback()
raise
if len(res) == 1:
return res[0]
return res

def read(self, model_obj, record_id, with_entities=None):
""" Get record by id and return the object
:param model_obj: Table class to search
Expand All @@ -216,16 +223,12 @@ def read(self, model_obj, record_id, with_entities=None):
:type with_entities: list or None
:rtype: :class:`sqlalchemy.engine.ResultProxy`
"""
session = self._get_session(model_obj)
res = self._do_queries(
session,
lambda: session.query(model_obj).get(record_id),
no_commit=True,
)
if with_entities:
res.with_entities(*self._create_entities(
model_obj, with_entities
))
with self._get_session(model_obj) as session:
res = session.query(model_obj).get(record_id)
if with_entities:
res.with_entities(*self._create_entities(
model_obj, with_entities
))
return res

def search(self, model_obj, filters=None, with_entities=None):
Expand All @@ -238,18 +241,15 @@ def search(self, model_obj, filters=None, with_entities=None):
:type with_entities: list or None
:rtype: :class:`sqlalchemy.engine.ResultProxy`
"""
session = self._get_session(model_obj)
if filters is None:
filters = {}
filters = self._unwrap_filters(model_obj, filters)
res = self._do_queries(
session,
lambda: session.query(model_obj).filter(*filters),
)
if with_entities:
res.with_entities(*self._create_entities(
model_obj, with_entities
))
with self._get_session(model_obj) as session:
if filters is None:
filters = {}
filters = self._unwrap_filters(model_obj, filters)
res = session.query(model_obj).filter(*filters)
if with_entities:
res.with_entities(*self._create_entities(
model_obj, with_entities
))
return res

def create(self, model_obj, vals):
Expand All @@ -260,14 +260,10 @@ def create(self, model_obj, vals):
:type vals: dict
:rtype: :class:`sqlalchemy.ext.declarative.Declarative`
"""
session = self._get_session(model_obj)

def __create():
with self._get_session(model_obj) as session:
record = model_obj(**vals)
session.add(record)
return record

return self._do_queries(session, __create)
return record

def update(self, model_obj, record_id, vals):
""" Wrapper to update a record in Carepoint
Expand All @@ -279,16 +275,11 @@ def update(self, model_obj, record_id, vals):
:type vals: dict
:rtype: :class:`sqlalchemy.ext.declarative.Declarative`
"""

session = self._get_session(model_obj)

def __update():
with self._get_session(model_obj) as session:
record = self.read(model_obj, record_id)
for key, val in vals.items():
setattr(record, key, val)
return record

return self._do_queries(session, __update)
return record

def delete(self, model_obj, record_id):
""" Wrapper to delete a record in Carepoint
Expand All @@ -299,19 +290,14 @@ def delete(self, model_obj, record_id):
:return: Whether the record was found, and deleted
:rtype: bool
"""

session = self._get_session(model_obj)

def __delete():
with self._get_session(model_obj) as session:
record = self.read(model_obj, record_id)
result_cnt = record.count()
if result_cnt == 0:
return False
assert result_cnt == 1
session.delete(record)
return True

return self._do_queries(session, __delete)
return True

def get_pks(self, model_obj):
""" Return the Primary keys in the model
Expand Down Expand Up @@ -362,8 +348,18 @@ def __getattr__(self, key):
except KeyError:
raise AttributeError()

def __setitem__(self, key, val, __global=False, *args, **kwargs):
""" Re-implement __setitem__ to allow for global model sync """
super(Carepoint, self).__setitem__(key, val, *args, **kwargs)
if not __global:
global models
models[key] = val

def __getitem__(self, key, retry=True, default=False):
""" Re-implement __getitem__ to scan for models if key missing """
global models
for k, v in models.iteritems():
self.__setitem__(k, v, True)
try:
return super(Carepoint, self).__getitem__(key)
except KeyError:
Expand Down
24 changes: 20 additions & 4 deletions carepoint/db/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,24 @@ class Db(object):
ODBC_DRIVER = 'FreeTDS&TDS_VERSION=8.0'
SQLITE = 'sqlite'

def __new__(self, server=None, user=None, passwd=None,
db=None, port=1433, drv=ODBC_DRIVER, ):
def __new__(
self, server=None, user=None, passwd=None, db=None, port=1433,
drv=ODBC_DRIVER, **engine_args
):
""" It establishes a new database connection and returns engine
Args:
server (str): IP or Hostname to database
user (str): Username for database
passwd (str): Password for database
db (str): Name of database
port (int): Connection port
drv (str): Name of underlying database driver for connection
**engine_args (mixed): Kwargs to pass to ``create_engine``
Return:
sqlalchemy.engine.Engine
"""

if drv != self.SQLITE:
params = {
Expand All @@ -25,7 +41,7 @@ def __new__(self, server=None, user=None, passwd=None,
'prt': port,
}
dsn = 'mssql+pyodbc://{usr}:{pass}@{srv}:{prt}/{db}?driver={drv}'
return create_engine(dsn.format(**params))
return create_engine(dsn.format(**params), **engine_args)

else:
return create_engine('%s://' % self.SQLITE)
return create_engine('%s://' % self.SQLITE, **engine_args)

0 comments on commit 108f91d

Please sign in to comment.