## Subdivision of the dataset into N institutions 

In [18]:
# Libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split

In [27]:
# Load data
# Diabeters
df_train = pd.read_csv('diabetes_binary_5050split_health_indicators_BRFSS2015.csv')
df_train = df_train.rename(columns={'Diabetes_binary': 'Labels'})
# Breast cancer
X_breast = pd.read_csv('X_breast.csv')
y_breast = pd.read_csv('y_breast.csv')
y_breast['Diagnosis'] = y_breast['Diagnosis'].map({'M': 1, 'B': 0})
# add labels to X_breast with the same name as in df_train
df_train_breast = pd.DataFrame(X_breast)
df_train_breast['Labels'] = y_breast['Diagnosis']

print(f"Diabetes dataset: {df_train.shape}")
print(f"Breast cancer dataset: {df_train_breast.shape}")

Diabetes dataset: (70692, 22)
Breast cancer dataset: (569, 32)


### Random Subdivision 


In [43]:
# N institutions (5% out for testing)
N = 3

def random_split(df, N, file_prefix='df_diabetes'):
    """
    Splits a DataFrame into N parts and saves each part as a CSV file.

    Parameters:
    df (pd.DataFrame): The DataFrame to split.
    N (int): Number of parts to split the DataFrame into.
    file_prefix (str): Prefix for the output file names.
    """
    # Shuffle the DataFrame
    df_shuffled = df.sample(frac=1, random_state=1).reset_index(drop=True)

    # Leave out 5% for testing
    df_train, df_test = train_test_split(df_shuffled, test_size=0.15, random_state=1)
    df_test.to_csv(file_prefix + '_random_test.csv', index=False)
    print(f'Saved: {file_prefix}_random_test.csv of shape {df_test.shape}')

    # Split the DataFrame into N parts
    df_splits = np.array_split(df_train, N)

    # Save each part as a CSV file
    for i, split in enumerate(df_splits, start=1):
        filename = f'{file_prefix}_random_{i}.csv'
        split.to_csv(filename, index=False)
        print(f'Saved: {filename} of shape {split.shape}')


random_split(df_train, N, file_prefix='df_diabetes')
random_split(df_train_breast, N, file_prefix='df_breast')


Saved: df_diabetes_random_test.csv of shape (10604, 22)
Saved: df_diabetes_random_1.csv of shape (20030, 22)
Saved: df_diabetes_random_2.csv of shape (20029, 22)


  return bound(*args, **kwds)


Saved: df_diabetes_random_3.csv of shape (20029, 22)
Saved: df_breast_random_test.csv of shape (86, 32)
Saved: df_breast_random_1.csv of shape (161, 32)
Saved: df_breast_random_2.csv of shape (161, 32)
Saved: df_breast_random_3.csv of shape (161, 32)


### Cluster based Subdivision

In [44]:
from sklearn.cluster import KMeans
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import copy

# Function to calculate Euclidean distances between centroids
def centroid_distances(centroids0, centroids1):
    N = len(centroids0)
    print(f"N: {N}")
    distances = np.zeros((N, N))
    for i in range(N):
        for j in range(N):
            distances[i, j] = np.linalg.norm(centroids0[i] - centroids1[j])
    return distances

# Function to calculate centroids
def calculate_centroids(df, labels):
    N = len(np.unique(labels))
    centroids = []
    for i in range(N):
        centroids.append(df[labels == i].mean().to_numpy())
    return centroids

