-
Notifications
You must be signed in to change notification settings - Fork 1
/
base.py
86 lines (70 loc) · 2.58 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/methods/00_base.ipynb.
# %% auto 0
# Public names exported by `from <this module> import *`.
__all__ = [
    "CFModule",
    "ParametricCFModule",
]
# %% ../../nbs/methods/00_base.ipynb 2
from ..import_essentials import *
from ..base import BaseConfig, BaseModule, PredFnMixedin, TrainableMixedin
# %% ../../nbs/methods/00_base.ipynb 3
def default_apply_constraints_fn(x, cf, hard, **kwargs):
    """Fallback constraint hook: return ``cf`` unchanged.

    Used by ``CFModule.apply_constraints`` when no custom constraint
    function has been configured. ``x``, ``hard`` and any extra keyword
    arguments are accepted only so the signature matches custom hooks;
    they are ignored.

    Returns:
        The very same ``cf`` object that was passed in.
    """
    unconstrained_cf = cf
    return unconstrained_cf
def default_compute_reg_loss_fn(x, cf, **kwargs):
    """Fallback regularisation-loss hook: contributes no extra loss.

    Used by ``CFModule.compute_reg_loss`` when no custom regularisation
    function has been configured. All arguments are ignored.

    Returns:
        The float ``0.0``.
    """
    no_extra_loss = 0.0
    return no_extra_loss
# %% ../../nbs/methods/00_base.ipynb 4
class CFModule(BaseModule):
    """Base class for all counterfactual modules.

    Holds two pluggable hooks, a constraint function and a
    regularisation-loss function. Each falls back to a module-level default
    when unset (``None``). The class also keeps an optional reference to a
    data module. Subclasses are expected to override ``generate_cf``.
    """

    def __init__(
        self,
        config,
        *,
        name: str = None,
        apply_constraints_fn=None,
        compute_reg_loss_fn=None,
        **kwargs
    ):
        # Only ``config`` and ``name`` are forwarded to the base class; any
        # extra keyword arguments are accepted and ignored here.
        super().__init__(config, name=name)
        # ``None`` means "use the module-level default" when dispatching in
        # ``apply_constraints`` / ``compute_reg_loss``.
        self._apply_constraints_fn = apply_constraints_fn
        self._compute_reg_loss_fn = compute_reg_loss_fn
        self.data_module = None

    def set_data_module(self, data_module):
        """Store ``data_module`` on ``self.data_module``."""
        self.data_module = data_module

    def set_apply_constraints_fn(self, apply_constraints_fn: Callable):
        """Replace the constraint hook used by ``apply_constraints``."""
        self._apply_constraints_fn = apply_constraints_fn

    def set_compute_reg_loss_fn(self, compute_reg_loss_fn: Callable):
        """Replace the regularisation hook used by ``compute_reg_loss``."""
        self._compute_reg_loss_fn = compute_reg_loss_fn

    def apply_constraints(self, *args, **kwargs) -> Array:
        """Apply the configured constraint hook, or the default one.

        All arguments are passed through unchanged to the selected hook,
        and its result is returned as-is.
        """
        constraint_fn = self._apply_constraints_fn
        if constraint_fn is None:
            constraint_fn = default_apply_constraints_fn
        return constraint_fn(*args, **kwargs)

    def compute_reg_loss(self, *args, **kwargs):
        """Compute the regularisation loss via the configured hook, or the default.

        All arguments are passed through unchanged to the selected hook,
        and its result is returned as-is.
        """
        reg_loss_fn = self._compute_reg_loss_fn
        if reg_loss_fn is None:
            reg_loss_fn = default_compute_reg_loss_fn
        return reg_loss_fn(*args, **kwargs)

    def before_generate_cf(self, *args, **kwargs):
        """Hook run before generation; a no-op in this base class."""
        pass

    def generate_cf(
        self,
        x: Array,
        pred_fn: Callable = None,
        y_target: Array = None,
        rng_key: jrand.PRNGKey = None,
        **kwargs
    ) -> Array:
        """Return the counterfactual of ``x``.

        Raises:
            NotImplementedError: always; concrete subclasses must override.
        """
        raise NotImplementedError

    # NOTE(review): ``save`` and ``load_from_path`` are not defined in this
    # class (presumably inherited from ``BaseModule``), and
    # ``set_data_module`` is not listed. Confirm this list is intentional.
    __ALL__ = [
        "set_apply_constraints_fn",
        "set_compute_reg_loss_fn",
        "apply_constraints",
        "compute_reg_loss",
        "save",
        "load_from_path",
        "before_generate_cf",
        "generate_cf",
    ]
# %% ../../nbs/methods/00_base.ipynb 5
class ParametricCFModule(CFModule, TrainableMixedin):
    """Base class for parametric counterfactual modules.

    Combines ``CFModule`` with ``TrainableMixedin``. Concrete subclasses
    must provide their own ``train`` implementation.
    """

    def train(self, data, pred_fn, **kwargs):
        """Train the module.

        Args:
            data: training data (its type is not constrained here).
            pred_fn: prediction function (its contract is not constrained here).
            **kwargs: additional training options for subclasses.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError