# Train VAE to evaluate counterfactual realism
In this notebook, we train a VAE on the MNIST dataset so that we can evaluate the realism of any given counterfactual

In [1]:
inDrive = True

In [3]:
if inDrive:
    from google.colab import drive
    drive.mount('/content/drive')
    import os
    os.chdir('/content/drive/My Drive/Hybrid-CLUE/MyImplementation/training_notebooks')
    import sys

    # Add the parent directory to the system path
    current_dir = os.getcwd()
    parent_dir = os.path.dirname(current_dir)
    sys.path.insert(0, parent_dir)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Setup


Import libraries

In [8]:
import importlib
import models.regene_models
importlib.reload(models.regene_models)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os
import clue.new_CLUE
importlib.reload(clue.new_CLUE)

<module 'clue.new_CLUE' from '/content/drive/My Drive/Hybrid-CLUE/MyImplementation/clue/new_CLUE.py'>

Set the device

In [9]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cuda


Load the Datasets

In [10]:
# Load the MNIST dataset
from sklearn.model_selection import train_test_split

transform = transforms.Compose([transforms.ToTensor()])
mnist_dataset = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=transform)

# Split the dataset into training and validation sets
train_indices, val_indices = train_test_split(np.arange(len(mnist_dataset)), test_size=0.2, random_state=42)

train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

trainloader = torch.utils.data.DataLoader(mnist_dataset, batch_size=64, sampler=train_sampler, num_workers=2)
valloader = torch.utils.data.DataLoader(mnist_dataset, batch_size=64, sampler=val_sampler, num_workers=2)

Create a models directory if it doesn't exist

In [11]:
# Create models directory if it doesn't exist
os.makedirs('../model_saves', exist_ok=True)

## Train

In [14]:
from models.VAE_likelihood import train_vae_for_likelihood

vae = train_vae_for_likelihood(trainloader, valloader, device=device, epochs=400, patience=10, model_saves_dir='model_saves')

Epoch 1/400: Train Loss: 122.5581 (Recon: 109.1336, KL: 13.4245)
Epoch 1/400: Val Loss: 21.9799 (Recon: 17.1787, KL: 4.8012)
Saved best model to: model_saves/vae_likelihood_estimator_20.pt
Epoch 2/400: Train Loss: 85.2722 (Recon: 65.9348, KL: 19.3374)
Epoch 2/400: Val Loss: 20.8407 (Recon: 15.8421, KL: 4.9986)
Saved best model to: model_saves/vae_likelihood_estimator_20.pt
Epoch 3/400: Train Loss: 82.4757 (Recon: 62.8625, KL: 19.6132)
Epoch 3/400: Val Loss: 20.4224 (Recon: 15.4254, KL: 4.9970)
Saved best model to: model_saves/vae_likelihood_estimator_20.pt
Epoch 4/400: Train Loss: 81.2311 (Recon: 61.5472, KL: 19.6838)
Epoch 4/400: Val Loss: 20.2076 (Recon: 15.4794, KL: 4.7282)
Saved best model to: model_saves/vae_likelihood_estimator_20.pt
Epoch 5/400: Train Loss: 80.3767 (Recon: 60.7145, KL: 19.6621)
Epoch 5/400: Val Loss: 20.0462 (Recon: 15.0032, KL: 5.0430)
Saved best model to: model_saves/vae_likelihood_estimator_20.pt
Epoch 6/400: Train Loss: 79.7105 (Recon: 60.0623, KL: 19.6482)
