In [1]:
%load_ext autoreload
%autoreload 2

In [13]:
import torch
import numpy as np

from counterfactuals.datasets import LawDataset, AdultDataset, GermanCreditDataset
from counterfactuals.cf_methods.ppcef import PPCEF
from counterfactuals.generative_models import MaskedAutoregressiveFlow
from counterfactuals.discriminative_models import MultilayerPerceptron
from counterfactuals.losses import MulticlassDiscLoss
from counterfactuals.metrics import evaluate_cf

In [18]:
datasets = {
    "adult": (
        AdultDataset("../data/adult.csv"),
        "adult_disc_model.pt",
        "adult_flow.pth",
    ),
    "law": (LawDataset("../data/law.csv"), "law_disc_model.pt", "law_flow.pth"),
    "german": (
        GermanCreditDataset("../data/german_credit.csv"),
        "german_disc_model.pt",
        "german_flow.pth",
    ),
}

dataset, disc_model_path, gen_model_path = datasets["german"]

In [22]:
dataset.X_train[:, 7:11]

array([[ 1.4554957 , -1.1856792 , -0.7186756 , -0.7083931 ],
       [-1.1928759 , -1.8245676 , -0.42487657,  2.010407  ],
       [-0.71613324, -2.3951652 ,  1.2338164 , -0.9816397 ],
       ...,
       [-0.9022146 ,  1.425001  , -0.3370698 , -1.5520179 ],
       [-0.08803034,  1.5119776 , -0.799332  , -1.2137854 ],
       [ 1.963898  , -0.8233758 , -0.6724981 , -1.436671  ]],
      dtype=float32)

In [16]:
dataset.X_train[:, 7:11]

array([[1., 0., 0., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       ...,
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [1., 0., 0., 0.]], dtype=float32)

In [None]:
# dataset = AdultDataset("../data/adult.csv")
# dataset = GermanCreditDataset("../data/german_credit.csv")
# dataset = LawDataset("../data/law.csv")

In [5]:
disc_model = MultilayerPerceptron(dataset.X_test.shape[1], [512, 512], 2)
disc_model.fit(
    dataset.train_dataloader(batch_size=128, shuffle=True),
    dataset.test_dataloader(batch_size=128, shuffle=False),
    epochs=5000,
    patience=100,
    lr=1e-3,
    checkpoint_path=disc_model_path,
)
disc_model.load(disc_model_path)

  self.load_state_dict(torch.load(path))
Epoch 134, Train: 0.5444, test: 0.5725, patience: 100:   3%|▎         | 135/5000 [00:02<01:25, 56.93it/s]


In [6]:
y_pred = disc_model.predict(dataset.X_test).detach().numpy().flatten()
print("Test accuracy:", (y_pred == dataset.y_test).mean())

Test accuracy: 0.725


In [None]:
dataset.y_train = disc_model.predict(dataset.X_train).detach().numpy()
dataset.y_test = disc_model.predict(dataset.X_test).detach().numpy()

In [11]:
gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1],
    hidden_features=16,
    num_blocks_per_layer=4,
    num_layers=8,
    context_features=1,
    batch_norm_within_layers=True,
    batch_norm_between_layers=True,
    use_random_permutations=True,
)
train_dataloader = dataset.train_dataloader(
    batch_size=256, shuffle=True, noise_lvl=0.03
)
test_dataloader = dataset.test_dataloader(batch_size=256, shuffle=False)

gen_model.fit(
    train_dataloader,
    train_dataloader,
    learning_rate=1e-3,
    patience=100,
    num_epochs=500,
    checkpoint_path=gen_model_path,
)
gen_model.load(gen_model_path)

Epoch 499, Train: -59.6201, test: -59.2076, patience: 33: 100%|██████████| 500/500 [00:53<00:00,  9.42it/s]               
  self.load_state_dict(torch.load(path))


In [None]:
# t = torch.nn.functional.gumbel_softmax(torch.rand(10, 4), dim=1)
# t

In [12]:
cf = PPCEF(
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=MulticlassDiscLoss(),
    neptune_run=None,
)
cf_dataloader = dataset.test_dataloader(batch_size=1024, shuffle=False)
log_prob_threshold = torch.quantile(gen_model.predict_log_prob(cf_dataloader), 0.25)
deltas, X_orig, y_orig, y_target, logs = cf.explain_dataloader(
    cf_dataloader,
    alpha=100,
    log_prob_threshold=log_prob_threshold,
    epochs=10000,
    lr=0.0005,
    categorical_intervals=dataset.intervals,
)

Discriminator loss: 0.5178, Prob loss: 81617190912.0000: 100%|██████████| 10000/10000 [02:51<00:00, 58.16it/s]        


