In [1]:
from numpy.core.numeric import True_
import sys
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, TensorDataset
from sklearn.linear_model import LogisticRegression
#from matplotlib import pyplot as plt
#from PIL import Image, ImageDraw
#from IPython import display
#from torchvision import transforms

In [2]:
# Pick the fastest backend that is actually available: CUDA first, then
# Apple-silicon MPS, then the CPU. Probing MPS alone makes CUDA machines
# silently fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
device

device(type='mps')

In [3]:
# Dataset identifier; it is reused further down to name the saved checkpoints.
data = 'blood_alcohol'
# Load "<data>_sc.csv" from the sibling datasets directory.
df = pd.read_csv(f'../datasets/{data}_sc.csv')

In [4]:
df.head()

Unnamed: 0,gender,meal,units_consumed,weight,duration,class
0,0.829965,0.64502,0.272727,0.513158,0.691729,0
1,0.829965,0.64502,0.363636,0.447368,0.601504,1
2,0.829965,0.64502,0.363636,0.460526,0.330827,1
3,0.829965,0.68598,0.272727,0.565789,0.669173,0
4,0.829965,0.64502,0.636364,0.473684,0.413534,1


In [5]:
# Split the frame into labels (y) and features (x).
# pop() returns the label column and removes it from df in place, so df holds
# only feature columns afterwards. NOTE: because df is mutated, running this
# cell a second time raises KeyError - reload df first.
target = 'class'
y = df.pop(target).values
x = df.values

In [6]:
x.shape[1]

5

In [7]:
y

array([0, 1, 1, ..., 1, 0, 1])

In [8]:
# train the classifier to be explained -> Logistic Regression Model (can be replaced with a MLP)
# clf = LogisticRegression(max_iter=1000, fit_intercept=False, class_weight='balanced')
# clf.fit(x, y)

In [9]:
# Turn the numpy arrays into tensors on the selected device:
# features as float32, labels as int64.
x_train_tensor = torch.tensor(x, dtype=torch.float32).to(device)
y_train_tensor = torch.tensor(y, dtype=torch.int64).to(device)

In [10]:
# Pair each feature row with its label so both are indexed and batched together
dataset = TensorDataset(x_train_tensor, y_train_tensor)

In [11]:
# Mini-batch loader over the training set; shuffle=True reorders the rows
# every time a new iterator is created.
# NOTE: dl_batch_size is also reused as the number of optimisation steps per
# epoch in the c2c training loop, so changing it changes that loop's length too.
dl_batch_size = 32  # Set your desired batch size
train_dataloader = DataLoader(dataset, batch_size=dl_batch_size, shuffle=True)

In [13]:
# Number of epochs; shared by the VAE and the c2c VAE training loops
num_epochs = 500

In [14]:
# Size of the VAE latent vector (output width of the encoder's mu / logvar heads)
latent_dims = 4
# capacity controls the hidden width: Encoder and Decoder each use a hidden
# layer of capacity*2 units (12 with the value below)
capacity = 6
# Weight applied to the KL-divergence term in vae_loss
variational_beta = 1

In [15]:
class Encoder(nn.Module):
    """Maps an input row to the mean and log-variance of its latent Gaussian.

    Architecture: Linear -> ReLU, followed by two parallel Linear heads.

    Args:
        input_dim: number of input features.
        hidden_dim: width of the hidden layer. Defaults to ``capacity * 2``
            (the notebook-level setting) when omitted.
        latent_dim: size of the latent vector. Defaults to the notebook-level
            ``latent_dims`` when omitted.
    """

    def __init__(self, input_dim, hidden_dim=None, latent_dim=None):
        super().__init__()
        # Only fall back to the notebook globals when no explicit size is given,
        # so the module can be built (and tested) independently of them.
        if hidden_dim is None:
            hidden_dim = capacity * 2
        if latent_dim is None:
            latent_dim = latent_dims
        # Attribute names are the state_dict keys of saved checkpoints and the
        # creation order fixes the weight-init RNG order - keep both as they are.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=latent_dim)  # mean head
        self.fc3 = nn.Linear(in_features=hidden_dim, out_features=latent_dim)  # log-variance head

    def forward(self, x):
        """Return ``(mu, logvar)`` for the batch ``x``."""
        hidden = F.relu(self.fc1(x))
        x_mu = self.fc2(hidden)
        x_logvar = self.fc3(hidden)
        return x_mu, x_logvar