def cluster_by_class_split(df_train, N, file_prefix='df_diabetes'):
    """
    In this code, distances will be a matrix where the element at [i, j] represents
    the distance between the i-th cluster of class 0 and the j-th cluster of class 1.
    The final matrix will be a N x N matrix, not simmetrical in general.
    The following result means that for the first cluster of class 0, the second cluster 
    of class 1 is the closest one. For the second cluster of class 0, the third cluster of
    class 1 is the closest one. And so on.
    array([[22.52661847, 16.58598092, 30.50548191],
       [ 4.33080647, 32.17891945, 25.41195157],
       [27.11059815, 19.7759446 ,  8.12520036]])
    """

    # Leave out 5% for testing
    df_train, df_test = train_test_split(df_train, test_size=0.15, random_state=1)
    df_test.to_csv(file_prefix + '_2cluster_test.csv', index=False)
    print(f'Saved: {file_prefix}_2cluster_test.csv of shape {df_test.shape}')

    # Splitting the dataset by class
    df_train_0 = df_train[df_train['Labels'] == 0].drop('Labels', axis=1)
    df_train_1 = df_train[df_train['Labels'] == 1].drop('Labels', axis=1)
    # KMeans clustering
    kmeans_0 = KMeans(n_clusters=N, random_state=1).fit(df_train_0)
    kmeans_1 = KMeans(n_clusters=N, random_state=1).fit(df_train_1)
    # Calculating centroids
    centroids_0 = calculate_centroids(df_train_0, kmeans_0.labels_)
    centroids_1 = calculate_centroids(df_train_1, kmeans_1.labels_)
    # Calculating distances
    distance_matrix = centroid_distances(centroids_0, centroids_1)  

    # Pairing clusters
    pairs = pair_clusters(distance_matrix)

    # create the N clusters
    i = 1
    for c0,c1 in pairs:
        df_0 = df_train[df_train['Labels'] == 0][kmeans_0.labels_ == c0]
        df_1 = df_train[df_train['Labels'] == 1][kmeans_1.labels_ == c1]
        # merge the clusters
        df = pd.concat([df_0, df_1])
        # randomize the order of the rows
        df = df.sample(frac=1).reset_index(drop=True)
        # save the new dataset
        filename = f'{file_prefix}_2cluster_{i}.csv'
        df.to_csv(filename, index=False)
        print(f'Saved: {filename} of shape {df.shape} pairs: {c0} and {c1}')
        i += 1

def pair_clusters(dist_matrix):
    distances_copy = copy.deepcopy(dist_matrix)
    pairs = []
    # cycle
    while distances_copy.size > 0:
        # Find the minimum value and its column index
        min_value = np.min(distances_copy)
        min_col_index = np.argmin(np.min(distances_copy, axis=0))
        min_row_index = np.argmin(distances_copy[:, min_col_index])

        # identify the real position 
        ind = np.where(dist_matrix == min_value) #print("Minimum value:", min_value)#print("Column index of minimum value:", ind[1])#print("Row index of minimum value:", ind[0])

        # record pairing 
        pairs.append((ind[1].item(0), ind[0].item(0)))  # (cluster_{min_col_index}_0, cluster_{min_row_index}_1)

        # remove the paired clusters from further consideration
        distances_copy = np.delete(distances_copy, min_row_index, axis=0)  # remove row
        distances_copy = np.delete(distances_copy, min_col_index, axis=1)  # remove column

    return pairs

cluster_by_class_split(df_train, N, file_prefix='df_diabetes')
cluster_by_class_split(df_train_breast, N, file_prefix='df_breast')

Saved: df_diabetes_2cluster_test.csv of shape (10604, 22)
N: 3
Saved: df_diabetes_2cluster_1.csv of shape (8179, 22) pairs: 0 and 1
Saved: df_diabetes_2cluster_2.csv of shape (46463, 22) pairs: 1 and 0
Saved: df_diabetes_2cluster_3.csv of shape (5446, 22) pairs: 2 and 2
Saved: df_breast_2cluster_test.csv of shape (86, 32)
N: 3
Saved: df_breast_2cluster_1.csv of shape (90, 32) pairs: 0 and 1
Saved: df_breast_2cluster_2.csv of shape (229, 32) pairs: 2 and 0
Saved: df_breast_2cluster_3.csv of shape (164, 32) pairs: 1 and 2


In [45]:
# N institutions - clusters _ OLD VERSION
N = 3

def cluster_split(df, N, file_prefix='df_diabetes'):
    """
    Splits a DataFrame into N clusters and saves each cluster as a CSV file.

    Parameters:
    df (pd.DataFrame): The DataFrame to cluster.
    N (int): Number of clusters to form.
    file_prefix (str): Prefix for the output file names.
    """

    # Leave out 5% for testing
    df_train, df_test = train_test_split(df, test_size=0.15, random_state=1)
    df_test.to_csv(file_prefix + '_cluster_test.csv', index=False)
    print(f'Saved: {file_prefix}_cluster_test.csv of shape {df_test.shape}')

    # Perform KMeans clustering
    kmeans = KMeans(n_clusters=N, random_state=1)
    clusters = kmeans.fit_predict(df_train)

    # Split the DataFrame based on clusters
    for i in range(N):
        cluster_df = df_train[clusters == i]
        filename = f'{file_prefix}_cluster_{i+1}.csv'
        cluster_df.to_csv(filename, index=False)
        print(f'Saved: {filename} of shape {cluster_df.shape}')


