In [None]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"

import torch
from torch import nn
from torchvision.transforms import v2
from torch.utils.data import DataLoader
import torch.utils.data
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR
import importlib
from sklearn.impute import SimpleImputer
import utils
import data
import modelMP
import modelSDAE
import helper_train_MP
import helper_train_SDAE
import helper_train_MP_GAN
import helper_noise

# Re-import the project modules so that edits made on disk are picked up without a kernel restart.
for _project_module in (utils, data, modelMP, helper_train_MP, modelSDAE, helper_train_SDAE, helper_noise, helper_train_MP_GAN):
    importlib.reload(_project_module)

# Ask for deterministic GPU kernels when CUDA is present, then seed torch's global RNG.
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
torch.random.manual_seed(1)

# Device string that is stored in every config dict below.
device = 'cuda' if cuda_available else 'cpu'
print(device)
def dae_noise(x):
    """Return uniform random noise on [0, 1) with the same shape, dtype and device as ``x``.

    Passed as the ``noise_model`` of the plain denoising autoencoders in this notebook.
    """
    uniform_noise = torch.rand_like(x)
    return uniform_noise

## Data setup

In [43]:

# Data configuration: dataset choice plus all options of the simulated missingness mechanism.
# It is passed to data.ImputationDatasetGen and to the training helpers below.
dcon = {
    'dataset': 'FashionMNIST', # one of 'MNIST', 'FashionMNIST', 'CIFAR10'
    'noise_mechanism': 'mnar', # missingness mechanism ('mcar', 'mar', 'mnar', 'patch', 'patches', 'threshold', 'special_mar', 'special_mnar_log', 'special_mnar_self_log', 'special_mnar_quant', 'fixed_patch')
    'na_obs_percentage': 0.4, # share of observations that have missing values (given as a fraction, not a count)
    'replacement': 0, # what value is plugged in for missing values in the observations with missing values (number or 'uniform')
    'noise_level': 0.2, # the percentage of missing values per observation (share of features that are missing)
    'download': False,
    'regenerate': True,
    'device': device,


    # for noise_mechanism 'patches'
    'patch_size_axis': 5, # size of the patch in the x and y direction

    # for noise_mechanism MAR and MNAR
    'randperm_cols': False, # whether to shuffle the columns of the data matrix before applying the noise mechanism
    'average_missing_rates': 'normal', # can be 'uniform', 'normal' or a list/tuple/tensor of length no. of features, determines how the missingness rates are generated
    # if uniform, noise_level is used as mean for uniform distribution and impossible values are clipped to the interval [0,1] --> noise_level != exact missingness rate in mask
    # if normal, noise_level is ignored and the options below apply
    'chol_eps': 1e-6, # epsilon for the cholesky decomposition (epsilon*I is added to the covariance matrix to make it positive definite)
    'sigmoid_offset': 0.05, # offset for the sigmoid function applied to the average missing rates generated by MultivariateNormal
    'sigmoid_k': 10, # steepness of sigmoid

    # for MNAR only
    'dependence': 'simple_unobserved', # can be 'simple_unobserved', 'complex_unobserved', 'unobserved_and_observed'

    ### for the 'special_*' missingness generators only
    # all of them
    'p': 0.8,

    # MAR_mask
    'p_obs': 0.6,

    # MNAR_mask_logistic
    # p as before
    'p_params': 0.7,
    'exclude_inputs': False,

    # MNAR_mask_quantiles
    # p and p_params as before
    'q': 0.2,
    'cut': 'both', # can be 'upper', 'lower', 'both'
    'MCAR': False, # whether MCAR is added on the non-MNAR mask

}

In [44]:
# Re-seed right before generating the data so that re-running this cell reproduces the same datasets.
torch.random.manual_seed(1)
# Both datasets share the same config; only the missing_vals flag differs.
# NOTE(review): presumably these are the fully observed part and the part with missing values — confirm in data.py.
data_without_nas_1 = data.ImputationDatasetGen(config=dcon, missing_vals=False)
data_with_nas_1 = data.ImputationDatasetGen(config=dcon, missing_vals=True)

## Encoder model

In [None]:
# Model configuration of the plain denoising autoencoder that is trained on data_without_nas_1.
mcon0 = {
    'architecture': 'encoder_model_dae',
    'loss': 'full', # must be full otherwise error
    'epochs': 10,
    'batch_size': 64,
    'learning_rate': 3e-4,
    'lr_decay': False,
    'gamma': 2e-4,
    'step_size': 45,
    'layer_dims_enc': [784, 2000, 700],
    'layer_dims_dec': [700, 2000, 784],
    'device': device,
    'relu': True,
    'image': True,
    'noise_model': dae_noise,
    'corruption_share': 0.2, # level of the dropout noise that is used for training the DAE
    'mask_between_epochs': 'equal', # equal or random, determines the scope of the random generator that is passed to the mask bernoulli sampling function
    'additional_noise': 0, # does not apply here
}

# Training options handed to the training helper.
tcon0 = {
    'new_training': 1,
    'log': 0,
    'save_model': 1,
    'img_index': 4, # index of the image to be plotted
    'activations': 1,
    'device': device,
    'train_val_test_split': [0.8, 0.2, 0]
}

# Set the input/output width from the dataset name: the first matching key wins
# ('FashionMNIST' matches 'MNIST').
for dataset_key, n_features in (('MNIST', 784), ('CIFAR10', 1024)):
    if dataset_key in dcon['dataset']:
        mcon0['layer_dims_enc'][0] = n_features
        mcon0['layer_dims_dec'][-1] = n_features
        break

split_0 = tcon0['train_val_test_split']
nona_train_loader_0 = DataLoader(
    data.DatasetWithSplits(data_without_nas_1, 'train', split_0),
    batch_size=mcon0['batch_size'],
    shuffle=True,
)
nona_val_loader_0 = DataLoader(
    data.DatasetWithSplits(data_without_nas_1, 'validation', split_0),
    batch_size=mcon0['batch_size'],
    shuffle=False,
)

