In [67]:
from carla.data.catalog import OnlineCatalog
from carla import MLModelCatalog
from carla.recourse_methods import Clue
from carla.models.negative_instances import predict_negative_instances
import numpy as np

import warnings
warnings.filterwarnings("ignore")

num = 10
data_name = "heloc"

In [None]:
dataset = OnlineCatalog(data_name)

In [64]:
def train_new_model(dataset):
    model = MLModelCatalog(dataset, "ann", backend="pytorch")
    model.train(
        learning_rate = 0.001,
        epochs = 10,
        max_depth = 50,
        n_estimators = 50,
        batch_size = 20,
        force_train = True
    )
    return model

In [65]:
def update_dataset(dataset, factuals, counterfactuals):
    fac_ind = []
    for index, row in factuals.iterrows():
        fac_ind.append(index)
    for index, row in counterfactuals.iterrows():
        dataset.loc[index] = counterfactuals.loc[index]

In [66]:
def train_recourse_method(method):
    rm = None
    if method == "clue":
        hyperparams = {
                "data_name": data_name,
                "train_vae": True,
                "width": 10,
                "depth": 3,
                "latent_dim": 12,
                "batch_size": 64,
                "epochs": 1,
                "lr": 0.001,
                "early_stop": 20,
            }

        # load a recourse model and pass black box model
        rm = Clue(dataset, model, hyperparams)
        
    return rm

In [68]:
def predict(model, data):
    pred = model.predict(data)
    return np.where(pred > 0.5, 1, 0)

In [18]:
# generate counterfactual examples
factuals = predict_negative_instances(model, dataset._df).iloc[:num]
print("Number of factuals", len(factuals))

pre = model.predict(factuals)

Unnamed: 0,ExternalRiskEstimate,MSinceOldestTradeOpen,MSinceMostRecentTradeOpen,AverageMInFile,NumSatisfactoryTrades,...,NumRevolvingTradesWBalance,NumInstallTradesWBalance,NumBank2NatlTradesWHighUtilization,PercentTradesWBalance,RiskPerformance
0,0.360656,0.177278,0.010444,0.211082,0.253165,...,0.25,0.0,0.055556,0.69,0.0
1,0.459016,0.069913,0.039164,0.097625,0.025316,...,0.0,0.071089,0.081335,0.0,0.0
2,0.557377,0.0799,0.013055,0.05277,0.113924,...,0.125,0.045455,0.055556,0.86,0.0
3,0.540984,0.208489,0.002611,0.182058,0.35443,...,0.1875,0.136364,0.166667,0.91,0.0
5,0.42623,0.168539,0.028721,0.195251,0.392405,...,0.375,0.136364,0.166667,0.94,0.0
6,0.344262,0.107366,0.018277,0.087071,0.316456,...,0.21875,0.272727,0.111111,1.0,1.0
7,0.57377,0.182272,0.018277,0.16095,0.21519,...,0.0625,0.045455,0.111111,0.4,1.0
8,0.42623,0.401998,0.005222,0.353562,0.303797,...,0.21875,0.0,0.166667,0.9,0.0
9,0.459016,0.09613,0.010444,0.084433,0.240506,...,0.15625,0.090909,0.055556,0.62,0.0
11,0.278689,0.385768,0.065274,0.189974,0.151899,...,0.15625,0.045455,0.111111,0.88,0.0


In [None]:
rm = train_recourse_method("clue")

In [21]:
counterfactuals = rm.get_counterfactuals(factuals)
print("Number of counterfactuals:", len(counterfactuals))

  f"cannot re-order features for non dataframe input: {type(x)}"
  f"cannot re-order features for non dataframe input: {type(x)}"


In [38]:
d_c = OnlineCatalog(data_name)
update_dataset(d_c._df, counterfactuals)
model2 = train_new_model(d_c)

In [45]:
post = model2.predict(factuals)
post

array([[0.19870104],
       [0.3094813 ],
       [0.16752979],
       [0.16535327],
       [0.1347224 ],
       [0.07916164],
       [0.42133692],
       [0.23459026],
       [0.14387123],
       [0.16688304]], dtype=float32)

In [70]:
post - pre

array([[ 0.02695996],
       [ 0.05455267],
       [-0.00967307],
       [-0.01217966],
       [-0.0190099 ],
       [-0.07457066],
       [ 0.10207954],
       [ 0.01860483],
       [-0.00986107],
       [ 0.01315074]], dtype=float32)

In [None]:
import matplotlib.pyplot as plt

plt.scatter(factuals['age'], factuals['length_of_stay'], c=factuals['score'], marker='o')
# plt.xlim([0, 0.1])
# plt.show()

plt.scatter(counterfactuals['age'], counterfactuals['length_of_stay'], c=counterfactuals['score'], marker='s')
# plt.xlim([0, 0.1])
plt.show()

