In [33]:
import os
# Set before torch is imported so the value is in place when CUDA/cuBLAS starts up.
# NOTE(review): PyTorch's reproducibility notes ask for this cuBLAS workspace setting
# when torch.use_deterministic_algorithms(True) is enabled (see the setup further down).
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.utils.data
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR
import importlib
import utils
import data
import modelMP
import modelSDAE
import helper_train_MP
import helper_train_SDAE
import helper_train_MP_GAN
import helper_noise

# Reload the project modules (same order as they are listed) so that edits made on
# disk are picked up without restarting the kernel.
for _project_module in (
    utils,
    data,
    modelMP,
    helper_train_MP,
    modelSDAE,
    helper_train_SDAE,
    helper_noise,
    helper_train_MP_GAN,
):
    importlib.reload(_project_module)

# Deterministic kernels are only requested when a GPU is actually present.
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

# Fixed seed for the global torch generator.
torch.random.manual_seed(1)

# Device string that is handed to every config dict below.
device = 'cuda' if cuda_available else 'cpu'
print(device)

def dae_noise(x):
    """Noise source used when training the benchmark DAE.

    Returns a tensor with the same shape, dtype and device as ``x`` whose
    entries are drawn uniformly from [0, 1).
    """
    uniform_noise = torch.rand_like(x)
    return uniform_noise

cuda


## Data setup

In [34]:

# Data configuration for the clean (no missing values) dataset.
dcon = dict(
    dataset='FashionMNIST',  # one of 'MNIST', 'FashionMNIST', 'CIFAR10'
    noise_mechanism='mcar',  # (does not apply here: the dataset for the MP model is created inside the for loop)
    na_obs_percentage=0.4,  # share of observations that have missing values
    replacement='uniform',  # value plugged in for missing entries (a number or 'uniform')
    noise_level=0.3,  # share of features that are missing per affected observation
    download=False,
    regenerate=True,
    device=device,

    # --- only for noise_mechanism 'patches' ---
    patch_size_axis=5,  # size of the patch in the x and y direction

    # --- only for noise_mechanism MAR and MNAR ---
    randperm_cols=False,  # whether to shuffle the columns of the data matrix before applying the noise mechanism
    # 'uniform', 'normal' or a list/tuple/tensor with one entry per feature; determines how the missingness rates are generated.
    # 'uniform': noise_level is the mean of the uniform distribution and impossible values are clipped to [0, 1],
    #            so noise_level != exact missingness rate in the mask.
    # 'normal':  noise_level is ignored and the options below apply.
    average_missing_rates='normal',
    chol_eps=1e-6,  # epsilon*I is added to the covariance matrix to make it positive definite for the Cholesky decomposition
    sigmoid_offset=0.15,  # offset of the sigmoid applied to the average missing rates generated by MultivariateNormal
    sigmoid_k=10,  # steepness of the sigmoid

    # --- only for MNAR ---
    dependence='simple_unobserved',  # 'simple_unobserved', 'complex_unobserved' or 'unobserved_and_observed'
)


In [35]:
# Re-seed so the dataset generation does not depend on how much randomness earlier cells consumed.
torch.manual_seed(1)
# Clean dataset (missing_vals=False): used to train the DAE and as input for mask simulation later.
data_without_nas_1 = data.ImputationDatasetGen(missing_vals=False, config=dcon)

## Encoder model

In [None]:
# Model configuration of the benchmark denoising autoencoder (DAE).
mcon0 = {
    'architecture': 'encoder_model_dae',
    'loss': 'full',  # must be 'full', otherwise error
    'epochs': 10,
    'batch_size': 64,
    'learning_rate': 3e-4,
    'lr_decay': False,
    'gamma': 2e-4,
    'step_size': 45,
    'layer_dims_enc': [784, 2000, 50],
    'layer_dims_dec': [50, 2000, 784],
    'device': device,
    'relu': True,
    'image': True,
    'noise_model': dae_noise,
    'corruption_share': 0.2,  # level of the masking noise that is used for training the DAE
    'mask_between_epochs': 'equal',  # 'equal' or 'random': scope of the random generator passed to the mask bernoulli sampling function
    'additional_noise': 0,  # does not apply here
}

