In [18]:
import torch
import numpy as np
from torch.utils.data import TensorDataset, Dataset, DataLoader

class MissingEntryDataset(Dataset):
    def __init__(self, data: torch.tensor, min_missing: float, max_missing: float):
        self.data = data

        n, d = data.shape
        self.min_missing = int(min_missing * d)
        self.max_missing = int(max_missing * d)
        
    def __getitem__(self, idx):
        sample = self.data[idx]
        
        n_missing = np.random.randint(self.min_missing, self.max_missing)
        idx = np.arange(len(sample))
        np.random.shuffle(idx)
        idx = idx[:n_missing]
        mask = torch.zeros(sample.shape, dtype=torch.long)
        mask[idx] = 1
        mask = mask.bool()

        return sample, mask
    
    def __len__(self):
        return len(self.data)

In [19]:
import pandas as pd
df = pd.read_csv("TCPA_data_sel.csv")

# select the 189 real valued columns only
X = df.iloc[:, 2:].values.astype("float32")
Xs = (X - X.mean(axis=0)) / X.var(axis=0)

# split train/test
n = len(Xs)
idx = np.arange(len(Xs))
ncut = int(n * 0.8)
Xtrain = MissingEntryDataset(torch.tensor(Xs[idx[:ncut]]), 0.1, 0.4)
Xtest = MissingEntryDataset(torch.tensor(Xs[idx[ncut:]]), 0.1, 0.4)
len(Xtrain), len(Xtest)

(3784, 946)

In [20]:
import torch
from torch import nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, ins=189, hidden=512, latent=64, variational=True):
        super(VAE, self).__init__()
        self.variational = variational
        
        self.enc = nn.Sequential(nn.Linear(ins, hidden), 
                                nn.ReLU(),
                                #nn.BatchNorm1d(hidden),
                                #nn.Linear(hidden, hidden),
                                #nn.ReLU(),
                                nn.BatchNorm1d(hidden))
        
        self.mean = nn.Linear(hidden, latent)
        self.log_variance = nn.Linear(hidden, latent)
        
        self.dec = nn.Sequential(nn.Linear(latent, hidden),
                                nn.ReLU(),
                                #nn.BatchNorm1d(hidden),
                                #nn.Linear(hidden, hidden),
                                #nn.ReLU(),
                                nn.BatchNorm1d(hidden),
                                nn.Linear(hidden, ins))

    def sample(self, mean, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x):
        x = self.enc(x)
        mu = self.mean(x)
        log_var = self.log_variance(x)
        z = self.sample(mu, log_var) if self.variational else mu
        x = self.dec(z)
        return x, mu, log_var
    
    def gibbs(self, x0, mask):
        # initalize unobserved (masked out) with random entries
        xn = x0[:]
        xn[mask] = torch.randn(mask.shape)[mask]
        # iterativly predict
        for _ in range(20):
            # reconstruction step
            xn,_,_ = self.forward(xn)
            # reset observed values
            xn[~mask] = x0[~mask]
        return xn
    
vae = VAE(hidden=512, latent=64, variational=False)
print(vae)
assert vae(torch.tensor(np.random.randn(20, 189).astype("float32")))

