In [1]:
import importlib
import src.utils
import src.models
import src.counterfactual

importlib.reload(src.utils)
importlib.reload(src.models)
importlib.reload(src.counterfactual)

from src.utils import load_data, load_model, DatasetMetadata, clean_instance
from src.counterfactual import newton_op, distance, unscale_instance
import torch
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import sympy as sp
# str to sympy


from torch.utils.data import DataLoader
from src.models import LogisticModel
import warnings
warnings.filterwarnings("ignore")

# Prefer the CUDA GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = device if not torch.backends.mps.is_available() else torch.device("mps")


In [2]:
class State:
    """Plain attribute container.

    Stores every constructor argument unchanged on the instance and
    initialises an epoch counter (``epochs = 0``) plus two boolean
    regularization flags (``apply_reg`` and ``reg_vars``, both ``False``).
    """

    def __init__(self, model, metadata, max_epochs, dx_scaled, mean_scaled, upd_weights):
        # Epoch bookkeeping: the counter starts at zero, the limit comes from the caller.
        self.epochs: int = 0
        self.max_epochs: int = max_epochs

        # Regularization switches, both disabled on construction.
        self.apply_reg = False  # When to apply integer regularization
        self.reg_vars = False  # When to apply number-of-variables regularization

        # Objects handed in by the caller, kept as-is.
        self.model: LogisticModel = model
        self.metadata: DatasetMetadata = metadata
        self.dx_scaled: torch.Tensor = dx_scaled
        self.mean_scaled: torch.Tensor = mean_scaled
        self.upd_weights: torch.Tensor = upd_weights  # columns to be updated


In [3]:
# Input CSV, the name of the model to load, and the weights file derived from that name.
filename = 'data/Loan_default.csv'
model_name = "model_small"
model_dict = f"models/{model_name}.pth"

In [4]:
# Build the data loaders (batch_size=1024) and keep the dataset metadata;
# the second and fourth return values of load_data are discarded.
test_data: DataLoader
train_data, _, test_data, _, metadata = load_data(filename, batch_size=1024)

# Element 0 of the first test batch, cast to float32 on the selected device.
inputs = next(iter(test_data))[0].to(device=device, dtype=torch.float32)

# Load the model by name and bring it to the same dtype/device as the inputs.
model = load_model(model_name).to(device=device, dtype=torch.float32)

# Write the model's state_dict to the weights path defined above.
torch.save(model.state_dict(), model_dict)


## Extract model equation

In [5]:
# import sympy as sp
# import torch
# from sympy.parsing.sympy_parser import parse_expr

# def extract_symbolic_equation(model: torch.nn.Module, instance: torch.Tensor):
#     """
#     Extracts a symbolic equation from a trained PyTorch model.
#     Assumes a feedforward structure with linear layers and activations.
#     """
#     # Define symbolic variables for input features
#     x2, x3 = sp.symbols('x2 x3')  # Inputs
#     # constants = sp.symbols(f'c1:{model.input_dim + 1}')  # Constants for other features
    
#     # Build input vector with constants
#     x = [instance[i].item() if i not in [1, 2] else (x2 if i == 1 else x3) for i in range(model.input_dim)]
    
#     # Convert to a sympy matrix
#     X = sp.Matrix(x)
#     activations = []

#     # Iterate over layers
#     for layer in model.layers:
#         if isinstance(layer, torch.nn.Linear):
#             W = sp.Matrix(layer.weight.detach().numpy())  # Extract weight matrix
#             b = sp.Matrix(layer.bias.detach().numpy())    # Extract bias
#             X = W * X + b  # Apply linear transformation
#         elif isinstance(layer, torch.nn.ReLU):
#             activations.append(X)
#             X = X.applyfunc(lambda val: sp.Max(0, val))  # ReLU activation
#         elif isinstance(layer, torch.nn.Sigmoid):
#             activations.append(X)
#             X = X.applyfunc(lambda val: 1 / (1 + sp.exp(-val)))  # Sigmoid activation
#         # X.subs({sp.symbols(f'c{i+1}'): val for i, val in enumerate(inputs[0]) if i != 1 and i != 2})
#         print("Done: ", layer)
#         # print(X)


#     # Apply softmax at the end
#     denominator = sp.Add(*(sp.exp(e) for e in X))
#     softmax_expr = sp.Matrix([sp.exp(e) / denominator for e in X])

#     return softmax_expr, activations # .simplify()