# Training configuration of the DAE.
tcon0 = {
    'new_training': 1,
    'log': 0,
    'save_model': 0,
    'img_index': 4,  # index of the image to be plotted
    'activations': 1,
    'device': device,
    'train_val_test_split': [0.8, 0.2, 0],
}

# Match the input/output width to the flattened image size of the chosen dataset.
# The first matching key wins; any other dataset name leaves the defaults untouched.
for dataset_key, n_pixels in (('MNIST', 784), ('CIFAR10', 1024)):
    if dataset_key in dcon['dataset']:
        mcon0['layer_dims_enc'][0] = n_pixels
        mcon0['layer_dims_dec'][-1] = n_pixels
        break

# Train / validation loaders over the clean dataset.
split_shares = tcon0['train_val_test_split']
nona_train_loader_0 = DataLoader(
    data.DatasetWithSplits(data_without_nas_1, 'train', split_shares),
    batch_size=mcon0['batch_size'],
    shuffle=True,
)
nona_val_loader_0 = DataLoader(
    data.DatasetWithSplits(data_without_nas_1, 'validation', split_shares),
    batch_size=mcon0['batch_size'],
    shuffle=False,
)

# Model, element-wise loss (reduction='none'), optimiser and LR scheduler.
model_autoencoder = modelSDAE.SyntheticDenoisingAutoEncoder(
    noise_model=dae_noise,
    layer_dims_enc=mcon0['layer_dims_enc'],
    layer_dims_dec=mcon0['layer_dims_dec'],
    relu=mcon0['relu'],
    image=mcon0['image'],
).to(device)
loss_fn_autoencoder = nn.MSELoss(reduction='none')
optimizer_autoencoder = torch.optim.Adam(
    model_autoencoder.parameters(),
    lr=mcon0['learning_rate'],
)
scheduler_autoencoder = StepLR(
    optimizer_autoencoder,
    step_size=mcon0['step_size'],
    gamma=mcon0['gamma'],
)
print(model_autoencoder)

In [None]:
# Train the benchmark DAE on the clean data; no upstream encoder is used for this model.
dae_training_kwargs = {
    'model': model_autoencoder,
    'encoder': None,
    'loss_fn': loss_fn_autoencoder,
    'optimizer': optimizer_autoencoder,
    'scheduler': scheduler_autoencoder,
    'dcon': dcon,
    'mcon': mcon0,
    'tcon': tcon0,
    'train_dataloader': nona_train_loader_0,
    'validation_dataloader': nona_val_loader_0,
    'noise_model': dae_noise,
}
helper_train_SDAE.train_imputation_model(**dae_training_kwargs)

## MP model

In [None]:
# Accumulators for the quantitative comparison; one entry is appended per
# missingness mechanism, in the order patch, mnar, mcar.
rmse_all = []
rmse_all_no_enc = []
accuracy_all = []
accuracy_all_no_enc = []
corr_all = []
corr_all_no_enc = []

# images[i] holds six flattened pictures for mechanism i:
# [simulated mask, MP mask, MP-without-encoder mask, and the per-pixel means of the same three].
images = torch.zeros(3, 6, 784) if 'MNIST' in dcon['dataset'] else torch.zeros(3, 6, 1024)
idx = 10  # index of the observation whose individual masks are stored for plotting

