Skip to content

Commit

Permalink
Merge pull request #6826 from VJalili/authnz_trans
Browse files Browse the repository at this point in the history
Remove `trans` from OIDC-related types
  • Loading branch information
jmchilton committed Oct 7, 2018
2 parents 7d9093d + 69bd26e commit 8ed5d8c
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 57 deletions.
2 changes: 1 addition & 1 deletion lib/galaxy/authnz/managers.py
Expand Up @@ -129,7 +129,7 @@ def _extend_cloudauthz_config(self, trans, cloudauthz):
if cloudauthz.provider == "aws":
success, message, backend = self._get_authnz_backend(cloudauthz.authn.provider)
strategy = Strategy(trans, Storage, backend.config)
on_the_fly_config(trans)
on_the_fly_config(trans.sa_session)
try:
config['id_token'] = cloudauthz.authn.get_id_token(strategy)
except requests.exceptions.HTTPError as e:
Expand Down
18 changes: 9 additions & 9 deletions lib/galaxy/authnz/psa_authnz.py
Expand Up @@ -123,13 +123,13 @@ def _login_user(self, backend, user, social_user):
self.config['user'] = user

def authenticate(self, trans):
on_the_fly_config(trans)
on_the_fly_config(trans.sa_session)
strategy = Strategy(trans, Storage, self.config)
backend = self._load_backend(strategy, self.config['redirect_uri'])
return do_auth(backend)

def callback(self, state_token, authz_code, trans, login_redirect_url):
on_the_fly_config(trans)
on_the_fly_config(trans.sa_session)
self.config[setting_name('LOGIN_REDIRECT_URL')] = login_redirect_url
strategy = Strategy(trans, Storage, self.config)
strategy.session_set(BACKENDS_NAME[self.config['provider']] + '_state', state_token)
Expand All @@ -142,7 +142,7 @@ def callback(self, state_token, authz_code, trans, login_redirect_url):
return redirect_url, self.config.get('user', None)

def disconnect(self, provider, trans, disconnect_redirect_url=None, association_id=None):
on_the_fly_config(trans)
on_the_fly_config(trans.sa_session)
self.config[setting_name('DISCONNECT_REDIRECT_URL')] =\
disconnect_redirect_url if disconnect_redirect_url is not None else ()
strategy = Strategy(trans, Storage, self.config)
Expand Down Expand Up @@ -235,12 +235,12 @@ def is_integrity_error(cls, exception):
return exception.__class__ is IntegrityError


def on_the_fly_config(trans):
trans.app.model.PSACode.trans = trans
trans.app.model.UserAuthnzToken.trans = trans
trans.app.model.PSANonce.trans = trans
trans.app.model.PSAPartial.trans = trans
trans.app.model.PSAAssociation.trans = trans
def on_the_fly_config(sa_session):
PSACode.sa_session = sa_session
UserAuthnzToken.sa_session = sa_session
PSANonce.sa_session = sa_session
PSAPartial.sa_session = sa_session
PSAAssociation.sa_session = sa_session


def contains_required_data(response=None, is_new=False, **kwargs):
Expand Down
89 changes: 42 additions & 47 deletions lib/galaxy/model/__init__.py
Expand Up @@ -4753,9 +4753,8 @@ def __init__(self, user=None, session=None, openid=None):

class PSAAssociation(AssociationMixin):

# This static property is of type: galaxy.web.framework.webapp.GalaxyWebTransaction
# and it is set in: galaxy.authnz.psa_authnz.PSAAuthnz
trans = None
# This static property is set at: galaxy.authnz.psa_authnz.PSAAuthnz
sa_session = None