model_autoencoder = modelSDAE.SyntheticDenoisingAutoEncoder(
    noise_model=dae_noise,
    layer_dims_enc=mcon0['layer_dims_enc'],
    layer_dims_dec=mcon0['layer_dims_dec'],
    relu=mcon0['relu'],
    image=mcon0['image'],
).to(device)
# reduction='none' keeps the element-wise losses; aggregation is left to the training helper.
loss_fn_autoencoder = nn.MSELoss(reduction='none')
optimizer_autoencoder = torch.optim.Adam(model_autoencoder.parameters(), lr=mcon0['learning_rate'])
scheduler_autoencoder = StepLR(optimizer_autoencoder, step_size=mcon0['step_size'], gamma=mcon0['gamma'])
print(model_autoencoder)

In [None]:
# Train the plain DAE on the fully observed loaders; no separate encoder is supplied (encoder=None)
# and dae_noise is passed as the noise model.
helper_train_SDAE.train_imputation_model(
    model=model_autoencoder,
    encoder=None,
    loss_fn=loss_fn_autoencoder,
    optimizer=optimizer_autoencoder,
    scheduler=scheduler_autoencoder,
    dcon=dcon,
    mcon=mcon0,
    tcon=tcon0,
    train_dataloader=nona_train_loader_0,
    validation_dataloader=nona_val_loader_0,
    noise_model=dae_noise,
)

## MP model

In [None]:
# load an Encoder model
# The checkpoint is the last file (in sorted file-name order) of the model directory.
# NOTE(review): sorted file-name order is lexicographic, not chronological — confirm the checkpoint naming scheme sorts by time.
autoencoder_dir = f'models/{mcon0["architecture"]}/{dcon["dataset"]}/{dcon["noise_mechanism"]}'
last_autoencoder = sorted(os.listdir(autoencoder_dir))[-1]
last_autoencoder_model_path = os.path.join(autoencoder_dir, last_autoencoder)
# map_location makes a checkpoint that was saved on a GPU loadable on a CPU-only machine (and vice versa).
checkpoint_autoencoder = torch.load(last_autoencoder_model_path, map_location=device)

# Rebuild the architecture from the configuration stored in the checkpoint, then restore the weights.
model_autoencoder = modelSDAE.SyntheticDenoisingAutoEncoder(noise_model=dae_noise, 
                                                            layer_dims_enc=checkpoint_autoencoder['mcon']['layer_dims_enc'], layer_dims_dec=checkpoint_autoencoder['mcon']['layer_dims_dec'],
                                                            relu=checkpoint_autoencoder['mcon']['relu'], image=checkpoint_autoencoder['mcon']['image']).to(device)
model_autoencoder.load_state_dict(checkpoint_autoencoder['model_state_dict'])
model_autoencoder.eval()

In [None]:
# Model configuration of the mask prediction (MP) model.
mcon = {
    'architecture': 'mask_pred_mlp', # one of 'mask_pred_mlp', 'mask_pred_vae', 'mask_pred_gan', 'mask_pred_cgan'
    'epochs': 1,
    'batch_size': 64,
    'learning_rate': 3e-4,
    'lr_decay': False,
    'gamma': 2e-4,
    'step_size': 45,
    'layer_dims': [784, 2000, 2000, 2000, 784],
    'dropout': 0.5,
    'device': device,
    'relu': True,
    'image': True,
    'encoder': last_autoencoder_model_path,

    # for mask_pred_vae
    'layer_dims_enc': [784, 1000, 2],
    'layer_dims_dec': [2, 1000, 784],

    # for mask_pred_gan
    'layer_dims_gen': [10+1024, 128, 1024],
    'layer_dims_disc': [2*1024, 128, 1],
    'lr_gen': 0.002,
    'lr_disc': 0.0002,
    'betas': (0.9, 0.999),
    
}

# Training options handed to the training helpers.
tcon = {
    'new_training': 1,
    'log': 0,
    'save_model': 1,
    'img_index': 4, # index of the image to be plotted
    'activations': 1,
    'device': device,
    'train_val_test_split': [0.8, 0.2, 0]
}

# The MP model consumes the encoder output, so its input width is the last encoder layer of the
# loaded autoencoder; its output width follows the dataset (first matching key wins).
for dataset_key, n_features in (('MNIST', 784), ('CIFAR10', 1024)):
    if dataset_key in dcon['dataset']:
        latent_dim = checkpoint_autoencoder['mcon']['layer_dims_enc'][-1]
        mcon['layer_dims'][0] = latent_dim
        mcon['layer_dims'][-1] = n_features
        mcon['layer_dims_enc'][0] = latent_dim
        mcon['layer_dims_dec'][-1] = n_features
        break

split_1 = tcon['train_val_test_split']
batch_size_1 = mcon['batch_size']

nona_train_loader_1 = DataLoader(data.DatasetWithSplits(data_without_nas_1, 'train', split_1), batch_size=batch_size_1, shuffle=True)
nona_val_loader_1 = DataLoader(data.DatasetWithSplits(data_without_nas_1, 'validation', split_1), batch_size=batch_size_1, shuffle=False)
nona_test_loader_1 = DataLoader(data.DatasetWithSplits(data_without_nas_1, 'test', split_1), batch_size=batch_size_1, shuffle=False)

na_train_loader_1 = DataLoader(data.DatasetWithSplits(data_with_nas_1, 'train', split_1), batch_size=batch_size_1, shuffle=True)
na_val_loader_1 = DataLoader(data.DatasetWithSplits(data_with_nas_1, 'validation', split_1), batch_size=batch_size_1, shuffle=False)
na_test_loader_1 = DataLoader(data.DatasetWithSplits(data_with_nas_1, 'test', split_1), batch_size=batch_size_1, shuffle=False)