for i, noise_mechanism in enumerate([helper_noise.missingness_adder_patch, helper_noise.missingness_adder_mnar, helper_noise.missingness_adder_mcar]):
    # Data configuration for this mechanism; it is used both to generate the
    # training data and to simulate the reference mask further down.
    dcon_temp = {
        'dataset': dcon['dataset'],
        'noise_mechanism': ['patch', 'mnar', 'mcar'][i], # missingness mechanism
        'na_obs_percentage': 0.4, # share of observations that have missing values
        'replacement': 'uniform', # what value is plugged in for missing values in the observations with missing values (number or 'uniform')
        'noise_level': 0.2, # the percentage of missing values per observation (share of features that are missing)
        'download': False,
        'regenerate': True,
        'device': device,


        # for noise_mechanism 'patches'
        'patch_size_axis': 5, # size of the patch in the x and y direction

        # for noise_mechanism MAR and MNAR
        'randperm_cols': False, # whether to shuffle the columns of the data matrix before applying the noise mechanism
        'average_missing_rates': 'normal', # can be 'uniform', 'normal' or a list/tuple/tensor of length no. of features, determines how the missingness rates are generated
        # if uniform, noise_level is used as mean for uniform distribution and impossible values are clipped to the interval [0,1] --> noise_level != exact missingness rate in mask
        # if normal, noise_level is ignored and the options below apply
        'chol_eps': 1e-6, # epsilon for the cholesky decomposition (epsilon*I is added to the covariance matrix to make it positive definite)
        'sigmoid_offset': 0.15, # offset for the sigmoid function applied to the average missing rates generated by MultivariateNormal
        'sigmoid_k': 10, # steepness of sigmoid

        # for MNAR only
        'dependence': 'simple_unobserved', # can be 'simple_unobserved', 'complex_unobserved', 'unobserved_and_observed'

    }

    # Mask-prediction model that operates on the DAE encoder's output.
    mcon = {
        'architecture': 'mask_pred_mlp',
        'epochs': 10,
        'batch_size': 64,
        'learning_rate': 3e-4,
        'lr_decay': False,
        'gamma': 2e-4,
        'step_size': 45,
        'layer_dims': [784, 2000, 2000, 2000, 784],
        'dropout': 0.5,
        'device': device,
        'relu': True,
        'image': True,
        'encoder': str(model_autoencoder),
    }
    # Baseline without encoder: identical hyper-parameters. layer_dims is copied so
    # that the two configs do not share (and mutate) the same list.
    # NOTE(review): 'encoder' still carries the autoencoder description although this
    # model is trained with encoder=None - confirm whether the helper expects that.
    mcon_no_enc = {**mcon, 'layer_dims': list(mcon['layer_dims'])}

    tcon = {
        'new_training': 1,
        'log': 0,
        'save_model': 0,
        'img_index': 4, # index of the image to be plotted
        'activations': 0,
        'device': device,
        'train_val_test_split': [0.8, 0.2, 0]
    }

    # Input width: encoder bottleneck for the MP model, raw pixel count for the
    # baseline. Output width: one probability per pixel.
    if 'MNIST' in dcon['dataset']:
        mcon['layer_dims'][0] = mcon0['layer_dims_enc'][-1]
        mcon['layer_dims'][-1] = 784
        mcon_no_enc['layer_dims'][0] = 784
        mcon_no_enc['layer_dims'][-1] = 784

    elif 'CIFAR10' in dcon['dataset']:
        mcon['layer_dims'][0] = mcon0['layer_dims_enc'][-1]
        mcon['layer_dims'][-1] = 1024
        mcon_no_enc['layer_dims'][0] = 1024
        mcon_no_enc['layer_dims'][-1] = 1024

    # Same seed for every mechanism so the runs are comparable.
    torch.random.manual_seed(1)
    data_with_nas_1 = data.ImputationDatasetGen(config=dcon_temp, missing_vals=True)

    na_train_loader_1 = DataLoader(data.DatasetWithSplits(data_with_nas_1, 'train', tcon['train_val_test_split']), batch_size=mcon['batch_size'], shuffle=True)
    na_val_loader_1 = DataLoader(data.DatasetWithSplits(data_with_nas_1, 'validation', tcon['train_val_test_split']), batch_size=mcon['batch_size'], shuffle=False)
    # Not consumed below; kept so the sequence of constructor calls is unchanged.
    na_test_loader_1 = DataLoader(data.DatasetWithSplits(data_with_nas_1, 'test', tcon['train_val_test_split']), batch_size=mcon['batch_size'], shuffle=False)

    model_mp = modelMP.MaskPredMLP(layer_dims=mcon['layer_dims'], dropout=mcon['dropout'], relu=mcon['relu'], image=mcon['image']).to(device)
    model_mp_no_enc = modelMP.MaskPredMLP(layer_dims=mcon_no_enc['layer_dims'], dropout=mcon_no_enc['dropout'], relu=mcon_no_enc['relu'], image=mcon_no_enc['image']).to(device)

    loss_fn = nn.BCELoss()
    loss_fn_no_enc = nn.BCELoss()
    optimizer = torch.optim.Adam(model_mp.parameters(), lr=mcon['learning_rate'])
    optimizer_no_enc = torch.optim.Adam(model_mp_no_enc.parameters(), lr=mcon_no_enc['learning_rate'])
    # Each optimiser gets its own scheduler: a StepLR instance only adjusts the
    # optimiser it was constructed with.
    scheduler = StepLR(optimizer, step_size=mcon['step_size'], gamma=mcon['gamma'])
    scheduler_no_enc = StepLR(optimizer_no_enc, step_size=mcon_no_enc['step_size'], gamma=mcon_no_enc['gamma'])

    helper_train_MP.train_model(model=model_mp, encoder=model_autoencoder.encoder, loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler,
                                        dcon=dcon_temp, mcon=mcon, tcon=tcon,
                                        train_dataloader=na_train_loader_1, validation_dataloader=na_val_loader_1
                                        )
    helper_train_MP.train_model(model=model_mp_no_enc, encoder=None, loss_fn=loss_fn_no_enc, optimizer=optimizer_no_enc, scheduler=scheduler_no_enc,
                                        dcon=dcon_temp, mcon=mcon_no_enc, tcon=tcon,
                                        train_dataloader=na_train_loader_1, validation_dataloader=na_val_loader_1
                                        )

    ## Comparison Learned and True Masks
    # Simulate the reference mask with the SAME configuration that produced the
    # training data (dcon_temp), so that simulated and learned masks are comparable.
    _, mask_true = noise_mechanism(dataset=data_without_nas_1.data, config=dcon_temp)

    model_autoencoder.eval()
    model_mp.eval()
    model_mp_no_enc.eval()
    # Pure inference over the whole dataset: no autograd graph is needed.
    with torch.no_grad():
        clean_inputs = data_without_nas_1.data.to(device)
        mask_probs = model_mp(model_autoencoder.encoder(clean_inputs))
        mask_probs_no_enc = model_mp_no_enc(clean_inputs)
        # Sample binary masks from the predicted per-pixel probabilities.
        mask_gen = torch.bernoulli(mask_probs)
        mask_gen_no_enc = torch.bernoulli(mask_probs_no_enc)


    # Masks of one example observation (leading dimension kept for concatenation).
    corrupted_true= mask_true[None, idx].cpu()
    corrupted_gen = mask_gen[None, idx].cpu()
    corrupted_gen_no_enc = mask_gen_no_enc[None, idx].cpu()
    # Per-feature mean of each mask over all observations.
    mean_mask_true = torch.mean(mask_true, dim=0, keepdim=True).cpu()
    mean_mask_gen = torch.mean(mask_gen, dim=0, keepdim=True).cpu()
    mean_mask_gen_no_enc = torch.mean(mask_gen_no_enc, dim=0, keepdim=True).cpu()

    images[i] = torch.cat([corrupted_true, corrupted_gen, corrupted_gen_no_enc,
                           mean_mask_true, mean_mask_gen, mean_mask_gen_no_enc], dim=0)

    # compute rmse between true and generated missingness shares
    rmse = torch.sqrt(torch.mean((mean_mask_true - mean_mask_gen)**2))
    rmse_no_enc = torch.sqrt(torch.mean((mean_mask_true - mean_mask_gen_no_enc)**2))

    # numpy correlation between the two sequences
    corr_np = np.corrcoef(mean_mask_true.squeeze().numpy(), mean_mask_gen.squeeze().numpy())
    corr_np_no_enc = np.corrcoef(mean_mask_true.squeeze().numpy(), mean_mask_gen_no_enc.squeeze().numpy())

    # share of entries on which the simulated and the sampled mask agree
    accuracy = torch.mean((mask_true.cpu() == mask_gen.cpu()).float())
    accuracy_no_enc = torch.mean((mask_true.cpu() == mask_gen_no_enc.cpu()).float())

    rmse_all.append(rmse)
    rmse_all_no_enc.append(rmse_no_enc)
    corr_all.append(corr_np[0,1])
    corr_all_no_enc.append(corr_np_no_enc[0,1])
    accuracy_all.append(accuracy)
    accuracy_all_no_enc.append(accuracy_no_enc)


