diff --git a/bcitflex/app_functions/__init__.py b/bcitflex/app_functions/__init__.py index 8ce7b41..c21df04 100644 --- a/bcitflex/app_functions/__init__.py +++ b/bcitflex/app_functions/__init__.py @@ -1,2 +1,2 @@ """Functions used by the app.""" -from .course_query import CourseFilter, filter_courses +from .course_query import ColumnCondition, filter_courses diff --git a/bcitflex/app_functions/course_query.py b/bcitflex/app_functions/course_query.py index 1ba51cd..736585c 100644 --- a/bcitflex/app_functions/course_query.py +++ b/bcitflex/app_functions/course_query.py @@ -1,8 +1,10 @@ """Filter courses.""" +from abc import ABC, abstractmethod + from sqlalchemy import Column, ColumnExpressionArgument, select from sqlalchemy.orm import Session -from bcitflex.model import Base, Course +from bcitflex.model import Base, Course, T_Base def coerce_to_column_type(model_column: Column, value: str | int) -> str | int: @@ -11,36 +13,46 @@ def coerce_to_column_type(model_column: Column, value: str | int) -> str | int: return column_data_type(value) -class CourseFilter: - """Course filter class.""" +class ModelCondition(ABC): + """Base class for SQLAlchemy condition.""" - def __init__(self, model_column: Column, filter_value: str): + def __init__(self, model: T_Base, attribute: str, value: str): """Initialize the filter.""" - self.model_column = model_column - self.filter_value = coerce_to_column_type(model_column, filter_value) - - @property - def model(self) -> Base: - """Return the model to filter on.""" - return self.model_column.class_ + self.model = model + self.model_attribute = getattr(model, attribute) + self.value = value - def clause(self) -> ColumnExpressionArgument: + @abstractmethod + def condition(self): """Return a filter condition.""" - if self.filter_value: - return self.model_column.__eq__(self.filter_value) + raise NotImplementedError def __repr__(self): """Return a string representation of the filter.""" - return f"{self.__class__.__name__}({self.model_column}, {self.filter_value})" + return f"{self.__class__.__name__}({self.model_attribute}, {self.value})" def __bool__(self): """Return True if the filter has a value.""" - return bool(self.filter_value) + return bool(self.value) + + +class ColumnCondition(ModelCondition): + """SQLAlchemy column condition.""" + + def __init__(self, model_attribute: Column, value: str): + """Initialize the filter.""" + super().__init__(model_attribute, value) + self.filter_value = coerce_to_column_type(model_attribute, value) + + def condition(self) -> ColumnExpressionArgument: + """Return a filter condition.""" + if self.filter_value: + return self.model_attribute.__eq__(self.filter_value) def filter_courses( session: Session, - *filters: CourseFilter, + *filters: ModelCondition, available: bool = False, ) -> list[Course]: """Return a list of courses that match the given criteria.""" @@ -57,7 +69,7 @@ def filter_courses( if models: query = query.join(*models) query = query.where( - *[course_filter.clause() for course_filter in filters if course_filter] + *[course_filter.condition() for course_filter in filters if course_filter] ) # Execute the query diff --git a/bcitflex/course.py b/bcitflex/course.py index 38e3489..4c46c2a 100644 --- a/bcitflex/course.py +++ b/bcitflex/course.py @@ -4,7 +4,7 @@ from sqlalchemy import select from werkzeug.datastructures import ImmutableMultiDict -from bcitflex.app_functions import CourseFilter, filter_courses +from bcitflex.app_functions import ColumnCondition, filter_courses from bcitflex.db import DBSession from bcitflex.model import Course @@ -16,10 +16,10 @@ } -def filters_from_form(form: ImmutableMultiDict) -> list[CourseFilter]: +def filters_from_form(form: ImmutableMultiDict) -> list[ColumnCondition]: """Return a list of CourseFilters from the given form.""" return [ - CourseFilter(FILTER_TO_FIELD[key], value) + ColumnCondition(FILTER_TO_FIELD[key], value) for key, value in form.items() if key in FILTER_TO_FIELD and value is not None ] diff --git a/bcitflex/model/__init__.py b/bcitflex/model/__init__.py index a618ad3..afe22d1 100644 --- a/bcitflex/model/__init__.py +++ b/bcitflex/model/__init__.py @@ -1,4 +1,4 @@ -from .base import Base, MappedAsDataclass +from .base import Base, MappedAsDataclass, T_Base from .course import Course from .meeting import Meeting from .offering import Offering diff --git a/bcitflex/model/base.py b/bcitflex/model/base.py index 9ad8e76..eed7dcc 100644 --- a/bcitflex/model/base.py +++ b/bcitflex/model/base.py @@ -15,7 +15,7 @@ from sqlalchemy.orm import DeclarativeBase, Mapper, Session from sqlalchemy.orm import MappedAsDataclass as MappedAsDataclassBase -_T = TypeVar("_T", bound="Base") +T_Base = TypeVar("T_Base", bound="Base") # Constraint naming conventions convention = { @@ -38,7 +38,7 @@ def db_to_attr(cls_mapper: Mapper, db_name: str) -> str: raise ValueError(f"Unknown database name: {db_name}") -def updated_pks(obj: _T, new_pk_vals: dict) -> dict: +def updated_pks(obj: T_Base, new_pk_vals: dict) -> dict: """Return a dict of primary keys updated with new_pk_vals.""" cls_mapper = inspect(obj.__class__) @@ -80,7 +80,7 @@ class Base(SoftDeleteMixin, DeclarativeBase): @classmethod def _unique_constraint( - cls: "_T", + cls: "T_Base", constraint_name: str | None = None, ) -> UniqueConstraint | None: """Return the unique constraint of a model.""" @@ -119,11 +119,11 @@ def _unique_constraint( @classmethod def get_by_unique( - cls: "_T", + cls: "T_Base", session: Session, unique_id: int | str | tuple, constraint_name: str | None = None, - ) -> _T | None: + ) -> T_Base | None: """ Return an object using the unique constraint instead of the primary key. @@ -163,7 +163,7 @@ def clone( pk_id: int | str | tuple | dict | None = None, include_relationships: bool = True, **kwargs, - ) -> _T: + ) -> T_Base: """ Clone the object with the given primary key and kwargs including FK relationships. diff --git a/tests/app_functions/test_course_query.py b/tests/app_functions/test_course_query.py index 1e4dc00..0b4ba10 100644 --- a/tests/app_functions/test_course_query.py +++ b/tests/app_functions/test_course_query.py @@ -2,7 +2,7 @@ import pytest from bcitflex.app_functions.course_query import ( - CourseFilter, + ColumnCondition, coerce_to_column_type, filter_courses, ) @@ -32,16 +32,20 @@ class TestCourseFilter: def test_course_filter_init(self): """Test CourseFilter initialization.""" - code_filter = CourseFilter(Course.code, "1234") + code_filter = ColumnCondition(Course.code, "1234") model = code_filter.model assert issubclass(model, Course) def test_course_filter_clause(self): """Test CourseFilter clause method.""" - code_filter = CourseFilter(Course.code, "1234") - clause = code_filter.clause() + code_filter = ColumnCondition(Course.code, "1234") + clause = code_filter.condition() assert clause.right.value == "1234" + def test_condition_attribute(self): + """Test ColumnCondition with a class attribute.""" + code_filter = ColumnCondition(Course.is_available, "1234") + assert code_filter.model_attribute == Course.code @dbtest class TestFilterCourses: @@ -60,17 +64,17 @@ def test_filter_courses_available(self, session): def test_filter_courses_code(self, session): """Test that filter_courses returns only courses with the given code.""" - courses = filter_courses(session, CourseFilter(Course.code, "1234")) + courses = filter_courses(session, ColumnCondition(Course.code, "1234")) assert all(course.code == "1234" for course in courses) def test_filter_courses_subject(self, session): """Test that filter_courses returns only courses with the given subject.""" - courses = filter_courses(session, CourseFilter(Course.subject_id, "COMP")) + courses = filter_courses(session, ColumnCondition(Course.subject_id, "COMP")) assert all(course.subject_id == "COMP" for course in courses) def test_filter_courses_status(self, session): """Test that filter_courses returns only courses with the given status.""" - courses = filter_courses(session, CourseFilter(Offering.status, "Open")) + courses = filter_courses(session, ColumnCondition(Offering.status, "Open")) assert len(courses) != 0 assert all( any(offering.status == "Open" for offering in course.offerings)