class Decoder(nn.Module):
    """Maps a latent vector back to feature space (Linear -> ReLU -> Linear).

    Args:
        output_dim: number of reconstructed features.
        hidden_dim: width of the hidden layer. Defaults to ``capacity * 2``
            (the notebook-level setting) when omitted.
        latent_dim: size of the incoming latent vector. Defaults to the
            notebook-level ``latent_dims`` when omitted.
    """

    def __init__(self, output_dim, hidden_dim=None, latent_dim=None):
        super().__init__()
        # Only fall back to the notebook globals when no explicit size is given.
        if hidden_dim is None:
            hidden_dim = capacity * 2
        if latent_dim is None:
            latent_dim = latent_dims
        # Attribute names are the state_dict keys of saved checkpoints - keep them.
        self.fc = nn.Linear(in_features=latent_dim, out_features=hidden_dim)
        self.fc_output = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Return the reconstruction for the latent batch ``x`` (no output activation)."""
        hidden = F.relu(self.fc(x))
        x_recon = self.fc_output(hidden)
        return x_recon

class VariationalAutoencoder(nn.Module):
    """Encoder + decoder pair with a sampling step in between.

    ``forward`` returns the reconstruction together with the latent mean and
    log-variance, which is everything the loss needs.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.encoder = Encoder(input_dim)
        self.decoder = Decoder(output_dim)

    def forward(self, x):
        """Return ``(x_recon, latent_mu, latent_logvar)`` for the batch ``x``."""
        mu, logvar = self.encoder(x)
        z = self.latent_sample(mu, logvar)
        return self.decoder(z), mu, logvar

    def latent_sample(self, mu, logvar):
        """Draw a latent via the reparameterisation trick while training.

        In eval mode the mean is returned unchanged, so inference is deterministic.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2)  ->  sigma
        eps = torch.empty_like(std).normal_()
        return mu + eps * std
    
def vae_loss(recon_x, x, mu, logvar, beta=None):
    """VAE objective: summed squared reconstruction error plus weighted KL term.

    Args:
        recon_x: reconstruction produced by the decoder.
        x: the original input, same shape as ``recon_x``.
        mu: latent means from the encoder.
        logvar: latent log-variances from the encoder.
        beta: weight on the KL term. Defaults to the notebook-level
            ``variational_beta`` when omitted.

    Returns:
        Scalar tensor ``recon_loss + beta * kl``. Both terms are summed over
        the whole batch (not averaged), so the value scales with batch size.
    """
    if beta is None:
        beta = variational_beta
    # reduction='sum' makes this the *sum* of squared errors, not the mean.
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    # Closed-form KL( N(mu, sigma^2) || N(0, 1) ), summed over batch and latent dims.
    kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + beta * kldivergence

In [16]:
# Train VAE
print(device)
vae = VariationalAutoencoder(input_dim=x.shape[1], output_dim=x.shape[1])
vae = vae.to(device)

optimizer = torch.optim.Adam(params=vae.parameters(), lr=1e-3, weight_decay=1e-5)

# Mean per-batch loss of every epoch (stays empty when a checkpoint is loaded).
train_loss_avg = []

# Set to False to load the saved checkpoint instead of training from scratch.
train_vae = True

# One checkpoint path for both branches. The load branch used to hard-code
# 'blood_alcohol_vae.pth', which diverges from the save path as soon as
# `data` names a different dataset.
model_name = data+'_vae.pth'

if not train_vae:
    filename = model_name
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on one backend (e.g. mps) still loads on another.
    vae.load_state_dict(torch.load(filename, map_location=device))
    vae.eval()
    print('done')

else:
    for epoch in range(num_epochs):
        vae.train()
        train_loss_avg.append(0)
        num_batches = 0

        for x_batch, _ in train_dataloader:

            x_batch = x_batch.to(device)

            # Forward pass
            x_recon, latent_mu, latent_logvar = vae(x_batch)
            loss = vae_loss(x_recon, x_batch, latent_mu, latent_logvar)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss_avg[-1] += loss.item()
            num_batches += 1

        train_loss_avg[-1] /= num_batches
        print('Epoch [%d / %d] average loss: %f' % (epoch + 1, num_epochs, train_loss_avg[-1]))

    # Save the model
    torch.save(vae.state_dict(), model_name)

    # Match the load branch: leave the VAE in eval mode so later cells that
    # call it get the deterministic latent mean instead of a fresh sample.
    vae.eval()

mps
Epoch [1 / 500] average loss: 42.142004
Epoch [2 / 500] average loss: 19.496738
Epoch [3 / 500] average loss: 11.248014
Epoch [4 / 500] average loss: 7.117998
Epoch [5 / 500] average loss: 5.241925
Epoch [6 / 500] average loss: 4.707400
Epoch [7 / 500] average loss: 4.303568
Epoch [8 / 500] average loss: 4.105686
Epoch [9 / 500] average loss: 3.974809
Epoch [10 / 500] average loss: 3.836573
Epoch [11 / 500] average loss: 3.778034
Epoch [12 / 500] average loss: 3.729431
Epoch [13 / 500] average loss: 3.721640
Epoch [14 / 500] average loss: 3.647621
Epoch [15 / 500] average loss: 3.636164
Epoch [16 / 500] average loss: 3.608647
Epoch [17 / 500] average loss: 3.624574
Epoch [18 / 500] average loss: 3.583366
Epoch [19 / 500] average loss: 3.583718
Epoch [20 / 500] average loss: 3.574937
Epoch [21 / 500] average loss: 3.544930
Epoch [22 / 500] average loss: 3.561048
Epoch [23 / 500] average loss: 3.568011
Epoch [24 / 500] average loss: 3.536254
Epoch [25 / 500] average loss: 3.530812
Ep

Epoch [204 / 500] average loss: 3.482290
Epoch [205 / 500] average loss: 3.485061
Epoch [206 / 500] average loss: 3.487126
Epoch [207 / 500] average loss: 3.483248
Epoch [208 / 500] average loss: 3.485379
Epoch [209 / 500] average loss: 3.485599
Epoch [210 / 500] average loss: 3.486572
Epoch [211 / 500] average loss: 3.479945
Epoch [212 / 500] average loss: 3.487856
Epoch [213 / 500] average loss: 3.480717
Epoch [214 / 500] average loss: 3.484770
Epoch [215 / 500] average loss: 3.476067
Epoch [216 / 500] average loss: 3.483261
Epoch [217 / 500] average loss: 3.483432
Epoch [218 / 500] average loss: 3.489142
Epoch [219 / 500] average loss: 3.487213
Epoch [220 / 500] average loss: 3.483817
Epoch [221 / 500] average loss: 3.487183
Epoch [222 / 500] average loss: 3.481345
Epoch [223 / 500] average loss: 3.477669
Epoch [224 / 500] average loss: 3.483326
Epoch [225 / 500] average loss: 3.482565
Epoch [226 / 500] average loss: 3.487175
Epoch [227 / 500] average loss: 3.492199
Epoch [228 / 500

Epoch [404 / 500] average loss: 3.482475
Epoch [405 / 500] average loss: 3.480536
Epoch [406 / 500] average loss: 3.483107
Epoch [407 / 500] average loss: 3.477901
Epoch [408 / 500] average loss: 3.486045
Epoch [409 / 500] average loss: 3.481052
Epoch [410 / 500] average loss: 3.489034
Epoch [411 / 500] average loss: 3.479826
Epoch [412 / 500] average loss: 3.483731
Epoch [413 / 500] average loss: 3.479947
Epoch [414 / 500] average loss: 3.488205
Epoch [415 / 500] average loss: 3.482322
Epoch [416 / 500] average loss: 3.485060
Epoch [417 / 500] average loss: 3.482186
Epoch [418 / 500] average loss: 3.482924
Epoch [419 / 500] average loss: 3.483619
Epoch [420 / 500] average loss: 3.485049
Epoch [421 / 500] average loss: 3.483232
Epoch [422 / 500] average loss: 3.480690
Epoch [423 / 500] average loss: 3.481406
Epoch [424 / 500] average loss: 3.485642
Epoch [425 / 500] average loss: 3.481554
Epoch [426 / 500] average loss: 3.485291
Epoch [427 / 500] average loss: 3.486458
Epoch [428 / 500

In [17]:
# Size of the c2c VAE's own latent vector
c2c_latent_dims = 3
# Hidden-layer width of the c2c encoder and decoder
c2c_capacity = 4
# Weight applied to the KL-divergence term in c2c_vae_loss
c2c_variational_beta = 1
# Width of the label part of the c2c input: one column for each label of the
# (labels1, labels2) pair concatenated in the c2c training loop. It is the
# size of that pair, not the number of classes.
label_size = 2

In [18]:
class c2c_Encoder(nn.Module):
    """Encodes a (latent difference, label pair) row into ``(mu, logvar)``.

    Architecture: Linear -> ReLU, followed by two parallel Linear heads.

    Args:
        input_dim: width of the input row. Defaults to
            ``latent_dims + label_size`` (notebook-level settings).
        hidden_dim: hidden-layer width. Defaults to ``c2c_capacity``.
        latent_dim: size of the c2c latent vector. Defaults to ``c2c_latent_dims``.
    """

    def __init__(self, input_dim=None, hidden_dim=None, latent_dim=None):
        super().__init__()
        # Only fall back to the notebook globals when no explicit size is given.
        if input_dim is None:
            input_dim = latent_dims + label_size
        if hidden_dim is None:
            hidden_dim = c2c_capacity
        if latent_dim is None:
            latent_dim = c2c_latent_dims
        # Attribute names are the state_dict keys of saved checkpoints - keep them.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc_mu = nn.Linear(in_features=hidden_dim, out_features=latent_dim)
        self.fc_logvar = nn.Linear(in_features=hidden_dim, out_features=latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)`` for the batch ``x``."""
        hidden = F.relu(self.fc1(x))
        x_mu = self.fc_mu(hidden)
        x_logvar = self.fc_logvar(hidden)
        return x_mu, x_logvar