In [None]:
# create table with quantitative results
# The metric lists hold 0-dim torch tensors (rmse, accuracy) and numpy scalars (corr).
# Converting them to plain Python floats before building the frame yields float64
# columns and avoids DataFrame.map, which only exists in pandas >= 2.1.
metric_rows = {
    'rmse': rmse_all,
    'rmse_no_enc': rmse_all_no_enc,
    'corr': corr_all,
    'corr_no_enc': corr_all_no_enc,
    'accuracy': accuracy_all,
    'accuracy_no_enc': accuracy_all_no_enc,
}
df = pd.DataFrame(
    [[float(value) for value in values] for values in metric_rows.values()],
    columns=['patch', 'mnar', 'mcar'],
    index=list(metric_rows),
)
df

In [None]:
# create plot with qualitative results
n_mech = 3  # number of missingness mechanisms (rows)
n_panels = 6  # pictures per mechanism (columns)
fig, ax = plt.subplots(n_mech, n_panels, figsize=(16, 8))

# side lengths used to turn a flattened mask back into a picture
sizes = [32, 32] if 'CIFAR10' in dcon['dataset'] else [28, 28]

x_labels = ['Simulated', 'MP', 'MPnoEnc', 'Avg Simulated', 'Avg MP', 'Avg MPnoEnc']
y_labels = ['Patch', 'MNAR', 'MCAR']

