In [1]:
# Automatically reload imported modules (e.g. the project's counterfactuals.*
# package) before each cell executes, so library edits are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

import numpy as np
import torch

from counterfactuals.datasets import LawDataset, AdultDataset, GermanCreditDataset
from counterfactuals.cf_methods.ppcef import PPCEF
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.discriminative_models import MultilayerPerceptron
from counterfactuals.losses import MulticlassDiscLoss
from counterfactuals.metrics import evaluate_cf

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Which experiment to run: one of the keys of DATASET_REGISTRY below.
DATASET_NAME = "law"

# name -> (dataset factory, classifier checkpoint path, flow checkpoint path).
# The first element is a zero-argument factory rather than an instance. Only the
# selected dataset is constructed, so unused datasets are never built and their
# CSV files (presumably read by the constructors) do not need to exist.
DATASET_REGISTRY = {
    "adult": (
        lambda: AdultDataset("../data/adult.csv"),
        "adult_disc_model.pt",
        "adult_flow.pth",
    ),
    "law": (
        lambda: LawDataset("../data/law.csv"),
        "law_disc_model.pt",
        "law_flow.pth",
    ),
    "german": (
        lambda: GermanCreditDataset("../data/german_credit.csv"),
        "german_disc_model.pt",
        "german_flow.pth",
    ),
}

make_dataset, disc_model_path, gen_model_path = DATASET_REGISTRY[DATASET_NAME]
dataset = make_dataset()

In [None]:
from counterfactuals.datasets.utils import CustomCategoricalTransformer
from sklearn.compose import ColumnTransformer


def dequantize(dataset):
    """Fit a per-group categorical dequantizer on the training split and apply it.

    One ``CustomCategoricalTransformer`` is created for every column group in
    ``dataset.categorical_features_lists``. All remaining columns pass through
    unchanged. The ``ColumnTransformer`` is fitted on ``dataset.X_train`` and
    then applied to ``dataset.X_test``. Both attributes on ``dataset`` are
    overwritten with the transformed arrays.

    Note: ``ColumnTransformer`` emits each transformer's output in declaration
    order and appends the passthrough columns at the end. The resulting arrays
    therefore do not keep the input column order.

    Returns:
        The fitted ``ColumnTransformer``, which is needed to undo the mapping.
    """
    column_transformer = ColumnTransformer(
        transformers=[
            (f"cat_group_{idx}", CustomCategoricalTransformer(), columns)
            for idx, columns in enumerate(dataset.categorical_features_lists)
        ],
        remainder="passthrough",
    )

    dataset.X_train = column_transformer.fit_transform(dataset.X_train)
    dataset.X_test = column_transformer.transform(dataset.X_test)
    return column_transformer


def quantize(dataset, dequantizer):
    """Invert ``dequantize``: map dequantized features back to the original layout.

    ``ColumnTransformer`` writes each transformer's output in declaration order
    and appends the ``remainder="passthrough"`` columns at the end. As a result:

    * categorical group ``i`` sits at a running column offset in the
      dequantized array, not at its original indices;
    * the passthrough columns are moved to the right-hand side.

    This function walks ``dequantizer.transformers_`` in output order and
    slices each block from its real offset. Categorical blocks are inverted,
    and every block (including the passthrough one) is scattered back to its
    original column indices.

    Args:
        dataset: object whose ``X_train`` / ``X_test`` arrays were produced by
            ``dequantize``. They are overwritten in place with the original
            column layout.
        dequantizer: the fitted ``ColumnTransformer`` returned by ``dequantize``.

    Only the layout produced by ``dequantize`` is supported (no ``"drop"``
    transformers). Each categorical transformer is assumed to output as many
    columns as it consumes. NOTE(review): confirm this for
    CustomCategoricalTransformer.
    """

    def _restore(X_dequantized):
        # Rebuild the original column order from the ColumnTransformer output.
        X_restored = np.empty_like(X_dequantized)
        offset = 0
        for name, transformer, columns in dequantizer.transformers_:
            width = len(columns)
            block = X_dequantized[:, offset : offset + width]
            if name == "remainder" or transformer == "passthrough":
                # Passthrough columns were copied unchanged; just move them back.
                X_restored[:, columns] = block
            else:
                X_restored[:, columns] = transformer.inverse_transform(block)
            offset += width
        return X_restored

    # Compute the full restored array first, then copy it into the existing
    # arrays. This keeps the in-place semantics of the previous implementation.
    dataset.X_train[:] = _restore(dataset.X_train)
    dataset.X_test[:] = _restore(dataset.X_test)