# # Example usage
# model_sym = LogisticModel(inputs.shape[1], hidden_sizes=[16, 8])
# model_sym.load_state_dict(torch.load(model_dict))  # Load trained weights
# symbolic_eq, activations = extract_symbolic_equation(model_sym, inputs[0])
# model_eq = symbolic_eq[0]
# print(symbolic_eq)
# # with open('small.txt', 'w') as f:
# #     f.write(str(symbolic_eq))
# with open('small.txt', 'r') as f:
#     symbolic_eq1 = f.read()
#     symbolic_eq1 = parse_expr(symbolic_eq1)


## Training

In [6]:

# First instance of the batch (replaced by a later cell).
person: torch.Tensor = inputs[0].to(device=device, dtype=torch.float32)

# Predicted class per row; only rows predicted as class 1 are kept.
outputs = torch.argmax(model(inputs), dim=1)
inputs_useful = inputs[outputs.eq(1)]
# metadata.cols_for_mask = [True] * 100 + [False] * (len(metadata.cols_for_mask) - 100)
# metadata.cols_for_mask = [True] * len(metadata.cols_for_mask)
# metadata.cols_for_mask[1] = True
# metadata.cols_for_mask[2] = True
# metadata.cols_for_mask[3] = True
# metadata.cols_for_mask[4] = True
# metadata.cols_for_mask[5] = True
# metadata.cols_for_mask[6] = True
# metadata.cols_for_mask[7] = True
# metadata.cols_for_mask[8] = True

# Column mask from the metadata as a float32 tensor created directly on the device.
weights = torch.tensor(metadata.cols_for_mask, dtype=torch.float32, device=device)
# weights = torch.ones_like(inputs_useful[0], dtype=torch.float32).to(device)


In [7]:
# Use the first row predicted as class 1 as the working instance.
person = inputs_useful[0].to(device=device, dtype=torch.float32)
# weights = torch.tensor([0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int).to(device)

## Global check

In [8]:
# import src.counterfactual

# importlib.reload(src.counterfactual)
# from src.counterfactual import newton_op, distance
# person = inputs_useful[0]
# p_new, state_p = newton_op(model, person, metadata, weights, 0.1, reg_int=False, reg_vars=False, reg_clamp=True, print_=False)
# torch.manual_seed(torch.randint(0, 100, (1,)).item())
# n = 5
# num_points = 10000
# num_linspace = 5000
# indexes = torch.nonzero(weights).reshape(-1)
# print(indexes)
# sampled_indexes = torch.tensor([6, 8]) # indexes #[torch.randint(0, len(indexes), (n,))]
# print(sampled_indexes)

# for sample_var in sampled_indexes:
#     w = weights.clone()
#     w[sample_var] = 0
#     print("Sampled variable:", sample_var.item())
#     x = p_new.repeat(num_points*num_linspace, 1)
#     print("points repeated")
#     # print("x:", x[:, w != 0])

#     # print(x)
#     x[:, w != 0] = (torch.distributions.uniform.Uniform(metadata.min_values, metadata.max_values).sample((num_points,)) * w)[:, w != 0].repeat(num_linspace, 1)
#     print("points generated")

#     x[:, sample_var] = torch.linspace(metadata.min_values[sample_var], metadata.max_values[sample_var], num_linspace).repeat(num_points)
#     print("linspace generated")

#     x = x[model(x)[:, 0] > metadata.threshold]
#     print("model filtered")

#     # calculate the distance
#     dists = distance(person, x, weights, state_p, with_sum=False)
#     print(torch.min(dists), distance(person, p_new, weights, state_p))
#     x = x[dists < distance(person, p_new, weights, state_p)]
#     print("distance filtered")
#     print(len(x))

## Trials

### Only 1 person

In [12]:
import src.counterfactual
import src.checks

# Reload the project modules so edits made on disk are picked up without a kernel restart.
importlib.reload(src.counterfactual)
importlib.reload(src.checks)
from src.counterfactual import newton_op
from src.checks import Checks

# Flags forwarded unchanged to both newton_op and Checks.
reg_int=True
reg_clamp=True

