Skip to content

Commit

Permalink
wip(course): add filter for attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
jonbiemond committed Oct 21, 2023
1 parent 59de381 commit 572b469
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 36 deletions.
2 changes: 1 addition & 1 deletion bcitflex/app_functions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Functions used by the app."""
from .course_query import CourseFilter, filter_courses
from .course_query import ColumnCondition, filter_courses
48 changes: 30 additions & 18 deletions bcitflex/app_functions/course_query.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Filter courses."""
from abc import ABC, abstractmethod

from sqlalchemy import Column, ColumnExpressionArgument, select
from sqlalchemy.orm import Session

from bcitflex.model import Base, Course
from bcitflex.model import Base, Course, T_Base


def coerce_to_column_type(model_column: Column, value: str | int) -> str | int:
Expand All @@ -11,36 +13,46 @@ def coerce_to_column_type(model_column: Column, value: str | int) -> str | int:
return column_data_type(value)


class CourseFilter:
"""Course filter class."""
class ModelCondition(ABC):
"""Base class for SQLAlchemy condition."""

def __init__(self, model_column: Column, filter_value: str):
def __init__(self, model: T_Base, attribute: str, value: str):
"""Initialize the filter."""
self.model_column = model_column
self.filter_value = coerce_to_column_type(model_column, filter_value)

@property
def model(self) -> Base:
"""Return the model to filter on."""
return self.model_column.class_
self.model = model
self.model_attribute = getattr(model, attribute)
self.value = value

def clause(self) -> ColumnExpressionArgument:
@abstractmethod
def condition(self):
"""Return a filter condition."""
if self.filter_value:
return self.model_column.__eq__(self.filter_value)
raise NotImplementedError

def __repr__(self):
"""Return a string representation of the filter."""
return f"{self.__class__.__name__}({self.model_column}, {self.filter_value})"
return f"{self.__class__.__name__}({self.model_attribute}, {self.value})"

def __bool__(self):
"""Return True if the filter has a value."""
return bool(self.filter_value)
return bool(self.value)


class ColumnCondition(ModelCondition):
"""SQLAlchemy column condition."""

def __init__(self, model_attribute: Column, value: str):
"""Initialize the filter."""
super().__init__(model_attribute, value)
self.filter_value = coerce_to_column_type(model_attribute, value)

def condition(self) -> ColumnExpressionArgument:
"""Return a filter condition."""
if self.filter_value:
return self.model_attribute.__eq__(self.filter_value)


def filter_courses(
session: Session,
*filters: CourseFilter,
*filters: ModelCondition,
available: bool = False,
) -> list[Course]:
"""Return a list of courses that match the given criteria."""
Expand All @@ -57,7 +69,7 @@ def filter_courses(
if models:
query = query.join(*models)
query = query.where(
*[course_filter.clause() for course_filter in filters if course_filter]
*[course_filter.condition() for course_filter in filters if course_filter]
)

# Execute the query
Expand Down
6 changes: 3 additions & 3 deletions bcitflex/course.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sqlalchemy import select
from werkzeug.datastructures import ImmutableMultiDict

from bcitflex.app_functions import CourseFilter, filter_courses
from bcitflex.app_functions import ColumnCondition, filter_courses
from bcitflex.db import DBSession
from bcitflex.model import Course

Expand All @@ -16,10 +16,10 @@
}


def filters_from_form(form: ImmutableMultiDict) -> list[CourseFilter]:
def filters_from_form(form: ImmutableMultiDict) -> list[ColumnCondition]:
"""Return a list of CourseFilters from the given form."""
return [
CourseFilter(FILTER_TO_FIELD[key], value)
ColumnCondition(FILTER_TO_FIELD[key], value)
for key, value in form.items()
if key in FILTER_TO_FIELD and value is not None
]
Expand Down
2 changes: 1 addition & 1 deletion bcitflex/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import Base, MappedAsDataclass
from .base import Base, MappedAsDataclass, T_Base
from .course import Course
from .meeting import Meeting
from .offering import Offering
Expand Down
12 changes: 6 additions & 6 deletions bcitflex/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sqlalchemy.orm import DeclarativeBase, Mapper, Session
from sqlalchemy.orm import MappedAsDataclass as MappedAsDataclassBase

