-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: TreatmentEffect builder pattern
- Loading branch information
Showing
10 changed files
with
699 additions
and
376 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from medmodels.medrecord import MedRecord | ||
from medmodels.treatment_effect_estimation import TreatmentEffect | ||
|
||
__all__ = [MedRecord] | ||
__all__ = [MedRecord, TreatmentEffect] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from medmodels.treatment_effect_estimation.treatment_effect import TreatmentEffect | ||
|
||
__all__ = ["TreatmentEffect"] |
48 changes: 47 additions & 1 deletion
48
medmodels/treatment_effect_estimation/analysis_modules/adjust.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,59 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
from typing import TYPE_CHECKING, Optional, Set, Literal, Tuple | ||
from typing_extensions import TypeAlias | ||
|
||
from medmodels import MedRecord | ||
from medmodels.medrecord.types import ( | ||
NodeIndex, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from medmodels.treatment_effect_estimation.treatment_effect import TreatmentEffect | ||
|
||
|
||
MatchingMethod: TypeAlias = Literal["propensity", "nearest_neighbors"] | ||
|
||
|
||
class Adjust: | ||
_treatment_effect: TreatmentEffect | ||
|
||
def __init__(self, treatment_effect: TreatmentEffect) -> None: | ||
self._treatment_effect = treatment_effect | ||
|
||
def _apply_matching( | ||
self, | ||
method: Optional[MatchingMethod], | ||
medrecord: MedRecord, | ||
treatment_all: Set[NodeIndex], | ||
control_true: Set[NodeIndex], | ||
control_false: Set[NodeIndex], | ||
) -> Tuple[Set[NodeIndex], Set[NodeIndex]]: | ||
""" | ||
Update the treatment effect object with the matched controls. | ||
Args: | ||
medrecord (MedRecord): The MedRecord object containing the data. | ||
treatment_all (Set[NodeIndex]): The set of all patients in the treatment | ||
group. | ||
control_true (Set[NodeIndex]): The set of patients in the control group with | ||
the outcome of interest. | ||
control_false (Set[NodeIndex]): The set of patients in the control group | ||
without the outcome of interest. | ||
control_false (Set[NodeIndex]): The set of patients in the control group | ||
without the outcome of interest. | ||
Returns: | ||
Tuple[Set[NodeIndex], Set[NodeIndex]]: The updated control_true and | ||
control_false sets after matching. | ||
""" | ||
if method is None: | ||
return control_true, control_false | ||
|
||
# If it is not None, apply the matching method | ||
method_function = getattr(self, method) | ||
control_true, control_false = method_function( | ||
medrecord, treatment_all, control_true, control_false | ||
) | ||
|
||
return control_true, control_false |
119 changes: 0 additions & 119 deletions
119
medmodels/treatment_effect_estimation/analysis_modules/configure.py
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.