# Threshold set marginally above 0.5.
metadata.threshold = 0.5 + 1e-7
person = inputs_useful[6]
p_new, _ = newton_op(model, person, metadata, weights, 0.2, reg_int=reg_int, reg_clamp=reg_clamp, print_=True)
# Detached copy of the instance; not read again in the visible cells.
p = person.clone().detach()
# NOTE(review): the training batch passed as `dataset` is not cast to float32 or moved to
# `device` the way `inputs` is — confirm that Checks does not need it on the model's device.
check = Checks(model, metadata, reg_int=reg_int, reg_clamp=reg_clamp, dataset=next(iter(train_data))[0])
valid = check(person, p_new, weights)
print("Valid:", valid)
# .cpu() is required before .numpy(): the notebook runs on CUDA when available and
# Tensor.numpy() only works on CPU tensors.
display(pd.DataFrame([unscale_instance(p_new, metadata).detach().cpu().numpy()], columns=metadata.columns))
# print(model(p_new.unsqueeze(0))[0][metadata.good_class].item(), model(p_new.unsqueeze(0)).argmax(dim=1))
# if check.sorted_points is not None:
#     minimal = torch.tensor(check.sorted_points.iloc[0, :-2].values).float()
#     print(distance(person, p_new, weights).item(), distance(person, minimal, weights).item())
#     display(pd.DataFrame(unscale_batch(torch.tensor(check.sorted_points.to_numpy()[:, :-2]).float(), metadata), columns=metadata.columns))

Epoch: 0
ONLY MODEL DERIVATIVE: 0.05711791664361954
Changes:  delta1: -0.15528741478919983  delta_l: 0.0
dist: 0.10394056886434555 , threshold: 0.2816452980041504
Epoch: 1
ONLY MODEL DERIVATIVE: 0.09744583815336227
Changes:  delta1: -0.13539542257785797  delta_l: 0.0
dist: 0.3648679852485657 , threshold: 0.22263136506080627
Epoch: 2
ONLY MODEL DERIVATIVE: 0.1574321985244751
Changes:  delta1: -0.10684225708246231  delta_l: 0.0
dist: 0.6833893656730652 , threshold: 0.15195104479789734
Epoch: 3
Changes:  delta1: -0.19220706820487976  delta_l: -5.7508745193481445
dist: 1.5147488117218018 , threshold: -0.04436570405960083
Epoch: 4
Changes:  delta1: -7.619365078426199e-06  delta_l: 2.3715884685516357
dist: 1.351460337638855 , threshold: -0.0009301900863647461
Epoch: 5
Changes:  delta1: 5.775530054208389e-11  delta_l: 0.05691205710172653
dist: 1.3479050397872925 , threshold: -5.960464477539062e-07
Epoch: 6
Changes:  delta1: 5.775642117344937e-11  delta_l: 6.262936949497089e-05
dist: 1.3479028

Unnamed: 0,Age,Income,LoanAmount,CreditScore,MonthsEmployed,NumCreditLines,InterestRate,LoanTerm,DTIRatio,Education_High School,...,EmploymentType_Unemployed,MaritalStatus_Married,MaritalStatus_Single,HasMortgage_Yes,HasDependents_Yes,LoanPurpose_Business,LoanPurpose_Education,LoanPurpose_Home,LoanPurpose_Other,HasCoSigner_Yes
0,19.0,121621.0,76545.0,563.0,64.0,4.0,17.823526,60.0,0.677917,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0


### 1 person, different weights


In [137]:
import src.counterfactual

# Reload so the latest on-disk version of newton_op is used.
importlib.reload(src.counterfactual)
from src.counterfactual import newton_op

# Flags forwarded unchanged to both newton_op and Checks.
reg_int=True
reg_clamp=False
check = Checks(model, metadata, reg_int=reg_int, reg_clamp=reg_clamp)

person = inputs_useful[1]

# Run 1: an unmodified copy of the column weights.
weights1 = weights.clone()
print(weights1)
p_new1, state_p = newton_op(model, person, metadata, weights1, 0.2, reg_int=reg_int, reg_clamp=reg_clamp, print_=True)
# Validate the result of this run with the weights it was computed under
# (checking `p_new` here would validate a point left over from another cell).
valid1 = check(person, p_new1, weights1)
print("Valid1:", valid1)

# Run 2: same weights except that column 1 is set to 2.
weights2 = weights.clone()
weights2[1] = 2
print(weights2)
p_new2, state_p = newton_op(model, person, metadata, weights2, 0.2, reg_int=reg_int, reg_clamp=reg_clamp, print_=True)
valid2 = check(person, p_new2, weights2)
print("Valid2:", valid2)

# One row each for the original instance and the two results, in unscaled units.
# .cpu() is required before .numpy() when the tensors live on a CUDA device.
display(pd.DataFrame(
    [unscale_instance(row, metadata).detach().cpu().numpy() for row in (person, p_new1, p_new2)],
    columns=metadata.columns,
))
# display(check.sorted_points)
# print(model(p_new.unsqueeze(0))[0][metadata.good_class].item(), model(p_new.unsqueeze(0)).argmax(dim=1))
# if check.sorted_points is not None:
#     minimal = torch.tensor(check.sorted_points.iloc[0, :-2].values).float()
#     print(distance(person, p_new, weights).item(), distance(person, minimal, weights).item())
#     display(pd.DataFrame(unscale_batch(torch.tensor(check.sorted_points.to_numpy()[:, :-2]).float(), metadata), columns=metadata.columns))