cluster_split(df_train, N, file_prefix='df_diabetes')
cluster_split(df_train_breast, N, file_prefix='df_breast')


Saved: df_diabetes_cluster_test.csv of shape (10604, 22)
Saved: df_diabetes_cluster_1.csv of shape (12405, 22)
Saved: df_diabetes_cluster_2.csv of shape (9835, 22)
Saved: df_diabetes_cluster_3.csv of shape (37848, 22)
Saved: df_breast_cluster_test.csv of shape (86, 32)
Saved: df_breast_cluster_1.csv of shape (364, 32)
Saved: df_breast_cluster_2.csv of shape (18, 32)
Saved: df_breast_cluster_3.csv of shape (101, 32)


#### Double-Check 

In [46]:
# read the data
print("Diabetes dataset")
df1 = pd.read_csv('df_diabetes_2cluster_1.csv')
df2 = pd.read_csv('df_diabetes_2cluster_2.csv')
df3 = pd.read_csv('df_diabetes_2cluster_3.csv')
print(f"Total shape 2cluster: {df1.shape[0] + df2.shape[0] + df3.shape[0]},{df1.shape[1]}")

df1 = pd.read_csv('df_diabetes_random_1.csv')
df2 = pd.read_csv('df_diabetes_random_2.csv')
df3 = pd.read_csv('df_diabetes_random_3.csv')
print(f"Total shape random: {df1.shape[0] + df2.shape[0] + df3.shape[0]},{df1.shape[1]}")

df1 = pd.read_csv('df_diabetes_cluster_1.csv')
df2 = pd.read_csv('df_diabetes_cluster_2.csv')
df3 = pd.read_csv('df_diabetes_cluster_3.csv')
print(f"Total shape cluster: {df1.shape[0] + df2.shape[0] + df3.shape[0]},{df1.shape[1]}")

# print the shape of the data
df1 = pd.read_csv('df_diabetes_random_test.csv')
df2 = pd.read_csv('df_diabetes_2cluster_test.csv')
df3 = pd.read_csv('df_diabetes_cluster_test.csv')
print(f"Test shape random: {df1.shape}")
print(f"Test shape 2cluster: {df2.shape}")
print(f"Test shape cluster: {df3.shape}")

# breast dataset
print("\nBreast cancer dataset")
df1 = pd.read_csv('df_breast_random_1.csv')
df2 = pd.read_csv('df_breast_random_2.csv')
df3 = pd.read_csv('df_breast_random_3.csv')
print(f"Total shape random: {df1.shape[0] + df2.shape[0] + df3.shape[0]},{df1.shape[1]}")

df1 = pd.read_csv('df_breast_2cluster_1.csv')
df2 = pd.read_csv('df_breast_2cluster_2.csv')
df3 = pd.read_csv('df_breast_2cluster_3.csv')
print(f"Total shape 2cluster: {df1.shape[0] + df2.shape[0] + df3.shape[0]},{df1.shape[1]}")

df1 = pd.read_csv('df_breast_cluster_1.csv')
df2 = pd.read_csv('df_breast_cluster_2.csv')
df3 = pd.read_csv('df_breast_cluster_3.csv')
print(f"Total shape cluster: {df1.shape[0] + df2.shape[0] + df3.shape[0]},{df1.shape[1]}")

# print the shape of the data
df1 = pd.read_csv('df_breast_random_test.csv')
df2 = pd.read_csv('df_breast_2cluster_test.csv')
df3 = pd.read_csv('df_breast_cluster_test.csv')
print(f"Test shape random: {df1.shape}")
print(f"Test shape 2cluster: {df2.shape}")
print(f"Test shape cluster: {df3.shape}")


Diabetes dataset
Total shape 2cluster: 60088,22
Total shape random: 60088,22
Total shape cluster: 60088,22
Test shape random: (10604, 22)
Test shape 2cluster: (10604, 22)
Test shape cluster: (10604, 22)

Breast cancer dataset
Total shape random: 483,32
Total shape 2cluster: 483,32
Total shape cluster: 483,32
Test shape random: (86, 32)
Test shape 2cluster: (86, 32)
Test shape cluster: (86, 32)


In [47]:
df1.head()