# Pick the model class from the architecture name.
architecture = mcon['architecture']
if 'mlp' in architecture:
    model_mp = modelMP.MaskPredMLP(
        layer_dims=mcon['layer_dims'], dropout=mcon['dropout'], relu=mcon['relu'], image=mcon['image']
    ).to(device)
elif 'vae' in architecture:
    model_mp = modelMP.MaskPredVAE(
        layer_dims_enc=mcon['layer_dims_enc'], layer_dims_dec=mcon['layer_dims_dec'], dropout=mcon['dropout'],
        relu=mcon['relu'], image=mcon['image'], device=device
    ).to(device)
elif 'gan' in architecture:
    model_mp = modelMP.MaskPredGAN(
        layer_dims_gen=mcon['layer_dims_gen'], layer_dims_disc=mcon['layer_dims_disc'], dropout=mcon['dropout'],
        relu=mcon['relu'], image=mcon['image']
    ).to(device)
else:
    raise ValueError('Invalid architecture')


# MLP/VAE use one optimiser and a BCE loss; the GAN gets separate optimisers and schedulers
# for generator and discriminator.
if any(tag in architecture for tag in ('mlp', 'vae')):
    loss_fn = nn.BCELoss()
    optimizer = torch.optim.Adam(model_mp.parameters(), lr=mcon['learning_rate'])
    scheduler = StepLR(optimizer, step_size=mcon['step_size'], gamma=mcon['gamma'])
elif 'gan' in architecture:
    optimizer_gen = torch.optim.Adam(model_mp.generator.parameters(), lr=mcon['lr_gen'], betas=mcon['betas'])
    optimizer_disc = torch.optim.SGD(model_mp.discriminator.parameters(), lr=mcon['lr_disc'])
    scheduler_gen = StepLR(optimizer_gen, step_size=mcon['step_size'], gamma=mcon['gamma'])
    scheduler_disc = StepLR(optimizer_disc, step_size=mcon['step_size'], gamma=mcon['gamma'])
print(model_mp)

In [None]:
# train (and save) model
if 'gan' in mcon['architecture']:
    # The GAN helper is given loaders over the complete datasets (no train/validation split).
    helper_train_MP_GAN.train_model(
        model=model_mp,
        optimizer_gen=optimizer_gen,
        optimizer_disc=optimizer_disc,
        scheduler_gen=scheduler_gen,
        scheduler_disc=scheduler_disc,
        dcon=dcon,
        mcon=mcon,
        tcon=tcon,
        train_na_dataloader=DataLoader(data_with_nas_1, batch_size=mcon['batch_size'], shuffle=True),
        train_nona_dataloader=DataLoader(data_without_nas_1, batch_size=mcon['batch_size'], shuffle=True),
        train_both_loader=DataLoader(data.DatasetZipped(data_without_nas_1, data_with_nas_1), batch_size=mcon['batch_size'], shuffle=True),
        test_na_dataloader=DataLoader(data_with_nas_1, batch_size=mcon['batch_size'], shuffle=False),
        test_nona_dataloader=DataLoader(data_without_nas_1, batch_size=mcon['batch_size'], shuffle=False),
    )
else:
    # MLP / VAE: train on the observations with missing values, using the encoder of the loaded autoencoder.
    helper_train_MP.train_model(
        model=model_mp,
        encoder=model_autoencoder.encoder,
        loss_fn=loss_fn,
        optimizer=optimizer,
        scheduler=scheduler,
        dcon=dcon,
        mcon=mcon,
        tcon=tcon,
        train_dataloader=na_train_loader_1,
        validation_dataloader=na_val_loader_1,
    )
    # The validation loader is evaluated here as well.
    helper_train_MP.test(
        dataloader=na_val_loader_1,
        model=model_mp,
        encoder=model_autoencoder.encoder,
        loss_fn=loss_fn,
        tcon=tcon,
        dcon=dcon,
        mcon=mcon,
    )

## Define Synthetic Denoising Autoencoder

Load noise model (mask prediction model) and Encoder from above

In [None]:
# Load the last (in sorted file-name order) mask prediction checkpoint.
# NOTE(review): sorted file-name order is lexicographic, not chronological — confirm the checkpoint naming scheme sorts by time.
noise_model_dir = f'models/{mcon["architecture"]}/{dcon["dataset"]}/{dcon["noise_mechanism"]}'
last_noise_model = sorted(os.listdir(noise_model_dir))[-1]
last_noise_model_path = os.path.join(noise_model_dir, last_noise_model)
# map_location makes a checkpoint that was saved on a GPU loadable on a CPU-only machine (and vice versa).
checkpoint = torch.load(last_noise_model_path, map_location=device)

# NOTE(review): the checkpoint is always restored as a MaskPredMLP, regardless of mcon['architecture'].
noise_model = modelMP.MaskPredMLP(layer_dims=checkpoint['mcon']['layer_dims'], dropout=checkpoint['mcon']['dropout'],
                                  relu=checkpoint['mcon']['relu'], 
                                  image=checkpoint['mcon']['image']).to(device)
noise_model.load_state_dict(checkpoint['model_state_dict'])
noise_model.eval()

# load an encoder model
autoencoder_dir = f'models/{mcon0["architecture"]}/{dcon["dataset"]}/{dcon["noise_mechanism"]}'
last_autoencoder = sorted(os.listdir(autoencoder_dir))[-1]
last_autoencoder_model_path = os.path.join(autoencoder_dir, last_autoencoder)
checkpoint_autoencoder = torch.load(last_autoencoder_model_path, map_location=device)

model_autoencoder = modelSDAE.SyntheticDenoisingAutoEncoder(noise_model=dae_noise, 
                                                            layer_dims_enc=checkpoint_autoencoder['mcon']['layer_dims_enc'], layer_dims_dec=checkpoint_autoencoder['mcon']['layer_dims_dec'],
                                                            relu=checkpoint_autoencoder['mcon']['relu'], image=checkpoint_autoencoder['mcon']['image']).to(device)