tensor([0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.])
Epoch: 0
ONLY MODEL DERIVATIVE: 0.0828336551785469
Changes:  delta1: -0.13978469371795654  delta_l: 0.0
dist: 0.08863423019647598 , threshold: 0.24497714638710022
Epoch: 1
ONLY MODEL DERIVATIVE: 0.1373034566640854
Changes:  delta1: -0.11506202071905136  delta_l: 0.0
dist: 0.29451456665992737 , threshold: 0.17655232548713684
Epoch: 2
Changes:  delta1: -0.24510157108306885  delta_l: -5.36952018737793
dist: 1.130959391593933 , threshold: -0.07328683137893677
Epoch: 3
Changes:  delta1: 1.162972785095917e-05  delta_l: 3.02851939201355
dist: 0.9077669978141785 , threshold: -0.0018201470375061035
Epoch: 4
Changes:  delta1: 8.847264587608095e-11  delta_l: 0.030933715403079987
dist: 0.9026430249214172 , threshold: -2.1457672119140625e-06
Epoch: 5
Changes:  delta1: 8.847574062276209e-11  delta_l: 5.725659502786584e-05
dist: 0.9026370644569397 , threshold: -5.960464477539063e-08
Epoch:

Unnamed: 0,Age,Income,LoanAmount,CreditScore,MonthsEmployed,NumCreditLines,InterestRate,LoanTerm,DTIRatio,Education_High School,...,EmploymentType_Unemployed,MaritalStatus_Married,MaritalStatus_Single,HasMortgage_Yes,HasDependents_Yes,LoanPurpose_Business,LoanPurpose_Education,LoanPurpose_Home,LoanPurpose_Other,HasCoSigner_Yes
0,19.0,29467.0,151769.0,606.0,33.0,1.0,6.63,24.0,0.48,1.0,...,1.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0
1,19.0,48946.0,121685.0,633.0,49.0,1.0,3.442996,24.0,0.464706,1.0,...,1.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0
2,19.0,40415.0,117950.0,636.0,52.0,1.0,3.010474,24.0,0.462681,1.0,...,1.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0


### 1 person, different flags (clamp, int)

In [138]:
# import src.counterfactual

# importlib.reload(src.counterfactual)
# from src.counterfactual import newton_op, distance, unscale_instance, scale_instance
# person = inputs_useful[47]
# person_new, state = newton_op(model, person, metadata, weights, 0.1, print_=True)
# person_new_clamp, _ = newton_op(model, person, metadata, weights, 0.1, reg_clamp=True, print_=True)
# person_new_int, _ = newton_op(model, person, metadata, weights, reg_int=True, print_=True)
# person_new_clamp_int, _ = newton_op(model, person, metadata, weights, 0.1, reg_int=True, reg_clamp=True, print_=True)
# # person_new_vars, _ = newton_op(model, person, metadata, weights, reg_vars=True, print_=True)
# # person_new_int_vars, _ = newton_op(model, person, metadata, weights, 0.1, reg_int=True, reg_vars=True, print_=True)

# names = ['person', 'person_new', 'person_new_clamp','person_new_int', 'person_new_clamp_int']
# ps = [eval(i) for i in names]
# outputs = [model(p.unsqueeze(0))[0][metadata.good_class].item() for p in ps]

# distances = [distance(person, p, weights, state=state).item() for p in ps]

# a = pd.DataFrame([unscale_instance(x, metadata).detach().cpu().numpy().reshape(-1) for x in ps], columns=metadata.columns)
# a['output'] = outputs
# a['distance'] = distances
# # set index
# a['names'] = names
# a = a.set_index('names')
# print(a.columns)
# a

### Batch

In [151]:
import src.counterfactual
import src.checks
# Reload the project modules so edits made on disk are picked up without a kernel restart.
importlib.reload(src.counterfactual)
importlib.reload(src.checks)
from src.counterfactual import newton_op
from src.checks import Checks
# Imported once here instead of on every loop iteration.
from time import perf_counter

# Flags forwarded unchanged to both newton_op and Checks.
reg_int = True
reg_clamp = True
metadata.threshold = 0.5 + 1e-7

