In [1]:
import random
import pandas as pd
import numpy as np
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers.core import Dropout, Dense
import keras as k
import keras.backend as K
from scipy.stats import zscore
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt

In [2]:
# Read the Data
#df = pd.read_csv('../../data/original.csv',  index_col=0)
df = pd.read_csv('../../data/MCAR/mcar10/mcar_10.csv',  index_col=0)
y = np.array(df['group'])
encoder = LabelEncoder()
encoder.fit(y)
y = encoder.transform(y)

#scale the data
X = df.drop(['group'], axis = 1)
#scaler = StandardScaler()
#X = scaler.fit_transform(X)

X = np.array(X)  
n_dims = X.shape[1]

latent_dim = 2

In [3]:
def masked_mae(X_true, X_pred, mask):
    masked_diff = X_true[mask] - X_pred[mask]
    return np.mean(np.abs(masked_diff))


def fill(self, missing_mask):
  self.data[missing_mask] = -1


def create_missing_mask(X):
  if X.dtype != "f" and X.dtype != "d":
      X = X.astype(float)
  return np.isnan(X.data)


def bool_to_binary(matrix):
  """
  Converts a boolean matrix to a binary matrix

  :param matrix: a boolean matrix
  :return: a binary matrix
  """
  binary_matrix = []
  for row in matrix:
      binary_row = []
      for value in row:
          binary_row.append(1 if value else 0)
      binary_matrix.append(binary_row)
  return binary_matrix


def replace_nan(data, replacement):
  """
  Replace NaN values in a given array with a specific number.

  Args:
  data (array): The data to be processed.
  replacement (float or int): The number to replace NaN values with.

  Returns:
  The processed data with NaN values replaced by the specified number.
  """

  if not isinstance(data, np.ndarray):
      raise ValueError("Unsupported data type. Function supports numpy arrays only.")

  data[np.isnan(data)] = replacement

  return data



In [14]:
import tensorflow as tf
from keras.layers import Input, Dense, Flatten, Reshape, Dropout
from keras.models import Model, Sequential
from keras.optimizers import Adam
from keras.backend import binary_crossentropy
from keras.layers import concatenate
from keras.losses import mean_squared_error


encoder_input = Input(shape=(n_dims*2,))
encoder_seq = Sequential()
encoder_seq.add(Dense(32, activation='linear', input_shape= (n_dims*2,)))
encoder_seq.add(Dense(32, activation='linear'))

encoder_mu = Dense(latent_dim, activation='linear')(encoder_seq(encoder_input))
encoder_log_sigma = Dense(latent_dim, activation='linear')(encoder_seq(encoder_input))

def _sample_z(args):
    mu, log_sigma = args
    eps = K.random_normal(shape=K.shape(mu), mean=0., stddev=1.)
    return mu + K.exp(log_sigma / 2) * eps

encoder_output = k.layers.Lambda(_sample_z)([encoder_mu, encoder_log_sigma])

decoder_input = Input(shape=[latent_dim])

decoder_seq = Sequential()
decoder_seq.add(Dense(32, activation='linear', input_shape=[latent_dim]))
decoder_seq.add(Dense(32, activation='linear'))
decoder_seq.add(Dense(n_dims, activation='linear'))

encoder_model = Model(inputs=encoder_input, outputs=encoder_output)
decoder_model = Model(inputs=decoder_input, outputs=decoder_seq(decoder_input))
full_model = Model(inputs=encoder_input, 
                   outputs=concatenate([encoder_mu, encoder_log_sigma, decoder_seq(encoder_output)]))

adam_opt = Adam(lr=0.01)