In [None]:
X_cf = X_orig + deltas

evaluate_cf(
    disc_model=disc_model,
    gen_model=gen_model,
    X_cf=X_cf,
    model_returned=np.ones(X_cf.shape[0]),
    continuous_features=dataset.numerical_features,
    categorical_features=dataset.categorical_features,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=X_orig,
    y_test=y_orig,
    median_log_prob=log_prob_threshold,
    y_target=y_target,
)

{'coverage': 1.0,
 'validity': 1.0,
 'actionability': 0.03333333333333333,
 'sparsity': 0.9644736842105263,
 'proximity_categorical_hamming': 0.8588701917714915,
 'proximity_categorical_jaccard': 0.8588701917714915,
 'proximity_continuous_manhattan': 0.8743564683240689,
 'proximity_continuous_euclidean': 0.8588701917714915,
 'proximity_continuous_mad': 1.0398214643327925,
 'proximity_l2_jaccard': 0.8588701917714915,
 'proximity_mad_hamming': 1.0398214643327925,
 'prob_plausibility': 1.0,
 'log_density_cf': -178.47653,
 'log_density_test': -735.59247,
 'lof_scores_cf': 1.0352931,
 'lof_scores_test': 1.0257134,
 'isolation_forest_scores_cf': 0.030959670740704336,
 'isolation_forest_scores_test': 0.031583695862440614}

In [None]:
# torch.nn.functional.gumbel_softmax(torch.rand(4, 3), tau=0.1, dim=1)

In [None]:
X_cf = X_orig + deltas
X_cf_cat = X_cf.copy()

X_cf_cat = torch.tensor(X_cf_cat)
for interval in dataset.intervals:
    begin, end = interval
    X_cf_cat[:, begin:end] = torch.nn.functional.gumbel_softmax(
        X_cf_cat[:, begin:end], tau=0.1, dim=1
    )
    # X_cf_cat[:, begin:end] = torch.nn.functional.softmax(X_cf_cat[:, begin:end], dim=1)
X_cf_cat = X_cf_cat.numpy()

for interval in dataset.intervals:
    begin, end = interval
    max_indices = np.argmax(X_cf_cat[:, begin:end], axis=1)
    X_cf_cat[:, begin:end] = np.eye(X_cf_cat[:, begin:end].shape[1])[max_indices]

NameError: name 'X_orig' is not defined

In [None]:
evaluate_cf(
    disc_model=disc_model,
    gen_model=gen_model,
    X_cf=X_cf_cat,
    model_returned=np.ones(X_cf_cat.shape[0]),
    continuous_features=dataset.numerical_features,
    categorical_features=dataset.categorical_features,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=X_orig,
    y_test=y_orig,
    median_log_prob=log_prob_threshold,
    y_target=y_target,
)

{'coverage': 1.0,
 'validity': 0.4666666666666667,
 'actionability': 0.0,
 'sparsity': 0.995906432748538,
 'proximity_categorical_hamming': 0.8910642958216046,
 'proximity_categorical_jaccard': 0.8910642958216046,
 'proximity_continuous_manhattan': 0.9074907010846737,
 'proximity_continuous_euclidean': 0.8910642958216046,
 'proximity_continuous_mad': 1.0735054347226936,
 'proximity_l2_jaccard': 0.8910642958216046,
 'proximity_mad_hamming': 1.0735054347226936,
 'prob_plausibility': 0.0,
 'log_density_cf': -574714.56,
 'log_density_test': -735.59247,
 'lof_scores_cf': 1.5396397,
 'lof_scores_test': 1.0257134,
 'isolation_forest_scores_cf': -0.12522032679129527,
 'isolation_forest_scores_test': 0.031583695862440614}

In [None]:
# from collections import defaultdict
# import bisect

# import pandas as pd
# from sklearn.model_selection import KFold
# from sklearn.preprocessing import StandardScaler

# SEED = 42

# class TargetEncoderNormalizingDataCatalog():
#     def __init__(self, data):
#         self.data_frame = data.raw
#         self.continous = data.continous
#         self.categoricals = data.categoricals
#         self.feature_names = self.categoricals + self.continous
#         self.scaler = StandardScaler()
#         self.target = data.target
#         self.data_catalog = data
#         self.convert_to_target_encoding_form()
#         self.normalize_feature()
#         self.encoded_feature_name = ""
#         self.immutables = data.immutables