# Accumulators for the summary printed after the loop.
successes = 0
epochs = 0
bad_idxs = []
total = 0
total_time = []
check = Checks(model, metadata, reg_int=reg_int, reg_clamp=reg_clamp)
for idx, p in enumerate(inputs_useful[:176]):

    # Time only the newton_op call, not the validity check.
    s = perf_counter()
    p_new, ep = newton_op(model, p, metadata, weights, 0.2, reg_int=reg_int, reg_clamp=reg_clamp, der=False)
    time = perf_counter()-s
    print("Person:", idx, f'Time: {time} s') # if time > 0.2 else '')
    # TODO: add the minimality check
    valid = check(p, p_new, weights)
    successes += valid # and (((state_p.metadata.max_values < p_new) | (state_p.metadata.min_values > p_new)).sum() == 0))
    # print("Person:", idx, "Rate of grad desc:",minimality_check(p, p_new, weights, ep, model))
    # The second value returned by newton_op exposes the number of epochs it ran.
    epochs += ep.epochs
    total += 1
    total_time.append(time)
    # if not valid:
    #     bad_idxs.append(idx)
        # print(idx, valid)
print("Successes:", successes, "Total:", total)
if total == 0:
    # No rows were processed (inputs_useful was empty), so the averages below
    # would divide by zero.
    print("No instances were processed; averages are undefined.")
else:
    print("Average epochs:", epochs / total)
    print("Time:", np.array(total_time).mean() , np.array(total_time).std() )
    # NOTE(review): an empty check.diff_factors makes both values nan (mean/std of an empty array).
    print("Stability:", np.array(check.diff_factors).mean(), '±', np.array(check.diff_factors).std())
    print("Success rate:", successes / total)


Person: 0 Time: 0.03498062497237697 s
Person: 1 Time: 0.024592834000941366 s
Person: 2 Time: 0.023454791982658207 s
Person: 3 Time: 0.01598016597563401 s
Person: 4 Time: 0.0363622090080753 s
Person: 5 Time: 0.01630404096795246 s
Person: 6 Time: 0.031305333017371595 s
Person: 7 Time: 0.026613374997396022 s
Person: 8 Time: 0.026215707999654114 s
Person: 9 Time: 0.020610959036275744 s
Person: 10 Time: 0.02429954201215878 s
Person: 11 Time: 0.029072375036776066 s
Person: 12 Time: 0.03063470800407231 s
Person: 13 Time: 0.020688374992460012 s
Person: 14 Time: 0.019534000020939857 s
Person: 15 Time: 0.023705292027443647 s
Person: 16 Time: 0.021956415963359177 s
Person: 17 Time: 0.026825707987882197 s
Person: 18 Time: 0.02642929198918864 s
Person: 19 Time: 0.021233165985904634 s
Person: 20 Time: 0.022055500012356788 s
Person: 21 Time: 0.01629433297784999 s
Person: 22 Time: 0.02514075004728511 s
Person: 23 Time: 0.02792491699801758 s
Person: 24 Time: 0.022228166984859854 s
Person: 25 Time: 0.02

KeyboardInterrupt: 

### All batches

In [None]:
# loan 4.353627812117338 ± 5.806163177495694

In [None]:
# successes = 0
# bad_idxs = []
# total = 0
# for i, inputs in enumerate(test_data):
#     print(i, end='\r')
#     outputs = model(inputs[0]).argmax(dim=1)
#     inputs_useful = inputs[0][outputs == 1]
#     for idx, p in enumerate(inputs_useful):
#         _, ep = newton_op(model, p, weights, 0.1) #if idx not in [103, 105, 237, 406, 417, 450] else None
#         # print("Person:", idx, "Success:", not ep)
#         successes += ep
#         total += 1
#         # if not ep:
#         #     bad_idxs.append(idx)
#     print(successes/total)
# print("Successes:", successes, "Total:", total)
# print("Success rate:", successes / total)

In [114]:
# Re-print the batch summary from the accumulators left in the kernel:
# mean epochs, per-instance newton_op wall time (mean, std), stability
# (mean ± std of check.diff_factors) and the success rate.
print("Average epochs:", epochs / total)
print("Time:", np.asarray(total_time).mean(), np.asarray(total_time).std())
print("Stability:", np.asarray(check.diff_factors).mean(), '±', np.asarray(check.diff_factors).std())
print("Success rate:", successes / total)


Average epochs: 8.909090909090908
Time: 0.03304126153869385 0.025611918371383748
Stability: nan ± nan
Success rate: 1.0