Unnamed: 0.1,Unnamed: 0,radius1,texture1,perimeter1,area1,smoothness1,compactness1,concavity1,concave_points1,symmetry1,...,texture3,perimeter3,area3,smoothness3,compactness3,concavity3,concave_points3,symmetry3,fractal_dimension3,Labels
0,556,10.16,19.59,64.73,311.7,0.1003,0.07504,0.005025,0.01116,0.1791,...,22.88,67.88,347.3,0.1265,0.12,0.01005,0.02232,0.2262,0.06742,0
1,273,9.742,15.67,61.5,289.9,0.09037,0.04689,0.01103,0.01407,0.2081,...,20.88,68.09,355.2,0.1467,0.0937,0.04043,0.05159,0.2841,0.08175,0
2,256,19.55,28.77,133.6,1207.0,0.0926,0.2063,0.1784,0.1144,0.1893,...,36.27,178.6,1926.0,0.1281,0.5329,0.4251,0.1941,0.2818,0.1005,1
3,168,17.47,24.68,116.1,984.6,0.1049,0.1603,0.2159,0.1043,0.1538,...,32.33,155.3,1660.0,0.1376,0.383,0.489,0.1721,0.216,0.093,1
4,340,14.42,16.54,94.15,641.2,0.09751,0.1139,0.08007,0.04223,0.1912,...,21.51,111.4,862.1,0.1294,0.3371,0.3755,0.1414,0.3053,0.08764,0


In [48]:
df2.head()

Unnamed: 0.1,Unnamed: 0,radius1,texture1,perimeter1,area1,smoothness1,compactness1,concavity1,concave_points1,symmetry1,...,texture3,perimeter3,area3,smoothness3,compactness3,concavity3,concave_points3,symmetry3,fractal_dimension3,Labels
0,421,14.69,13.98,98.22,656.1,0.1031,0.1836,0.145,0.063,0.2086,...,18.34,114.1,809.2,0.1312,0.3635,0.3219,0.1108,0.2827,0.09208,0
1,47,13.17,18.66,85.98,534.6,0.1158,0.1231,0.1226,0.0734,0.2128,...,27.95,102.8,759.4,0.1786,0.4166,0.5006,0.2088,0.39,0.1179,1
2,292,12.95,16.02,83.14,513.7,0.1005,0.07943,0.06155,0.0337,0.173,...,19.93,88.81,585.4,0.1483,0.2068,0.2241,0.1056,0.338,0.09584,0
3,186,18.31,18.58,118.6,1041.0,0.08588,0.08468,0.08169,0.05814,0.1621,...,26.36,139.2,1410.0,0.1234,0.2445,0.3538,0.1571,0.3206,0.06938,1
4,414,15.13,29.81,96.71,719.5,0.0832,0.04605,0.04686,0.02739,0.1852,...,36.91,110.1,931.4,0.1148,0.09866,0.1547,0.06575,0.3233,0.06165,1


# Breast dataset


In [5]:
import torch
import torch.nn as nn
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import os
import sys
sys.path.append(os.path.join(os.path.dirname("utils.py"), '..'))
import utils

# Set random seed and Use 'cuda' GPU
device = utils.check_gpu(manual_seed=True)

X = pd.read_csv('X_breast.csv', index_col=0)
Y = pd.read_csv('y_breast.csv', index_col=0)
X.head()

MPS is available


Unnamed: 0,radius1,texture1,perimeter1,area1,smoothness1,compactness1,concavity1,concave_points1,symmetry1,fractal_dimension1,...,radius3,texture3,perimeter3,area3,smoothness3,compactness3,concavity3,concave_points3,symmetry3,fractal_dimension3
0,17.99,10.38,122.8,1001.0,0.1184,0.2776,0.3001,0.1471,0.2419,0.07871,...,25.38,17.33,184.6,2019.0,0.1622,0.6656,0.7119,0.2654,0.4601,0.1189
1,20.57,17.77,132.9,1326.0,0.08474,0.07864,0.0869,0.07017,0.1812,0.05667,...,24.99,23.41,158.8,1956.0,0.1238,0.1866,0.2416,0.186,0.275,0.08902
2,19.69,21.25,130.0,1203.0,0.1096,0.1599,0.1974,0.1279,0.2069,0.05999,...,23.57,25.53,152.5,1709.0,0.1444,0.4245,0.4504,0.243,0.3613,0.08758
3,11.42,20.38,77.58,386.1,0.1425,0.2839,0.2414,0.1052,0.2597,0.09744,...,14.91,26.5,98.87,567.7,0.2098,0.8663,0.6869,0.2575,0.6638,0.173
4,20.29,14.34,135.1,1297.0,0.1003,0.1328,0.198,0.1043,0.1809,0.05883,...,22.54,16.67,152.2,1575.0,0.1374,0.205,0.4,0.1625,0.2364,0.07678