class c2c_Decoder(nn.Module):
    """Decodes a (c2c latent, label pair) row into a latent-space difference.

    Args:
        input_dim: width of the input row. Defaults to
            ``c2c_latent_dims + label_size`` (notebook-level settings).
        hidden_dim: hidden-layer width. Defaults to ``c2c_capacity``.
        output_dim: width of the output. Defaults to ``latent_dims``.
    """

    def __init__(self, input_dim=None, hidden_dim=None, output_dim=None):
        super().__init__()
        # Only fall back to the notebook globals when no explicit size is given.
        if input_dim is None:
            input_dim = c2c_latent_dims + label_size
        if hidden_dim is None:
            hidden_dim = c2c_capacity
        if output_dim is None:
            output_dim = latent_dims
        # fc2 is applied first, fc1 second. The names (and creation order) are
        # kept because they are the state_dict keys of saved checkpoints.
        self.fc2 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc1 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Apply ``fc2`` then ``fc1`` to the batch ``x``."""
        # NOTE(review): there is no activation between the two layers, so the
        # decoder is a single affine map overall. Left as-is to stay compatible
        # with existing checkpoints - confirm whether a ReLU was intended.
        hidden = self.fc2(x)
        return self.fc1(hidden)
    
class c2c_VariationalAutoencoder(nn.Module):
    """Conditional VAE over latent differences.

    Each input row is ``latent_dims`` difference columns followed by
    ``label_size`` label columns; the label columns are fed to both the
    encoder and the decoder.
    """

    def __init__(self):
        super().__init__()
        self.encoder = c2c_Encoder()
        self.decoder = c2c_Decoder()

    def forward(self, x):
        """Return ``(x_recon, c2c_latent_mu, c2c_latent_logvar)`` for the batch ``x``."""
        # torch.split also validates that x has exactly latent_dims + label_size columns.
        latent_part, label = torch.split(x, (latent_dims, label_size), dim=1)
        encoder_input = torch.cat((latent_part, label), dim=1)
        c2c_latent_mu, c2c_latent_logvar = self.encoder(encoder_input)
        c2c_latent = self.latent_sample(c2c_latent_mu, c2c_latent_logvar)
        decoder_input = torch.cat((c2c_latent, label), dim=1)
        return self.decoder(decoder_input), c2c_latent_mu, c2c_latent_logvar

    def latent_sample(self, mu, logvar):
        """Reparameterised sample while training; the plain mean in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2)  ->  sigma
        eps = torch.empty_like(std).normal_()
        return mu + eps * std
    