def __init__(self, server_url=None, handle=None, secret=None, issued=None, lifetime=None, assoc_type=None):
self.server_url = server_url
Expand All @@ -4766,82 +4765,79 @@ def __init__(self, server_url=None, handle=None, secret=None, issued=None, lifet
self.assoc_type = assoc_type

def save(self):
self.trans.sa_session.add(self)
self.trans.sa_session.flush()
self.sa_session.add(self)
self.sa_session.flush()

@classmethod
def store(cls, server_url, association):
try:
assoc = cls.trans.sa_session.query(cls).filter_by(server_url=server_url, handle=association.handle)[0]
assoc = cls.sa_session.query(cls).filter_by(server_url=server_url, handle=association.handle)[0]
except IndexError:
assoc = cls(server_url=server_url, handle=association.handle)
assoc.secret = base64.encodestring(association.secret).decode()
assoc.issued = association.issued
assoc.lifetime = association.lifetime
assoc.assoc_type = association.assoc_type
cls.trans.sa_session.add(assoc)
cls.trans.sa_session.flush()
cls.sa_session.add(assoc)
cls.sa_session.flush()

@classmethod
def get(cls, *args, **kwargs):
return cls.trans.sa_session.query(cls).filter_by(*args, **kwargs)
return cls.sa_session.query(cls).filter_by(*args, **kwargs)

@classmethod
def remove(cls, ids_to_delete):
cls.trans.sa_session.query(cls).filter(cls.id.in_(ids_to_delete)).delete(synchronize_session='fetch')
cls.sa_session.query(cls).filter(cls.id.in_(ids_to_delete)).delete(synchronize_session='fetch')


class PSACode(CodeMixin):
__table_args__ = (UniqueConstraint('code', 'email'),)

# This static property is of type: galaxy.web.framework.webapp.GalaxyWebTransaction
# and it is set in: galaxy.authnz.psa_authnz.PSAAuthnz
trans = None
# This static property is set at: galaxy.authnz.psa_authnz.PSAAuthnz
sa_session = None

def __init__(self, email, code):
self.email = email
self.code = code

def save(self):
self.trans.sa_session.add(self)
self.trans.sa_session.flush()
self.sa_session.add(self)
self.sa_session.flush()

@classmethod
def get_code(cls, code):
return cls.trans.sa_session.query(cls).filter(cls.code == code).first()
return cls.sa_session.query(cls).filter(cls.code == code).first()


class PSANonce(NonceMixin):

# This static property is of type: galaxy.web.framework.webapp.GalaxyWebTransaction
# and it is set in: galaxy.authnz.psa_authnz.PSAAuthnz
trans = None
# This static property is set at: galaxy.authnz.psa_authnz.PSAAuthnz
sa_session = None

def __init__(self, server_url, timestamp, salt):
self.server_url = server_url
self.timestamp = timestamp
self.salt = salt

def save(self):
self.trans.sa_session.add(self)
self.trans.sa_session.flush()
self.sa_session.add(self)
self.sa_session.flush()

@classmethod
def use(cls, server_url, timestamp, salt):
try:
return cls.trans.sa_session.query(cls).filter_by(server_url=server_url, timestamp=timestamp, salt=salt)[0]
return cls.sa_session.query(cls).filter_by(server_url=server_url, timestamp=timestamp, salt=salt)[0]
except IndexError:
instance = cls(server_url=server_url, timestamp=timestamp, salt=salt)
cls.trans.sa_session.add(instance)
cls.trans.sa_session.flush()
cls.sa_session.add(instance)
cls.sa_session.flush()
return instance


class PSAPartial(PartialMixin):

# This static property is of type: galaxy.web.framework.webapp.GalaxyWebTransaction
# and it is set in: galaxy.authnz.psa_authnz.PSAAuthnz
trans = None
# This static property is set at: galaxy.authnz.psa_authnz.PSAAuthnz
sa_session = None

def __init__(self, token, data, next_step, backend):
self.token = token
Expand All @@ -4850,26 +4846,25 @@ def __init__(self, token, data, next_step, backend):
self.backend = backend

def save(self):
self.trans.sa_session.add(self)
self.trans.sa_session.flush()
self.sa_session.add(self)
self.sa_session.flush()

@classmethod
def load(cls, token):
return cls.trans.sa_session.query(cls).filter(cls.token == token).first()
return cls.sa_session.query(cls).filter(cls.token == token).first()

@classmethod
def destroy(cls, token):
partial = cls.load(token)
if partial:
cls.trans.sa_session.delete(partial)
cls.sa_session.delete(partial)


class UserAuthnzToken(UserMixin):
__table_args__ = (UniqueConstraint('provider', 'uid'),)

# This static property is of type: galaxy.web.framework.webapp.GalaxyWebTransaction
# and it is set in: galaxy.authnz.psa_authnz.PSAAuthnz
trans = None
# This static property is set at: galaxy.authnz.psa_authnz.PSAAuthnz
sa_session = None

def __init__(self, provider, uid, extra_data=None, lifetime=None, assoc_type=None, user=None):
self.provider = provider
Expand All @@ -4888,12 +4883,12 @@ def get_id_token(self, strategy):

def set_extra_data(self, extra_data=None):
if super(UserAuthnzToken, self).set_extra_data(extra_data):
self.trans.sa_session.add(self)
self.trans.sa_session.flush()
self.sa_session.add(self)
self.sa_session.flush()

def save(self):
self.trans.sa_session.add(self)
self.trans.sa_session.flush()
self.sa_session.add(self)
self.sa_session.flush()

@classmethod
def username_max_length(cls):
Expand All @@ -4907,12 +4902,12 @@ def user_model(cls):

@classmethod
def changed(cls, user):
cls.trans.sa_session.add(user)
cls.trans.sa_session.flush()
cls.sa_session.add(user)
cls.sa_session.flush()

@classmethod
def user_query(cls):
return cls.trans.sa_session.query(cls.user_model())
return cls.sa_session.query(cls.user_model())

@classmethod
def user_exists(cls, *args, **kwargs):
Expand All @@ -4927,8 +4922,8 @@ def create_user(cls, *args, **kwargs):
model = cls.user_model()
instance = model(*args, **kwargs)
instance.set_random_password()
cls.trans.sa_session.add(instance)
cls.trans.sa_session.flush()
cls.sa_session.add(instance)
cls.sa_session.flush()
return instance

@classmethod
Expand All @@ -4943,13 +4938,13 @@ def get_users_by_email(cls, email):
def get_social_auth(cls, provider, uid):
uid = str(uid)
try:
return cls.trans.sa_session.query(cls).filter_by(provider=provider, uid=uid)[0]
return cls.sa_session.query(cls).filter_by(provider=provider, uid=uid)[0]
except IndexError:
return None

@classmethod
def get_social_auth_for_user(cls, user, provider=None, id=None):
qs = cls.trans.sa_session.query(cls).filter_by(user_id=user.id)
qs = cls.sa_session.query(cls).filter_by(user_id=user.id)
if provider:
qs = qs.filter_by(provider=provider)
if id:
Expand All @@ -4960,8 +4955,8 @@ def get_social_auth_for_user(cls, user, provider=None, id=None):
def create_social_auth(cls, user, uid, provider):
uid = str(uid)
instance = cls(user=user, uid=uid, provider=provider)
cls.trans.sa_session.add(instance)
cls.trans.sa_session.flush()
cls.sa_session.add(instance)
cls.sa_session.flush()
return instance


Expand Down

0 comments on commit 8ed5d8c

Please sign in to comment.