In [None]:
# NOTE(review): this cell has no execution count in the saved notebook
# (In [None]), so the recorded outputs below were produced WITHOUT it.
# On Restart & Run All it WILL run. It overwrites dataset.X_train / X_test
# with the free-function dequantize(), whose ColumnTransformer output has
# reordered columns. That happens before dataset.inverse_dequantization() in
# the next cell. Confirm this is intended, or remove the cell.
dequantizer = dequantize(dataset)

In [4]:
# Undo the dataset's dequantization (dataset method) before training the
# classifier. Presumably this restores the discrete categorical encoding;
# confirm in the dataset class.
dataset.inverse_dequantization()

# Classifier hyper-parameters. An earlier experiment used hidden layers
# [512, 512]; the active configuration is [256, 256].
disc_hidden_layers = [256, 256]
disc_batch_size = 128

# Third argument: number of outputs (presumably the two classes).
disc_model = MultilayerPerceptron(dataset.X_test.shape[1], disc_hidden_layers, 2)
disc_model.fit(
    dataset.train_dataloader(batch_size=disc_batch_size, shuffle=True),
    dataset.test_dataloader(batch_size=disc_batch_size, shuffle=False),
    epochs=5000,
    patience=100,
    lr=1e-3,
    checkpoint_path=disc_model_path,
)

# Reload the weights fit() wrote to the checkpoint. Presumably this is the best
# epoch under early stopping; confirm in MultilayerPerceptron.fit.
disc_model.load(disc_model_path)

  self.load_state_dict(torch.load(path))
Epoch 221, Train: 0.5015, test: 0.4936, patience: 100:   4%|▍         | 222/5000 [00:07<02:37, 30.42it/s]


In [5]:
# Accuracy of the classifier against the dataset's ground-truth test labels.
# Run this BEFORE the next cell: that cell overwrites dataset.y_test with the
# classifier's own predictions, after which this would compare the model
# against itself.
y_pred = disc_model.predict(dataset.X_test).detach().numpy().flatten()

# Flatten the labels as well. If y_test were stored as a column vector (n, 1),
# `==` against the flat (n,) predictions would broadcast to an (n, n) matrix
# and silently yield a meaningless "accuracy".
y_true = np.ravel(dataset.y_test)
assert y_pred.shape == y_true.shape, (y_pred.shape, y_true.shape)
print("Test accuracy:", (y_pred == y_true).mean())

Test accuracy: 0.75


In [6]:
# Replace the stored labels with the classifier's predictions. The
# counterfactual search below then selects test points by what the classifier
# predicts (y_test != target_class), not by the ground truth.
y_train_pred = disc_model.predict(dataset.X_train).detach().numpy()
y_test_pred = disc_model.predict(dataset.X_test).detach().numpy()
dataset.y_train, dataset.y_test = y_train_pred, y_test_pred

In [7]:
# Switch the dataset to its dequantized representation (dataset method). This
# is presumably the continuous encoding the flow below is trained on; confirm
# in the dataset class.
# NOTE(review): disc_model was fitted BEFORE this call, yet the CF search later
# feeds it dataset.X_test from after this call. Check that the classifier is
# meant to consume this representation.
dataset.dequantize()

In [8]:
# Conditional normalizing flow (one context feature: the class label) used as
# the density model for plausibility.
gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1],
    hidden_features=16,
    num_blocks_per_layer=4,
    num_layers=8,
    context_features=1,
    batch_norm_within_layers=True,
    batch_norm_between_layers=True,
    use_random_permutations=True,
)
train_dataloader = dataset.train_dataloader(
    batch_size=256, shuffle=True, noise_lvl=0.03
)
test_dataloader = dataset.test_dataloader(batch_size=256, shuffle=False)

# Train only when no checkpoint exists. Restart & Run All then works on a fresh
# checkout, while reruns reuse the saved flow unchanged. Early stopping is
# validated on test_dataloader, assuming fit's second positional argument is
# the validation loader, as with disc_model.fit above.
if not Path(gen_model_path).exists():
    gen_model.fit(
        train_dataloader,
        test_dataloader,
        learning_rate=1e-3,
        patience=100,
        num_epochs=500,
        checkpoint_path=gen_model_path,
    )
gen_model.load(gen_model_path)

  self.load_state_dict(torch.load(path))


In [9]:
# t = torch.nn.functional.gumbel_softmax(torch.rand(10, 4), dim=1)
# t

In [10]:
cf = PPCEF(
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=MulticlassDiscLoss(),
    neptune_run=None,
)

# Explain only the test points the classifier does NOT already assign to the
# target class (dataset.y_test holds the classifier's predictions here).
target_class = 0
is_origin = dataset.y_test != target_class
X_test_origin = dataset.X_test[is_origin]
y_test_origin = dataset.y_test[is_origin]

cf_dataset = torch.utils.data.TensorDataset(
    torch.tensor(X_test_origin).float(),
    torch.tensor(y_test_origin).float(),
)
cf_dataloader = torch.utils.data.DataLoader(cf_dataset, batch_size=1024, shuffle=False)