def c2c_vae_loss(recon_x, x, mu, logvar, beta=None, latent_dim=None):
    """c2c VAE objective: summed squared error plus weighted KL term.

    Args:
        recon_x: latent difference reconstructed by the c2c decoder.
        x: the target latent difference.
        mu: c2c latent means.
        logvar: c2c latent log-variances.
        beta: weight on the KL term. Defaults to the notebook-level
            ``c2c_variational_beta`` when omitted.
        latent_dim: number of columns both tensors are reshaped to before
            comparison. Defaults to the notebook-level ``latent_dims``.

    Returns:
        Scalar tensor ``recon_loss + beta * kl``; both terms are summed over
        the batch (not averaged).
    """
    if beta is None:
        beta = c2c_variational_beta
    if latent_dim is None:
        latent_dim = latent_dims
    # reshape (rather than view) also accepts non-contiguous tensors such as
    # slices or transposes, which view() rejects with a RuntimeError.
    recon_loss = F.mse_loss(
        recon_x.reshape(-1, latent_dim),
        x.reshape(-1, latent_dim),
        reduction='sum',
    )

    # Closed-form KL( N(mu, sigma^2) || N(0, 1) ), summed over batch and latent dims.
    kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + beta * kldivergence

In [19]:
print(device)
c2c_vae = c2c_VariationalAutoencoder()
c2c_vae = c2c_vae.to(device)

