Skip to content

Commit

Permalink
🐛 Refactor auto_commit decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
s9891326 committed Jan 6, 2023
1 parent 9673559 commit 9b1e108
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 10 deletions.
44 changes: 35 additions & 9 deletions repository/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,21 @@
from typing import List

from sqlalchemy import and_
from sqlalchemy.orm import Session

from repository.database import Session
from repository.database import SessionLocal
from repository.models import User, UserSubGood


def auto_close_or_rollback(func):
def auto_commit(func):
@functools.wraps(func)
def wrapper(*args):
session = Session()
session: Session = SessionLocal()
try:
result = func(*args, session=session)
# if insert、update and delete will commit the transaction
if session.new or session.dirty or session.deleted:
session.commit()
except Exception:
session.rollback()
raise
Expand All @@ -23,25 +27,47 @@ def wrapper(*args):
return wrapper


@auto_close_or_rollback
@auto_commit
def get_users(_skip: int = 0, _limit: int = 100, **kwargs) -> List[User]:
session = kwargs["session"]
session: Session = kwargs["session"]
return session.query(User).offset(_skip).limit(_limit).all()


@auto_close_or_rollback
@auto_commit
def upsert_user(user_id: str, chat_id: str, **kwargs):
session = kwargs["session"]
session: Session = kwargs["session"]
data = User(id=user_id, chat_id=chat_id, state=1)
session.merge(data)
session.commit()


@auto_close_or_rollback
@auto_commit
def find_user_by_good_id(good_id: str, **kwargs):
session = kwargs["session"]
session: Session = kwargs["session"]
return (
session.query(User.chat_id)
.join(UserSubGood, and_(User.id == UserSubGood.user_id, User.state == 1))
.filter(UserSubGood.good_id == good_id)
.all()
)


# @auto_commit
# def insert_user(**kwargs):
# session: Session = kwargs["session"]
# data = User(id="1234", chat_id="12", state=1)
# session.add(data)
#
#
# @auto_commit
# def update_user(**kwargs):
# session: Session = kwargs["session"]
# user = session.query(User).filter(User.id == "1234").first()
# user.state = 0
#
#
# @auto_commit
# def delete_user(**kwargs):
# session: Session = kwargs["session"]
# user = session.query(User).filter(User.id == "1234").first()
# session.delete(user)
2 changes: 1 addition & 1 deletion repository/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
f"postgresql://{pt_config.DB_USER}:{pt_config.DB_PASSWORD}@"
f"{pt_config.DB_HOST}:5432/{pt_config.DB_NAME}"
)
Session = sessionmaker(bind=Engine, autocommit=True)
SessionLocal = sessionmaker(bind=Engine)
Base = declarative_base()

0 comments on commit 9b1e108

Please sign in to comment.