In [None]:
import warnings
warnings.filterwarnings("ignore")

from carla.data.catalog import OnlineCatalog
from carla import MLModelCatalog
from carla.recourse_methods import Clue
from carla.models.negative_instances import predict_negative_instances
import numpy as np
from sklearn.metrics import f1_score, accuracy_score

# --- Notebook configuration -------------------------------------------------
# Number of factual rows to draw for counterfactual generation.
num = 10
# Name of the dataset passed to OnlineCatalog and to the CLUE hyperparameters.
data_name = "compas"

# Seed NumPy's global RNG so that random draws relying on it (e.g. pandas'
# DataFrame.sample when no random_state is passed) repeat across
# "Restart Kernel -> Run All".
# NOTE(review): this does not seed PyTorch; model training may still differ
# between runs, which can change which rows are available to sample from.
SEED = 42
np.random.seed(SEED)

In [None]:
def train_new_model(dataset):
    """Create an "ann" MLModelCatalog (PyTorch backend) for `dataset` and train it.

    Args:
        dataset: Data catalog object handed to MLModelCatalog.

    Returns:
        The MLModelCatalog instance after `train` has been called on it.
    """
    # NOTE(review): max_depth / n_estimators look like tree-model options and
    # may be ignored for an "ann" model -- confirm against MLModelCatalog.train.
    training_options = {
        "learning_rate": 0.001,
        "epochs": 10,
        "max_depth": 50,
        "n_estimators": 50,
        "batch_size": 20,
        "force_train": True,
    }
    classifier = MLModelCatalog(dataset, "ann", backend="pytorch")
    classifier.train(**training_options)
    return classifier

In [None]:
def update_dataset(dataset, factuals, counterfactuals):
    """Overwrite factual rows of `dataset` with their counterfactuals, in place.

    `factuals` and `counterfactuals` are paired by position (first row with
    first row, and so on). For each pair, the row of `dataset` labelled with
    the factual's index is replaced by the counterfactual row, unless the
    counterfactual row consists only of missing values, in which case the
    dataset row is left untouched.

    Args:
        dataset: pandas DataFrame to modify in place.
        factuals: DataFrame whose index labels select the rows of `dataset`
            to overwrite.
        counterfactuals: DataFrame of replacement rows, in the same order as
            `factuals`.

    Returns:
        None. `dataset` is mutated.
    """
    paired_rows = zip(factuals.iterrows(), counterfactuals.iterrows())
    for (factual_label, _), (_, counterfactual_row) in paired_rows:
        # Use the row yielded by iterrows directly instead of a second
        # label-based lookup, which is redundant and misbehaves when
        # `counterfactuals` has duplicate index labels.
        if counterfactual_row.notna().any():
            dataset.loc[factual_label] = counterfactual_row

In [None]:
def train_recourse_method(method):
    """Instantiate the recourse method named by `method`.

    Only "clue" is handled; any other value yields None.

    Note: this reads the notebook-level globals `dataset`, `model` and
    `data_name`, so the cells defining them must have been run first.

    Args:
        method: Name of the recourse method to build.

    Returns:
        A Clue instance when `method == "clue"`, otherwise None.
    """
    if method != "clue":
        return None

    clue_settings = {
        "data_name": data_name,
        "train_vae": True,
        "width": 10,
        "depth": 3,
        "latent_dim": 12,
        "batch_size": 64,
        "epochs": 1,
        "lr": 0.001,
        "early_stop": 20,
    }

    # Build the recourse model around the already-trained black-box model.
    return Clue(dataset, model, clue_settings)

In [None]:
def predict(model, data):
    """Return hard 0/1 labels for every row of `data._df`.

    The model's raw outputs are thresholded: values strictly greater than
    0.5 become 1, everything else becomes 0.

    Args:
        model: Object exposing `predict(frame)`.
        data: Object exposing the underlying frame as `_df`.

    Returns:
        Integer array of 0s and 1s with the same shape as the model output.
    """
    threshold = 0.5
    raw_scores = model.predict(data._df)
    labels = np.where(raw_scores > threshold, 1, 0)
    return labels