optimizer = torch.optim.Adam(params=c2c_vae.parameters(), lr=1e-3, weight_decay=1e-5)

# Mean per-step loss of every epoch (stays empty when a checkpoint is loaded).
train_loss_avg = []

# Set to False to load the saved checkpoint instead of training from scratch.
train_c2c_vae = True

# One checkpoint path for both branches. The load branch used to hard-code
# 'blood_alcohol_c2c.pth', which diverges from the save path as soon as
# `data` names a different dataset.
model_name = data+'_c2c.pth'

if not train_c2c_vae:
    filename = model_name
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on one backend (e.g. mps) still loads on another.
    c2c_vae.load_state_dict(torch.load(filename, map_location=device))
    c2c_vae.eval()
    print('done')
else:
    # The step count per epoch has always been taken from the batch size;
    # it now has its own name so the coupling is visible and easy to change.
    steps_per_epoch = dl_batch_size

    for epoch in range(num_epochs):
        c2c_vae.train()
        train_loss_avg.append(0)
        num_batches = 0

        for _ in range(steps_per_epoch):

            # A fresh iterator over the shuffled loader yields the first batch
            # of a new random permutation, i.e. one random batch per call.
            x_batch1, labels1 = next(iter(train_dataloader))
            x_batch2, labels2 = next(iter(train_dataloader))

            # The base VAE is a frozen feature extractor here: only its latent
            # mean is needed, so call the encoder directly and skip autograd
            # (previously gradients were also back-propagated into the VAE).
            with torch.no_grad():
                latent_mu1, _ = vae.encoder(x_batch1.to(device))
                latent_mu2, _ = vae.encoder(x_batch2.to(device))

            latent_diff = latent_mu1 - latent_mu2
            # Despite the name this is the label *pair* (labels1, labels2), not a difference.
            label_diff = torch.cat((labels1.unsqueeze(1), labels2.unsqueeze(1)), dim=1)
            label_diff = label_diff.float().to(device)

            c2c_input = torch.cat((latent_diff, label_diff), dim=1)

            c2c_recon, c2c_latent_mu, c2c_latent_logvar = c2c_vae(c2c_input)
            loss = c2c_vae_loss(c2c_recon, latent_diff, c2c_latent_mu, c2c_latent_logvar)
            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            train_loss_avg[-1] += loss.item()
            num_batches += 1

        train_loss_avg[-1] /= num_batches
        # The reported value is the full loss (reconstruction + KL), averaged over steps.
        print('Epoch [%d / %d] average reconstruction error: %f' % (epoch + 1, num_epochs, train_loss_avg[-1]))

    # Save the model
    torch.save(c2c_vae.state_dict(), model_name)

    # Match the load branch: leave the model in eval mode after training.
    c2c_vae.eval()