In [6]:
# Map 'M' to 1 and 'B' to 0
Y['Diagnosis'] = Y['Diagnosis'].map({'M': 1, 'B': 0})
Y.head()

Unnamed: 0,Diagnosis
0,1
1,1
2,1
3,1
4,1


In [10]:
# Use 10 % of total data as Test set and the rest as (Train + Validation) set 
X_train_val, X_test, y_train_val, y_test = train_test_split(X, Y, test_size=0.1)

# Use 20 % of (Train + Validation) set as Validation set
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=0.2)

In [11]:
# Normalize the data
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train.values)
X_val = scaler.transform(X_val.values)

# Convert to PyTorch tensor
X_train = torch.tensor(X_train).float().to(device)
X_val = torch.tensor(X_val).float().to(device)

y_train = torch.LongTensor(y_train.values).to(device).squeeze()
y_val = torch.LongTensor(y_val.values).to(device).squeeze()

# Hyperparameter
learning_rate = 1e-1   ### After tuning?
n_epochs = 500
drop_prob = 0.3

def randomize_class(a, include=True):
        # Get the number of classes and the number of samples
        num_classes = a.size(1)
        num_samples = a.size(0)

        # Generate random indices for each row to place 1s, excluding the original positions
        random_indices = torch.randint(0, num_classes, (num_samples,)).to(a.device)

        # Ensure that the generated indices are different from the original positions
        # TODO we inclue also same label to make sure that every class is represented 
        if not include:
            original_indices = a.argmax(dim=1)
            random_indices = torch.where(random_indices == original_indices, (random_indices + 1) % num_classes, random_indices)

        # Create a second tensor with 1s at the random indices
        b = torch.zeros_like(a)
        b[torch.arange(num_samples), random_indices] = 1
        return b

In [14]:
# Model
EPS = 1e-9
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        input_dim = 30
        self.fc1 = nn.Linear(input_dim, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 256)
        self.fc4 = nn.Linear(256, 64)
        self.fc5 = nn.Linear(64, 2)
        self.concept_mean_predictor = torch.nn.Sequential(torch.nn.Linear(input_dim, 128), torch.nn.LeakyReLU(), torch.nn.Linear(128, 20))  # 20 is fine?
        self.concept_var_predictor = torch.nn.Sequential(torch.nn.Linear(input_dim, 128), torch.nn.LeakyReLU(), torch.nn.Linear(128, 20))
        self.decoder = torch.nn.Sequential(torch.nn.Linear(20, 128), torch.nn.LeakyReLU(), torch.nn.Linear(128, input_dim))
        self.concept_mean_z3_predictor = torch.nn.Sequential(torch.nn.Linear(20 + input_dim + 2, 128), torch.nn.LeakyReLU(), torch.nn.Linear(128, 20))
        self.concept_var_z3_predictor = torch.nn.Sequential(torch.nn.Linear(20 + input_dim + 2, 128), torch.nn.LeakyReLU(), torch.nn.Linear(128, 20))
        self.concept_mean_qz3_predictor = torch.nn.Sequential(torch.nn.Linear(20 + input_dim + 4 + input_dim, 128), torch.nn.LeakyReLU(), torch.nn.Linear(128, 20))
        self.concept_var_qz3_predictor = torch.nn.Sequential(torch.nn.Linear(20 + input_dim + 4 + input_dim, 128), torch.nn.LeakyReLU(), torch.nn.Linear(128, 20))
        
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)
#         self.mask = torch.nn.Parameter(torch.Tensor([0,0,0,0,0,1,0,0,0,0,
#                                   0,0,0,0,0,0,0,1,1,1,1]), requires_grad=False)
        mean = 0.5
        self.mask = torch.nn.Parameter(torch.Tensor([mean,mean,mean,mean,mean,mean,mean,mean,mean,1,mean,mean,mean,mean,
                                  mean,mean,mean,mean,mean,mean,mean,1,1,1,1,1,1,1,1,1]), requires_grad=False)
        std = 1
        self.mask_std = torch.nn.Parameter(torch.Tensor([std,std,std,std,std,std,std,std,std,0.00001,std,std,std,std,
                                      std,std,std,std,std,std,std,0.00001,0.00001,0.00001,0.00001,0.00001,0.00001,0.00001,0.00001,0.00001]), requires_grad=False)
        self.binary_feature = torch.nn.Parameter(torch.Tensor(
                            [1,1,1,0,1,1,1,1,1,1,1,1,1,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0]).bool(), requires_grad=False)
        self.mask.to(device)
        self.mask_std.to(device)
        
        self.distr_mask = torch.distributions.Normal(self.mask, self.mask_std)
        
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight.data)
                
    def get_mask(self, x):
        mask = torch.rand(x.shape).to(device)
        return mask
                
    def multiple_cf(self, x, include=False, n=1):
        out = self.fc1(x)
        out = self.relu(out)
        
        out = self.fc2(out)
        out = self.relu(out)
        
        out = self.fc3(out)
        out = self.relu(out)
        
        out = self.fc4(out)
        out = self.relu(out)
        
        out = self.fc5(out)
        
        z2_mu = self.concept_mean_predictor(x)
        z2_log_var = self.concept_var_predictor(x)
        z2_sigma = torch.exp(z2_log_var / 2) + EPS
        qz2_x = torch.distributions.Normal(z2_mu, z2_sigma)
        z2 = qz2_x.rsample()
        p_z2 = torch.distributions.Normal(torch.zeros_like(qz2_x.mean), torch.ones_like(qz2_x.mean))

        x_reconstructed = self.decoder(z2)