model_autoencoder.load_state_dict(checkpoint_autoencoder['model_state_dict'])
model_autoencoder.eval()

# Data config for the SDAE: a shallow copy of dcon with its own replacement value.
dcon2 = dcon.copy()
dcon2['replacement'] = 0 # the value that is inserted with the noise process for the training of the DAE (on the fully observed data) and for the test set that is used to evaluate the DAE
# Regenerate the data only when the replacement value differs from the one used for dcon.
dcon2['regenerate'] = dcon2['replacement'] != dcon['replacement']

mcon2 = {
    'architecture': 'synthetic_dae',
    'loss': 'full', # full or focused
    'epochs': 15,
    'batch_size': 64,
    'learning_rate': 3e-4,
    'lr_decay': False,
    'gamma': 2e-4,
    'step_size': 45,
    'layer_dims_enc': [784, 2000, 2000, 2000],
    'layer_dims_dec': [2000, 2000, 2000, 784],
    'device': device,
    'relu': True,
    'image': True,
    'noise_model': last_noise_model_path,
    'encoder': last_autoencoder_model_path,
    'corruption_share': -1, # the share of features that are corrupted in the training of the DAE, -1 means that all missingness generated by the MP model is used
    'mask_between_epochs': 'equal', # equal or random, determines the scope of the random generator that is passed to the mask bernoulli sampling function
    'additional_noise': 0, # the share of additional noise that is added to the data during training
}
# Set the input/output width from the dataset name ('FashionMNIST' matches 'MNIST').
if 'MNIST' in dcon2['dataset']:
    mcon2['layer_dims_enc'][0] = 784
    mcon2['layer_dims_dec'][-1] = 784
elif 'CIFAR10' in dcon2['dataset']:
    mcon2['layer_dims_enc'][0] = 1024
    mcon2['layer_dims_dec'][-1] = 1024

tcon2 = {
    'new_training': 1,
    'log': 1,
    'save_model': 1,
    'img_index': 10, # index of the image to be plotted
    'activations': 1,
    'device': device,
    'train_val_test_split': [0.8, 0.2, 0]
}


In [10]:
# Re-seed right before generating the data so that re-running this cell reproduces the same datasets.
torch.random.manual_seed(1)
# Both datasets share the dcon2 config; only the missing_vals flag differs.
# NOTE(review): presumably these are the fully observed part and the part with missing values — confirm in data.py.
data_without_nas_2 = data.ImputationDatasetGen(config=dcon2, missing_vals=False)
data_with_nas_2 = data.ImputationDatasetGen(config=dcon2, missing_vals=True)

In [None]:
split_2 = tcon2['train_val_test_split']
batch_size_2 = mcon2['batch_size']

# Loaders over the fully observed data, split according to tcon2.
nona_train_loader_2 = DataLoader(data.DatasetWithSplits(data_without_nas_2, 'train', split_2), batch_size=batch_size_2, shuffle=True)
nona_val_loader_2 = DataLoader(data.DatasetWithSplits(data_without_nas_2, 'validation', split_2), batch_size=batch_size_2, shuffle=False)
nona_test_loader_2 = DataLoader(data.DatasetWithSplits(data_without_nas_2, 'test', split_2), batch_size=batch_size_2, shuffle=False)

# The observations with missing values are requested with a [0, 0, 1] split, i.e. everything goes to 'test'.
# here shuffle false, because it is only used for testing
na_test_loader_2 = DataLoader(data.DatasetWithSplits(data_with_nas_2, 'test', [0, 0, 1]), batch_size=batch_size_2, shuffle=False)

# The SDAE receives the loaded mask prediction model as its noise model.
model_sdae = modelSDAE.SyntheticDenoisingAutoEncoder(
    noise_model=noise_model,
    layer_dims_enc=mcon2['layer_dims_enc'],
    layer_dims_dec=mcon2['layer_dims_dec'],
    relu=mcon2['relu'],
    image=mcon2['image'],
).to(device)
loss_fn_sdae = nn.MSELoss(reduction='none')
optimizer_sdae = torch.optim.Adam(model_sdae.parameters(), lr=mcon2['learning_rate'])
scheduler_sdae = StepLR(optimizer_sdae, step_size=mcon2['step_size'], gamma=mcon2['gamma'])
print(model_sdae)

In [None]:
# Train the SDAE on the fully observed loaders, passing the autoencoder's encoder and the
# mask prediction model as noise model, then evaluate on the observations with missing values.
helper_train_SDAE.train_imputation_model(
    model=model_sdae,
    encoder=model_autoencoder.encoder,
    loss_fn=loss_fn_sdae,
    optimizer=optimizer_sdae,
    scheduler=scheduler_sdae,
    dcon=dcon2,
    mcon=mcon2,
    tcon=tcon2,
    train_dataloader=nona_train_loader_2,
    validation_dataloader=nona_val_loader_2,
    test_dataloader=na_test_loader_2,
    noise_model=noise_model,
)
helper_train_SDAE.test(
    dataloader=na_test_loader_2,
    model=model_sdae,
    loss_fn=loss_fn_sdae,
    dcon=dcon2,
    mcon=mcon2,
    tcon=tcon2,
)

### Downstream Task

In [14]:
# Load the last (in sorted file-name order) SDAE checkpoint.
# NOTE(review): sorted file-name order is lexicographic, not chronological — confirm the checkpoint naming scheme sorts by time.
imputation_model_dir = f'models/{mcon2["architecture"]}/{dcon2["dataset"]}/{dcon2["noise_mechanism"]}'
last_imputation_model = sorted(os.listdir(imputation_model_dir))[-1]
last_imputation_model_path = os.path.join(imputation_model_dir, last_imputation_model)
# map_location makes a checkpoint that was saved on a GPU loadable on a CPU-only machine (and vice versa).
checkpoint_imputation = torch.load(last_imputation_model_path, map_location=device)