mps
Epoch [1 / 500] average reconstruction error: 35.686619
Epoch [2 / 500] average reconstruction error: 24.619245
Epoch [3 / 500] average reconstruction error: 17.011734
Epoch [4 / 500] average reconstruction error: 12.343631
Epoch [5 / 500] average reconstruction error: 8.651665
Epoch [6 / 500] average reconstruction error: 5.989851
Epoch [7 / 500] average reconstruction error: 3.629751
Epoch [8 / 500] average reconstruction error: 2.374440
Epoch [9 / 500] average reconstruction error: 1.467598
Epoch [10 / 500] average reconstruction error: 0.934400
Epoch [11 / 500] average reconstruction error: 0.664530
Epoch [12 / 500] average reconstruction error: 0.526600
Epoch [13 / 500] average reconstruction error: 0.403222
Epoch [14 / 500] average reconstruction error: 0.332474
Epoch [15 / 500] average reconstruction error: 0.262767
Epoch [16 / 500] average reconstruction error: 0.228804
Epoch [17 / 500] average reconstruction error: 0.184004
Epoch [18 / 500] average reconstruction error: 0.

Epoch [147 / 500] average reconstruction error: 0.000175
Epoch [148 / 500] average reconstruction error: 0.000169
Epoch [149 / 500] average reconstruction error: 0.000159
Epoch [150 / 500] average reconstruction error: 0.000168
Epoch [151 / 500] average reconstruction error: 0.000166
Epoch [152 / 500] average reconstruction error: 0.000156
Epoch [153 / 500] average reconstruction error: 0.000158
Epoch [154 / 500] average reconstruction error: 0.000155
Epoch [155 / 500] average reconstruction error: 0.000161
Epoch [156 / 500] average reconstruction error: 0.000166
Epoch [157 / 500] average reconstruction error: 0.000156
Epoch [158 / 500] average reconstruction error: 0.000177
Epoch [159 / 500] average reconstruction error: 0.000164
Epoch [160 / 500] average reconstruction error: 0.000164
Epoch [161 / 500] average reconstruction error: 0.000162
Epoch [162 / 500] average reconstruction error: 0.000159
Epoch [163 / 500] average reconstruction error: 0.000175
Epoch [164 / 500] average recon

