In [1]:
!pip install neural-tangents

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting neural-tangents
  Downloading neural_tangents-0.6.1-py2.py3-none-any.whl (249 kB)
[K     |████████████████████████████████| 249 kB 4.7 MB/s 
[?25hCollecting tf2jax>=0.3.0
  Downloading tf2jax-0.3.1-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 5.6 MB/s 
Collecting frozendict>=2.3
  Downloading frozendict-2.3.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (110 kB)
[K     |████████████████████████████████| 110 kB 57.7 MB/s 
Installing collected packages: tf2jax, frozendict, neural-tangents
Successfully installed frozendict-2.3.4 neural-tangents-0.6.1 tf2jax-0.3.1


In [2]:
import numpy as np
from collections import defaultdict
import jax.numpy
import neural_tangents as nt
from neural_tangents import stax
from jax import random
import math

from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import *
from sklearn.linear_model import *

In [3]:
def MI_NNGP1(input, number_of_imputation=1, W_std=1.0, b_std=0.0):
    """Impute NaN entries of ``input`` with an NNGP fitted on the complete rows.

    Rows are grouped by their missingness pattern. For every incomplete
    pattern the features (columns) play the role of samples: the observed
    columns restricted to the complete rows are the training inputs, the same
    columns restricted to the incomplete rows are the training targets, and
    the missing columns restricted to the complete rows are the test inputs.
    The predicted values fill the missing cells of the incomplete rows.

    Args:
        input: 2-D float array of shape (n, p) with NaN marking missing cells.
            It is not modified; every imputation is written to a copy.
        number_of_imputation: 1 fills with the posterior mean; any other
            value draws that many imputations from the posterior.
        W_std: weight scale. Its square root is what gets passed to
            ``stax.Dense`` as ``W_std``.
            NOTE(review): the original comment calls this a standard
            deviation, but taking the square root suggests it is treated as
            a variance - confirm the intended meaning.
        b_std: bias standard deviation, passed to ``stax.Dense`` unchanged.

    Returns:
        The imputed array when ``number_of_imputation == 1``, otherwise a list
        of imputed arrays. ``None`` (after printing a hint) when ``input``
        contains no complete row.
    """
    n, p = input.shape
    nan_mask = np.isnan(input)

    # Group row indices by missingness pattern (a tuple of p booleans).
    pattern = defaultdict(list)
    for row in range(n):
        pattern[tuple(nan_mask[row])].append(row)

    # Plain check instead of try/assert/except: asserts vanish under `-O`.
    complete_pattern = tuple([False] * p)
    if complete_pattern not in pattern:
        print('no complete cases found, please use MI-NNGP2')
        return
    complete_cases = pattern[complete_pattern]

    # Only the kernel function of this two-hidden-layer ReLU network is used.
    _, _, kernel_fn = stax.serial(
        stax.Dense(300, W_std=math.sqrt(W_std), b_std=b_std, parameterization="standard"), stax.Relu(),
        stax.Dense(300, W_std=math.sqrt(W_std), b_std=b_std, parameterization="standard"), stax.Relu(),
        stax.Dense(1, W_std=math.sqrt(W_std), b_std=b_std, parameterization="standard")
    )

    # The posterior depends only on `input`, so compute it once per pattern
    # rather than once per imputation.
    complete_rows = input[complete_cases]
    posteriors = []
    for miss_pattern, incomplete_cases in pattern.items():
        if not any(miss_pattern):
            continue  # the complete-case pattern needs no imputation
        missing = np.array(miss_pattern, dtype=bool)
        observed = ~missing
        train_input = jax.numpy.array(np.transpose(complete_rows[:, observed]))
        test_input = jax.numpy.array(np.transpose(complete_rows[:, missing]))
        train_target = jax.numpy.array(np.transpose(input[incomplete_cases][:, observed]))

        predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, train_input, train_target)
        # nngp_mean: (missing columns, incomplete rows);
        # nngp_covariance: (missing columns, missing columns).
        nngp_mean, nngp_covariance = predict_fn(x_test=test_input, get='nngp', compute_cov=True)
        posteriors.append((incomplete_cases, missing, nngp_mean, nngp_covariance))

    imputation_list = []
    for draw in range(number_of_imputation):
        # Seed from the imputation index. The original used a stale index left
        # over from the grouping loop, which made every imputation identical.
        key = random.PRNGKey(draw * 71)
        imputation = input.copy()
        for incomplete_cases, missing, nngp_mean, nngp_covariance in posteriors:
            if number_of_imputation == 1:
                # for single imputation, use mean value as imputation
                filled = np.asarray(jax.numpy.transpose(nngp_mean))
            else:
                # for multiple imputation, draw imputation from posterior
                # distribution. A fresh subkey per pattern and one batched
                # call give every incomplete row its own independent draw
                # (reusing one key repeats the same noise for every row).
                key, subkey = random.split(key)
                filled = np.asarray(
                    random.multivariate_normal(subkey, jax.numpy.transpose(nngp_mean), nngp_covariance)
                )
            rows = imputation[incomplete_cases]  # fancy indexing returns a copy
            rows[:, missing] = filled
            imputation[incomplete_cases] = rows

        imputation_list.append(imputation)

    if number_of_imputation == 1:
        return imputation_list[0]
    return imputation_list

In [4]:
def MI_NNGP2(input, number_of_imputation=1, burn_in=3, interval=1, W_std=1.0, b_std=0.0):
    """Iteratively impute NaN entries of ``input`` with an NNGP.

    Starts from an initial imputation (``MI_NNGP1`` when at least one complete
    row exists, otherwise scikit-learn's ``IterativeImputer`` with a
    ``BayesianRidge`` estimator). Then, for ``burn_in + number_of_imputation *
    interval`` epochs, it re-imputes each missingness pattern in turn, using
    all rows outside that pattern (with their current imputed values) as the
    reference rows. Updates are applied in place, so later patterns in the
    same epoch see the values just written for earlier ones.

    Args:
        input: 2-D float array of shape (n, p) with NaN marking missing cells.
        number_of_imputation: 1 fills with the posterior mean on every epoch;
            any other value samples from the posterior and collects that many
            imputations.
        burn_in: number of initial epochs whose result is discarded.
        interval: number of epochs between two retained imputations.
        W_std: weight scale. Its square root is what gets passed to
            ``stax.Dense`` as ``W_std``.
            NOTE(review): the original comment calls this a standard
            deviation, but taking the square root suggests it is treated as
            a variance - confirm the intended meaning.
        b_std: bias standard deviation, passed to ``stax.Dense`` unchanged.

    Returns:
        The imputed array when ``number_of_imputation == 1``, otherwise a list
        of imputed arrays (one per retained epoch).
    """
    n, p = input.shape
    nan_mask = np.isnan(input)

    # Group row indices by missingness pattern (a tuple of p booleans).
    pattern = defaultdict(list)
    for row in range(n):
        pattern[tuple(nan_mask[row])].append(row)

    if tuple([False] * p) in pattern:
        # Forward the kernel hyper-parameters so the initialisation uses the
        # same kernel as the iterations below.
        initial_imputation = MI_NNGP1(input, W_std=W_std, b_std=b_std)
    else:
        MICE_imputer = IterativeImputer(estimator=BayesianRidge(), skip_complete=True, max_iter=20, tol=0.01, sample_posterior=False, random_state=42)
        initial_imputation = MICE_imputer.fit_transform(input)
    print('finish initial imputation!')

    # Only the kernel function of this two-hidden-layer ReLU network is used.
    _, _, kernel_fn = stax.serial(
        stax.Dense(300, W_std=math.sqrt(W_std), b_std=b_std, parameterization="standard"), stax.Relu(),
        stax.Dense(300, W_std=math.sqrt(W_std), b_std=b_std, parameterization="standard"), stax.Relu(),
        stax.Dense(1, W_std=math.sqrt(W_std), b_std=b_std, parameterization="standard")
    )

    # Row sets and column masks never change between epochs, so build them
    # once. A set makes the complement O(n) instead of O(n^2) list lookups.
    incomplete_patterns = []
    for miss_pattern, incomplete_cases in pattern.items():
        if not any(miss_pattern):
            continue  # the complete-case pattern needs no imputation
        members = set(incomplete_cases)
        complement_cases = [row for row in range(n) if row not in members]
        incomplete_patterns.append((incomplete_cases, complement_cases, np.array(miss_pattern, dtype=bool)))

    imputation = initial_imputation.copy()
    imputation_list = []
    for i in range(burn_in + number_of_imputation * interval):
        key = random.PRNGKey(i * 71)
        for incomplete_cases, complement_cases, missing in incomplete_patterns:
            observed = ~missing
            reference_rows = imputation[complement_cases]
            train_input = jax.numpy.array(np.transpose(reference_rows[:, observed]))
            test_input = jax.numpy.array(np.transpose(reference_rows[:, missing]))
            train_target = jax.numpy.array(np.transpose(imputation[incomplete_cases][:, observed]))

            predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, train_input, train_target)
            # nngp_mean: (missing columns, incomplete rows);
            # nngp_covariance: (missing columns, missing columns).
            nngp_mean, nngp_covariance = predict_fn(x_test=test_input, get='nngp', compute_cov=True)

            if number_of_imputation == 1:
                # for single imputation, use mean value as imputation
                filled = np.asarray(jax.numpy.transpose(nngp_mean))
            else:
                # for multiple imputation, draw imputation from posterior
                # distribution. A fresh subkey per pattern and one batched
                # call give every incomplete row its own independent draw
                # (reusing one key repeats the same noise for every row).
                key, subkey = random.split(key)
                filled = np.asarray(
                    random.multivariate_normal(subkey, jax.numpy.transpose(nngp_mean), nngp_covariance)
                )
            rows = imputation[incomplete_cases]  # fancy indexing returns a copy
            rows[:, missing] = filled
            imputation[incomplete_cases] = rows

        # Keep every `interval`-th epoch once the burn-in period is over.
        if i >= burn_in and (i + 1 - burn_in) % interval == 0:
            imputation_list.append(imputation.copy())
        print('finish epoch {}!'.format(i))

    if number_of_imputation == 1:
        return imputation_list[0]
    return imputation_list

In [5]:
def data_G_linear(n,p,miss_pat1=[False,False,False,True,False],miss_pat2=[False,False,False,False,True],a0=0,a1=0.1,a2=0.1,a3=0,a4=0.1,a5=0.1,rho=0.1,seed=1,sigma1=0.2,sigma2=0.5):
  """Simulate an AR(1)-style data matrix and mask two groups of columns.

  Column 0 is standard normal; each later column is ``rho`` times the previous
  column plus normal noise with scale ``sigma1``. ``miss_pat1`` / ``miss_pat2``
  are tiled across the ``p`` columns to select two column groups. Each row
  independently has group 1 set to NaN with probability 0.7 and group 2 set to
  NaN with probability 0.6.

  The returned matrices are column-reordered: never-masked columns first,
  then group 1, then group 2.

  Args:
    n: number of rows.
    p: number of columns; must be a multiple of both pattern lengths.
    miss_pat1, miss_pat2: boolean patterns marking the columns of each group.
    rho: coefficient linking consecutive columns.
    seed: seed for NumPy's global random state.
    sigma1: noise scale of columns 1..p-1.
    a0..a5, sigma2: accepted but not used by this function.

  Returns:
    (observed, truth): the matrix with NaN in the masked cells and the same
    matrix without masking.

  Raises:
    ValueError: if a pattern is empty or ``p`` is not a multiple of its length.
  """
  # Fail early with a clear message; otherwise the tiled mask has the wrong
  # length and NumPy raises an opaque boolean-index error further down.
  for pat_name, pat in (('miss_pat1', miss_pat1), ('miss_pat2', miss_pat2)):
    if len(pat) == 0 or p % len(pat) != 0:
      raise ValueError('p={} must be a positive multiple of len({})={}'.format(p, pat_name, len(pat)))

  np.random.seed(seed)
  data=np.zeros((n,p))
  data[:,0]=np.random.normal(size=(n,),scale=1)
  for col in range(1,p):
    data[:,col]=rho*data[:,col-1]+np.random.normal(size=(n,),scale=sigma1)

  # split to observed col and missing col
  missing_col1=np.array(list(miss_pat1)*(p//len(miss_pat1)),dtype=bool)
  missing_col2=np.array(list(miss_pat2)*(p//len(miss_pat2)),dtype=bool)
  obs=data[:,~(missing_col1|missing_col2)]
  mis1=data[:,missing_col1]   # boolean indexing copies, so `data` stays intact
  mis2=data[:,missing_col2]
  truth=np.concatenate((obs,mis1.copy(),mis2.copy()),axis=1)

  missing_row1=np.zeros((n,),dtype=bool)
  missing_row2=np.zeros((n,),dtype=bool)
  p_miss1 = 0.7
  p_miss2 = 0.6
  for i in range(n):
    # Scalar draws (no size argument) avoid assigning a size-1 array to an
    # array element, which NumPy has deprecated. The two draws stay
    # interleaved per row so the random stream order is unchanged.
    missing_row1[i]=np.random.choice(2, p=[1-p_miss1,p_miss1])==1
    missing_row2[i]=np.random.choice(2, p=[1-p_miss2,p_miss2])==1
  mis1[missing_row1]=np.nan
  mis2[missing_row2]=np.nan
  return np.concatenate((obs,mis1,mis2),axis=1), truth

# Generate a data set with four missing-data patterns

In [6]:
# Simulate 200 rows x 250 columns: `input` is the matrix with NaN in the masked
# cells, `truth` is the same matrix without masking.
# NOTE(review): `input` shadows Python's built-in input() for the rest of the notebook.
input, truth = data_G_linear(n=200, p=250)

# conduct imputation

In [7]:
# Run both imputers on their own copy of the masked data.
imp1 = MI_NNGP1(input.copy())
imp2 = MI_NNGP2(input.copy())

# Mean squared error against the ground truth, averaged over every cell of the
# matrix (observed cells included).
mse1 = np.mean(np.square(imp1 - truth))
mse2 = np.mean(np.square(imp2 - truth))

print('minngp1 mse', mse1)
print('minngp2 mse', mse2)



finish initial imputation!
finish epoch 0!
finish epoch 1!
finish epoch 2!
minngp1 mse 0.016245352536295866
minngp2 mse 0.01369883496306891