VAE(
  (enc): Sequential(
    (0): Linear(in_features=189, out_features=512, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (mean): Linear(in_features=512, out_features=64, bias=True)
  (log_variance): Linear(in_features=512, out_features=64, bias=True)
  (dec): Sequential(
    (0): Linear(in_features=64, out_features=512, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=512, out_features=189, bias=True)
  )
)


In [21]:
import torch.optim as optim
optimizer = optim.Adam(vae.parameters(), lr=0.001)

def train(epoch):
    sample, mask = Xtrain[:]
    reconstruction = sample[:]
    reconstruction[mask] = torch.randn(mask.shape)[mask]
    for n in range(10):
        vae.train()
        optimizer.zero_grad()
        reconstruction, mu, log_var = vae(reconstruction)
        mse = F.mse_loss(reconstruction, sample, reduction="mean")
        kl = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
        loss = mse + 0 * kl  # set to 1 for variational regularization (centered gaussian)
        loss.backward()
        optimizer.step()
        print(f"Train Epoch {epoch}.{n}", 
              f"var loss {loss.item()}", 
              f"reconstruction mse {mse.item()}",
              f"imputation mse {F.mse_loss(reconstruction[mask], sample[mask], reduction='mean')}")
        reconstruction = reconstruction.detach() # drop previous computation graph before going into next iter
    
def test():
    vae.eval()
    with torch.no_grad():
        sample, mask = Xtest[:]
        reconstruction = vae.gibbs(sample, mask)   
        test_loss = F.mse_loss(reconstruction[mask], sample[mask], reduction="mean").item()
    print('====> Test imputation mse: {:.8f}'.format(test_loss))

In [None]:
test()
test()
test()
for epoch in range(1, 501):
    train(epoch)
    if epoch%10==0:
        test()
        test()
        test()

====> Test imputation mse: 1.01946390
====> Test imputation mse: 1.01727521
====> Test imputation mse: 1.00465620
Train Epoch 1.0 var loss 6.3818159103393555 reconstruction mse 6.3818159103393555 imputation mse 1.1431632041931152
Train Epoch 1.1 var loss 6.330203533172607 reconstruction mse 6.330203533172607 imputation mse 1.1528018712997437
Train Epoch 1.2 var loss 6.226535320281982 reconstruction mse 6.226535320281982 imputation mse 1.1786030530929565
Train Epoch 1.3 var loss 6.111794948577881 reconstruction mse 6.111794948577881 imputation mse 1.214613914489746
Train Epoch 1.4 var loss 6.003492832183838 reconstruction mse 6.003492832183838 imputation mse 1.2472885847091675
Train Epoch 1.5 var loss 5.942350387573242 reconstruction mse 5.942350387573242 imputation mse 1.2768481969833374
Train Epoch 1.6 var loss 5.953159809112549 reconstruction mse 5.953159809112549 imputation mse 1.2974590063095093
Train Epoch 1.7 var loss 6.0398125648498535 reconstruction mse 6.0398125648498535 imput

Train Epoch 8.2 var loss 4.7006120681762695 reconstruction mse 4.7006120681762695 imputation mse 1.043976068496704
Train Epoch 8.3 var loss 4.7307047843933105 reconstruction mse 4.7307047843933105 imputation mse 1.048693299293518
Train Epoch 8.4 var loss 4.770837306976318 reconstruction mse 4.770837306976318 imputation mse 1.059625267982483
Train Epoch 8.5 var loss 4.782680511474609 reconstruction mse 4.782680511474609 imputation mse 1.0709208250045776
Train Epoch 8.6 var loss 4.779810428619385 reconstruction mse 4.779810428619385 imputation mse 1.080039381980896
Train Epoch 8.7 var loss 4.763060569763184 reconstruction mse 4.763060569763184 imputation mse 1.0831209421157837
Train Epoch 8.8 var loss 4.74363374710083 reconstruction mse 4.74363374710083 imputation mse 1.0828630924224854
Train Epoch 8.9 var loss 4.7332892417907715 reconstruction mse 4.7332892417907715 imputation mse 1.088200330734253
Train Epoch 9.0 var loss 4.222461223602295 reconstruction mse 4.222461223602295 imputatio

Train Epoch 15.3 var loss 4.593225479125977 reconstruction mse 4.593225479125977 imputation mse 1.1455354690551758
Train Epoch 15.4 var loss 4.663137435913086 reconstruction mse 4.663137435913086 imputation mse 1.155791997909546
Train Epoch 15.5 var loss 4.673597812652588 reconstruction mse 4.673597812652588 imputation mse 1.1595126390457153
Train Epoch 15.6 var loss 4.643106937408447 reconstruction mse 4.643106937408447 imputation mse 1.1639668941497803
Train Epoch 15.7 var loss 4.602234840393066 reconstruction mse 4.602234840393066 imputation mse 1.1765488386154175
Train Epoch 15.8 var loss 4.575121879577637 reconstruction mse 4.575121879577637 imputation mse 1.2016748189926147
Train Epoch 15.9 var loss 4.574089050292969 reconstruction mse 4.574089050292969 imputation mse 1.239866018295288
Train Epoch 16.0 var loss 3.860809803009033 reconstruction mse 3.860809803009033 imputation mse 1.0378267765045166
Train Epoch 16.1 var loss 4.18718147277832 reconstruction mse 4.18718147277832 imp

Train Epoch 22.4 var loss 3.5913138389587402 reconstruction mse 3.5913138389587402 imputation mse 1.052688717842102
Train Epoch 22.5 var loss 3.5952632427215576 reconstruction mse 3.5952632427215576 imputation mse 1.064509391784668
Train Epoch 22.6 var loss 3.602904796600342 reconstruction mse 3.602904796600342 imputation mse 1.0760841369628906
Train Epoch 22.7 var loss 3.6301651000976562 reconstruction mse 3.6301651000976562 imputation mse 1.0861073732376099
Train Epoch 22.8 var loss 3.6871206760406494 reconstruction mse 3.6871206760406494 imputation mse 1.0941162109375
Train Epoch 22.9 var loss 3.7684314250946045 reconstruction mse 3.7684314250946045 imputation mse 1.1004258394241333
Train Epoch 23.0 var loss 3.1018178462982178 reconstruction mse 3.1018178462982178 imputation mse 1.0070816278457642
Train Epoch 23.1 var loss 3.345503807067871 reconstruction mse 3.345503807067871 imputation mse 1.0268869400024414
Train Epoch 23.2 var loss 3.4664556980133057 reconstruction mse 3.4664556

Train Epoch 29.5 var loss 3.5769715309143066 reconstruction mse 3.5769715309143066 imputation mse 1.0659540891647339
Train Epoch 29.6 var loss 3.577104091644287 reconstruction mse 3.577104091644287 imputation mse 1.062965989112854
Train Epoch 29.7 var loss 3.574521541595459 reconstruction mse 3.574521541595459 imputation mse 1.0596139430999756
Train Epoch 29.8 var loss 3.5720760822296143 reconstruction mse 3.5720760822296143 imputation mse 1.0562784671783447
Train Epoch 29.9 var loss 3.5728402137756348 reconstruction mse 3.5728402137756348 imputation mse 1.053379774093628
Train Epoch 30.0 var loss 2.986536979675293 reconstruction mse 2.986536979675293 imputation mse 1.0157467126846313
Train Epoch 30.1 var loss 3.2497336864471436 reconstruction mse 3.2497336864471436 imputation mse 1.0274782180786133
Train Epoch 30.2 var loss 3.3458855152130127 reconstruction mse 3.3458855152130127 imputation mse 1.0305675268173218
Train Epoch 30.3 var loss 3.400271415710449 reconstruction mse 3.4002714

Train Epoch 36.5 var loss 3.69392991065979 reconstruction mse 3.69392991065979 imputation mse 1.1589802503585815
Train Epoch 36.6 var loss 3.6847846508026123 reconstruction mse 3.6847846508026123 imputation mse 1.160784125328064
Train Epoch 36.7 var loss 3.6510260105133057 reconstruction mse 3.6510260105133057 imputation mse 1.1612516641616821
Train Epoch 36.8 var loss 3.6143438816070557 reconstruction mse 3.6143438816070557 imputation mse 1.1644487380981445
Train Epoch 36.9 var loss 3.586603879928589 reconstruction mse 3.586603879928589 imputation mse 1.1719541549682617
Train Epoch 37.0 var loss 2.7282893657684326 reconstruction mse 2.7282893657684326 imputation mse 1.0337860584259033
Train Epoch 37.1 var loss 2.994641065597534 reconstruction mse 2.994641065597534 imputation mse 1.0645363330841064
Train Epoch 37.2 var loss 3.1456458568573 reconstruction mse 3.1456458568573 imputation mse 1.0802165269851685
Train Epoch 37.3 var loss 3.2410848140716553 reconstruction mse 3.2410848140716

Train Epoch 43.5 var loss 2.8711740970611572 reconstruction mse 2.8711740970611572 imputation mse 0.9998350143432617
Train Epoch 43.6 var loss 2.896728277206421 reconstruction mse 2.896728277206421 imputation mse 1.003167748451233
Train Epoch 43.7 var loss 2.903547763824463 reconstruction mse 2.903547763824463 imputation mse 1.0045933723449707
Train Epoch 43.8 var loss 2.897083282470703 reconstruction mse 2.897083282470703 imputation mse 1.0044139623641968
Train Epoch 43.9 var loss 2.8856210708618164 reconstruction mse 2.8856210708618164 imputation mse 1.0033488273620605
Train Epoch 44.0 var loss 2.2676408290863037 reconstruction mse 2.2676408290863037 imputation mse 0.983416736125946
Train Epoch 44.1 var loss 2.506145715713501 reconstruction mse 2.506145715713501 imputation mse 0.9954383969306946
Train Epoch 44.2 var loss 2.665285110473633 reconstruction mse 2.665285110473633 imputation mse 1.0010173320770264
Train Epoch 44.3 var loss 2.7905972003936768 reconstruction mse 2.7905972003

Train Epoch 50.6 var loss 2.7931361198425293 reconstruction mse 2.7931361198425293 imputation mse 1.0396511554718018
Train Epoch 50.7 var loss 2.8209426403045654 reconstruction mse 2.8209426403045654 imputation mse 1.0493696928024292
Train Epoch 50.8 var loss 2.8377418518066406 reconstruction mse 2.8377418518066406 imputation mse 1.0584149360656738
Train Epoch 50.9 var loss 2.8464088439941406 reconstruction mse 2.8464088439941406 imputation mse 1.0668350458145142
====> Test imputation mse: 3.28914976
====> Test imputation mse: 2.92469835
====> Test imputation mse: 2.84176493
Train Epoch 51.0 var loss 2.0476293563842773 reconstruction mse 2.0476293563842773 imputation mse 0.9540315866470337
Train Epoch 51.1 var loss 2.2482149600982666 reconstruction mse 2.2482149600982666 imputation mse 0.9690876007080078
Train Epoch 51.2 var loss 2.3681461811065674 reconstruction mse 2.3681461811065674 imputation mse 0.9783523082733154
Train Epoch 51.3 var loss 2.4498233795166016 reconstruction mse 2.4

Train Epoch 57.6 var loss 2.3402750492095947 reconstruction mse 2.3402750492095947 imputation mse 1.0067161321640015
Train Epoch 57.7 var loss 2.3622660636901855 reconstruction mse 2.3622660636901855 imputation mse 1.0102601051330566
Train Epoch 57.8 var loss 2.373817205429077 reconstruction mse 2.373817205429077 imputation mse 1.0139367580413818
Train Epoch 57.9 var loss 2.378225326538086 reconstruction mse 2.378225326538086 imputation mse 1.018074870109558
Train Epoch 58.0 var loss 1.8026800155639648 reconstruction mse 1.8026800155639648 imputation mse 0.9716179966926575
Train Epoch 58.1 var loss 1.9622136354446411 reconstruction mse 1.9622136354446411 imputation mse 0.9883421063423157
Train Epoch 58.2 var loss 2.068556070327759 reconstruction mse 2.068556070327759 imputation mse 0.9976480007171631
Train Epoch 58.3 var loss 2.151287078857422 reconstruction mse 2.151287078857422 imputation mse 1.0049577951431274
Train Epoch 58.4 var loss 2.2166831493377686 reconstruction mse 2.2166831

Train Epoch 64.6 var loss 2.1566929817199707 reconstruction mse 2.1566929817199707 imputation mse 1.0086021423339844
Train Epoch 64.7 var loss 2.1854801177978516 reconstruction mse 2.1854801177978516 imputation mse 1.017305612564087
Train Epoch 64.8 var loss 2.198931932449341 reconstruction mse 2.198931932449341 imputation mse 1.0246371030807495
Train Epoch 64.9 var loss 2.202540874481201 reconstruction mse 2.202540874481201 imputation mse 1.0304334163665771
Train Epoch 65.0 var loss 1.6281039714813232 reconstruction mse 1.6281039714813232 imputation mse 0.9469174146652222
Train Epoch 65.1 var loss 1.772411584854126 reconstruction mse 1.772411584854126 imputation mse 0.9656484127044678
Train Epoch 65.2 var loss 1.888127326965332 reconstruction mse 1.888127326965332 imputation mse 0.9764717817306519
Train Epoch 65.3 var loss 1.9915108680725098 reconstruction mse 1.9915108680725098 imputation mse 0.985380232334137
Train Epoch 65.4 var loss 2.080082893371582 reconstruction mse 2.080082893