Epoch [291 / 500] average reconstruction error: 0.000201
Epoch [292 / 500] average reconstruction error: 0.000180
Epoch [293 / 500] average reconstruction error: 0.000179
Epoch [294 / 500] average reconstruction error: 0.000177
Epoch [295 / 500] average reconstruction error: 0.000183
Epoch [296 / 500] average reconstruction error: 0.000172
Epoch [297 / 500] average reconstruction error: 0.000177
Epoch [298 / 500] average reconstruction error: 0.000199
Epoch [299 / 500] average reconstruction error: 0.000187
Epoch [300 / 500] average reconstruction error: 0.000185
Epoch [301 / 500] average reconstruction error: 0.000198
Epoch [302 / 500] average reconstruction error: 0.000181
Epoch [303 / 500] average reconstruction error: 0.000202
Epoch [304 / 500] average reconstruction error: 0.000194
Epoch [305 / 500] average reconstruction error: 0.000226
Epoch [306 / 500] average reconstruction error: 0.000202
Epoch [307 / 500] average reconstruction error: 0.000197
Epoch [308 / 500] average recon

Epoch [435 / 500] average reconstruction error: 0.000198
Epoch [436 / 500] average reconstruction error: 0.000207
Epoch [437 / 500] average reconstruction error: 0.000232
Epoch [438 / 500] average reconstruction error: 0.000201
Epoch [439 / 500] average reconstruction error: 0.000195
Epoch [440 / 500] average reconstruction error: 0.000182
Epoch [441 / 500] average reconstruction error: 0.000185
Epoch [442 / 500] average reconstruction error: 0.000197
Epoch [443 / 500] average reconstruction error: 0.000215
Epoch [444 / 500] average reconstruction error: 0.000190
Epoch [445 / 500] average reconstruction error: 0.000185
Epoch [446 / 500] average reconstruction error: 0.000181
Epoch [447 / 500] average reconstruction error: 0.000194
Epoch [448 / 500] average reconstruction error: 0.000225
Epoch [449 / 500] average reconstruction error: 0.000210
Epoch [450 / 500] average reconstruction error: 0.000195
Epoch [451 / 500] average reconstruction error: 0.000223
Epoch [452 / 500] average recon

In [86]:
def c2c_vae_cfe_v3(label1, label2, query, n_guides=4, lambd=0.2, guide_std=0.5, generator=None):
    """Generate one modified copy of ``query`` guided by the c2c VAE.

    Steps:
      1. Encode ``query`` with the base VAE (latent mean only).
      2. Draw ``n_guides`` random c2c latents, decode each together with the
         label pair into a latent-space offset, subtract the offset from the
         query latent and decode the result back to feature space.
      3. Keep the candidate closest to ``query`` (squared Euclidean distance).
      4. Re-encode that candidate, move the query latent a fraction ``lambd``
         of the way towards it, and decode.

    Relies on the notebook-level ``vae``, ``c2c_vae``, ``device`` and
    ``c2c_latent_dims``.

    Args:
        label1: first label of the pair given to the c2c decoder (the slot the
            training loop fills with the label of the minuend sample).
        label2: second label of the pair (the subtrahend sample's slot).
        query: tensor of shape ``(1, n_features)`` or ``(n_features,)``.
        n_guides: number of random candidates to draw (was hard-coded to 4).
        lambd: interpolation weight towards the chosen candidate, in latent
            space (was hard-coded to 0.2).
        guide_std: standard deviation of the sampled c2c latents (was 0.5).
        generator: optional CPU ``torch.Generator`` for reproducible sampling.

    Returns:
        1-D CPU tensor with ``n_features`` values.

    Raises:
        ValueError: if ``query`` is not a single row or ``n_guides`` < 1.
    """
    if n_guides < 1:
        raise ValueError("n_guides must be at least 1")
    if query.dim() == 1:
        query = query.unsqueeze(0)
    if query.dim() != 2 or query.shape[0] != 1:
        raise ValueError("query must hold exactly one row, got shape %s" % (tuple(query.shape),))

    query_cpu = query.detach().cpu()
    label_pair = torch.tensor([[label1, label2]], dtype=torch.float32).to(device)

    with torch.no_grad():
        # Only the latent mean is needed, so call the encoder directly instead
        # of running the whole VAE (which also samples and decodes).
        latent_mu1, _ = vae.encoder(query.to(device))

        # Sample on the CPU (so a CPU generator can be used), then move over.
        c2c_latents = torch.empty(n_guides, c2c_latent_dims).normal_(
            mean=0, std=guide_std, generator=generator
        ).to(device)
        c2c_encoding = torch.cat((c2c_latents, label_pair.repeat(n_guides, 1)), dim=1)

        # The c2c decoder predicts latent1 - latent2, hence the subtraction.
        latent_offsets = c2c_vae.decoder(c2c_encoding)
        candidates = vae.decoder(latent_mu1 - latent_offsets).cpu()

        # Broadcasting compares every candidate row against the single query row.
        sq_dist = ((candidates - query_cpu) ** 2).sum(dim=1)
        native_guide = candidates[torch.argmin(sq_dist)]

        latent_mu2, _ = vae.encoder(native_guide.to(device).unsqueeze(0))
        new_latent_mu = (1 - lambd) * latent_mu1 + lambd * latent_mu2
        sf = vae.decoder(new_latent_mu)

    return sf[0].cpu()

