diff --git a/apiserver/dora/store/__init__.py b/apiserver/dora/store/__init__.py index 45f48b876..2f959d93a 100644 --- a/apiserver/dora/store/__init__.py +++ b/apiserver/dora/store/__init__.py @@ -2,6 +2,8 @@ from flask_sqlalchemy import SQLAlchemy +from dora.utils.log import LOG + db = SQLAlchemy() @@ -20,3 +22,15 @@ def configure_db_with_app(app): app.config["SQLALCHEMY_DATABASE_URI"] = connection_uri app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"pool_size": 20, "max_overflow": 5} db.init_app(app) + + +def rollback_on_exc(func): + def wrapper(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except Exception as e: + self._db.session.rollback() + LOG.error(f"Error in {func.__name__} - {str(e)}") + raise + + return wrapper diff --git a/apiserver/dora/store/repos/code.py b/apiserver/dora/store/repos/code.py index 6f68a7534..512d303d5 100644 --- a/apiserver/dora/store/repos/code.py +++ b/apiserver/dora/store/repos/code.py @@ -5,7 +5,7 @@ from sqlalchemy import or_ from sqlalchemy.orm import defer -from dora.store import db +from dora.store import db, rollback_on_exc from dora.store.models.code import ( PullRequest, PullRequestEvent, @@ -22,47 +22,56 @@ class CodeRepoService: + def __init__(self): + self._db = db + + @rollback_on_exc def get_active_org_repos(self, org_id: str) -> List[OrgRepo]: return ( - db.session.query(OrgRepo) + self._db.session.query(OrgRepo) .filter(OrgRepo.org_id == org_id, OrgRepo.is_active.is_(True)) .all() ) + @rollback_on_exc def update_org_repos(self, org_repos: List[OrgRepo]): - [db.session.merge(org_repo) for org_repo in org_repos] - db.session.commit() + [self._db.session.merge(org_repo) for org_repo in org_repos] + self._db.session.commit() + @rollback_on_exc def save_pull_requests_data( self, pull_requests: List[PullRequest], pull_request_commits: List[PullRequestCommit], pull_request_events: List[PullRequestEvent], ): - [db.session.merge(pull_request) for pull_request in pull_requests] + [self._db.session.merge(pull_request) for pull_request in pull_requests] [ - db.session.merge(pull_request_commit) + self._db.session.merge(pull_request_commit) for pull_request_commit in pull_request_commits ] [ - db.session.merge(pull_request_event) + self._db.session.merge(pull_request_event) for pull_request_event in pull_request_events ] - db.session.commit() + self._db.session.commit() + @rollback_on_exc def update_prs(self, prs: List[PullRequest]): - [db.session.merge(pr) for pr in prs] - db.session.commit() + [self._db.session.merge(pr) for pr in prs] + self._db.session.commit() + @rollback_on_exc def save_revert_pr_mappings( self, revert_pr_mappings: List[PullRequestRevertPRMapping] ): - [db.session.merge(revert_pr_map) for revert_pr_map in revert_pr_mappings] - db.session.commit() + [self._db.session.merge(revert_pr_map) for revert_pr_map in revert_pr_mappings] + self._db.session.commit() + @rollback_on_exc def get_org_repo_bookmark(self, org_repo: OrgRepo, bookmark_type): return ( - db.session.query(Bookmark) + self._db.session.query(Bookmark) .filter( and_( Bookmark.repo_id == org_repo.id, @@ -72,16 +81,21 @@ def get_org_repo_bookmark(self, org_repo: OrgRepo, bookmark_type): .one_or_none() ) + @rollback_on_exc def update_org_repo_bookmark(self, bookmark: Bookmark): - db.session.merge(bookmark) - db.session.commit() + self._db.session.merge(bookmark) + self._db.session.commit() + @rollback_on_exc def get_repo_by_id(self, repo_id: str) -> Optional[OrgRepo]: - return db.session.query(OrgRepo).filter(OrgRepo.id == repo_id).one_or_none() + return ( + self._db.session.query(OrgRepo).filter(OrgRepo.id == repo_id).one_or_none() + ) + @rollback_on_exc def get_repo_pr_by_number(self, repo_id: str, pr_number) -> Optional[PullRequest]: return ( - db.session.query(PullRequest) + self._db.session.query(PullRequest) .options(defer(PullRequest.data)) .filter( and_( @@ -91,31 +105,34 @@ def get_repo_pr_by_number(self, repo_id: str, pr_number) -> Optional[PullRequest .one_or_none() ) + @rollback_on_exc def get_pr_events(self, pr_model: PullRequest): if not pr_model: return [] pr_events = ( - db.session.query(PullRequestEvent) + self._db.session.query(PullRequestEvent) .options(defer(PullRequestEvent.data)) .filter(PullRequestEvent.pull_request_id == pr_model.id) .all() ) return pr_events + @rollback_on_exc def get_prs_by_ids(self, pr_ids: List[str]): query = ( - db.session.query(PullRequest) + self._db.session.query(PullRequest) .options(defer(PullRequest.data)) .filter(PullRequest.id.in_(pr_ids)) ) return query.all() + @rollback_on_exc def get_prs_by_head_branch_match_strings( self, repo_ids: List[str], match_strings: List[str] ) -> List[PullRequest]: query = ( - db.session.query(PullRequest) + self._db.session.query(PullRequest) .options(defer(PullRequest.data)) .filter( and_( @@ -133,11 +150,12 @@ def get_prs_by_head_branch_match_strings( return query.all() + @rollback_on_exc def get_reverted_prs_by_numbers( self, repo_ids: List[str], numbers: List[str] ) -> List[PullRequest]: query = ( - db.session.query(PullRequest) + self._db.session.query(PullRequest) .options(defer(PullRequest.data)) .filter( and_( @@ -150,27 +168,31 @@ def get_reverted_prs_by_numbers( return query.all() + @rollback_on_exc def get_active_team_repos_by_team_id(self, team_id: str) -> List[TeamRepos]: return ( - db.session.query(TeamRepos) + self._db.session.query(TeamRepos) .filter(TeamRepos.team_id == team_id, TeamRepos.is_active.is_(True)) .all() ) + @rollback_on_exc def get_active_team_repos_by_team_ids(self, team_ids: List[str]) -> List[TeamRepos]: return ( - db.session.query(TeamRepos) + self._db.session.query(TeamRepos) .filter(TeamRepos.team_id.in_(team_ids), TeamRepos.is_active.is_(True)) .all() ) + @rollback_on_exc def get_active_org_repos_by_ids(self, repo_ids: List[str]) -> List[OrgRepo]: return ( - db.session.query(OrgRepo) + self._db.session.query(OrgRepo) .filter(OrgRepo.id.in_(repo_ids), OrgRepo.is_active.is_(True)) .all() ) + @rollback_on_exc def get_prs_merged_in_interval( self, repo_ids: List[str], @@ -179,7 +201,7 @@ def get_prs_merged_in_interval( base_branches: List[str] = None, has_non_null_mtd=False, ) -> List[PullRequest]: - query = db.session.query(PullRequest).options(defer(PullRequest.data)) + query = self._db.session.query(PullRequest).options(defer(PullRequest.data)) query = self._filter_prs_by_repo_ids(query, repo_ids) query = self._filter_prs_merged_in_interval(query, interval) @@ -194,17 +216,19 @@ def get_prs_merged_in_interval( return query.all() + @rollback_on_exc def get_pull_request_by_id(self, pr_id: str) -> PullRequest: return ( - db.session.query(PullRequest) + self._db.session.query(PullRequest) .options(defer(PullRequest.data)) .filter(PullRequest.id == pr_id) .one_or_none() ) + @rollback_on_exc def get_previous_pull_request(self, pull_request: PullRequest) -> PullRequest: return ( - db.session.query(PullRequest) + self._db.session.query(PullRequest) .options(defer(PullRequest.data)) .filter( PullRequest.repo_id == pull_request.repo_id, @@ -216,15 +240,17 @@ def get_previous_pull_request(self, pull_request: PullRequest) -> PullRequest: .first() ) + @rollback_on_exc def get_repos_by_ids(self, ids: List[str]) -> List[OrgRepo]: if not ids: return [] - return db.session.query(OrgRepo).filter(OrgRepo.id.in_(ids)).all() + return self._db.session.query(OrgRepo).filter(OrgRepo.id.in_(ids)).all() + @rollback_on_exc def get_team_repos(self, team_id) -> List[OrgRepo]: team_repos = ( - db.session.query(TeamRepos) + self._db.session.query(TeamRepos) .filter(and_(TeamRepos.team_id == team_id, TeamRepos.is_active == True)) .all() ) @@ -234,26 +260,29 @@ def get_team_repos(self, team_id) -> List[OrgRepo]: team_repo_ids = [tr.org_repo_id for tr in team_repos] return self.get_repos_by_ids(team_repo_ids) + @rollback_on_exc def get_merge_to_deploy_broker_bookmark( self, repo_id: str ) -> BookmarkMergeToDeployBroker: return ( - db.session.query(BookmarkMergeToDeployBroker) + self._db.session.query(BookmarkMergeToDeployBroker) .filter(BookmarkMergeToDeployBroker.repo_id == repo_id) .one_or_none() ) + @rollback_on_exc def update_merge_to_deploy_broker_bookmark( self, bookmark: BookmarkMergeToDeployBroker ): - db.session.merge(bookmark) - db.session.commit() + self._db.session.merge(bookmark) + self._db.session.commit() + @rollback_on_exc def get_prs_in_repo_merged_before_given_date_with_merge_to_deploy_as_null( self, repo_id: str, to_time: datetime ): return ( - db.session.query(PullRequest) + self._db.session.query(PullRequest) .options(defer(PullRequest.data)) .filter( PullRequest.repo_id == repo_id, @@ -264,11 +293,12 @@ def get_prs_in_repo_merged_before_given_date_with_merge_to_deploy_as_null( .all() ) + @rollback_on_exc def get_repo_revert_prs_mappings_updated_in_interval( self, repo_id, from_time, to_time ) -> List[PullRequestRevertPRMapping]: query = ( - db.session.query(PullRequestRevertPRMapping) + self._db.session.query(PullRequestRevertPRMapping) .join(PullRequest, PullRequest.id == PullRequestRevertPRMapping.pr_id) .filter( PullRequest.repo_id == repo_id, diff --git a/apiserver/dora/store/repos/core.py b/apiserver/dora/store/repos/core.py index b479127b3..d26d06008 100644 --- a/apiserver/dora/store/repos/core.py +++ b/apiserver/dora/store/repos/core.py @@ -2,7 +2,7 @@ from sqlalchemy import and_ -from dora.store import db +from dora.store import db, rollback_on_exc from dora.store.models import UserIdentityProvider, Integration from dora.store.models.core import Organization, Team, Users from dora.utils.cryptography import get_crypto_service @@ -11,56 +11,64 @@ class CoreRepoService: def __init__(self): self._crypto = get_crypto_service() + self._db = db + @rollback_on_exc def get_org(self, org_id): return ( - db.session.query(Organization) + self._db.session.query(Organization) .filter(Organization.id == org_id) .one_or_none() ) + @rollback_on_exc def get_org_by_name(self, org_name: str): return ( - db.session.query(Organization) + self._db.session.query(Organization) .filter(Organization.name == org_name) .one_or_none() ) + @rollback_on_exc def get_team(self, team_id: str) -> Team: return ( - db.session.query(Team) + self._db.session.query(Team) .filter(Team.id == team_id, Team.is_deleted.is_(False)) .one_or_none() ) + @rollback_on_exc def delete_team(self, team_id: str): - team = db.session.query(Team).filter(Team.id == team_id).one_or_none() + team = self._db.session.query(Team).filter(Team.id == team_id).one_or_none() if not team: return None team.is_deleted = True - db.session.merge(team) - db.session.commit() - return db.session.query(Team).filter(Team.id == team_id).one_or_none() + self._db.session.merge(team) + self._db.session.commit() + return self._db.session.query(Team).filter(Team.id == team_id).one_or_none() + @rollback_on_exc def get_user(self, user_id) -> Optional[Users]: - return db.session.query(Users).filter(Users.id == user_id).one_or_none() + return self._db.session.query(Users).filter(Users.id == user_id).one_or_none() + @rollback_on_exc def get_org_integrations_for_names(self, org_id: str, provider_names: List[str]): return ( - db.session.query(Integration) + self._db.session.query(Integration) .filter( and_(Integration.org_id == org_id, Integration.name.in_(provider_names)) ) .all() ) + @rollback_on_exc def get_access_token(self, org_id, provider: UserIdentityProvider) -> Optional[str]: user_identity: Integration = ( - db.session.query(Integration) + self._db.session.query(Integration) .filter( and_(Integration.org_id == org_id, Integration.name == provider.value) ) diff --git a/apiserver/dora/store/repos/incidents.py b/apiserver/dora/store/repos/incidents.py index 52967b627..532fa8c99 100644 --- a/apiserver/dora/store/repos/incidents.py +++ b/apiserver/dora/store/repos/incidents.py @@ -2,7 +2,7 @@ from sqlalchemy import and_ -from dora.store import db +from dora.store import db, rollback_on_exc from dora.store.models.incidents import ( Incident, IncidentFilter, @@ -19,17 +19,26 @@ class IncidentsRepoService: + def __init__(self): + self._db = db + + @rollback_on_exc def get_org_incident_services(self, org_id: str) -> List[OrgIncidentService]: return ( - db.session.query(OrgIncidentService) + self._db.session.query(OrgIncidentService) .filter(OrgIncidentService.org_id == org_id) .all() ) + @rollback_on_exc def update_org_incident_services(self, incident_services: List[OrgIncidentService]): - [db.session.merge(incident_service) for incident_service in incident_services] - db.session.commit() + [ + self._db.session.merge(incident_service) + for incident_service in incident_services + ] + self._db.session.commit() + @rollback_on_exc def get_incidents_bookmark( self, entity_id: str, @@ -37,7 +46,7 @@ def get_incidents_bookmark( provider: IncidentProvider, ) -> IncidentsBookmark: return ( - db.session.query(IncidentsBookmark) + self._db.session.query(IncidentsBookmark) .filter( and_( IncidentsBookmark.entity_id == entity_id, @@ -48,22 +57,25 @@ def get_incidents_bookmark( .one_or_none() ) + @rollback_on_exc def save_incidents_bookmark(self, bookmark: IncidentsBookmark): - db.session.merge(bookmark) - db.session.commit() + self._db.session.merge(bookmark) + self._db.session.commit() + @rollback_on_exc def save_incidents_data( self, incidents: List[Incident], incident_org_incident_service_map: List[IncidentOrgIncidentServiceMap], ): - [db.session.merge(incident) for incident in incidents] + [self._db.session.merge(incident) for incident in incidents] [ - db.session.merge(incident_service_map) + self._db.session.merge(incident_service_map) for incident_service_map in incident_org_incident_service_map ] - db.session.commit() + self._db.session.commit() + @rollback_on_exc def get_resolved_team_incidents( self, team_id: str, interval: Interval, incident_filter: IncidentFilter = None ) -> List[Incident]: @@ -78,6 +90,7 @@ def get_resolved_team_incidents( return query.all() + @rollback_on_exc def get_team_incidents( self, team_id: str, interval: Interval, incident_filter: IncidentFilter = None ) -> List[Incident]: @@ -89,11 +102,12 @@ def get_team_incidents( return query.all() + @rollback_on_exc def get_incident_by_key_type_and_provider( self, key: str, incident_type: IncidentType, provider: IncidentProvider ) -> Incident: return ( - db.session.query(Incident) + self._db.session.query(Incident) .filter( and_( Incident.key == key, @@ -108,7 +122,7 @@ def _get_team_incidents_query( self, team_id: str, incident_filter: IncidentFilter = None ): query = ( - db.session.query(Incident) + self._db.session.query(Incident) .join( IncidentOrgIncidentServiceMap, Incident.id == IncidentOrgIncidentServiceMap.incident_id, diff --git a/apiserver/dora/store/repos/settings.py b/apiserver/dora/store/repos/settings.py index 27a6e4a60..24736d004 100644 --- a/apiserver/dora/store/repos/settings.py +++ b/apiserver/dora/store/repos/settings.py @@ -2,7 +2,7 @@ from sqlalchemy import and_ -from dora.store import db +from dora.store import db, rollback_on_exc from dora.store.models import ( Settings, SettingType, @@ -13,11 +13,15 @@ class SettingsRepoService: + def __init__(self): + self._db = db + + @rollback_on_exc def get_setting( self, entity_id: str, entity_type: EntityType, setting_type: SettingType ) -> Optional[Settings]: return ( - db.session.query(Settings) + self._db.session.query(Settings) .filter( and_( Settings.setting_type == setting_type, @@ -29,14 +33,16 @@ def get_setting( .one_or_none() ) + @rollback_on_exc def create_settings(self, settings: List[Settings]) -> List[Settings]: - [db.session.merge(setting) for setting in settings] - db.session.commit() + [self._db.session.merge(setting) for setting in settings] + self._db.session.commit() return settings + @rollback_on_exc def save_setting(self, setting: Settings) -> Optional[Settings]: - db.session.merge(setting) - db.session.commit() + self._db.session.merge(setting) + self._db.session.commit() return self.get_setting( entity_id=setting.entity_id, @@ -44,6 +50,7 @@ def save_setting(self, setting: Settings) -> Optional[Settings]: setting_type=setting.setting_type, ) + @rollback_on_exc def delete_setting( self, entity_id: str, @@ -58,10 +65,11 @@ def delete_setting( setting.is_deleted = True setting.updated_by = deleted_by.id setting.updated_at = time_now() - db.session.merge(setting) - db.session.commit() + self._db.session.merge(setting) + self._db.session.commit() return setting + @rollback_on_exc def get_settings( self, entity_id: str, @@ -69,7 +77,7 @@ def get_settings( setting_types: List[SettingType], ) -> Optional[Settings]: return ( - db.session.query(Settings) + self._db.session.query(Settings) .filter( and_( Settings.setting_type.in_(setting_types), diff --git a/apiserver/dora/store/repos/workflows.py b/apiserver/dora/store/repos/workflows.py index 4c7ed1146..8b8fefa23 100644 --- a/apiserver/dora/store/repos/workflows.py +++ b/apiserver/dora/store/repos/workflows.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import defer from sqlalchemy import and_ -from dora.store import db +from dora.store import db, rollback_on_exc from dora.store.models.code.workflows.enums import ( RepoWorkflowRunsStatus, RepoWorkflowType, @@ -20,12 +20,16 @@ class WorkflowRepoService: + def __init__(self): + self._db = db + + @rollback_on_exc def get_active_repo_workflows_by_repo_ids_and_providers( self, repo_ids: List[str], providers: List[RepoWorkflowProviders] ) -> List[RepoWorkflow]: return ( - db.session.query(RepoWorkflow) + self._db.session.query(RepoWorkflow) .options(defer(RepoWorkflow.meta)) .filter( RepoWorkflow.org_repo_id.in_(repo_ids), @@ -35,11 +39,12 @@ def get_active_repo_workflows_by_repo_ids_and_providers( .all() ) + @rollback_on_exc def get_repo_workflow_run_by_provider_workflow_run_id( self, repo_workflow_id: str, provider_workflow_run_id: str ) -> RepoWorkflowRuns: return ( - db.session.query(RepoWorkflowRuns) + self._db.session.query(RepoWorkflowRuns) .filter( RepoWorkflowRuns.repo_workflow_id == repo_workflow_id, RepoWorkflowRuns.provider_workflow_run_id == provider_workflow_run_id, @@ -47,31 +52,35 @@ def get_repo_workflow_run_by_provider_workflow_run_id( .one_or_none() ) + @rollback_on_exc def save_repo_workflow_runs(self, repo_workflow_runs: List[RepoWorkflowRuns]): [ - db.session.merge(repo_workflow_run) + self._db.session.merge(repo_workflow_run) for repo_workflow_run in repo_workflow_runs ] - db.session.commit() + self._db.session.commit() + @rollback_on_exc def get_repo_workflow_runs_bookmark( self, repo_workflow_id: str ) -> RepoWorkflowRunsBookmark: return ( - db.session.query(RepoWorkflowRunsBookmark) + self._db.session.query(RepoWorkflowRunsBookmark) .filter(RepoWorkflowRunsBookmark.repo_workflow_id == repo_workflow_id) .one_or_none() ) + @rollback_on_exc def update_repo_workflow_runs_bookmark(self, bookmark: RepoWorkflowRunsBookmark): - db.session.merge(bookmark) - db.session.commit() + self._db.session.merge(bookmark) + self._db.session.commit() + @rollback_on_exc def get_repo_workflow_by_repo_ids( self, repo_ids: List[str], type: RepoWorkflowType ) -> List[RepoWorkflow]: return ( - db.session.query(RepoWorkflow) + self._db.session.query(RepoWorkflow) .options(defer(RepoWorkflow.meta)) .filter( and_( @@ -83,9 +92,10 @@ def get_repo_workflow_by_repo_ids( .all() ) + @rollback_on_exc def get_repo_workflows_by_repo_id(self, repo_id: str) -> List[RepoWorkflow]: return ( - db.session.query(RepoWorkflow) + self._db.session.query(RepoWorkflow) .options(defer(RepoWorkflow.meta)) .filter( RepoWorkflow.org_repo_id == repo_id, @@ -94,11 +104,12 @@ def get_repo_workflows_by_repo_id(self, repo_id: str) -> List[RepoWorkflow]: .all() ) + @rollback_on_exc def get_successful_repo_workflows_runs_by_repo_ids( self, repo_ids: List[str], interval: Interval, workflow_filter: WorkflowFilter ) -> List[Tuple[RepoWorkflow, RepoWorkflowRuns]]: query = ( - db.session.query(RepoWorkflow, RepoWorkflowRuns) + self._db.session.query(RepoWorkflow, RepoWorkflowRuns) .options(defer(RepoWorkflow.meta), defer(RepoWorkflowRuns.meta)) .join( RepoWorkflowRuns, RepoWorkflow.id == RepoWorkflowRuns.repo_workflow_id @@ -117,6 +128,7 @@ def get_successful_repo_workflows_runs_by_repo_ids( return query.all() + @rollback_on_exc def get_repos_workflow_runs_by_repo_ids( self, repo_ids: List[str], @@ -124,7 +136,7 @@ def get_repos_workflow_runs_by_repo_ids( workflow_filter: WorkflowFilter = None, ) -> List[Tuple[RepoWorkflow, RepoWorkflowRuns]]: query = ( - db.session.query(RepoWorkflow, RepoWorkflowRuns) + self._db.session.query(RepoWorkflow, RepoWorkflowRuns) .options(defer(RepoWorkflow.meta), defer(RepoWorkflowRuns.meta)) .join( RepoWorkflowRuns, RepoWorkflow.id == RepoWorkflowRuns.repo_workflow_id @@ -141,22 +153,24 @@ def get_repos_workflow_runs_by_repo_ids( return query.all() + @rollback_on_exc def get_repo_workflow_run_by_id( self, repo_workflow_run_id: str ) -> Tuple[RepoWorkflow, RepoWorkflowRuns]: return ( - db.session.query(RepoWorkflow, RepoWorkflowRuns) + self._db.session.query(RepoWorkflow, RepoWorkflowRuns) .options(defer(RepoWorkflow.meta), defer(RepoWorkflowRuns.meta)) .join(RepoWorkflow, RepoWorkflow.id == RepoWorkflowRuns.repo_workflow_id) .filter(RepoWorkflowRuns.id == repo_workflow_run_id) .one_or_none() ) + @rollback_on_exc def get_previous_workflow_run( self, workflow_run: RepoWorkflowRuns ) -> Tuple[RepoWorkflow, RepoWorkflowRuns]: return ( - db.session.query(RepoWorkflow, RepoWorkflowRuns) + self._db.session.query(RepoWorkflow, RepoWorkflowRuns) .options(defer(RepoWorkflow.meta), defer(RepoWorkflowRuns.meta)) .join(RepoWorkflow, RepoWorkflow.id == RepoWorkflowRuns.repo_workflow_id) .filter( @@ -168,11 +182,12 @@ def get_previous_workflow_run( .first() ) + @rollback_on_exc def get_repo_workflow_runs_conducted_after_time( self, repo_id: str, from_time: datetime = None, limit_value: int = 500 ): query = ( - db.session.query(RepoWorkflowRuns) + self._db.session.query(RepoWorkflowRuns) .options(defer(RepoWorkflowRuns.meta)) .join(RepoWorkflow, RepoWorkflow.id == RepoWorkflowRuns.repo_workflow_id) .filter(