#         x_reconstructed = torch.clamp(x_reconstructed, min=0, max=1)
#         x_reconstructed[:, self.binary_feature] = torch.sigmoid(x_reconstructed[:, self.binary_feature])
#         x_reconstructed[:, ~self.binary_feature] = torch.clamp(x_reconstructed[:, ~self.binary_feature], min=0, max=1)
        
        y_prime = randomize_class((out).float(), include=include)
        
        z2_c_y_y_prime = torch.cat((z2, x, out, y_prime, ), dim=1)
        z3_mu = self.concept_mean_qz3_predictor(z2_c_y_y_prime)
        z3_log_var = self.concept_var_qz3_predictor(z2_c_y_y_prime)
        z3_sigma = torch.exp(z3_log_var / 2) + EPS
        qz3_z2_c_y_y_prime = torch.distributions.Normal(z3_mu, z3_sigma)
        z3 = qz3_z2_c_y_y_prime.rsample((n,))
        
        z2_c_y = torch.cat((z2, x, out), dim=1)
        z3_mu = self.concept_mean_z3_predictor(z2_c_y)
        z3_log_var = self.concept_var_z3_predictor(z2_c_y)
        z3_sigma = torch.exp(z3_log_var / 2) + EPS
        pz3_z2_c_y = torch.distributions.Normal(z3_mu, z3_sigma)
        
        x_prime_reconstructed = self.decoder(z3)
#         x_prime_reconstructed = torch.clamp(x_prime_reconstructed, min=0, max=1)
#         x_prime_reconstructed[:, :, self.binary_feature] = torch.sigmoid(x_prime_reconstructed[:, :, self.binary_feature])
#         x_prime_reconstructed[:, :, ~self.binary_feature] = torch.clamp(x_prime_reconstructed[:, :, ~self.binary_feature], min=0, max=1)
        
        x_prime_reconstructed = x_prime_reconstructed * (1 - self.mask) + (x * self.mask)
        out2 = self.fc1(x_prime_reconstructed)
        out2 = self.relu(out2)
        
        out2 = self.fc2(out2)
        out2 = self.relu(out2)
        
        out2 = self.fc3(out2)
        out2 = self.relu(out2)
        
        out2 = self.fc4(out2)
        out2 = self.relu(out2)
        
        out2 = self.fc5(out2)
        
        return out, x_reconstructed, qz2_x, p_z2, out2, x_prime_reconstructed, qz3_z2_c_y_y_prime, pz3_z2_c_y, y_prime
                
    def forward(self, x, include=True, n=1, mask_init=None):
        out = self.fc1(x)
        out = self.relu(out)
        
        out = self.fc2(out)
        out = self.relu(out)
        
        out = self.fc3(out)
        out = self.relu(out)
        
        out = self.fc4(out)
        out = self.relu(out)
        
        out = self.fc5(out)
        
        z2_mu = self.concept_mean_predictor(x)
        z2_log_var = self.concept_var_predictor(x)
        z2_sigma = torch.exp(z2_log_var / 2) + EPS
        qz2_x = torch.distributions.Normal(z2_mu, z2_sigma)
        z2 = qz2_x.rsample()
        p_z2 = torch.distributions.Normal(torch.zeros_like(qz2_x.mean), torch.ones_like(qz2_x.mean))

        x_reconstructed = self.decoder(z2)
        x_reconstructed = F.hardtanh(x_reconstructed, -0.1, 1.1)