# Plausibility threshold: the 25th percentile (not the median) of the flow's
# log-density over the points being explained.
log_prob_threshold = torch.quantile(gen_model.predict_log_prob(cf_dataloader), 0.25)
deltas, X_orig, y_orig, y_target, logs = cf.explain_dataloader(
    cf_dataloader,
    alpha=100,
    log_prob_threshold=log_prob_threshold,
    epochs=4000,
    lr=0.001,
    categorical_intervals=dataset.categorical_features_lists,
)

Discriminator loss: 0.3668, Prob loss: 0.0000: 100%|██████████| 4000/4000 [01:10<00:00, 56.87it/s]  


In [11]:
# X_cf = X_orig + deltas

# evaluate_cf(
#     disc_model=disc_model,
#     gen_model=gen_model,
#     X_cf=X_cf,
#     model_returned=np.ones(X_cf.shape[0]),
#     continuous_features=dataset.numerical_features,
#     categorical_features=dataset.categorical_features,
#     X_train=dataset.X_train,
#     y_train=dataset.y_train,
#     X_test=X_orig,
#     y_test=y_orig,
#     median_log_prob=log_prob_threshold,
#     y_target=y_target,
# )

In [12]:
# torch.nn.functional.gumbel_softmax(torch.rand(4, 3), tau=0.1, dim=1)

In [13]:
# disc_model = MultilayerPerceptron(dataset.X_test.shape[1], [512, 512], 2)
# disc_model.load("german_disc_model_onehot.pt")

In [20]:
X_cf = X_orig + deltas


def _snap_to_one_hot(X, groups):
    """Return a copy of X in which every column group becomes a one-hot vector.

    For each group (a list of column indices), the column with the largest
    value gets 1 and the rest get 0. np.argmax picks the first column on ties.
    Groups are processed in order on the same working copy.
    """
    X_snapped = X.copy()
    for group in groups:
        winners = np.argmax(X_snapped[:, group], axis=1)
        X_snapped[:, group] = np.eye(X_snapped[:, group].shape[1])[winners]
    return X_snapped


X_cf_cat = _snap_to_one_hot(X_cf, dataset.categorical_features_lists)

In [21]:
# Evaluate the one-hot-snapped counterfactuals. Every row is treated as
# "returned" by the method (model_returned is all ones).
# NOTE(review): `median_log_prob` receives log_prob_threshold, which is the
# 25th percentile of the flow's log-density, not the median.
cf_metrics = evaluate_cf(
    disc_model=disc_model,
    gen_model=gen_model,
    X_cf=X_cf_cat,
    model_returned=np.ones(X_cf_cat.shape[0]),
    y_target=y_target,
    median_log_prob=log_prob_threshold,
    X_test=X_orig,
    y_test=y_orig,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    continuous_features=dataset.numerical_features,
    categorical_features=dataset.categorical_features,
)
cf_metrics

{'coverage': 1.0,
 'validity': 0.7323943661971831,
 'actionability': 0.0,
 'sparsity': 0.8461538461538461,
 'proximity_categorical_hamming': 0.7041734617815552,
 'proximity_categorical_jaccard': 0.7041734617815552,
 'proximity_continuous_manhattan': 0.758775208960976,
 'proximity_continuous_euclidean': 0.7041734617815552,
 'proximity_continuous_mad': 1.8347812809887722,
 'proximity_l2_jaccard': 0.7041734617815552,
 'proximity_mad_hamming': 1.8347812809887722,
 'prob_plausibility': 1.0,
 'log_density_cf': -183.50435,
 'log_density_test': -14291.4,
 'lof_scores_cf': 1.7357957,
 'lof_scores_test': 2.9949172,
 'isolation_forest_scores_cf': -0.013422678107123704,
 'isolation_forest_scores_test': -0.05890787676718063}

In [17]:
# from collections import defaultdict
# import bisect

# import pandas as pd
# from sklearn.model_selection import KFold
# from sklearn.preprocessing import StandardScaler

# SEED = 42

# class TargetEncoderNormalizingDataCatalog():
#     def __init__(self, data):
#         self.data_frame = data.raw
#         self.continous = data.continous
#         self.categoricals = data.categoricals
#         self.feature_names = self.categoricals + self.continous
#         self.scaler = StandardScaler()
#         self.target = data.target
#         self.data_catalog = data
#         self.convert_to_target_encoding_form()
#         self.normalize_feature()
#         self.encoded_feature_name = ""
#         self.immutables = data.immutables