#     def convert_to_target_encoding_form(self):
#         self.cat_dict = {}
#         self.target_encoded_dict = {}
#         for feature in self.categoricals:
#             tmp_dict = defaultdict(lambda: 0)
#             data_tmp = pd.DataFrame({feature: self.data_frame[feature], self.target: self.data_frame[self.target]})
#             target_mean = data_tmp.groupby(feature)[self.target].mean()
#             self.target_encoded_dict[feature] = target_mean
#             for cat in target_mean.index.tolist():
#                 tmp_dict[cat] = target_mean[cat]
#             self.cat_dict[feature] = dict(tmp_dict)

#             tmp = np.repeat(np.nan, self.data_frame.shape[0])
#             kf = KFold(n_splits=10, shuffle=True, random_state=SEED)
#             for idx_1, idx_2 in kf.split(self.data_frame):
#                 target_mean = data_tmp.iloc[idx_1].groupby(feature)[self.target].mean()
#                 tmp[idx_2] = self.data_frame[feature].iloc[idx_2].map(target_mean)
#             self.data_frame[feature] = tmp

#         self.data_frame[self.categoricals] = self.data_frame[self.categoricals].astype('float')

#     def normalize_feature(self):
#         self.data_frame[self.feature_names] = self.scaler.fit_transform(self.data_frame[self.feature_names])

#     def denormalize_continuous_feature(self, df):
#         df[self.feature_names] = self.scaler.inverse_transform(df[self.feature_names])
#         return df

#     def convert_from_targetenc_to_original_forms(self, df):
#         for cat in self.categoricals:
#             d = self.cat_dict[cat]
#             # ソート済みのキーと値のリストを作成
#             sorted_keys = sorted(d.keys(), key=lambda k: d[k])
#             sorted_values = [d[k] for k in sorted_keys]

#             values = df[cat].values
#             replace_values = []
#             for val in values:
#                 # 二分探索でbに最も近い値のインデックスを見つける
#                 index = bisect.bisect_left(sorted_values, val)

#                 # 最も近い値のインデックスを範囲内に収める
#                 if index == len(sorted_values):
#                     index -= 1
#                 elif index > 0 and abs(sorted_values[index] - val) > abs(sorted_values[index - 1] - val):
#                     index -= 1

#                 # 最も絶対値の差が小さいキーを見つける
#                 closest_key = sorted_keys[index]
#                 replace_values.append(closest_key)
#             df[cat] = replace_values
#         return df


In [None]:
# columns = {
#     "compas": ["Sex", "Age_Cat", "Race", "C_Charge_Degree",
#                 "Priors_Count", "Time_Served", "Status"],
#     "german_credit": ["Existing-Account-Status", "Month-Duration",
#                         "Credit-History", "Purpose", "Credit-Amount",
#                         "Savings-Account", "Present-Employment", "Instalment-Rate",
#                         "Sex", "Guarantors", "Residence","Property", "Age",
#                         "Installment", "Housing", "Existing-Credits", "Job",
#                         "Num-People", "Telephone", "Foreign-Worker", "Status"],
#     "adult_income": ["Age", "Workclass", "Fnlwgt", "Education", "Marital-Status",
#                         "Occupation", "Relationship", "Race", "Sex", "Capital-Gain",
#                         "Capital-Loss", "Hours-Per-Week", "Native-Country", "Status"],
#     "default_credit": ['Limit_Bal', 'Sex', 'Education', 'Marriage', 'Age', 'Pay_0',
#                         'Pay_2', 'Pay_3', 'Pay_4', 'Pay_5', 'Pay_6', 'Bill_Amt1',
#                         'Bill_Amt2', 'Bill_Amt3', 'Bill_Amt4', 'Bill_Amt5',
#                         'Bill_Amt6', 'Pay_Amt1', 'Pay_Amt2', 'Pay_Amt3', 'Pay_Amt4',
#                         'Pay_Amt5', 'Pay_Amt6', 'Status'],
#     "heloc": ['ExternalRiskEstimate', 'MSinceOldestTradeOpen',
#                 'MSinceMostRecentTradeOpen', 'AverageMInFile',
#                 'NumSatisfactoryTrades', 'NumTrades60Ever2DerogPubRec',
#                 'NumTrades90Ever2DerogPubRec', 'PercentTradesNeverDelq',
#                 'MSinceMostRecentDelq', 'MaxDelq2PublicRecLast12M', 'MaxDelqEver',
#                 'NumTotalTrades', 'NumTradesOpeninLast12M', 'PercentInstallTrades',
#                 'MSinceMostRecentInqexcl7days', 'NumInqLast6M',
#                 'NumInqLast6Mexcl7days', 'NetFractionRevolvingBurden',
#                 'NetFractionInstallBurden', 'NumRevolvingTradesWBalance',
#                 'NumInstallTradesWBalance', 'NumBank2NatlTradesWHighUtilization',
#                 'PercentTradesWBalance', 'Status']
# }