#         x_reconstructed = torch.clamp(x_reconstructed, min=0, max=1)
#         x_reconstructed[:, self.binary_feature] = torch.sigmoid(x_reconstructed[:, self.binary_feature])
#         x_reconstructed[:, ~self.binary_feature] = torch.clamp(x_reconstructed[:, ~self.binary_feature], min=0, max=1)
        
        y_prime = randomize_class((out).float(), include=include)
        
        if self.training:
#             mask = self.distr_mask.sample((y_prime.shape[0],))
#             mask.to('cuda')
#             mask = torch.clip(mask, min=0, max=1)
            mask = self.get_mask(x)
#             mask = self.mask
        else:
            if mask_init is not None:
                mask = mask_init
                mask = mask.to(device)
                mask = mask.repeat(y_prime.shape[0], 1)
            else:
                mask = self.get_mask(x)
#                 mask = self.distr_mask.sample((y_prime.shape[0],))
#                 mask = torch.clip(mask, min=0, max=1)
#                 mask = self.mask
#         mask = mask.repeat(y_prime.shape[0], 1)
        
        z2_c_y_y_prime = torch.cat((z2, x, out, y_prime, mask), dim=1)
        z3_mu = self.concept_mean_qz3_predictor(z2_c_y_y_prime)
        z3_log_var = self.concept_var_qz3_predictor(z2_c_y_y_prime)
        z3_sigma = torch.exp(z3_log_var / 2) + EPS
        qz3_z2_c_y_y_prime = torch.distributions.Normal(z3_mu, z3_sigma)
        z3 = qz3_z2_c_y_y_prime.rsample((n,))
        
        if n == 1: 
            z3 = z3.squeeze(0)
            
            
        z2_c_y = torch.cat((z2, x, out), dim=1)
        z3_mu = self.concept_mean_z3_predictor(z2_c_y)
        z3_log_var = self.concept_var_z3_predictor(z2_c_y)
        z3_sigma = torch.exp(z3_log_var / 2) + EPS
        pz3_z2_c_y = torch.distributions.Normal(z3_mu, z3_sigma)
        
        x_prime_reconstructed = self.decoder(z3)
        x_prime_reconstructed = F.hardtanh(x_prime_reconstructed, -0.1, 1.1)
#         x_prime_reconstructed = torch.clamp(x_prime_reconstructed, min=0, max=1)
#         x_prime_reconstructed[:, self.binary_feature] = torch.sigmoid(x_prime_reconstructed[:, self.binary_feature])
#         x_prime_reconstructed[:, ~self.binary_feature] = torch.clamp(x_prime_reconstructed[:, ~self.binary_feature], min=0, max=1)
        
        x_prime_reconstructed = x_prime_reconstructed * (1 - mask) + (x * mask)
        
        if not self.training:
#             print(x_prime_reconstructed[:2, 3])
            x_prime_reconstructed = torch.clamp(x_prime_reconstructed, min=0, max=1)
            x_prime_reconstructed = scaler.inverse_transform(x_prime_reconstructed.detach().cpu().numpy())
            x_prime_reconstructed = np.round(x_prime_reconstructed)
            x_prime_reconstructed = scaler.transform(x_prime_reconstructed)
            x_prime_reconstructed = torch.Tensor(x_prime_reconstructed).to(device)
            
        out2 = self.fc1(x_prime_reconstructed)
        out2 = self.relu(out2)
        
        out2 = self.fc2(out2)
        out2 = self.relu(out2)
        
        out2 = self.fc3(out2)
        out2 = self.relu(out2)
        
        out2 = self.fc4(out2)
        out2 = self.relu(out2)
        
        out2 = self.fc5(out2)
    
        return out, x_reconstructed, qz2_x, p_z2, out2, x_prime_reconstructed, qz3_z2_c_y_y_prime, pz3_z2_c_y, y_prime

In [17]:
import torch.nn.functional as F
import numpy as np

model = Net().to(device)

# Optimizer and Loss function
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()
mse_loss = nn.MSELoss()

train_loss = list()
val_loss = list()

learning_rate = 1e-2 
n_epochs = 1000
best_model_round = 0
best_val = 0