def _vae_loss(input_and_mask, model_output):
    K.print_tensor(model_output)
    encoder_mu = model_output[:, 0:latent_dim]
    encoder_log_sigma = model_output[:, latent_dim:latent_dim*2]
    y_pred = model_output[:, latent_dim*2:]

    X_values = input_and_mask[:, :n_dims]
    missing_mask = input_and_mask[:, n_dims:]
    observed_mask = 1 - missing_mask
    y_true = X_values * observed_mask
    pred_observed = y_pred * observed_mask
    


    # E[log P(X|z)] - this is because we model our P(X_i|z) as a normal distribution
    #recon =  K.sum(K.binary_crossentropy(y_truey_true), axis=1)

    recon = mean_squared_error(y_true, y_pred)
    # D_KL(Q(z|X) || P(z|X)); calculate in closed form as both dist. are Gaussian

    kl = 0.5 * K.sum(K.exp(encoder_log_sigma) + K.square(encoder_mu) - 1. - encoder_log_sigma, axis=1)

    return recon + kl

opt = Adam(lr=0.01)

full_model.compile(optimizer=opt, loss=_vae_loss)



In [15]:
mask = np.array(bool_to_binary(create_missing_mask(X)))
X_no_na = replace_nan(X, replacement = 1)
input_with_mask = np.hstack([X_no_na, mask])

full_model.fit(x=input_with_mask, y=input_with_mask, epochs=5, batch_size=32, verbose=1)