# shared keyword sets, so every panel is styled the same way
hide_ticks = dict(left=False, bottom=False, labelleft=False, labelbottom=False)
label_style = dict(labelpad=8, fontsize=12, va='center', ha='center')

# draw the pictures without ticks or frame
for i in range(n_mech):
    row_pictures = torch.unflatten(images[i], dim=1, sizes=sizes)
    for j in range(n_panels):
        panel = ax[i, j]
        panel.imshow(row_pictures[j].cpu().float(), cmap='gray')
        panel.set_xticks([])
        panel.set_yticks([])
        panel.tick_params(**hide_ticks)
        for side in ('top', 'right', 'bottom', 'left'):
            panel.spines[side].set_visible(False)

# column captions go below the first two rows
for row in (0, 1):
    for j, caption in enumerate(x_labels):
        ax[row, j].set_xlabel(caption, **label_style)
        ax[row, j].tick_params(**hide_ticks)

# mechanism names go to the left of each row
for i, caption in enumerate(y_labels):
    ax[i, 0].set_ylabel(caption, rotation=90, **label_style)
    ax[i, 0].tick_params(**hide_ticks)

plt.suptitle('Simulated and Learned Missingness Masks', fontsize=16, y=0.95, ha='center')
plt.subplots_adjust(top=0.92, wspace=0.1, hspace=0)
plt.show()