imputation_model = modelSDAE.SyntheticDenoisingAutoEncoder(noise_model=noise_model, 
                                                           layer_dims_enc=checkpoint_imputation['mcon']['layer_dims_enc'], layer_dims_dec=checkpoint_imputation['mcon']['layer_dims_dec'],
                                                           relu=checkpoint_imputation['mcon']['relu'], 
                                                           image=checkpoint_imputation['mcon']['image']).to(device)
imputation_model.load_state_dict(checkpoint_imputation['model_state_dict'])
imputation_model.eval()

# Inference only: no_grad avoids building an autograd graph over the whole dataset.
with torch.no_grad():
    imputed_data = imputation_model(data_with_nas_2.data.to(device))
imputed_targets = data_with_nas_2.labels

# Downstream data set = fully observed observations followed by the model output for the incomplete ones.
full_data = torch.cat((data_without_nas_2.data.cpu(), imputed_data.cpu()), dim=0).detach()
full_targets = torch.cat((data_without_nas_2.labels.cpu(), imputed_targets.cpu()), dim=0).detach()

utils.softmaxRegression(full_data, full_targets, num_classes=10, num_epochs=10)


## Benchmark Models

### Benchmark DAE

In [66]:
# dae_noise is defined once in the first cell of the notebook; the identical second definition
# that used to live here only shadowed it and has been removed.

# Data config of the benchmark models: a copy of dcon2 whose data is always regenerated.
dcon3 = dcon2.copy()
dcon3['replacement'] = 0
dcon3['regenerate'] = True

mcon3 = {
    'architecture': 'benchmark_dae',
    'loss': 'full', # full or focused
    'epochs': 5,
    'batch_size': 64,
    'learning_rate': 3e-4,
    'lr_decay': False,
    'gamma': 2e-4,
    'step_size': 45,
    'layer_dims_enc': [784, 2000, 2000],
    'layer_dims_dec': [2000, 2000, 784],
    'device': device,
    'relu': True,
    'image': True,
    'noise_model': dae_noise,
    'corruption_share': 0.2, # the share of features that are corrupted in the training of the DAE
    'mask_between_epochs': 'random', # (DOES NOT APPLY to benchmark_DAE)
    'additional_noise': 0, # the share of additional noise that is added to the data during training
}
# Set the input/output width from the dataset name ('FashionMNIST' matches 'MNIST').
if 'MNIST' in dcon3['dataset']:
    mcon3['layer_dims_enc'][0] = 784
    mcon3['layer_dims_dec'][-1] = 784
elif 'CIFAR10' in dcon3['dataset']:
    mcon3['layer_dims_enc'][0] = 1024
    mcon3['layer_dims_dec'][-1] = 1024

tcon3 = {
    'new_training': 1,
    'log': 0,
    'save_model': 1,
    'img_index': 10, # index of the image to be plotted
    'activations': 1,
    'device': device,
    'train_val_test_split': [0.8, 0.2, 0]
}

In [67]:
# Re-seed, then generate the benchmark datasets from dcon3 (same statement order as before,
# because the generation draws from torch's global RNG).
torch.random.manual_seed(1)
data_without_nas_3 = data.ImputationDatasetGen(config=dcon3, missing_vals=False)
data_with_nas_3 = data.ImputationDatasetGen(config=dcon3, missing_vals=True)

split_3 = tcon3['train_val_test_split']
batch_size_3 = mcon3['batch_size']
nona_train_loader_3 = DataLoader(data.DatasetWithSplits(data_without_nas_3, 'train', split_3), batch_size=batch_size_3, shuffle=True)
nona_val_loader_3 = DataLoader(data.DatasetWithSplits(data_without_nas_3, 'validation', split_3), batch_size=batch_size_3, shuffle=False)
nona_test_loader_3 = DataLoader(data.DatasetWithSplits(data_without_nas_3, 'test', split_3), batch_size=batch_size_3, shuffle=False)

# The observations with missing values are requested with a [0, 0, 1] split, i.e. everything goes to 'test'.
# here shuffle false, because it is only used for testing
na_test_loader_3 = DataLoader(data.DatasetWithSplits(data_with_nas_3, 'test', [0, 0, 1]), batch_size=batch_size_3, shuffle=False)

In [None]:
# Benchmark DAE: same model class as the SDAE, but with dae_noise as its noise model.
model_bdae = modelSDAE.SyntheticDenoisingAutoEncoder(
    noise_model=dae_noise,
    layer_dims_enc=mcon3['layer_dims_enc'],
    layer_dims_dec=mcon3['layer_dims_dec'],
    relu=mcon3['relu'],
    image=mcon3['image'],
).to(device)
loss_fn_bdae = nn.MSELoss(reduction='none')
optimizer_bdae = torch.optim.Adam(model_bdae.parameters(), lr=mcon3['learning_rate'])
scheduler_bdae = StepLR(optimizer_bdae, step_size=mcon3['step_size'], gamma=mcon3['gamma'])
print(model_bdae)

In [None]:
# Train the benchmark DAE (no separate encoder, dae_noise as noise model), then evaluate it
# on the observations with missing values.
helper_train_SDAE.train_imputation_model(
    model=model_bdae,
    encoder=None,
    loss_fn=loss_fn_bdae,
    optimizer=optimizer_bdae,
    scheduler=scheduler_bdae,
    dcon=dcon3,
    mcon=mcon3,
    tcon=tcon3,
    train_dataloader=nona_train_loader_3,
    validation_dataloader=nona_val_loader_3,
    test_dataloader=na_test_loader_3,
    noise_model=dae_noise,
)
helper_train_SDAE.test(
    dataloader=na_test_loader_3,
    model=model_bdae,
    loss_fn=loss_fn_bdae,
    dcon=dcon3,
    mcon=mcon3,
    tcon=tcon3,
)

In [19]:
model_bdae.eval()
# Inference only: no_grad avoids building an autograd graph over the whole dataset.
# NOTE(review): the benchmark is evaluated on data_with_nas_2 (not data_with_nas_3) — confirm this is intended.
with torch.no_grad():
    imputed_data = model_bdae(data_with_nas_2.data.to(device))
