 # Deep learning meets missing data: Doing it MIWAE on MAR MNIST

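 In this notebook, we train a deep generative model (MIWAE) on MNIST images whose pixels have been artificially masked, and use it to impute the missing pixels. Both a MAR and an MNAR masking function are defined below; note that the masking cell currently calls `create_mnar_mask`, so the missingness actually applied to the data is MNAR.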

 # Installing and loading useful stuff

In [None]:
!pip3 install --user --upgrade scikit-learn

import torch
import torchvision
import torch.nn as nn
import numpy as np
import scipy.stats
import pandas as pd
import matplotlib.pyplot as plt
import torch.distributions as td
from torch import optim
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.experimental import enable_iterative_imputer  # noqa
from sklearn.linear_model import BayesianRidge
from sklearn.impute import IterativeImputer, SimpleImputer

from torchvision import transforms



 # Loading MNIST and applying MAR

In [None]:
from sklearn.model_selection import train_test_split



# `transform` only applies to samples drawn through the Dataset API; the
# `data` array below is built from the raw `mnist.data` tensor instead.
transform = torchvision.transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
mnist = torchvision.datasets.MNIST(root='.', train=True, download=True, transform=transform)

# Flatten every image to a 784-vector and rescale the pixel values by 1/255.
mnist_pixels = mnist.data.float().view(-1, 784).numpy() / 255.0

# Additive Gaussian noise (std 0.1). A dedicated, seeded generator is used so
# that the noisy dataset is identical on every "Restart & Run All"; the global
# NumPy seed is only set later, right before the masks are drawn.
noise_rng = np.random.default_rng(42)
data = mnist_pixels + noise_rng.normal(0, 1, mnist_pixels.shape) * 0.1

# Keep a 20% subsample of the images to keep the notebook fast.
data, _ = train_test_split(
    data, train_size=0.2, random_state=42, shuffle=True
)

# 80% / 20% split of the subsample into the "bias" and "unbias" sets.
bias, unbias = train_test_split(data, test_size=0.2, random_state=42)
print(f"Bias data shape: {bias.shape}")
print(f"Unbias data shape: {unbias.shape}")


In [None]:
def create_mar_mask(data):
    """Draw a missingness mask driven by the second half of each row (MAR).

    For every row, ``h = sum(example[392:]) / 392 + 0.3`` is used as the
    success probability of ``pi ~ Binomial(2, h)``.  The entries ``392:`` are
    never masked by this function, so the mechanism only depends on values
    that stay observed.  ``pi`` selects which part of the first half is hidden:

    * ``pi == 0`` -> entries ``196:392``
    * ``pi == 1`` -> entries ``0:392``
    * ``pi == 2`` -> entries ``0:196``

    The slice boundaries are hard-coded for rows of length 784 (flattened
    28x28 images).

    Parameters
    ----------
    data : numpy.ndarray of shape (n_samples, n_features)
        Complete data, one example per row.

    Returns
    -------
    numpy.ndarray of shape (n_samples, n_features)
        Float mask with 1 for observed entries and 0 for missing ones.

    Notes
    -----
    Draws from the global NumPy random state (``np.random.seed``).
    """
    masks = np.zeros((data.shape[0], data.shape[1]))
    for i, example in enumerate(data):
        h = (1. / (784. / 2.)) * np.sum(example[392:]) + 0.3
        # np.random.binomial raises ValueError for p outside [0, 1]; bright or
        # noisy rows can push h past 1 (and negative noise below 0), so clip.
        h = np.clip(h, 0.0, 1.0)
        pi = np.random.binomial(2, h)
        _mask = np.ones(example.shape[0])
        if pi == 0:
            _mask[196:392] = 0
        elif pi == 1:
            _mask[:392] = 0
        elif pi == 2:
            _mask[:196] = 0
        masks[i, :] = _mask
    return masks

def create_mnar_mask(data):
    """Draw a missingness mask driven by the first half of each row (MNAR).

    For every row, ``h = sum(example[:392]) / 392 + 0.3`` is used as the
    success probability of ``pi ~ Binomial(2, h)``.  The entries ``:392`` are
    exactly the ones this function may hide, so the mechanism depends on
    values that can themselves be missing.  ``pi`` selects the hidden part:

    * ``pi == 0`` -> entries ``196:392``
    * ``pi == 1`` -> entries ``0:392``
    * ``pi == 2`` -> entries ``0:196``

    The slice boundaries are hard-coded for rows of length 784 (flattened
    28x28 images).

    Parameters
    ----------
    data : numpy.ndarray of shape (n_samples, n_features)
        Complete data, one example per row.

    Returns
    -------
    numpy.ndarray of shape (n_samples, n_features)
        Float mask with 1 for observed entries and 0 for missing ones.

    Notes
    -----
    Draws from the global NumPy random state (``np.random.seed``).
    """
    masks = np.zeros((data.shape[0], data.shape[1]))
    for i, example in enumerate(data):
        h = (1. / (784. / 2.)) * np.sum(example[:392]) + 0.3
        # np.random.binomial raises ValueError for p outside [0, 1]; bright or
        # noisy rows can push h past 1 (and negative noise below 0), so clip.
        h = np.clip(h, 0.0, 1.0)
        pi = np.random.binomial(2, h)
        _mask = np.ones(example.shape[0])
        if pi == 0:
            _mask[196:392] = 0
        elif pi == 1:
            _mask[:392] = 0
        elif pi == 2:
            _mask[:196] = 0
        masks[i, :] = _mask
    return masks

In [None]:
# Seed the global NumPy RNG so the masks drawn below are reproducible.
np.random.seed(1234)

# A global standardisation of the data ((x - mean) / std) was tried at this
# point for both sets and is currently disabled.

# Mask for the "bias" set; hidden entries (mask == 0) become NaN.
mask_bias = create_mnar_mask(bias)
data_obs_bias = np.where(mask_bias == 0, np.nan, bias)

# Same treatment for the "unbias" set.
mask_unbias = create_mnar_mask(unbias)
data_obs_unbias = np.where(mask_unbias == 0, np.nan, unbias)



In [None]:
import matplotlib.pyplot as plt
import numpy as np

# %%
# Show 15 random "bias" images (top row) above their masked versions (bottom row).
n = bias.shape[0]
indices = np.random.choice(n, 15, replace=False)

fig, axes = plt.subplots(2, 15, figsize=(30, 4))
for col, idx in enumerate(indices):
    orig = bias[idx].reshape(28, 28)
    masked_bias = data_obs_bias[idx].copy()
    masked_bias[mask_bias[idx] == 0] = 0.5  # Use grey (0.5) for masked pixels

    panels = ((orig, 'Original'), (masked_bias.reshape(28, 28), 'Masked'))
    for row, (image, label) in enumerate(panels):
        ax = axes[row, col]
        # A fixed colour range makes 0.5 render as mid-grey in every panel
        # instead of depending on each image's own min/max.
        ax.imshow(image, cmap='gray', vmin=0, vmax=1)
        # Hide only the ticks: `axis('off')` would also hide the y-label, so
        # the row names would never be displayed.
        ax.set_xticks([])
        ax.set_yticks([])
        if col == 0:
            ax.set_ylabel(label)
plt.show()


 # Preprocessing

In [None]:
# Incomplete "bias" data, NaN at the missing entries.
xfull_bias = data_obs_bias.copy()
n, p = xfull_bias.shape

# Zero-filled copy fed to the networks. `nan` must be passed by keyword: the
# second positional parameter of np.nan_to_num is `copy`, so the previous call
# `np.nan_to_num(xfull_bias, 0)` meant copy=False and overwrote the NaNs of
# `xfull_bias` in place.
xobs_zero_bias = np.nan_to_num(xfull_bias, nan=0.0)

# Float mask (1 = observed, 0 = missing); kept as float because later cells
# compute `1 - mask`.
mask_bool_bias = mask_bias.copy()


In [None]:
# Display one row of xfull_bias as a 28x28 image; NaN entries are drawn in red.
import matplotlib.colors as colors

example_idx = 0
example_data = xfull_bias[example_idx].reshape(28, 28)

# Grey colormap whose "bad" (NaN / masked) colour is red.
cmap = plt.cm.gray.copy()
cmap.set_bad(color='red')

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(example_data, cmap=cmap, vmin=0, vmax=1)
ax.set_title(f'Example {example_idx}: Original data with NaNs (red)')
fig.colorbar(im, ax=ax)
ax.axis('off')
plt.show()


 # Hyperparameters

In [None]:
# Width of the two hidden layers in both the encoder and the decoder MLPs.
h = 256
# Dimension of the latent variable z (size of the prior p_z and of the decoder input).
d = 1
# Number of samples drawn from q(z|x) per data point inside miwae_loss.
K = 20


 # Model building

In [None]:
device = "cpu"

# Standard-normal prior over the d-dimensional latent variable.
p_z = td.Independent(td.Normal(torch.zeros(d, device=device), torch.ones(d, device=device)), 1)


def _mlp(n_in, n_out):
    """Return a ReLU MLP with two hidden layers of width ``h``."""
    layers = [
        nn.Linear(n_in, h),
        nn.ReLU(),
        nn.Linear(h, h),
        nn.ReLU(),
        nn.Linear(h, n_out),
    ]
    return nn.Sequential(*layers)


# Decoder: latent code -> 3 * p outputs (sliced downstream into mean, scale
# and degrees-of-freedom blocks of width p each).
decoder = _mlp(d, 3 * p)

# Encoder: data row -> 2 * d outputs (mean and pre-softplus scale of q(z|x)).
encoder = _mlp(p, 2 * d)

encoder.to(device)
decoder.to(device)


 # MIWAE loss

In [None]:
def miwae_loss(iota_x, mask):
    """Negative MIWAE bound, averaged over the rows of ``iota_x``.

    Parameters
    ----------
    iota_x : torch.Tensor of shape (batch, p)
        Data with the missing entries already filled in (the notebook uses
        zeros).
    mask : torch.Tensor of shape (batch, p)
        1 for observed entries, 0 for missing ones.  Only observed entries
        contribute to the reconstruction term.

    Returns
    -------
    torch.Tensor (scalar)
        ``-mean_b [ logsumexp_k ( log p(x_obs|z_k) + log p(z_k) - log q(z_k|x) ) - log K ]``.

    Notes
    -----
    Reads the module-level ``encoder``, ``decoder``, ``p_z``, ``d``, ``p`` and
    ``K``.  The observation model is a Normal with the decoder mean and a
    fixed scale of 0.1; the decoder's scale and degrees-of-freedom outputs do
    not enter the loss.
    """
    n_rows = iota_x.shape[0]

    # Variational posterior q(z | x): diagonal Gaussian.
    enc_out = encoder(iota_x)
    q_scale = torch.nn.functional.softplus(enc_out[..., d:])
    q = td.Independent(td.Normal(enc_out[..., :d], q_scale), 1)

    # K reparameterised draws per row, shape (K, n_rows, d).
    z = q.rsample([K])

    # Decoder mean for every (draw, row) pair, shape (K * n_rows, p).
    mu = decoder(z.reshape([K * n_rows, d]))[:, :p]
    obs_model = td.Normal(loc=mu, scale=torch.full_like(mu, 0.1))

    # Tile data and mask K times so they line up with the flattened draws.
    tiled_x = iota_x.repeat(K, 1)
    tiled_mask = mask.repeat(K, 1)
    log_px_obs = (obs_model.log_prob(tiled_x) * tiled_mask).reshape(K, n_rows, p).sum(-1)

    # Log importance weights, shape (K, n_rows).
    log_ratio = log_px_obs + p_z.log_prob(z) - q.log_prob(z)

    log_K = torch.log(torch.tensor(K, dtype=torch.float, device=iota_x.device))
    bound = torch.logsumexp(log_ratio, 0) - log_K
    return -bound.mean()


In [None]:
# A single Adam optimiser updates the encoder and the decoder jointly.
miwae_params = [*encoder.parameters(), *decoder.parameters()]
optimizer = optim.Adam(miwae_params, lr=1e-3)


 # Single imputation

In [None]:
def miwae_impute(iota_x, mask, L, device):
    """Single imputation with MIWAE via self-normalised importance sampling.

    ``L`` latent draws are taken from q(z|x) per row; each draw is weighted by
    ``p(x_obs|z) p(z) / q(z|x)`` (normalised over the L draws) and the
    returned value is the weighted average of one sample from p(x|z) per draw.

    Parameters
    ----------
    iota_x : torch.Tensor of shape (batch, p)
        Data with missing entries filled in (the notebook uses zeros).
    mask : torch.Tensor of shape (batch, p)
        1 for observed entries, 0 for missing ones.
    L : int
        Number of importance samples per row.
    device : unused
        Kept for backward compatibility with existing callers.

    Returns
    -------
    torch.Tensor of shape (batch, p)
        Imputed values for every entry (observed positions included); callers
        select the missing positions themselves.

    Notes
    -----
    Reads the module-level ``encoder``, ``decoder``, ``p_z``, ``d`` and ``p``.
    """
    batch = iota_x.shape[0]
    out = encoder(iota_x)
    q = td.Independent(td.Normal(out[..., :d], torch.nn.Softplus()(out[..., d:])), 1)
    z = q.rsample([L])  # (L, batch, d)

    # Same observation model as miwae_loss: Normal(mean, 0.1).
    mu = decoder(z.reshape([L * batch, d]))[:, :p].reshape(L, batch, p)
    px = td.Normal(loc=mu, scale=torch.full_like(mu, 0.1))

    # Only observed entries may enter the weights. Without the mask, the
    # placeholder values at the missing entries were scored as if they were
    # data, biasing the weights towards draws that reconstruct the placeholder.
    log_px_obs = (px.log_prob(iota_x) * mask).sum(-1)  # (L, batch)

    log_pz = p_z.log_prob(z)
    log_q = q.log_prob(z)
    w = torch.nn.functional.softmax(log_px_obs + log_pz - log_q, 0)

    # Sample from the model that was actually trained (Normal). The former
    # Student-t sampler used a degrees-of-freedom head that miwae_loss never
    # trains.
    x_samples = px.sample()  # (L, batch, p)
    return torch.einsum('lb,lbp->bp', w, x_samples)


 # Training

In [None]:
xhat_bias = xobs_zero_bias.copy()
mask_bias_t = mask_bool_bias.astype(float)
bs = 64
epochs = 20

# The full-data tensors never change during training, so build them once
# instead of at every evaluation.
xhat_bias_full = torch.tensor(xhat_bias, dtype=torch.float).to(device)
mask_bias_full = torch.tensor(mask_bias_t, dtype=torch.float).to(device)

# Take the row count from the array itself rather than from the global `n`,
# which a later cell rebinds to the size of the "unbias" set.
n_train = xhat_bias.shape[0]

for ep in range(0, epochs):
    # Every 5 epochs: report the bound and the imputation error on the whole
    # "bias" set and show the first four images.
    if ep % 5 == 0:
        with torch.no_grad():
            total_bound = -miwae_loss(xhat_bias_full, mask_bias_full)
            print(f'Epoch {ep} bound {total_bound.item()}')
            xhat_bias_tensor = miwae_impute(xhat_bias_full, mask_bias_full, 10, device).cpu().numpy()
            # Error is measured on the missing entries only.
            print(f'Imputation MSE {np.mean((xhat_bias_tensor - bias)[mask_bool_bias == 0]**2)}')
            fig, axs = plt.subplots(4, 4, figsize=(20, 8))
            for col, title in enumerate(('Bias', 'Masked', 'Just MIWAE', 'MIWAE + Bias')):
                axs[0, col].set_title(title)
            for i in range(4):
                truth = bias[i].reshape(28, 28)
                imputed = xhat_bias_tensor[i].reshape(28, 28)
                # Use this row's own mask (the previous code used the mask of
                # row 0 for the observed part of every row).
                observed = mask_bool_bias[i].reshape(28, 28)
                axs[i, 0].imshow(truth, cmap='gray', vmin=0, vmax=1)
                axs[i, 1].imshow(xobs_zero_bias[i].reshape(28, 28), cmap='gray', vmin=0, vmax=1)
                axs[i, 2].imshow(imputed, cmap='gray', vmin=0, vmax=1)
                # Imputed values at missing pixels, true values at observed ones.
                axs[i, 3].imshow(imputed * (1 - observed) + truth * observed,
                                 cmap='gray', vmin=0, vmax=1)
            plt.show()

    # One pass over the data in shuffled mini-batches.
    idx = np.random.permutation(n_train)
    for i in range(0, n_train, bs):
        batch_id = idx[i:i+bs]
        b_x = torch.tensor(xhat_bias[batch_id], dtype=torch.float).to(device)
        b_m = torch.tensor(mask_bias_t[batch_id], dtype=torch.float).to(device)
        optimizer.zero_grad()
        loss = miwae_loss(b_x, b_m)
        loss.backward()
        optimizer.step()
    



## EBM Example

In [None]:
import torch.nn as nn
class EBM(nn.Module):
    """MLP that maps each input row to one scalar energy.

    Parameters
    ----------
    input_dim : int
        Number of features per input row.
    hidden_dim : int, default 256
        Width of the two hidden ReLU layers.
    """

    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the energies of ``x`` with the trailing size-1 axis removed."""
        energy = self.net(x)
        return energy.squeeze(-1)

# Energy network over full data rows (input width = number of columns of
# `unbias`), with 32 hidden units, and its own Adam optimiser.
ebm = EBM(unbias.shape[1], 32)
ebm = ebm.to(device)
ebm_optimizer = optim.Adam(ebm.parameters(), lr=1e-3)

In [None]:
def miwae_sample(N_samples, device):
    """Draw unconditional samples from the MIWAE generative model.

    Latent codes are sampled from a standard normal, pushed through the
    decoder, and one observation is drawn from ``Normal(mean, 0.1)`` for each
    code, the same observation model that ``miwae_loss`` trains.

    Parameters
    ----------
    N_samples : int
        Number of samples to draw.
    device : torch.device or str
        Device on which the latent codes are created.

    Returns
    -------
    torch.Tensor of shape (N_samples, 1, p)
        The singleton middle axis is kept for backward compatibility; callers
        squeeze it.

    Notes
    -----
    Reads the module-level ``decoder``, ``d`` and ``p``.  The decoder's scale
    and degrees-of-freedom outputs are not used.
    """
    z = td.Normal(loc=torch.zeros(N_samples, d, device=device), scale=torch.ones(N_samples, d, device=device)).sample()
    z = z.reshape([N_samples, -1, d])  # (N_samples, 1, d)

    # Only the first p decoder outputs (the means) are needed.
    all_means_obs_model = decoder(z)[..., :p]

    xgivenz = td.Independent(
        td.Normal(
            loc=all_means_obs_model,
            scale=0.1,
        ), 1
    )
    return xgivenz.sample()

In [None]:
# Incomplete "unbias" data, NaN at the missing entries.
xfull_unbias = data_obs_unbias.copy()

# NOTE: this rebinds the notebook-level `n` (and `p`) to the shape of the
# "unbias" set; cells above that read `n` see this value if re-run afterwards.
n, p = xfull_unbias.shape

# Zero-filled copy. `nan` must be passed by keyword: the second positional
# parameter of np.nan_to_num is `copy`, so `np.nan_to_num(xfull_unbias, 0)`
# meant copy=False and overwrote the NaNs of `xfull_unbias` in place.
xobs_unbias_zero = np.nan_to_num(xfull_unbias, nan=0.0)

# Float mask (1 = observed, 0 = missing).
mask_unbias_bool = mask_unbias.copy()

# Number of epochs for the EBM training loop.
n_epochs = 1000

In [None]:
def miwaebm_impute(iota_x, mask, L, device):
    """Single imputation with MIWAE importance weights tilted by the EBM.

    ``L`` latent draws are taken from q(z|x) per row and one sample from
    p(x|z) is drawn for each.  Every draw is weighted by
    ``p(x_obs|z) p(z) / q(z|x) * exp(-E(x_completed))``, where ``x_completed``
    keeps the observed entries of ``iota_x`` and takes the missing ones from
    the sample.  The weights are normalised over the L draws and the weighted
    average of the samples is returned.

    Parameters
    ----------
    iota_x : torch.Tensor of shape (batch, p)
        Data with missing entries filled in (the notebook uses zeros).
    mask : torch.Tensor of shape (batch, p)
        1 for observed entries, 0 for missing ones.
    L : int
        Number of importance samples per row.
    device : unused
        Kept for backward compatibility with existing callers.

    Returns
    -------
    torch.Tensor of shape (batch, p)
        Imputed values for every entry (observed positions included); callers
        select the missing positions themselves.

    Notes
    -----
    Reads the module-level ``encoder``, ``decoder``, ``ebm``, ``p_z``, ``d``
    and ``p``.
    """
    batch = iota_x.shape[0]
    out = encoder(iota_x)
    q = td.Independent(td.Normal(out[..., :d], torch.nn.Softplus()(out[..., d:])), 1)
    z = q.rsample([L])  # (L, batch, d)

    # Same observation model as miwae_loss: Normal(mean, 0.1).
    mu = decoder(z.reshape([L * batch, d]))[:, :p].reshape(L, batch, p)
    px = td.Normal(loc=mu, scale=torch.full_like(mu, 0.1))

    # Only observed entries may enter the likelihood part of the weights.
    log_px_obs = (px.log_prob(iota_x) * mask).sum(-1)  # (L, batch)

    # Sample from the trained Normal model rather than a Student-t whose
    # degrees-of-freedom head miwae_loss never trains.
    x_samples = px.sample()  # (L, batch, p)

    # The energy has to be evaluated on the completed samples. Evaluating it
    # on `iota_x` gives the same value for all L draws of a row, which cancels
    # in the softmax over draws, so the EBM would have no effect at all.
    x_completed = mask * iota_x + (1 - mask) * x_samples
    energy_px = ebm(x_completed.reshape(L * batch, p)).reshape(L, batch)

    log_px_corrected = log_px_obs - energy_px
    log_pz = p_z.log_prob(z)
    log_q = q.log_prob(z)
    w = torch.nn.functional.softmax(log_px_corrected + log_pz - log_q, 0)
    return torch.einsum('lb,lbp->bp', w, x_samples)

In [None]:
from torch.utils.data import TensorDataset, DataLoader

# Prepare TensorDataset and DataLoader for batching.
# NOTE(review): the EBM's "data" samples are the zero-filled incomplete
# "unbias" rows, not complete images - confirm this is intended.
xhat_unbias = np.copy(xfull_unbias)
xhat_unbias[np.isnan(data_obs_unbias)] = 0
xhat_unbias_0_tensor = torch.from_numpy(xhat_unbias).float().to(device)
dataset = TensorDataset(xhat_unbias_0_tensor)
loader = DataLoader(dataset, batch_size=bs, shuffle=True)

miwae_loss_train = np.array([])
mse_train = np.array([])

# Only the EBM is trained here; freeze the decoder so sampling builds no graph.
for param in decoder.parameters():
    param.requires_grad = False

# Evaluation inputs are constant across epochs: build them once.
xhat_bias_eval = torch.tensor(xhat_bias, dtype=torch.float).to(device)
mask_bias_eval = torch.tensor(mask_bias, dtype=torch.float).to(device)

# `range(1, n_epochs)` stopped one epoch short of `n_epochs`.
for ep in range(1, n_epochs + 1):

    for b_data in loader:
        b_data = b_data[0].to(device)
        ebm_optimizer.zero_grad()

        energy_gt = ebm(b_data)

        # Sample from MIWAE. Squeeze only the singleton middle axis: a bare
        # `.squeeze()` would also drop the batch axis for a batch of size 1.
        x_samples = miwae_sample(N_samples=b_data.shape[0], device=device).squeeze(1)
        energy_miwae = ebm(x_samples)

        # Lower the energy of data rows, raise the energy of model samples.
        loss = torch.mean(energy_gt) - torch.mean(energy_miwae)
        # L2 penalty keeping the energies close to zero.
        reg_loss = torch.mean(energy_gt**2) + torch.mean(energy_miwae**2)

        # Gradient-norm penalty on random interpolates between data and samples.
        interp = torch.rand(b_data.shape[0], 1, device=device)
        x_interp = interp * b_data + (1 - interp) * x_samples
        x_interp.requires_grad_(True)
        energy_interp = torch.mean(ebm(x_interp))
        grad_interp = torch.autograd.grad(
            outputs=energy_interp,
            inputs=x_interp,
            create_graph=True,
            retain_graph=True
        )[0]

        grad_reg_loss = grad_interp.norm(2, dim=1)
        loss += 0.1 * reg_loss + 0.1 * grad_reg_loss.mean()
        loss.backward()
        ebm_optimizer.step()

    # End-of-epoch report (values are those of the last mini-batch).
    print(f'Epoch {ep}')
    print(f'EBM loss: {loss.item()}, EBM gt : {torch.mean(energy_gt).item()}, EBM MIWAE: {torch.mean(energy_miwae).item()}')
    print('-----')
    # The MIWAE bound is no longer printed here: `ebm_optimizer` only updates
    # the EBM, so the bound cannot change in this loop, and the old line
    # re-printed a value left over from the MIWAE training cell under the
    # current epoch number.

    # Evaluation needs no gradients; without no_grad an autograd graph through
    # the encoder was built over the whole "bias" set at every epoch.
    with torch.no_grad():
        xhat_bias_tensor = miwaebm_impute(xhat_bias_eval, mask_bias_eval, 10, device).cpu().numpy()
    print(f'Imputation MSE {np.mean((xhat_bias_tensor - bias)[mask_bool_bias == 0]**2)}')

    fig, axs = plt.subplots(4, 4, figsize=(20, 8))
    for col, title in enumerate(('Bias', 'Masked', 'MIWAE + EBM', 'MIWAE + EBM + Bias')):
        axs[0, col].set_title(title)
    for i in range(4):
        # Ground truth must come from `bias` (the set being imputed); the
        # previous code showed the unrelated row `unbias[i]` in this column.
        truth = bias[i].reshape(28, 28)
        imputed = xhat_bias_tensor[i].reshape(28, 28)
        # Use this row's own mask (the previous code used the mask of row 0).
        observed = mask_bool_bias[i].reshape(28, 28)
        axs[i, 0].imshow(truth, cmap='gray', vmin=0, vmax=1)
        axs[i, 1].imshow(xobs_zero_bias[i].reshape(28, 28), cmap='gray', vmin=0, vmax=1)
        axs[i, 2].imshow(imputed, cmap='gray', vmin=0, vmax=1)
        # Imputed values at missing pixels, true values at observed ones.
        axs[i, 3].imshow(imputed * (1 - observed) + truth * observed,
                         cmap='gray', vmin=0, vmax=1)
    plt.show()
    # One figure is created per epoch; release it explicitly.
    plt.close(fig)
    