In [None]:
def print_f1_score(model, data):
    """Print the F1 score of `model`'s thresholded predictions on `data`.

    Ground truth is the `data.target` column of `data._df`; predictions come
    from `predict(model, data)`.
    """
    y_true = np.array(data._df[data.target])
    y_pred = predict(model, data)
    score = f1_score(y_true, y_pred)
    print(f"F1 score: {score}")

In [None]:
def print_accuracy(model, data):
    """Print the accuracy of `model`'s thresholded predictions on `data`.

    Ground truth is the `data.target` column of `data._df`; predictions come
    from `predict(model, data)`.
    """
    y_true = np.array(data._df[data.target])
    y_pred = predict(model, data)
    score = accuracy_score(y_true, y_pred)
    print(f"Accuracy score: {score}")

In [None]:
def print_scores(model_pre, model_post, data):
    """Print F1 and accuracy on `data` for two models, one after the other.

    Args:
        model_pre: Model reported under the "Before recourse:" heading.
        model_post: Model reported under the "After recourse:" heading.
        data: Data catalog both models are scored on.
    """
    report_sections = (
        ("Before recourse:", model_pre),
        ("\nAfter recourse:", model_post),
    )
    for heading, scored_model in report_sections:
        print(heading)
        print_f1_score(scored_model, data)
        print_accuracy(scored_model, data)

In [None]:
# Fetch the dataset named in the configuration cell.
dataset = OnlineCatalog(data_name)

# Fit the black-box classifier that recourse will be computed against.
model = train_new_model(dataset)

# Randomly pick `num` of the rows returned by predict_negative_instances
# (presumably those the model assigns to the negative class); these are the
# factuals for which counterfactuals are generated later.
factuals = (
    predict_negative_instances(model, dataset._df)
    .sample(n=num)
)
print("Number of factuals", len(factuals))

# Model outputs on the factuals before any recourse is applied.
pre = model.predict(factuals)

In [None]:
# Build the CLUE recourse method around the trained model.
rm = train_recourse_method(method="clue")

In [None]:
# Ask the recourse method for one counterfactual per factual row.
counterfactuals = rm.get_counterfactuals(factuals)
# Only rows without any missing value are counted here.
print("Number of counterfactuals:", counterfactuals.dropna().shape[0])

In [None]:
# Start from a fresh copy of the dataset so the original catalog stays intact.
d_c = OnlineCatalog(data_name)

# Overwrite the factual rows with their counterfactuals (in place on d_c._df).
# NOTE(review): confirm that MLModelCatalog.train reads from `_df`; if it
# trains on separately stored train/test splits, this edit never reaches the
# training data and model2 would be fit on unmodified rows.
update_dataset(dataset=d_c._df, factuals=factuals, counterfactuals=counterfactuals)

# Fit a second model on the modified catalog.
model2 = train_new_model(dataset=d_c)

In [None]:
# Compare the original and the retrained model on the original dataset.
print_scores(model_pre=model, model_post=model2, data=dataset)

In [None]:
import matplotlib.pyplot as plt

# Overlay the sampled factuals and their counterfactuals in the
# (age, length_of_stay) plane, coloured by the 'score' column.
# Circles are factuals, squares are counterfactuals.
fig, ax = plt.subplots()
ax.scatter(
    factuals['age'],
    factuals['length_of_stay'],
    c=factuals['score'],
    marker='o',
    label='factuals',
)
ax.scatter(
    counterfactuals['age'],
    counterfactuals['length_of_stay'],
    c=counterfactuals['score'],
    marker='s',
    label='counterfactuals',
)
# Label the figure so it can be read on its own when the notebook is skimmed.
ax.set_xlabel('age')
ax.set_ylabel('length_of_stay')
ax.set_title('Factuals vs. counterfactuals (colour = score)')
ax.legend()
plt.show()

