In [1]:
import sys

# Make the in-repo package (presumably the one providing ``algorec``)
# importable without installing it. Relative entries on ``sys.path`` are
# resolved against the current working directory, so this assumes the kernel
# was started from the notebook's own folder.
sys.path.extend(["../"])

In [6]:
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression

from algorec.environments import ClosedEnvironment
from algorec.populations import BasePopulation
from algorec.recourse import ActionableRecourse, NFeatureRecourse

# Synthetic dataset: N_ROWS rows of uniform [0, 1) features plus an integer
# 0/1 column ``cat_1``, and an independent random 0/1 target ``y``.
# The RNG calls must stay in this order to reproduce the saved outputs.
SEED = 42
N_ROWS = 100
FEATURES = ["a", "b", "c", "d"]

rng = np.random.default_rng(SEED)
feature_values = rng.random((N_ROWS, len(FEATURES)))
df = pd.DataFrame(feature_values, columns=FEATURES)
df["cat_1"] = rng.integers(0, 2, N_ROWS)
y = rng.integers(0, 2, N_ROWS)

# Classifier whose scores drive both the recourse method and the environment.
lr = LogisticRegression()
lr.fit(df, y)  # fit() returns the estimator itself; ``lr`` is now fitted

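### Environment using ``ActionableRecourse``

Simulate a ``ClosedEnvironment`` whose recourse method is ``ActionableRecourse``. Then check the number of steps taken, that a single row remains (the one with the lowest initial predicted probability), and that column dtypes are preserved.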

In [13]:
# Build a ClosedEnvironment driven by ActionableRecourse. No ``step_direction``
# or ``immutable`` settings are passed here; only ``cat_1`` is marked
# categorical.
THRESHOLD = 0.6  # used by both the recourse method and the environment

population = BasePopulation(categorical=["cat_1"], data=df)

recourse = ActionableRecourse(threshold=THRESHOLD, model=lr)

environment = ClosedEnvironment(
    threshold=THRESHOLD,
    recourse=recourse,
    population=population,
)

In [14]:
# Run the simulation and check its end state. This cell is not idempotent:
# ``run_simulation`` advances ``environment`` (``step_`` goes 0 -> N_STEPS),
# so re-run the setup cell above before re-running this one.
N_STEPS = 6

assert environment.step_ == 0, "environment already simulated; re-run the setup cell first"

environment.run_simulation(N_STEPS)

assert environment.step_ == N_STEPS, f"expected step {N_STEPS}, got {environment.step_}"

final_data = environment.population_.data
assert final_data.shape[0] == 1, f"expected 1 remaining row, got {final_data.shape[0]}"

# ``np.argmin`` returns a *position*; map it to an index *label* so the check
# does not silently rely on ``df`` having a default RangeIndex.
lowest_score_label = df.index[np.argmin(lr.predict_proba(df)[:, -1])]
assert final_data.index[0] == lowest_score_label, (
    f"expected row {lowest_score_label} to remain, got {final_data.index[0]}"
)

# Compare dtypes column by column and report which ones changed, instead of a
# bare AssertionError. Aligning on the union of columns also flags added or
# dropped columns (they show up as "nan") rather than raising on mismatched labels.
dtype_report = pd.DataFrame(
    {
        "before": environment.population.data.dtypes,
        "after": final_data.dtypes,
    }
).astype(str)
changed_dtypes = dtype_report[dtype_report["before"] != dtype_report["after"]]
assert changed_dtypes.empty, f"column dtypes changed during the simulation:\n{changed_dtypes}"

AssertionError: 

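### Environment using ``NFeatureRecourse``

Run the same checks with ``NFeatureRecourse`` and a constrained population: step directions for ``a`` and ``b``, with ``c`` immutable.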

In [None]:
# Build a ClosedEnvironment driven by NFeatureRecourse over a constrained
# population. ``step_direction`` presumably restricts ``a`` to increase and
# ``b`` to decrease (TODO confirm against algorec). ``c`` is immutable and
# ``cat_1`` is categorical.
THRESHOLD = 0.6  # used by both the recourse method and the environment

population = BasePopulation(
    data=df,
    categorical=["cat_1"],
    immutable=["c"],
    step_direction={"a": 1, "b": -1},
)

recourse = NFeatureRecourse(threshold=THRESHOLD, model=lr)

environment = ClosedEnvironment(
    threshold=THRESHOLD,
    recourse=recourse,
    population=population,
)

In [8]:
# Run the simulation and check its end state. This cell is not idempotent:
# ``run_simulation`` advances ``environment`` (``step_`` goes 0 -> N_STEPS),
# so re-run the setup cell above before re-running this one.
N_STEPS = 6

assert environment.step_ == 0, "environment already simulated; re-run the setup cell first"

environment.run_simulation(N_STEPS)

assert environment.step_ == N_STEPS, f"expected step {N_STEPS}, got {environment.step_}"

final_data = environment.population_.data
assert final_data.shape[0] == 1, f"expected 1 remaining row, got {final_data.shape[0]}"

# ``np.argmin`` returns a *position*; map it to an index *label* so the check
# does not silently rely on ``df`` having a default RangeIndex.
lowest_score_label = df.index[np.argmin(lr.predict_proba(df)[:, -1])]
assert final_data.index[0] == lowest_score_label, (
    f"expected row {lowest_score_label} to remain, got {final_data.index[0]}"
)

# Compare dtypes column by column and report which ones changed, instead of a
# bare AssertionError. Aligning on the union of columns also flags added or
# dropped columns (they show up as "nan") rather than raising on mismatched labels.
dtype_report = pd.DataFrame(
    {
        "before": environment.population.data.dtypes,
        "after": final_data.dtypes,
    }
).astype(str)
changed_dtypes = dtype_report[dtype_report["before"] != dtype_report["after"]]
assert changed_dtypes.empty, f"column dtypes changed during the simulation:\n{changed_dtypes}"

AssertionError: 

In [12]:
# Input features before any simulation; ``cat_1`` holds the integer codes 0/1.
df

Unnamed: 0,a,b,c,d,cat_1
0,0.773956,0.438878,0.858598,0.697368,0
1,0.094177,0.975622,0.761140,0.786064,0
2,0.128114,0.450386,0.370798,0.926765,1
3,0.643865,0.822762,0.443414,0.227239,1
4,0.554585,0.063817,0.827631,0.631664,1
...,...,...,...,...,...
95,0.552993,0.936140,0.780301,0.479370,1
96,0.376359,0.986632,0.717760,0.951195,1
97,0.118479,0.850534,0.637074,0.121922,0
98,0.588258,0.686096,0.012303,0.454318,0


In [10]:
# Population stored in ``metadata_`` under key 1 (presumably the state after
# the first simulation step; TODO confirm). In the saved output, ``cat_1``
# holds non-integer values such as -0.254532 / 0.745468 instead of the 0/1
# integers in ``df``.
logged_population = environment.metadata_[1]["population"]
logged_population.data

Unnamed: 0,a,b,c,d,cat_1
0,2.640310,0.438878,0.858598,0.998640,-0.254532
1,1.960531,0.975622,0.761140,1.087336,-0.254532
2,1.994468,0.450386,0.370798,1.228037,0.745468
3,2.510219,0.822762,0.443414,0.528511,0.745468
4,2.420939,0.063817,0.827631,0.932937,0.745468
...,...,...,...,...,...
95,2.419347,0.936140,0.780301,0.780642,0.745468
96,2.242713,0.986632,0.717760,1.252467,0.745468
97,1.984832,0.850534,0.637074,0.423194,-0.254532
98,2.454612,0.686096,0.012303,0.755590,-0.254532


In [11]:
# Counterfactuals the environment produces for that same logged population.
# In the saved output only some rows move (e.g. rows 2, 4 and 98). ``cat_1`` is
# given continuous values (e.g. 0.649679) even though it was declared
# categorical, which is a likely cause of the dtype assertion failures
# (TODO confirm in algorec).
logged_population = environment.metadata_[1]["population"]
environment.counterfactual(logged_population)

Unnamed: 0,a,b,c,d,cat_1
0,2.640310,0.438878,0.858598,0.998640,-0.254532
1,1.960531,0.975622,0.761140,1.087336,-0.254532
2,2.696836,0.450386,0.370798,1.341415,0.649679
3,2.510219,0.822762,0.443414,0.528511,0.745468
4,2.980790,0.063817,0.827631,1.023309,0.669116
...,...,...,...,...,...
95,2.419347,0.936140,0.780301,0.780642,0.745468
96,2.242713,0.986632,0.717760,1.252467,0.745468
97,1.984832,0.850534,0.637074,0.423194,-0.254532
98,2.668003,0.686096,0.012303,0.790036,-0.283634
