Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions apiserver/dora/store/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from flask_sqlalchemy import SQLAlchemy

from dora.utils.log import LOG

db = SQLAlchemy()


Expand All @@ -20,3 +22,15 @@ def configure_db_with_app(app):
app.config["SQLALCHEMY_DATABASE_URI"] = connection_uri
app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"pool_size": 20, "max_overflow": 5}
db.init_app(app)


def rollback_on_exc(func):
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except Exception as e:
self._db.session.rollback()
LOG.error(f"Error in {func.__name__} - {str(e)}")
raise

return wrapper
98 changes: 64 additions & 34 deletions apiserver/dora/store/repos/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy import or_
from sqlalchemy.orm import defer

from dora.store import db
from dora.store import db, rollback_on_exc
from dora.store.models.code import (
PullRequest,
PullRequestEvent,
Expand All @@ -22,47 +22,56 @@


class CodeRepoService:
def __init__(self):
self._db = db

@rollback_on_exc
def get_active_org_repos(self, org_id: str) -> List[OrgRepo]:
return (
db.session.query(OrgRepo)
self._db.session.query(OrgRepo)
.filter(OrgRepo.org_id == org_id, OrgRepo.is_active.is_(True))
.all()
)

@rollback_on_exc
def update_org_repos(self, org_repos: List[OrgRepo]):
[db.session.merge(org_repo) for org_repo in org_repos]
db.session.commit()
[self._db.session.merge(org_repo) for org_repo in org_repos]
self._db.session.commit()

@rollback_on_exc
def save_pull_requests_data(
self,
pull_requests: List[PullRequest],
pull_request_commits: List[PullRequestCommit],
pull_request_events: List[PullRequestEvent],
):
[db.session.merge(pull_request) for pull_request in pull_requests]
[self._db.session.merge(pull_request) for pull_request in pull_requests]
[
db.session.merge(pull_request_commit)
self._db.session.merge(pull_request_commit)
for pull_request_commit in pull_request_commits
]
[
db.session.merge(pull_request_event)
self._db.session.merge(pull_request_event)
for pull_request_event in pull_request_events
]
db.session.commit()
self._db.session.commit()

@rollback_on_exc
def update_prs(self, prs: List[PullRequest]):
[db.session.merge(pr) for pr in prs]
db.session.commit()
[self._db.session.merge(pr) for pr in prs]
self._db.session.commit()

@rollback_on_exc
def save_revert_pr_mappings(
self, revert_pr_mappings: List[PullRequestRevertPRMapping]
):
[db.session.merge(revert_pr_map) for revert_pr_map in revert_pr_mappings]
db.session.commit()
[self._db.session.merge(revert_pr_map) for revert_pr_map in revert_pr_mappings]
self._db.session.commit()

@rollback_on_exc
def get_org_repo_bookmark(self, org_repo: OrgRepo, bookmark_type):
return (
db.session.query(Bookmark)
self._db.session.query(Bookmark)
.filter(
and_(
Bookmark.repo_id == org_repo.id,
Expand All @@ -72,16 +81,21 @@ def get_org_repo_bookmark(self, org_repo: OrgRepo, bookmark_type):
.one_or_none()
)

@rollback_on_exc
def update_org_repo_bookmark(self, bookmark: Bookmark):
db.session.merge(bookmark)
db.session.commit()
self._db.session.merge(bookmark)
self._db.session.commit()

@rollback_on_exc
def get_repo_by_id(self, repo_id: str) -> Optional[OrgRepo]:
return db.session.query(OrgRepo).filter(OrgRepo.id == repo_id).one_or_none()
return (
self._db.session.query(OrgRepo).filter(OrgRepo.id == repo_id).one_or_none()
)

@rollback_on_exc
def get_repo_pr_by_number(self, repo_id: str, pr_number) -> Optional[PullRequest]:
return (
db.session.query(PullRequest)
self._db.session.query(PullRequest)
.options(defer(PullRequest.data))
.filter(
and_(
Expand All @@ -91,31 +105,34 @@ def get_repo_pr_by_number(self, repo_id: str, pr_number) -> Optional[PullRequest
.one_or_none()
)

@rollback_on_exc
def get_pr_events(self, pr_model: PullRequest):
if not pr_model:
return []

pr_events = (
db.session.query(PullRequestEvent)
self._db.session.query(PullRequestEvent)
.options(defer(PullRequestEvent.data))
.filter(PullRequestEvent.pull_request_id == pr_model.id)
.all()
)
return pr_events

@rollback_on_exc
def get_prs_by_ids(self, pr_ids: List[str]):
query = (
db.session.query(PullRequest)
self._db.session.query(PullRequest)
.options(defer(PullRequest.data))
.filter(PullRequest.id.in_(pr_ids))
)
return query.all()

@rollback_on_exc
def get_prs_by_head_branch_match_strings(
self, repo_ids: List[str], match_strings: List[str]
) -> List[PullRequest]:
query = (
db.session.query(PullRequest)
self._db.session.query(PullRequest)
.options(defer(PullRequest.data))
.filter(
and_(
Expand All @@ -133,11 +150,12 @@ def get_prs_by_head_branch_match_strings(

return query.all()

@rollback_on_exc
def get_reverted_prs_by_numbers(
self, repo_ids: List[str], numbers: List[str]
) -> List[PullRequest]:
query = (
db.session.query(PullRequest)
self._db.session.query(PullRequest)
.options(defer(PullRequest.data))
.filter(
and_(
Expand All @@ -150,27 +168,31 @@ def get_reverted_prs_by_numbers(

return query.all()

@rollback_on_exc
def get_active_team_repos_by_team_id(self, team_id: str) -> List[TeamRepos]:
return (
db.session.query(TeamRepos)
self._db.session.query(TeamRepos)
.filter(TeamRepos.team_id == team_id, TeamRepos.is_active.is_(True))
.all()
)

@rollback_on_exc
def get_active_team_repos_by_team_ids(self, team_ids: List[str]) -> List[TeamRepos]:
return (
db.session.query(TeamRepos)
self._db.session.query(TeamRepos)
.filter(TeamRepos.team_id.in_(team_ids), TeamRepos.is_active.is_(True))
.all()
)

@rollback_on_exc
def get_active_org_repos_by_ids(self, repo_ids: List[str]) -> List[OrgRepo]:
return (
db.session.query(OrgRepo)
self._db.session.query(OrgRepo)
.filter(OrgRepo.id.in_(repo_ids), OrgRepo.is_active.is_(True))
.all()
)

@rollback_on_exc
def get_prs_merged_in_interval(
self,
repo_ids: List[str],
Expand All @@ -179,7 +201,7 @@ def get_prs_merged_in_interval(
base_branches: List[str] = None,
has_non_null_mtd=False,
) -> List[PullRequest]:
query = db.session.query(PullRequest).options(defer(PullRequest.data))
query = self._db.session.query(PullRequest).options(defer(PullRequest.data))

query = self._filter_prs_by_repo_ids(query, repo_ids)
query = self._filter_prs_merged_in_interval(query, interval)
Expand All @@ -194,17 +216,19 @@ def get_prs_merged_in_interval(

return query.all()

@rollback_on_exc
def get_pull_request_by_id(self, pr_id: str) -> PullRequest:
return (
db.session.query(PullRequest)
self._db.session.query(PullRequest)
.options(defer(PullRequest.data))
.filter(PullRequest.id == pr_id)
.one_or_none()
)

@rollback_on_exc
def get_previous_pull_request(self, pull_request: PullRequest) -> PullRequest:
return (
db.session.query(PullRequest)
self._db.session.query(PullRequest)
.options(defer(PullRequest.data))
.filter(
PullRequest.repo_id == pull_request.repo_id,
Expand All @@ -216,15 +240,17 @@ def get_previous_pull_request(self, pull_request: PullRequest) -> PullRequest:
.first()
)

@rollback_on_exc
def get_repos_by_ids(self, ids: List[str]) -> List[OrgRepo]:
if not ids:
return []

return db.session.query(OrgRepo).filter(OrgRepo.id.in_(ids)).all()
return self._db.session.query(OrgRepo).filter(OrgRepo.id.in_(ids)).all()

@rollback_on_exc
def get_team_repos(self, team_id) -> List[OrgRepo]:
team_repos = (
db.session.query(TeamRepos)
self._db.session.query(TeamRepos)
.filter(and_(TeamRepos.team_id == team_id, TeamRepos.is_active == True))
.all()
)
Expand All @@ -234,26 +260,29 @@ def get_team_repos(self, team_id) -> List[OrgRepo]:
team_repo_ids = [tr.org_repo_id for tr in team_repos]
return self.get_repos_by_ids(team_repo_ids)

@rollback_on_exc
def get_merge_to_deploy_broker_bookmark(
self, repo_id: str
) -> BookmarkMergeToDeployBroker:
return (
db.session.query(BookmarkMergeToDeployBroker)
self._db.session.query(BookmarkMergeToDeployBroker)
.filter(BookmarkMergeToDeployBroker.repo_id == repo_id)
.one_or_none()
)

@rollback_on_exc
def update_merge_to_deploy_broker_bookmark(
self, bookmark: BookmarkMergeToDeployBroker
):
db.session.merge(bookmark)
db.session.commit()
self._db.session.merge(bookmark)
self._db.session.commit()

@rollback_on_exc
def get_prs_in_repo_merged_before_given_date_with_merge_to_deploy_as_null(
self, repo_id: str, to_time: datetime
):
return (
db.session.query(PullRequest)
self._db.session.query(PullRequest)
.options(defer(PullRequest.data))
.filter(
PullRequest.repo_id == repo_id,
Expand All @@ -264,11 +293,12 @@ def get_prs_in_repo_merged_before_given_date_with_merge_to_deploy_as_null(
.all()
)

@rollback_on_exc
def get_repo_revert_prs_mappings_updated_in_interval(
self, repo_id, from_time, to_time
) -> List[PullRequestRevertPRMapping]:
query = (
db.session.query(PullRequestRevertPRMapping)
self._db.session.query(PullRequestRevertPRMapping)
.join(PullRequest, PullRequest.id == PullRequestRevertPRMapping.pr_id)
.filter(
PullRequest.repo_id == repo_id,
Expand Down
30 changes: 19 additions & 11 deletions apiserver/dora/store/repos/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlalchemy import and_

from dora.store import db
from dora.store import db, rollback_on_exc
from dora.store.models import UserIdentityProvider, Integration
from dora.store.models.core import Organization, Team, Users
from dora.utils.cryptography import get_crypto_service
Expand All @@ -11,56 +11,64 @@
class CoreRepoService:
def __init__(self):
self._crypto = get_crypto_service()
self._db = db

@rollback_on_exc
def get_org(self, org_id):
return (
db.session.query(Organization)
self._db.session.query(Organization)
.filter(Organization.id == org_id)
.one_or_none()
)

@rollback_on_exc
def get_org_by_name(self, org_name: str):
return (
db.session.query(Organization)
self._db.session.query(Organization)
.filter(Organization.name == org_name)
.one_or_none()
)

@rollback_on_exc
def get_team(self, team_id: str) -> Team:
return (
db.session.query(Team)
self._db.session.query(Team)
.filter(Team.id == team_id, Team.is_deleted.is_(False))
.one_or_none()
)

@rollback_on_exc
def delete_team(self, team_id: str):

team = db.session.query(Team).filter(Team.id == team_id).one_or_none()
team = self._db.session.query(Team).filter(Team.id == team_id).one_or_none()

if not team:
return None

team.is_deleted = True

db.session.merge(team)
db.session.commit()
return db.session.query(Team).filter(Team.id == team_id).one_or_none()
self._db.session.merge(team)
self._db.session.commit()
return self._db.session.query(Team).filter(Team.id == team_id).one_or_none()

@rollback_on_exc
def get_user(self, user_id) -> Optional[Users]:
return db.session.query(Users).filter(Users.id == user_id).one_or_none()
return self._db.session.query(Users).filter(Users.id == user_id).one_or_none()

@rollback_on_exc
def get_org_integrations_for_names(self, org_id: str, provider_names: List[str]):
return (
db.session.query(Integration)
self._db.session.query(Integration)
.filter(
and_(Integration.org_id == org_id, Integration.name.in_(provider_names))
)
.all()
)

@rollback_on_exc
def get_access_token(self, org_id, provider: UserIdentityProvider) -> Optional[str]:
user_identity: Integration = (
db.session.query(Integration)
self._db.session.query(Integration)
.filter(
and_(Integration.org_id == org_id, Integration.name == provider.value)
)
Expand Down
Loading