#     def convert_to_target_encoding_form(self):
#         self.cat_dict = {}
#         self.target_encoded_dict = {}
#         for feature in self.categoricals:
#             tmp_dict = defaultdict(lambda: 0)
#             data_tmp = pd.DataFrame({feature: self.data_frame[feature], self.target: self.data_frame[self.target]})
#             target_mean = data_tmp.groupby(feature)[self.target].mean()
#             self.target_encoded_dict[feature] = target_mean
#             for cat in target_mean.index.tolist():
#                 tmp_dict[cat] = target_mean[cat]
#             self.cat_dict[feature] = dict(tmp_dict)

#             tmp = np.repeat(np.nan, self.data_frame.shape[0])
#             kf = KFold(n_splits=10, shuffle=True, random_state=SEED)
#             for idx_1, idx_2 in kf.split(self.data_frame):
#                 target_mean = data_tmp.iloc[idx_1].groupby(feature)[self.target].mean()
#                 tmp[idx_2] = self.data_frame[feature].iloc[idx_2].map(target_mean)
#             self.data_frame[feature] = tmp

#         self.data_frame[self.categoricals] = self.data_frame[self.categoricals].astype('float')

#     def normalize_feature(self):
#         self.data_frame[self.feature_names] = self.scaler.fit_transform(self.data_frame[self.feature_names])

#     def denormalize_continuous_feature(self, df):
#         df[self.feature_names] = self.scaler.inverse_transform(df[self.feature_names])
#         return df

#     def convert_from_targetenc_to_original_forms(self, df):
#         for cat in self.categoricals:
#             d = self.cat_dict[cat]
#             # Build key/value lists sorted by the target-encoded value
#             sorted_keys = sorted(d.keys(), key=lambda k: d[k])
#             sorted_values = [d[k] for k in sorted_keys]

#             values = df[cat].values
#             replace_values = []
#             for val in values:
#                 # Binary search for the index of the value closest to val
#                 index = bisect.bisect_left(sorted_values, val)

#                 # Clamp the index into range and pick the nearer neighbour
#                 if index == len(sorted_values):
#                     index -= 1
#                 elif index > 0 and abs(sorted_values[index] - val) > abs(sorted_values[index - 1] - val):
#                     index -= 1

#                 # Take the key whose value has the smallest absolute difference
#                 closest_key = sorted_keys[index]
#                 replace_values.append(closest_key)
#             df[cat] = replace_values
#         return df


In [18]:
# columns = {
#     "compas": ["Sex", "Age_Cat", "Race", "C_Charge_Degree",
#                 "Priors_Count", "Time_Served", "Status"],
#     "german_credit": ["Existing-Account-Status", "Month-Duration",
#                         "Credit-History", "Purpose", "Credit-Amount",
#                         "Savings-Account", "Present-Employment", "Instalment-Rate",
#                         "Sex", "Guarantors", "Residence","Property", "Age",
#                         "Installment", "Housing", "Existing-Credits", "Job",
#                         "Num-People", "Telephone", "Foreign-Worker", "Status"],
#     "adult_income": ["Age", "Workclass", "Fnlwgt", "Education", "Marital-Status",
#                         "Occupation", "Relationship", "Race", "Sex", "Capital-Gain",
#                         "Capital-Loss", "Hours-Per-Week", "Native-Country", "Status"],
#     "default_credit": ['Limit_Bal', 'Sex', 'Education', 'Marriage', 'Age', 'Pay_0',
#                         'Pay_2', 'Pay_3', 'Pay_4', 'Pay_5', 'Pay_6', 'Bill_Amt1',
#                         'Bill_Amt2', 'Bill_Amt3', 'Bill_Amt4', 'Bill_Amt5',
#                         'Bill_Amt6', 'Pay_Amt1', 'Pay_Amt2', 'Pay_Amt3', 'Pay_Amt4',
#                         'Pay_Amt5', 'Pay_Amt6', 'Status'],
#     "heloc": ['ExternalRiskEstimate', 'MSinceOldestTradeOpen',
#                 'MSinceMostRecentTradeOpen', 'AverageMInFile',
#                 'NumSatisfactoryTrades', 'NumTrades60Ever2DerogPubRec',
#                 'NumTrades90Ever2DerogPubRec', 'PercentTradesNeverDelq',
#                 'MSinceMostRecentDelq', 'MaxDelq2PublicRecLast12M', 'MaxDelqEver',
#                 'NumTotalTrades', 'NumTradesOpeninLast12M', 'PercentInstallTrades',
#                 'MSinceMostRecentInqexcl7days', 'NumInqLast6M',
#                 'NumInqLast6Mexcl7days', 'NetFractionRevolvingBurden',
#                 'NetFractionInstallBurden', 'NumRevolvingTradesWBalance',
#                 'NumInstallTradesWBalance', 'NumBank2NatlTradesWHighUtilization',
#                 'PercentTradesWBalance', 'Status']
# }