In [87]:
label1 = 0
label2 = 0

# x[:1] keeps the batch dimension, giving shape (1, n_features). Wrapping x[0]
# in a Python list instead makes torch build the tensor from a list of
# ndarrays, which is slow and emits a UserWarning.
query = torch.tensor(x[:1], dtype=torch.float32)

sf = c2c_vae_cfe_v3(label1, label2, query)
#print(sf)

In [88]:
# Convert only while sf is still a tensor, so the cell can be re-run safely
# (a second run used to fail because an ndarray has no .numpy()).
sf = sf.numpy() if torch.is_tensor(sf) else sf
sf

array([0.6830274 , 0.52631366, 0.3713135 , 0.41656673, 0.44854012],
      dtype=float32)

In [93]:
# Embedding value of every category, keyed by feature column index.
# Generated categorical features are snapped to the nearest of these values.
cat_embed = {
    # column 0: gender
    0: {
        'Female': 0.5046101778656127,
        'Male': 0.8299651515151515,
    },
    # column 1: meal
    1: {
        'No': 0.6450204795204795,
        'Yes': 0.6859795204795204,
    },
}

In [94]:
def find_nearest_value(index, value, embeddings=None):
    """Return the known embedding value of feature ``index`` closest to ``value``.

    Args:
        index: feature column index, used as key into the embedding table.
        value: the value to snap.
        embeddings: mapping ``{index: {category: embedding_value}}``. Defaults
            to the notebook-level ``cat_embed`` when omitted.

    Returns:
        The embedding value with the smallest absolute distance to ``value``;
        on a tie the first one in the table's order wins.

    Raises:
        KeyError: if ``index`` is not in the embedding table.
    """
    table = cat_embed if embeddings is None else embeddings
    candidates = table[index].values()
    return min(candidates, key=lambda candidate: abs(candidate - value))

In [96]:
# Snap every categorical column of sf to its closest known embedding value (in place)
cat_feats_idx = [0, 1]
for i in cat_feats_idx:
    sf[i] = find_nearest_value(i, sf[i])

In [97]:
sf

array([0.8299652 , 0.6450205 , 0.3713135 , 0.41656673, 0.44854012],
      dtype=float32)

In [30]:
x[0]

array([0.82996515, 0.64502048, 0.27272727, 0.51315789, 0.69172932])

In [31]:
y[0]

0