_T = TypeVar("_T", bound="Base")
T_Base = TypeVar("T_Base", bound="Base")

# Constraint naming conventions
convention = {
Expand All @@ -38,7 +38,7 @@ def db_to_attr(cls_mapper: Mapper, db_name: str) -> str:
raise ValueError(f"Unknown database name: {db_name}")


def updated_pks(obj: _T, new_pk_vals: dict) -> dict:
def updated_pks(obj: T_Base, new_pk_vals: dict) -> dict:
"""Return a dict of primary keys updated with new_pk_vals."""

cls_mapper = inspect(obj.__class__)
Expand Down Expand Up @@ -80,7 +80,7 @@ class Base(SoftDeleteMixin, DeclarativeBase):

@classmethod
def _unique_constraint(
cls: "_T",
cls: "T_Base",
constraint_name: str | None = None,
) -> UniqueConstraint | None:
"""Return the unique constraint of a model."""
Expand Down Expand Up @@ -119,11 +119,11 @@ def _unique_constraint(

@classmethod
def get_by_unique(
cls: "_T",
cls: "T_Base",
session: Session,
unique_id: int | str | tuple,
constraint_name: str | None = None,
) -> _T | None:
) -> T_Base | None:
"""
Return an object using the unique constraint instead of the primary key.
Expand Down Expand Up @@ -163,7 +163,7 @@ def clone(
pk_id: int | str | tuple | dict | None = None,
include_relationships: bool = True,
**kwargs,
) -> _T:
) -> T_Base:
"""
Clone the object with the given primary key and kwargs including FK relationships.
Expand Down
18 changes: 11 additions & 7 deletions tests/app_functions/test_course_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from bcitflex.app_functions.course_query import (
CourseFilter,
ColumnCondition,
coerce_to_column_type,
filter_courses,
)
Expand Down Expand Up @@ -32,16 +32,20 @@ class TestCourseFilter:

def test_course_filter_init(self):
"""Test CourseFilter initialization."""
code_filter = CourseFilter(Course.code, "1234")
code_filter = ColumnCondition(Course.code, "1234")
model = code_filter.model
assert issubclass(model, Course)

def test_course_filter_clause(self):
"""Test CourseFilter clause method."""
code_filter = CourseFilter(Course.code, "1234")
clause = code_filter.clause()
code_filter = ColumnCondition(Course.code, "1234")
clause = code_filter.condition()
assert clause.right.value == "1234"

def test_condition_attribute(self):
"""Test ColumnCondition with a class attribute."""
code_filter = ColumnCondition(Course.is_available, "1234")
assert code_filter.model_attribute == Course.code

@dbtest
class TestFilterCourses:
Expand All @@ -60,17 +64,17 @@ def test_filter_courses_available(self, session):

def test_filter_courses_code(self, session):
"""Test that filter_courses returns only courses with the given code."""
courses = filter_courses(session, CourseFilter(Course.code, "1234"))
courses = filter_courses(session, ColumnCondition(Course.code, "1234"))
assert all(course.code == "1234" for course in courses)

def test_filter_courses_subject(self, session):
"""Test that filter_courses returns only courses with the given subject."""
courses = filter_courses(session, CourseFilter(Course.subject_id, "COMP"))
courses = filter_courses(session, ColumnCondition(Course.subject_id, "COMP"))
assert all(course.subject_id == "COMP" for course in courses)

def test_filter_courses_status(self, session):
"""Test that filter_courses returns only courses with the given status."""
courses = filter_courses(session, CourseFilter(Offering.status, "Open"))
courses = filter_courses(session, ColumnCondition(Offering.status, "Open"))
assert len(courses) != 0
assert all(
any(offering.status == "Open" for offering in course.offerings)
Expand Down

0 comments on commit 572b469

Please sign in to comment.