for epoch in range(1, n_epochs+1):
    model.train()
    H, x_reconstructed, q, p, H2, x_prime, q_prime, p_prime, y_prime = model(X_train)
    loss_task = loss_fn(H, y_train)
    loss_kl = torch.distributions.kl_divergence(p, q).mean()
    loss_rec = F.mse_loss(x_reconstructed, X_train, reduction='mean')
    loss_validity = loss_fn(H2, y_prime.argmax(dim=-1))
    loss_kl2 = torch.distributions.kl_divergence(p_prime, q_prime).mean() 
    loss_p_d = torch.distributions.kl_divergence(p, p_prime).mean() 
    loss_q_d = torch.distributions.kl_divergence(q, q_prime).mean()
    
    
    #loss_h_d = ((torch.abs(X_train - x_prime)) * mask).mean()
    
    lambda1 = 2 # loss parameter for kl divergence p-q and p_prime-q_prime
    lambda2 = 10 # loss parameter for input reconstruction
    lambda3 = 0.5 # loss parameter for validity of counterfactuals
    lambda4 = 0 # loss parameter for creating counterfactuals that are closer to the initial input
    #             increasing it, decrease the validity of counterfactuals. It is expected and makes sense.
    #             It is a design choice to have better counterfactuals or closer counterfactuals.
    loss = loss_task + lambda1*loss_kl + lambda2*loss_rec + lambda3*loss_validity + lambda1*loss_kl2 + loss_p_d + lambda4*loss_q_d #+ 0*loss_h_d
    train_loss.append(loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    acc = (torch.argmax(H, dim=1) == y_train).float().mean().item()
    acc_prime = (torch.argmax(H2, dim=1) == y_prime.argmax(dim=-1)).float().mean().item()
    model.eval()
    with torch.no_grad():
        H_val, x_reconstructed, q, p, H2, x_prime, q_prime, p_prime, y_prime = model(X_val, False)
        loss_val = loss_fn(H_val, y_val)
        acc_val = (torch.argmax(H_val, dim=1) == y_val).float().mean().item()
        acc_prime_val = (torch.argmax(H2, dim=1) == y_prime.argmax(dim=-1)).float().mean().item()
        
        val_loss.append(loss_val.item())
        if acc_prime_val > best_val:
            print(f"Saved model at epoch: {epoch} with Validity: {acc_prime_val}")
            best_val = acc_prime_val
            best_model_round = epoch
            torch.save(model.state_dict(), f"checkpoints/model_round.pth")
        
    if epoch % 50 == 0:
        print('Epoch {:4d} / {}, Cost : {:.4f}, Acc : {:.2f} %, Validity : {:.2f} %, Val Cost : {:.4f}, Val Acc : {:.2f} % , Val Validity : {:.2f} %'.format(
            epoch, n_epochs, loss.item(), acc*100, acc_prime*100, loss_val.item(), acc_val*100, acc_prime_val*100))
        
model = Net().to(device)
model.load_state_dict(torch.load(f"checkpoints/model_round.pth"))

Saved model at epoch: 1 with Validity: 0.4757281541824341
Epoch   50 / 1000, Cost : 1.4673, Acc : 96.33 %, Validity : 52.32 %, Val Cost : 0.2249, Val Acc : 93.20 % , Val Validity : 31.07 %
Epoch  100 / 1000, Cost : 1.1949, Acc : 97.80 %, Validity : 53.55 %, Val Cost : 0.1733, Val Acc : 97.09 % , Val Validity : 33.98 %
Epoch  150 / 1000, Cost : 0.9952, Acc : 98.53 %, Validity : 63.33 %, Val Cost : 0.1391, Val Acc : 97.09 % , Val Validity : 34.95 %
Epoch  200 / 1000, Cost : 0.8394, Acc : 98.29 %, Validity : 80.44 %, Val Cost : 0.1074, Val Acc : 97.09 % , Val Validity : 36.89 %
Epoch  250 / 1000, Cost : 0.7544, Acc : 98.53 %, Validity : 90.71 %, Val Cost : 0.1139, Val Acc : 96.12 % , Val Validity : 37.86 %
Epoch  300 / 1000, Cost : 0.7103, Acc : 98.53 %, Validity : 90.71 %, Val Cost : 0.1044, Val Acc : 96.12 % , Val Validity : 39.81 %
Saved model at epoch: 329 with Validity: 0.5048543810844421
Epoch  350 / 1000, Cost : 0.6811, Acc : 98.78 %, Validity : 91.93 %, Val Cost : 0.1013, Val Acc 

<All keys matched successfully>