From 8d2b8e6457d5ae0ed1136091cb8c143a96abd614 Mon Sep 17 00:00:00 2001 From: Dolph Mathews Date: Wed, 1 May 2013 10:46:42 -0500 Subject: [PATCH] get SQL refs from session (bp sql-query-get) Change-Id: I2200e33868d50bb69089f3108a5a4c061afccd6e --- keystone/catalog/backends/sql.py | 10 +- keystone/credential/backends/sql.py | 26 ++-- keystone/identity/backends/sql.py | 198 ++++++++++++---------------- keystone/policy/backends/sql.py | 6 +- keystone/token/backends/sql.py | 11 +- keystone/trust/backends/sql.py | 6 +- tests/test_backend_sql.py | 5 + 7 files changed, 113 insertions(+), 149 deletions(-) diff --git a/keystone/catalog/backends/sql.py b/keystone/catalog/backends/sql.py index 9175a9e6c8..63a2ec019a 100644 --- a/keystone/catalog/backends/sql.py +++ b/keystone/catalog/backends/sql.py @@ -60,10 +60,10 @@ def list_services(self): return [s.to_dict() for s in list(services)] def _get_service(self, session, service_id): - try: - return session.query(Service).filter_by(id=service_id).one() - except sql.NotFound: + ref = session.query(Service).get(service_id) + if not ref: raise exception.ServiceNotFound(service_id=service_id) + return ref def get_service(self, service_id): session = self.get_session() @@ -112,8 +112,8 @@ def create_endpoint(self, endpoint_id, endpoint_ref): def delete_endpoint(self, endpoint_id): session = self.get_session() with session.begin(): - if not session.query(Endpoint).filter_by(id=endpoint_id).delete(): - raise exception.EndpointNotFound(endpoint_id=endpoint_id) + ref = self._get_endpoint(session, endpoint_id) + session.delete(ref) session.flush() def _get_endpoint(self, session, endpoint_id): diff --git a/keystone/credential/backends/sql.py b/keystone/credential/backends/sql.py index 721cdc6d56..cf8d4bd3e7 100644 --- a/keystone/credential/backends/sql.py +++ b/keystone/credential/backends/sql.py @@ -14,10 +14,8 @@ # License for the specific language governing permissions and limitations # under the License. -from keystone import clean from keystone.common import sql from keystone.common.sql import migration -from keystone.common import utils from keystone import credential from keystone import exception @@ -55,22 +53,21 @@ def list_credentials(self): refs = session.query(CredentialModel).all() return [ref.to_dict() for ref in refs] - def get_credential(self, credential_id): - session = self.get_session() - ref = (session.query(CredentialModel) - .filter_by(id=credential_id).first()) + def _get_credential(self, session, credential_id): + ref = session.query(CredentialModel).get(credential_id) if ref is None: raise exception.CredentialNotFound(credential_id=credential_id) - return ref.to_dict() + return ref + + def get_credential(self, credential_id): + session = self.get_session() + return self._get_credential(session, credential_id).to_dict() @sql.handle_conflicts(type='credential') def update_credential(self, credential_id, credential): session = self.get_session() with session.begin(): - ref = (session.query(CredentialModel) - .filter_by(id=credential_id).first()) - if ref is None: - raise exception.CredentialNotFound(credential_id=credential_id) + ref = self._get_credential(session, credential_id) old_dict = ref.to_dict() for k in credential: old_dict[k] = credential[k] @@ -85,12 +82,7 @@ def update_credential(self, credential_id, credential): def delete_credential(self, credential_id): session = self.get_session() - try: - ref = (session.query(CredentialModel) - .filter_by(id=credential_id).one()) - except sql.NotFound: - raise exception.CredentialNotFound(credential_id=credential_id) - with session.begin(): + ref = self._get_credential(session, credential_id) session.delete(ref) session.flush() diff --git a/keystone/identity/backends/sql.py b/keystone/identity/backends/sql.py index dd8ad6eab2..0262d87d07 100644 --- a/keystone/identity/backends/sql.py +++ b/keystone/identity/backends/sql.py @@ -153,7 +153,7 @@ def _check_password(self, password, user_ref): https://blueprints.launchpad.net/keystone/+spec/sql-identiy-pam """ - return utils.check_password(password, user_ref.get('password')) + return utils.check_password(password, user_ref.password) # Identity interface def authenticate(self, user_id=None, tenant_id=None, password=None): @@ -163,12 +163,14 @@ def authenticate(self, user_id=None, tenant_id=None, password=None): in the list of tenants on the user. """ + session = self.get_session() + user_ref = None tenant_ref = None metadata_ref = {} try: - user_ref = self._get_user(user_id) + user_ref = self._get_user(session, user_id) except exception.UserNotFound: raise AssertionError('Invalid user / password') @@ -188,14 +190,18 @@ def authenticate(self, user_id=None, tenant_id=None, password=None): metadata_ref = {} except exception.MetadataNotFound: metadata_ref = {} - return (identity.filter_user(user_ref), tenant_ref, metadata_ref) + user_ref = identity.filter_user(user_ref.to_dict()) + return (user_ref, tenant_ref, metadata_ref) + + def _get_project(self, session, project_id): + project_ref = session.query(Project).get(project_id) + if project_ref is None: + raise exception.ProjectNotFound(project_id=project_id) + return project_ref def get_project(self, tenant_id): session = self.get_session() - tenant_ref = session.query(Project).filter_by(id=tenant_id).first() - if tenant_ref is None: - raise exception.ProjectNotFound(project_id=tenant_id) - return tenant_ref.to_dict() + return self._get_project(session, tenant_id).to_dict() def get_project_by_name(self, tenant_name, domain_id): session = self.get_session() @@ -245,16 +251,16 @@ def get_metadata(self, user_id=None, tenant_id=None, def create_grant(self, role_id, user_id=None, group_id=None, domain_id=None, project_id=None): - - self.get_role(role_id) + session = self.get_session() + self._get_role(session, role_id) if user_id: - self.get_user(user_id) + self._get_user(session, user_id) if group_id: - self.get_group(group_id) + self._get_group(session, group_id) if domain_id: - self.get_domain(domain_id) + self._get_domain(session, domain_id) if project_id: - self.get_project(project_id) + self._get_project(session, project_id) try: metadata_ref = self.get_metadata(user_id, project_id, @@ -275,14 +281,15 @@ def create_grant(self, role_id, user_id=None, group_id=None, def list_grants(self, user_id=None, group_id=None, domain_id=None, project_id=None): + session = self.get_session() if user_id: - self.get_user(user_id) + self._get_user(session, user_id) if group_id: - self.get_group(group_id) + self._get_group(session, group_id) if domain_id: - self.get_domain(domain_id) + self._get_domain(session, domain_id) if project_id: - self.get_project(project_id) + self._get_project(session, project_id) try: metadata_ref = self.get_metadata(user_id, project_id, @@ -293,15 +300,16 @@ def list_grants(self, user_id=None, group_id=None, def get_grant(self, role_id, user_id=None, group_id=None, domain_id=None, project_id=None): - self.get_role(role_id) + session = self.get_session() + role_ref = self._get_role(session, role_id) if user_id: - self.get_user(user_id) + self._get_user(session, user_id) if group_id: - self.get_group(group_id) + self._get_group(session, group_id) if domain_id: - self.get_domain(domain_id) + self._get_domain(session, domain_id) if project_id: - self.get_project(project_id) + self._get_project(session, project_id) try: metadata_ref = self.get_metadata(user_id, project_id, @@ -311,19 +319,20 @@ def get_grant(self, role_id, user_id=None, group_id=None, role_ids = set(metadata_ref.get('roles', [])) if role_id not in role_ids: raise exception.RoleNotFound(role_id=role_id) - return self.get_role(role_id) + return role_ref.to_dict() def delete_grant(self, role_id, user_id=None, group_id=None, domain_id=None, project_id=None): - self.get_role(role_id) + session = self.get_session() + self._get_role(session, role_id) if user_id: - self.get_user(user_id) + self._get_user(session, user_id) if group_id: - self.get_group(group_id) + self._get_group(session, group_id) if domain_id: - self.get_domain(domain_id) + self._get_domain(session, domain_id) if project_id: - self.get_project(project_id) + self._get_project(session, project_id) try: metadata_ref = self.get_metadata(user_id, project_id, @@ -352,7 +361,7 @@ def list_projects(self): def get_projects_for_user(self, user_id): session = self.get_session() - self.get_user(user_id) + self._get_user(session, user_id) query = session.query(UserProjectGrant) query = query.filter_by(user_id=user_id) membership_refs = query.all() @@ -376,17 +385,19 @@ def _get_user_project_roles(self, metadata_ref, user_id, project_id): pass def get_roles_for_user_and_project(self, user_id, tenant_id): - self.get_user(user_id) - self.get_project(tenant_id) + session = self.get_session() + self._get_user(session, user_id) + self._get_project(session, tenant_id) metadata_ref = {} self._get_user_project_roles(metadata_ref, user_id, tenant_id) self._get_user_group_project_roles(metadata_ref, user_id, tenant_id) return list(set(metadata_ref.get('roles', []))) def add_role_to_user_and_project(self, user_id, tenant_id, role_id): - self.get_user(user_id) - self.get_project(tenant_id) - self.get_role(role_id) + session = self.get_session() + self._get_user(session, user_id) + self._get_project(session, tenant_id) + self._get_role(session, role_id) try: metadata_ref = self.get_metadata(user_id, tenant_id) is_new = False @@ -443,12 +454,9 @@ def update_project(self, tenant_id, tenant): if 'name' in tenant: tenant['name'] = clean.project_name(tenant['name']) - try: - tenant_ref = session.query(Project).filter_by(id=tenant_id).one() - except sql.NotFound: - raise exception.ProjectNotFound(project_id=tenant_id) with session.begin(): + tenant_ref = self._get_project(session, tenant_id) old_project_dict = tenant_ref.to_dict() for k in tenant: old_project_dict[k] = tenant[k] @@ -464,12 +472,9 @@ def update_project(self, tenant_id, tenant): def delete_project(self, tenant_id): session = self.get_session() - try: - tenant_ref = session.query(Project).filter_by(id=tenant_id).one() - except sql.NotFound: - raise exception.ProjectNotFound(project_id=tenant_id) - with session.begin(): + tenant_ref = self._get_project(session, tenant_id) + q = session.query(UserProjectGrant) q = q.filter_by(project_id=tenant_id) q.delete(False) @@ -482,10 +487,6 @@ def delete_project(self, tenant_id): q = q.filter_by(project_id=tenant_id) q.delete(False) - delete_query = session.query(Project).filter_by(id=tenant_id) - if not delete_query.delete(False): - raise exception.ProjectNotFound(project_id=tenant_id) - session.delete(tenant_ref) session.flush() @@ -561,14 +562,16 @@ def list_domains(self): refs = session.query(Domain).all() return [ref.to_dict() for ref in refs] - def get_domain(self, domain_id): - session = self.get_session() - ref = session.query(Domain).filter_by(id=domain_id).first() + def _get_domain(self, session, domain_id): + ref = session.query(Domain).get(domain_id) if ref is None: raise exception.DomainNotFound(domain_id=domain_id) - return ref.to_dict() + return ref + + def get_domain(self, domain_id): + session = self.get_session() + return self._get_domain(session, domain_id).to_dict() - @sql.handle_conflicts(type='domain') def get_domain_by_name(self, domain_name): session = self.get_session() try: @@ -581,9 +584,7 @@ def get_domain_by_name(self, domain_name): def update_domain(self, domain_id, domain): session = self.get_session() with session.begin(): - ref = session.query(Domain).filter_by(id=domain_id).first() - if ref is None: - raise exception.DomainNotFound(domain_id=domain_id) + ref = self._get_domain(session, domain_id) old_dict = ref.to_dict() for k in domain: old_dict[k] = domain[k] @@ -597,10 +598,8 @@ def update_domain(self, domain_id, domain): def delete_domain(self, domain_id): session = self.get_session() - ref = session.query(Domain).filter_by(id=domain_id).first() - if not ref: - raise exception.DomainNotFound(domain_id=domain_id) with session.begin(): + ref = self._get_domain(session, domain_id) session.delete(ref) session.flush() @@ -640,14 +639,17 @@ def list_users(self): user_refs = session.query(User) return [identity.filter_user(x.to_dict()) for x in user_refs] - def _get_user(self, user_id): - session = self.get_session() - user_ref = session.query(User).filter_by(id=user_id).first() + def _get_user(self, session, user_id): + user_ref = session.query(User).get(user_id) if not user_ref: raise exception.UserNotFound(user_id=user_id) - return user_ref.to_dict() + return user_ref - def _get_user_by_name(self, user_name, domain_id): + def get_user(self, user_id): + session = self.get_session() + return identity.filter_user(self._get_user(session, user_id).to_dict()) + + def get_user_by_name(self, user_name, domain_id): session = self.get_session() query = session.query(User) query = query.filter_by(name=user_name) @@ -656,14 +658,7 @@ def _get_user_by_name(self, user_name, domain_id): user_ref = query.one() except sql.NotFound: raise exception.UserNotFound(user_id=user_name) - return user_ref.to_dict() - - def get_user(self, user_id): - return identity.filter_user(self._get_user(user_id)) - - def get_user_by_name(self, user_name, domain_id): - return identity.filter_user( - self._get_user_by_name(user_name, domain_id)) + return identity.filter_user(user_ref.to_dict()) @sql.handle_conflicts(type='user') def update_user(self, user_id, user): @@ -676,9 +671,7 @@ def update_user(self, user_id, user): raise exception.ValidationError('Cannot change user ID') with session.begin(): - user_ref = session.query(User).filter_by(id=user_id).first() - if user_ref is None: - raise exception.UserNotFound(user_id=user_id) + user_ref = self._get_user(session, user_id) old_user_dict = user_ref.to_dict() user = utils.hash_user_password(user) for k in user: @@ -750,12 +743,8 @@ def list_users_in_group(self, group_id): def delete_user(self, user_id): session = self.get_session() - try: - ref = session.query(User).filter_by(id=user_id).one() - except sql.NotFound: - raise exception.UserNotFound(user_id=user_id) - with session.begin(): + ref = self._get_user(session, user_id) q = session.query(UserProjectGrant) q = q.filter_by(user_id=user_id) @@ -769,9 +758,6 @@ def delete_user(self, user_id): q = q.filter_by(user_id=user_id) q.delete(False) - if not session.query(User).filter_by(id=user_id).delete(False): - raise exception.UserNotFound(user_id=user_id) - session.delete(ref) session.flush() @@ -791,24 +777,22 @@ def list_groups(self): refs = session.query(Group).all() return [ref.to_dict() for ref in refs] - def _get_group(self, group_id): - session = self.get_session() - ref = session.query(Group).filter_by(id=group_id).first() + def _get_group(self, session, group_id): + ref = session.query(Group).get(group_id) if not ref: raise exception.GroupNotFound(group_id=group_id) - return ref.to_dict() + return ref def get_group(self, group_id): - return self._get_group(group_id) + session = self.get_session() + return self._get_group(session, group_id).to_dict() @sql.handle_conflicts(type='group') def update_group(self, group_id, group): session = self.get_session() with session.begin(): - ref = session.query(Group).filter_by(id=group_id).first() - if ref is None: - raise exception.GroupNotFound(group_id=group_id) + ref = self._get_group(session, group_id) old_dict = ref.to_dict() for k in group: old_dict[k] = group[k] @@ -823,12 +807,9 @@ def update_group(self, group_id, group): def delete_group(self, group_id): session = self.get_session() - try: - ref = session.query(Group).filter_by(id=group_id).one() - except sql.NotFound: - raise exception.GroupNotFound(group_id=group_id) - with session.begin(): + ref = self._get_group(session, group_id) + q = session.query(GroupProjectGrant) q = q.filter_by(group_id=group_id) q.delete(False) @@ -841,9 +822,6 @@ def delete_group(self, group_id): q = q.filter_by(group_id=group_id) q.delete(False) - if not session.query(Group).filter_by(id=group_id).delete(False): - raise exception.GroupNotFound(group_id=group_id) - session.delete(ref) session.flush() @@ -863,20 +841,21 @@ def list_roles(self): refs = session.query(Role).all() return [ref.to_dict() for ref in refs] - def get_role(self, role_id): - session = self.get_session() - ref = session.query(Role).filter_by(id=role_id).first() + def _get_role(self, session, role_id): + ref = session.query(Role).get(role_id) if ref is None: raise exception.RoleNotFound(role_id=role_id) - return ref.to_dict() + return ref + + def get_role(self, role_id): + session = self.get_session() + return self._get_role(session, role_id).to_dict() @sql.handle_conflicts(type='role') def update_role(self, role_id, role): session = self.get_session() with session.begin(): - ref = session.query(Role).filter_by(id=role_id).first() - if ref is None: - raise exception.RoleNotFound(role_id=role_id) + ref = self._get_role(session, role_id) old_dict = ref.to_dict() for k in role: old_dict[k] = role[k] @@ -891,12 +870,8 @@ def update_role(self, role_id, role): def delete_role(self, role_id): session = self.get_session() - try: - ref = session.query(Role).filter_by(id=role_id).one() - except sql.NotFound: - raise exception.RoleNotFound(role_id=role_id) - with session.begin(): + ref = self._get_role(session, role_id) for metadata_ref in session.query(UserProjectGrant): try: self.delete_grant(role_id, user_id=metadata_ref.user_id, @@ -922,8 +897,5 @@ def delete_role(self, role_id): except exception.RoleNotFound: pass - if not session.query(Role).filter_by(id=role_id).delete(): - raise exception.RoleNotFound(role_id=role_id) - session.delete(ref) session.flush() diff --git a/keystone/policy/backends/sql.py b/keystone/policy/backends/sql.py index 2472c1ed5c..c1eff268b8 100644 --- a/keystone/policy/backends/sql.py +++ b/keystone/policy/backends/sql.py @@ -53,10 +53,10 @@ def list_policies(self): def _get_policy(self, session, policy_id): """Private method to get a policy model object (NOT a dictionary).""" - try: - return session.query(PolicyModel).filter_by(id=policy_id).one() - except sql.NotFound: + ref = session.query(PolicyModel).get(policy_id) + if not ref: raise exception.PolicyNotFound(policy_id=policy_id) + return ref def get_policy(self, policy_id): session = self.get_session() diff --git a/keystone/token/backends/sql.py b/keystone/token/backends/sql.py index fef3b81b8a..2e68bdc975 100644 --- a/keystone/token/backends/sql.py +++ b/keystone/token/backends/sql.py @@ -41,11 +41,9 @@ def get_token(self, token_id): if token_id is None: raise exception.TokenNotFound(token_id=token_id) session = self.get_session() - query = session.query(TokenModel) - query = query.filter_by(id=token.unique_id(token_id), valid=True) - token_ref = query.first() + token_ref = session.query(TokenModel).get(token.unique_id(token_id)) now = datetime.datetime.utcnow() - if not token_ref: + if not token_ref or not token_ref.valid: raise exception.TokenNotFound(token_id=token_id) if not token_ref.expires: raise exception.TokenNotFound(token_id=token_id) @@ -73,9 +71,8 @@ def delete_token(self, token_id): session = self.get_session() key = token.unique_id(token_id) with session.begin(): - token_ref = session.query(TokenModel).filter_by(id=key, - valid=True).first() - if not token_ref: + token_ref = session.query(TokenModel).get(key) + if not token_ref or not token_ref.valid: raise exception.TokenNotFound(token_id=token_id) token_ref.valid = False session.flush() diff --git a/keystone/trust/backends/sql.py b/keystone/trust/backends/sql.py index cd68b0bc05..daa8e3f7d3 100644 --- a/keystone/trust/backends/sql.py +++ b/keystone/trust/backends/sql.py @@ -114,10 +114,8 @@ def list_trusts_for_trustor(self, trustor_user_id): def delete_trust(self, trust_id): session = self.get_session() with session.begin(): - try: - trust_ref = (session.query(TrustModel). - filter_by(id=trust_id).one()) - except sql.NotFound: + trust_ref = session.query(TrustModel).get(trust_id) + if not trust_ref: raise exception.TrustNotFound(trust_id=trust_id) trust_ref.deleted_at = timeutils.utcnow() session.flush() diff --git a/tests/test_backend_sql.py b/tests/test_backend_sql.py index 16408ed204..e4a19ef9d0 100644 --- a/tests/test_backend_sql.py +++ b/tests/test_backend_sql.py @@ -73,6 +73,11 @@ def tearDown(self): class SqlIdentity(SqlTests, test_backend.IdentityTests): + def test_password_hashed(self): + session = self.identity_api.get_session() + user_ref = self.identity_api._get_user(session, self.user_foo['id']) + self.assertNotEqual(user_ref['password'], self.user_foo['password']) + def test_delete_user_with_project_association(self): user = {'id': uuid.uuid4().hex, 'name': uuid.uuid4().hex,