imputed_targets = data_with_nas_2.labels

# Downstream data set = fully observed observations followed by the model output for the incomplete ones.
full_data = torch.cat((data_without_nas_2.data.cpu(), imputed_data.cpu()), dim=0).detach()
full_targets = torch.cat((data_without_nas_2.labels.cpu(), imputed_targets.cpu()), dim=0).detach()

utils.softmaxRegression(full_data, full_targets, num_classes=10, num_epochs=10)

### Benchmark VAE

In [17]:
# Model configuration of the benchmark VAE.
mcon4 = {
    'architecture': 'benchmark_vae',
    'loss': 'full', # full or focused
    'epochs': 1,
    'batch_size': 64,
    'learning_rate': 3e-4,
    'lr_decay': False,
    'gamma': 2e-4,
    'step_size': 45,
    'layer_dims_enc': [784, 2000, 500],
    'layer_dims_dec': [500, 2000, 784],
    'device': device,
    'relu': True,
    'image': True,
    'noise_model': None,
    'corruption_share': 0.0, # (DOES NOT APPLY to benchmark_VAE)
    'mask_between_epochs': 'random', # (DOES NOT APPLY to benchmark_VAE)
    'additional_noise': 0, # the share of additional noise that is added to the data during training
}

# Set the input/output width from the dataset name in dcon3: the first matching key wins
# ('FashionMNIST' matches 'MNIST').
for dataset_key, n_features in (('MNIST', 784), ('CIFAR10', 1024)):
    if dataset_key in dcon3['dataset']:
        mcon4['layer_dims_enc'][0] = n_features
        mcon4['layer_dims_dec'][-1] = n_features
        break

# Training options handed to the training helper.
tcon4 = {
    'new_training': 1,
    'log': 0,
    'save_model': 1,
    'img_index': 10, # index of the image to be plotted
    'activations': 1,
    'device': device,
    'train_val_test_split': [0.8, 0.2, 0]
}

In [None]:
# Benchmark VAE built from the mcon4 layer sizes.
model_bvae = modelSDAE.ImputeVAE(
    layer_dims_enc=mcon4['layer_dims_enc'],
    layer_dims_dec=mcon4['layer_dims_dec'],
    relu=mcon4['relu'],
    image=mcon4['image'],
).to(device)
loss_fn_bvae = nn.MSELoss(reduction='none')
optimizer_bvae = torch.optim.Adam(model_bvae.parameters(), lr=mcon4['learning_rate'])
scheduler_bvae = StepLR(optimizer_bvae, step_size=mcon4['step_size'], gamma=mcon4['gamma'])
print(model_bvae)

In [None]:
# Train the benchmark VAE on the dcon3 loaders (no encoder, no noise model), then evaluate it
# on the observations with missing values.
helper_train_SDAE.train_imputation_model(
    model=model_bvae,
    encoder=None,
    loss_fn=loss_fn_bvae,
    optimizer=optimizer_bvae,
    scheduler=scheduler_bvae,
    dcon=dcon3,
    mcon=mcon4,
    tcon=tcon4,
    train_dataloader=nona_train_loader_3,
    validation_dataloader=nona_val_loader_3,
    test_dataloader=na_test_loader_3,
    noise_model=None,
)
helper_train_SDAE.test(
    dataloader=na_test_loader_3,
    model=model_bvae,
    loss_fn=loss_fn_bvae,
    dcon=dcon3,
    mcon=mcon4,
    tcon=tcon4,
)

In [23]:
model_bvae.eval()
# Inference only: no_grad avoids building an autograd graph over the whole dataset.
# NOTE(review): the benchmark is evaluated on data_with_nas_2 (not data_with_nas_3) — confirm this is intended.
with torch.no_grad():
    imputed_data = model_bvae(data_with_nas_2.data.to(device))
imputed_targets = data_with_nas_2.labels

# Downstream data set = fully observed observations followed by the model output for the incomplete ones.
full_data = torch.cat((data_without_nas_2.data.cpu(), imputed_data.cpu()), dim=0).detach()
full_targets = torch.cat((data_without_nas_2.labels.cpu(), imputed_targets.cpu()), dim=0).detach()

utils.softmaxRegression(full_data, full_targets, num_classes=10, num_epochs=10)

### Mean and Mode imputation

In [None]:
# Mean imputation
# Mark the masked entries (mask == 1) of the incomplete observations as NaN so SimpleImputer treats them as missing.
data_with_nas_as_nas = data_with_nas_2.data.clone()
mask = data_with_nas_2.targets.clone()
data_with_nas_as_nas[mask == 1] = float('nan')

# full_data stacks the complete observations first and the incomplete ones after them.
n_complete_rows = data_without_nas_2.data.size(0)
full_data = torch.cat((data_without_nas_2.data.cpu(), data_with_nas_as_nas.cpu()), dim=0).detach()
full_targets = torch.cat((data_without_nas_2.labels.cpu(), data_with_nas_2.labels.cpu()), dim=0).detach()

mean_imputer = SimpleImputer(strategy='mean', copy=True).fit(full_data.cpu())
full_data_imputed = torch.tensor(mean_imputer.transform(full_data)).float().to(device)
ground_truth = data_with_nas_2.unmissing_data
# The imputed incomplete observations are the rows AFTER the complete ones. Taking the first
# rows instead would compare unrelated, fully observed images with ground_truth.
imputed_data = full_data_imputed[n_complete_rows:, :]

# Imputation error over the masked entries only.
mse = torch.sum(nn.MSELoss(reduction='none')(imputed_data.cpu(), ground_truth.cpu()) * (mask.cpu() == 1).float()) / torch.sum(mask.cpu())
print(f'MSE: {mse}, RMSE: {torch.sqrt(mse)}')

utils.softmaxRegression(full_data_imputed, full_targets, num_classes=10, num_epochs=10)