Epoch 1/5
 [[-4.67002678 -2.3309989 6.5585537 ... -3.03366065 5.74913168 -5.28998566]
 [-6.28934097 -2.72854614 6.58684254 ... 5.17115927 -6.8244133 5.02400494]
 [-4.44401073 -7.02756262 3.10342574 ... -3.35327888 3.35673237 -1.82373524]
 ...
 [-5.72956324 -0.756473601 8.03039169 ... -1.8603493 -0.133950621 1.6672231]
 [-6.2995863 -5.98570395 7.85362673 ... -10.1821623 9.24944782 -4.27191734]
 [-4.58925533 -0.474122524 6.00594473 ... -2.33316016 0.101660907 1.72906148]]
 1/32 [..............................] - ETA: 10s - loss: 4287.1978 [[69.5416489 83.3843918 -31.475626 ... -3.92592359 -2.47907448 4.29857779]
 [65.4281311 79.8623428 -30.7602749 ... -3.68148398 -2.37574959 4.06108046]
 [69.2316742 86.285553 -34.0587883 ... -3.88078427 -2.57058144 4.31693316]
 ...
 [71.0702057 83.9567795 -33.6574211 ... -4.02296734 -2.49437714 4.3784]
 [70.6572495 87.8782349 -32.4423447 ... -3.96235442 -2.61805606 4.4034977]
 [68.2344437 87.2866058 -30.3112755 ... -3.80587482 -2.60369396 4.28059244]]
 [

 [[-34.3396797 32.4526176 -151.358902 ... -0.158806548 -0.341198117 0.838866651]
 [-35.5620728 33.3989296 -161.013077 ... -0.162491664 -0.361125082 0.858947694]
 [-33.6479797 35.8395653 -157.799973 ... -0.243824854 -0.371946424 0.927737832]
 ...
 [-34.5394402 31.8926468 -158.773453 ... -0.143539354 -0.336797893 0.823886216]
 [-34.3231621 35.7992325 -154.917969 ... -0.233357564 -0.376921386 0.924442053]
 [-34.8709259 34.5173378 -155.503937 ... -0.197125301 -0.367573231 0.889846385]]
 [[-11.4968634 57.0771942 -168.063049 ... -1.05009806 -0.324921191 1.37476087]
 [-9.9232893 59.679451 -171.080978 ... -1.12558162 -0.349520922 1.45022392]
 [-11.4138479 54.6806107 -168.625458 ... -0.998160601 -0.299455 1.31908393]
 ...
 [-11.3970728 56.6980553 -166.899017 ... -1.04287577 -0.320733 1.3667345]
 [-12.3225079 57.3961258 -166.998978 ... -1.04766273 -0.329804 1.3748219]
 [-9.853405 58.8483963 -167.879578 ... -1.10804343 -0.340614796 1.431288]]
 [[17.7246437 21.061533 -180.445847 ... -0.522634625 0

 [[19.5429859 -6.25919724 -114.317787 ... 0.311685145 0.162184194 0.993852198]
 [18.4337234 -9.84434128 -112.926018 ... 0.316781 0.244191542 0.861853063]
 [19.6715851 -11.8416843 -114.472198 ... 0.330214262 0.239824653 0.865432501]
 ...
 [18.4024067 -8.72835636 -114.863251 ... 0.313044935 0.228821903 0.887292087]
 [20.8399525 -10.1409454 -118.333008 ... 0.331462324 0.18360211 0.954735041]
 [19.1897602 -8.57695484 -116.444862 ... 0.317056865 0.205386057 0.923637807]]
 [[12.7347 11.009901 -92.0014 ... 0.231459215 0.11343725 1.13164926]
 [11.6501513 12.8541307 -97.429718 ... 0.221257284 0.111635774 1.13099217]
 [13.2670498 10.1966133 -95.7229462 ... 0.236247301 0.112840816 1.13415754]
 ...
 [12.6033897 11.740325 -95.8380127 ... 0.229012311 0.105052739 1.14362383]
 [10.8636217 13.7150097 -95.7536545 ... 0.214997604 0.118003637 1.11918783]
 [13.6809635 9.52436638 -99.8056564 ... 0.240065396 0.113019586 1.13515937]]
 [[-6.72357273 17.2303886 -77.0144653 ... 0.137404084 0.508416891 0.49928784

 [[-185.063705 35.8729591 -2177.40356 ... -0.035048008 -6.22579193 27.5489674]
 [-184.680008 36.4939308 -2164.81128 ... -0.0158910751 -6.17525768 27.4590397]
 [-185.638245 36.4406 -2191.10913 ... -0.0254884958 -6.22369337 27.6097584]
 ...
 [-185.771606 35.9733391 -2171.74731 ... -0.0385690928 -6.25431681 27.6534176]
 [-185.295792 37.9066849 -2179.33545 ... 0.0148991346 -6.1312356 27.4845352]
 [-184.702087 36.8239403 -2174.90674 ... -0.00765049458 -6.15916872 27.4461746]]
 [-192.770325 10.1157694 -2083.5166 ... 4.2110796 -3.31205297 25.2813606]
 [-200.38504 10.8290129 -2173.32764 ... 4.35384178 -3.46824455 26.2513409]
 ...
 [-202.616623 11.2626295 -2192.78516 ... 4.3972559 -3.51166177 26.5357418]
 [-194.146851 13.1514349 -2090.03394 ... 4.25729084 -3.30978847 25.4584942]
 [-198.567856 12.5217619 -2128.42798 ... 4.33284712 -3.41142368 26.0210114]]
 [[-164.98291 -21.6679287 -2082.89697 ... 7.68915606 1.51861417 16.8101177]
 [-168.117264 -23.8499393 -2109.75049 ... 7.8301363 1.57162106 17.

 [[89.7628479 18.0545197 -1871.2207 ... -10.4083843 -8.19469357 2.54395342]
 [86.5128632 18.8299484 -1835.44495 ... -9.88562 -7.99718904 2.72093415]
 [86.8477783 18.1423569 -1836.95264 ... -9.98858833 -7.96230364 2.59984851]
 ...
 [85.0291595 18.1938171 -1812.93811 ... -9.72696114 -7.81701899 2.63415384]
 [87.4624329 14.8410463 -1844.47522 ... -10.3423367 -7.71290541 2.03245974]
 [86.5738831 15.2024879 -1813.70374 ... -10.1873379 -7.67248964 2.10613585]]
 [[50.4270439 -5.52596617 -1880.43884 ... -6.52688217 -2.49027085 -1.20164645]
 [51.1757 -4.77495241 -1823.84717 ... -6.56443787 -2.61410618 -1.08997166]
 [52.2105217 -4.10115385 -1841.44702 ... -6.647861 -2.75289559 -0.997101784]
 ...
 [53.3275871 -3.68553948 -1897.25562 ... -6.76488304 -2.87500381 -0.949475288]
 [51.5640373 -4.37019682 -1827.8092 ... -6.58260298 -2.67969131 -1.0294795]
 [49.1254539 -3.03594661 -1803.52869 ... -6.13323164 -2.61237288 -0.75506711]]
 [[0.920239329 -18.9948406 -1828.76733 ... -0.98897934 2.46806574 -2.47

 [[33.6700134 -1.83478522 -1777.06177 ... -1.71136212 1.19520116 -2.78216314]
 [32.3834381 -2.71650577 -1781.14148 ... -1.73880696 1.25014615 -2.79527688]
 [30.8941669 -3.21892047 -1812.61401 ... -1.7004559 1.27839 -2.72716212]
 ...
 [33.0128174 0.191931739 -1759.68665 ... -1.39020562 1.05425191 -2.39070439]
 [33.5731506 -1.28084755 -1779.86987 ... -1.62949276 1.15701342 -2.6834445]
 [32.2481842 0.114853404 -1799.62756 ... -1.34604096 1.05641186 -2.3266592]]
 [[36.9011116 7.37459373 -1839.02515 ... -0.60390687 0.677407742 -1.70628262]
 [34.0735512 9.53907108 -1754.2749 ... -0.112975478 0.51233983 -1.06932259]
 [36.3983498 7.71161652 -1807.39087 ... -0.523141146 0.651296258 -1.60070229]
 ...
 [35.4770241 7.97215319 -1760.67871 ... -0.423848152 0.627617121 -1.46452403]
 [35.283741 9.97480106 -1764.84473 ... -0.137271285 0.490777731 -1.12335658]
 [35.36623 8.79366875 -1778.5459 ... -0.304109693 0.571277142 -1.32134485]]
 [[26.9635162 5.35142708 -1772.10522 ... -0.127531528 0.817165673 -1.

 [[10.0489941 -4.5745039 -1771.69629 ... -0.175613165 1.53356266 -0.875455081]
 [12.5371532 -4.82371092 -1777.71545 ... -0.337421656 1.61711192 -1.21952319]
 [13.0834789 -2.52638197 -1798.12842 ... -0.0313619375 1.48642159 -0.922735393]
 ...
 [12.2657452 -4.35693121 -1782.50537 ... -0.255929 1.58014393 -1.11240339]
 [11.9557629 -1.46342313 -1759.35925 ... 0.179945111 1.388358 -0.616403401]
 [10.435545 -2.80856657 -1780.42676 ... 0.0613386631 1.43219256 -0.643226683]]
 [[17.196804 4.12417 -1819.61914 ... 0.737645566 1.18643665 -0.368740439]
 [16.2213554 1.14551735 -1801.19482 ... 0.353339732 1.34791088 -0.720299244]
 [12.3803978 2.78548 -1747.92664 ... 0.78413564 1.13792622 0.0109863877]
 ...
 [15.415966 4.51651764 -1774.82898 ... 0.883866668 1.11236191 -0.0879218]
 [17.7432976 3.12858129 -1820.922 ... 0.565507174 1.26453972 -0.593414307]
 [13.6233025 3.67229247 -1801.23547 ... 0.850861132 1.11619461 -0.00133609772]]
 [[10.8883467 7.54537296 -1791.31323 ... 1.55934012 0.802027524 0.9554

 [[7.86963 -0.0846278295 -1822.68164 ... 0.673733711 1.25758386 0.283087075]
 [5.6685648 1.04659081 -1739.32837 ... 0.943043172 1.11904025 0.738198519]
 [4.82927036 -0.461229056 -1750.64392 ... 0.760896444 1.18741179 0.605670571]
 ...
 [6.15754795 1.0306592 -1754.12634 ... 0.91778779 1.13510776 0.674242735]
 [5.98947239 -0.0132879857 -1800.1156 ... 0.772322476 1.19517493 0.530590475]
 [5.80476475 2.62469816 -1811.6167 ... 1.16846347 1.02460551 0.970164716]]
 [[4.72542 0.409696192 -1827.48499 ... 0.89778775 1.1327045 0.767365873]
 [4.9236927 2.12626529 -1818.58655 ... 1.14070415 1.03156877 1.01336181]
 [3.46374679 5.69399595 -1729.56946 ... 1.73319077 0.763561726 1.75994575]
 ...
 [4.75515938 2.86246753 -1750.10669 ... 1.25675273 0.980363786 1.15073657]
 [6.21939278 0.464465469 -1766.38733 ... 0.83591938 1.17541385 0.588226438]
 [3.55180526 2.44769144 -1805.5542 ... 1.25212896 0.969122648 1.23652911]]
 [[-0.161071151 1.46107423 -1778.29199 ... 1.28478265 0.918888867 1.55889273]
 [2.0320

 [[0.0730641782 3.98174572 -1763.18909 ... 1.7116785 0.820422888 2.10135603]
 [1.57932031 1.1726377 -1742.94946 ... 1.22852337 1.04262459 1.4685204]
 [3.17738843 -2.01184869 -1758.87854 ... 0.685872734 1.29109478 0.764909506]
 ...
 [1.97884119 -0.771713 -1760.17969 ... 0.923873663 1.17631078 1.11150241]
 [1.66101587 3.62285185 -1782.68298 ... 1.5853436 0.892459273 1.8447113]
 [1.92511499 -0.519349813 -1802.31506 ... 0.963501871 1.15890074 1.15807652]]
 [[-2.08966875 -3.32320714 -1812.95056 ... 0.74115485 1.21167243 1.23214793]
 [1.20787764 -4.74620199 -1804.18079 ... 0.379089713 1.40351641 0.592309117]
 [-2.91315508 -1.99104 -1768.27026 ... 0.975337386 1.10288501 1.54600668]
 ...
 [-0.562035 -8.08220196 -1777.42285 ... -0.0299750566 1.55606306 0.289068341]
 [-0.486320913 -5.09151793 -1760.49414 ... 0.406684935 1.37203884 0.75125742]
 [-0.313332021 -4.4621172 -1831.42822 ... 0.491311908 1.33822322 0.828742862]]
 [[-5.39457178 -0.17718178 -1751.94031 ... 1.36172378 0.915395677 2.15567565

<keras.callbacks.History at 0x7fbb0c16bc70>

In [16]:
results = full_model.predict(input_with_mask)
results



array([[ 4.7382228e-02, -4.2887330e+00, -1.7698503e+03, ...,
         5.7552850e-01,  1.3981266e+00,  9.7979093e-01],
       [-1.5468408e+00, -1.8908488e+00, -1.7870522e+03, ...,
         1.0022393e+00,  1.1989162e+00,  1.5587084e+00],
       [ 9.7514713e-01, -1.4744421e+00, -1.7731614e+03, ...,
         9.4701505e-01,  1.2519460e+00,  1.3067415e+00],
       ...,
       [ 1.7991065e+00, -5.8841252e+00, -1.7769749e+03, ...,
         2.5969285e-01,  1.5523016e+00,  5.0758564e-01],
       [-4.6897519e-01, -8.0504334e-01, -1.7301400e+03, ...,
         1.1123008e+00,  1.1650558e+00,  1.5941820e+00],
       [-3.1022303e+00, -7.6912820e-01, -1.7649939e+03, ...,
         1.2392627e+00,  1.0803796e+00,  1.9314913e+00]], dtype=float32)

In [None]:
embedding = encoder_model.predict(input_with_mask)



In [None]:
embedding