# Visualization Code

## Visualize Simulated Noise Patterns

In [None]:
# Overall share of masked entries in the simulated mask.
print('masked proportion: ', torch.sum(data_with_nas_1.targets)/torch.numel(data_with_nas_1.targets))

cols = 10
fig, ax = plt.subplots(3, cols)
fig.set_figwidth(15)

# Row titles are attached to the first column only.
for row, title in enumerate(('Original', 'Mask', 'Corrupted')):
    ax[row, 0].set_title(title)

sizes = [28, 28]
sizes = [32, 32] if 'CIFAR10' in dcon['dataset'] else sizes

# One flattened source per row: ground truth, missingness mask, corrupted input.
to_float = v2.ToDtype(torch.float32)
panels = (data_with_nas_1.unmissing_data, data_with_nas_1.targets, data_with_nas_1.data)

for i in range(cols):
    for row, flat in enumerate(panels):
        ax[row, i].imshow(to_float(torch.unflatten(flat, dim=1, sizes=sizes))[i].cpu(), cmap='gray')
        ax[row, i].axis('off')
plt.suptitle('Patch Missingness')
plt.show()

## Visualize Learned Noise Patterns

In [None]:
# insert path to relevant model below
# map_location makes a checkpoint that was saved on a GPU loadable on a CPU-only machine (and vice versa).
checkpoint = torch.load('models/mask_pred_mlp/CIFAR10/patch/modelJune18_12_14.pth', map_location=device)
if 'mlp' in checkpoint['mcon']['architecture']:
    noise_model_temp = modelMP.MaskPredMLP(layer_dims=checkpoint['mcon']['layer_dims'], relu=checkpoint['mcon']['relu'], 
                                        image=checkpoint['mcon']['image']).to(device)
elif 'vae' in checkpoint['mcon']['architecture']:
    noise_model_temp = modelMP.MaskPredVAE(layer_dims_enc=checkpoint['mcon']['layer_dims_enc'], layer_dims_dec=checkpoint['mcon']['layer_dims_dec'], 
                                        relu=checkpoint['mcon']['relu'], image=checkpoint['mcon']['image'], device=device).to(device)
else:
    raise ValueError('Architecture not supported')
noise_model_temp.load_state_dict(checkpoint['model_state_dict'])
noise_model_temp.eval()
print(noise_model_temp)

# Predict mask probabilities for the first 100 fully observed images and sample a binary mask from them.
originals = data_without_nas_1.data[range(100)]
with torch.no_grad():  # inference only, no autograd graph needed
    mask_probs = noise_model_temp(model_autoencoder.encoder(originals.to(device)))
mask = torch.bernoulli(mask_probs).detach()
corrupted = originals.clone()
# The mask lives on `device`; move the boolean index to wherever the data tensor is stored.
corrupted[(mask == 1).to(corrupted.device)] = 0

fig, ax = plt.subplots(3, 10)
fig.set_figwidth(15)

# give a title to the plot
ax[0, 0].set_title('Original')
ax[1, 0].set_title('Mask')
ax[2, 0].set_title('Corrupted')

sizes = [28, 28]
sizes = [32, 32] if 'CIFAR10' in dcon['dataset'] else sizes

# Convert and reshape once, and only the 100 selected images instead of the whole dataset per plotted image.
to_float = v2.ToDtype(torch.float32)
original_images = to_float(torch.unflatten(originals, dim=1, sizes=sizes))
mask_images = to_float(torch.unflatten(mask, dim=1, sizes=sizes))
corrupted_images = to_float(torch.unflatten(corrupted, dim=1, sizes=sizes))

for i in range(10):
    ax[0, i].imshow(original_images[i].cpu(), cmap='gray')
    ax[0, i].axis('off')
    ax[1, i].imshow(mask_images[i].cpu(), cmap='gray')
    ax[1, i].axis('off')
    ax[2, i].imshow(corrupted_images[i].cpu(), cmap='gray')
    ax[2, i].axis('off')
# add title to plot
plt.suptitle('Learned Missingness Patterns')
plt.show()

# Visualization Code for some of the Thesis Figures

## Example Missingness Figures

In [None]:
n_img = 5
fig, ax = plt.subplots(3, n_img, figsize=(10, 6))

sizes = [28, 28]
sizes = [32, 32] if 'CIFAR10' in dcon['dataset'] else sizes

# One flattened source per row: ground truth, missingness mask, corrupted input.
to_float = v2.ToDtype(torch.float32)
panels = (data_with_nas_1.unmissing_data, data_with_nas_1.targets, data_with_nas_1.data)
for i in range(n_img):
    for row, flat in enumerate(panels):
        ax[row, i].imshow(to_float(torch.unflatten(flat, dim=1, sizes=sizes))[i].cpu(), cmap='gray')

# Labels for the y-axis
y_labels = ['Ground Truth', 'Mask', 'Corrupted']

# Every panel loses its ticks, tick labels and spines; only the leftmost column keeps a y-axis label.
for row, label in enumerate(y_labels):
    for col in range(n_img):
        panel = ax[row, col]
        if col == 0:
            panel.set_ylabel(label, rotation=90, labelpad=8, fontsize=12, va='center', ha='center')
        panel.set_xticks([])
        panel.set_yticks([])
        panel.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        for spine in ['top', 'right', 'bottom', 'left']:
            panel.spines[spine].set_visible(False)

# Add title to plot
plt.suptitle('QMNAR Examples', fontsize=16, y=0.95, ha='center')
plt.subplots_adjust(top=0.92)
plt.subplots_adjust(wspace=0.1, hspace=0)
plt.show()

## Visualize Imputation Performance

In [None]:
# Compare imputation methods on a small window of observations: each model's
# reconstruction is used to fill only the masked entries of the original
# images, and the results are shown row by row under the original/corrupted
# images.
start_img = 25
n_img = 5
# Height/width used to un-flatten each observation back into a 2-D image.
sizes = [32, 32] if 'CIFAR10' in dcon['dataset'] else [28, 28]
# The same window of observations is taken from every tensor below.
img_slice = slice(start_img, start_img + n_img)


def _reconstruct(model, inputs, rows):
    """Run `model` on `inputs` and return the selected `rows`, clipped to [0, 1], on the CPU."""
    # The outputs are only plotted, so there is no need to record an autograd graph.
    with torch.no_grad():
        return torch.clip(model(inputs), 0, 1).cpu()[rows]


def _impute(original, reconstruction, missing):
    """Return a copy of `original` whose `missing` entries are taken from `reconstruction`."""
    imputed = original.clone()
    imputed[missing] = reconstruction[missing]
    return imputed


# NOTE(review): the models are evaluated on the whole data set although only
# `n_img` rows are displayed. Slicing before the forward pass would be cheaper,
# but is only equivalent if the outputs do not depend on batch composition
# (e.g. batch statistics, random sampling) — confirm before changing.
model_input = data_with_nas_2.data.to(device)
reconstructed_sdae = _reconstruct(model_sdae, model_input, img_slice)
reconstructed_bdae = _reconstruct(model_bdae, model_input, img_slice)
reconstructed_bvae = _reconstruct(model_bvae, model_input, img_slice)
reconstructed_mean = torch.tensor(mean_imputer.transform(data_with_nas_as_nas)).float().cpu()[img_slice]

ground_truth = data_with_nas_2.unmissing_data[img_slice]
corrupted = data_with_nas_2.data[img_slice]
# Entries where `targets == 1` are the ones replaced by the model output
# (presumably the missing-value mask — confirm against the data module).
missing_mask = data_with_nas_2.targets[img_slice] == 1

imputed_sdae = _impute(ground_truth, reconstructed_sdae, missing_mask)
imputed_bdae = _impute(ground_truth, reconstructed_bdae, missing_mask)
imputed_bvae = _impute(ground_truth, reconstructed_bvae, missing_mask)
imputed_mean = _impute(ground_truth, reconstructed_mean, missing_mask)

# One figure row per entry: (y-axis label, flattened images to display).
rows = [
    ('Original', ground_truth),
    ('Corrupted', corrupted),
    ('ImputeLM', imputed_sdae),
    ('DAE', imputed_bdae),
    ('VAE', imputed_bvae),
    ('Mean', imputed_mean),
]
y_labels = [label for label, _ in rows]

# squeeze=False keeps `ax` two-dimensional even when n_img == 1.
fig, ax = plt.subplots(len(rows), n_img, figsize=(10, 10), squeeze=False)

for i, (label, images) in enumerate(rows):
    images_2d = torch.unflatten(images, dim=1, sizes=sizes).cpu().float()
    for j in range(n_img):
        ax[i, j].imshow(images_2d[j], cmap='gray')
        # Strip ticks, tick labels and the frame so that only the image remains.
        ax[i, j].set_xticks([])
        ax[i, j].set_yticks([])
        ax[i, j].tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        for side in ('top', 'right', 'bottom', 'left'):
            ax[i, j].spines[side].set_visible(False)
    # Only the leftmost column carries the row label.
    ax[i, 0].set_ylabel(label, rotation=90, labelpad=8, fontsize=12, va='center', ha='center')

# Title follows the configured data set instead of a hard-coded name.
fig.suptitle(f"Imputation Results on {dcon['dataset']}", fontsize=16, y=0.95, ha='center')
fig.subplots_adjust(top=0.92, wspace=0.1, hspace=0)
plt.show()

## Visualize Denoising Autoencoder (Replication)

In [None]:
# Show the denoising autoencoder on a small window of complete observations:
# original images, the same images after noise corruption, and the model's
# reconstruction of the corrupted input.
start_img = 25
n_img = 8
# Height/width used to un-flatten each observation back into a 2-D image.
sizes = [32, 32] if 'CIFAR10' in dcon['dataset'] else [28, 28]
img_slice = slice(start_img, start_img + n_img)

ground_truth = data_without_nas_2.data[img_slice]
# The helper receives its own copy of the batch, so it cannot alter
# `ground_truth` (a view of the data set) should it work in place.
# Only the first element of the helper's return value is used.
corrupted = helper_noise.add_noise_with_model(
    noise_model=dae_noise,
    encoder=None,
    data=data_without_nas_2.data[img_slice].clone(),
    corruption_share=0.4,
    device=device,
)[0]
# The reconstruction is only plotted, so no autograd graph is recorded.
with torch.no_grad():
    reconstructed = torch.clip(model_bdae(corrupted.to(device)), 0, 1).cpu()

# One figure row per entry: (y-axis label, flattened images to display).
rows = [
    ('Original', ground_truth),
    ('Corrupted', corrupted),
    ('Reconstructed', reconstructed),
]
y_labels = [label for label, _ in rows]

# squeeze=False keeps `ax` two-dimensional even when n_img == 1.
fig, ax = plt.subplots(len(rows), n_img, figsize=(16, 6), squeeze=False)

for i, (label, images) in enumerate(rows):
    images_2d = torch.unflatten(images, dim=1, sizes=sizes).cpu().float()
    for j in range(n_img):
        ax[i, j].imshow(images_2d[j], cmap='gray')
        # Strip ticks, tick labels and the frame so that only the image remains.
        ax[i, j].set_xticks([])
        ax[i, j].set_yticks([])
        ax[i, j].tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        for side in ('top', 'right', 'bottom', 'left'):
            ax[i, j].spines[side].set_visible(False)
    # Only the leftmost column carries the row label.
    ax[i, 0].set_ylabel(label, rotation=90, labelpad=8, fontsize=12, va='center', ha='center')

fig.suptitle('Denoising Autoencoder Reconstructions', fontsize=16, y=0.95, ha='center')
fig.subplots_adjust(top=0.92, wspace=0.1, hspace=0)
plt.show()