In [None]:
import os
# Expose only the GPU with index 1 to CUDA-based libraries.
# NOTE(review): this cell runs ahead of the TensorFlow import; presumably the
# variable has to be set before TensorFlow initialises CUDA to take effect —
# confirm before reordering cells.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
import time
import winsound as ws
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random, os, winsound
from collections import Counter
from datetime import datetime
from numpy.linalg import slogdet, inv
from scipy import optimize
from scipy.linalg import block_diag
from sklearn.metrics.pairwise import rbf_kernel

from scipy.spatial.distance import pdist, squareform
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras import initializers
from tensorflow.keras.layers import Dense, Input, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, Callback
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam

In [None]:
def seed_everything(seed: int = 42):
    """Seed every random number generator this notebook relies on.

    Sets the ``PYTHONHASHSEED`` environment variable and seeds Python's
    ``random`` module, NumPy's global generator and TensorFlow's global
    generator with the same value.

    Parameters
    ----------
    seed : int, default 42
        Value handed to each generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)
    
class config:
    """Container for notebook-wide settings."""
    # NOTE(review): no visible cell reads these attributes; seeding is done
    # through seed_everything(repeat) and the GPU is selected through the
    # CUDA_VISIBLE_DEVICES environment variable — confirm whether this class
    # is still needed.
    seed = 42
    device = "cuda:0"

In [None]:
def mean_model(X):
    """Evaluate the fixed part of the mean model (the marginal mean).

    Each row ``x`` of ``X`` is mapped to ``s*cos(s) + 2*x[0]*x[1]``, where
    ``s`` is the sum of the entries of that row.

    Parameters
    ----------
    X : 2-D array indexed as ``X[:, j]``
        Must have at least two columns.

    Returns
    -------
    1-D array with one value per row of ``X``.
    """
    row_total = np.sum(X, axis=1)
    nonlinear_part = row_total * np.cos(row_total)
    interaction_part = 2*X[:, 0]*X[:, 1]
    return nonlinear_part + interaction_part

In [None]:
def make_mean_model():
    """Build the Keras network used for the h-likelihood fit.

    Input widths are read from the notebook-level arrays ``X_train``,
    ``Z1_train`` and ``Z2_train``, so those must exist before this is called.

    Returns
    -------
    Model
        Takes ``[X, Z1, Z2]`` and returns three outputs: a ReLU network
        (100-50-25-12 units followed by one linear unit) applied to ``X``,
        and one bias-free linear unit applied to each of ``Z1`` and ``Z2``.
    """
    n_features_x = np.shape(X_train)[1]
    n_levels_z1 = np.shape(Z1_train)[1]
    n_levels_z2 = np.shape(Z2_train)[1]

    input_X = Input(shape=(n_features_x,), dtype='float32')
    input_Z1 = Input(shape=(n_levels_z1,), dtype='float32')
    input_Z2 = Input(shape=(n_levels_z2,), dtype='float32')

    # Fixed-effect branch: a stack of ReLU layers of decreasing width.
    hidden = input_X
    for units in (100, 50, 25, 12):
        hidden = Dense(units, activation='relu')(hidden)
    xb = Dense(1, activation='linear')(hidden)

    # Random-effect branches: the kernel of each bias-free layer holds one
    # coefficient per column of Z1 / Z2. A variant of these two layers with
    # kernel_initializer=initializers.Zeros() is left disabled; the default
    # initializer is used.
    zv1 = Dense(1, activation='linear', use_bias=False)(input_Z1)
    zv2 = Dense(1, activation='linear', use_bias=False)(input_Z2)

    return Model(inputs=[input_X, input_Z1, input_Z2], outputs=[xb, zv1, zv2])

def make_marginal_model():
    """Build the Keras network used for the integrated (marginal) likelihood fit.

    Input widths are read from the notebook-level arrays ``X_train``,
    ``Z1_train`` and ``Z2_train``, so those must exist before this is called.

    Returns
    -------
    Model
        Takes ``[X, Z1, Z2]`` and returns a single output: a ReLU network
        (100-50-25-12 units followed by one linear unit) applied to ``X``.
        The ``Z1`` and ``Z2`` inputs are declared but not connected to the
        output.
    """
    n_features_x = np.shape(X_train)[1]
    n_levels_z1 = np.shape(Z1_train)[1]
    n_levels_z2 = np.shape(Z2_train)[1]

    input_X = Input(shape=(n_features_x,), dtype='float32')
    input_Z1 = Input(shape=(n_levels_z1,), dtype='float32')
    input_Z2 = Input(shape=(n_levels_z2,), dtype='float32')

    # Fixed-effect network: a stack of ReLU layers of decreasing width.
    hidden = input_X
    for units in (100, 50, 25, 12):
        hidden = Dense(units, activation='relu')(hidden)
    xb = Dense(1, activation='linear')(hidden)

    return Model(inputs=[input_X, input_Z1, input_Z2], outputs=[xb])

In [None]:
def hlik_loss_mean(mean_model, mean_inputs, y, phi, lam1, lam2, batch_ratio=1.):
    """Mean-parameter part of the h-likelihood loss for one batch.

    The loss is the residual sum of squares divided by ``phi`` plus the
    squared random-effect weights divided by their variances, all divided by
    the number of rows in ``y``.

    Parameters
    ----------
    mean_model : Keras model whose outputs are summed to form the prediction
        and whose last two weights are the random-effect coefficient vectors.
    mean_inputs : list ``[X, Z1, Z2]`` passed to ``mean_model``.
    y : responses for the batch.
    phi : residual variance.
    lam1, lam2 : variances of the first / second random-effect vector.
    batch_ratio : batch_size / N_train (= 1 for full batch or validation);
        scales the penalty so it is not over-counted across mini-batches.

    Returns
    -------
    Scalar tensor.
    """
    n_batch = np.shape(y)[0]

    mu = K.transpose(K.sum(mean_model(mean_inputs), axis=0))
    loss = K.sum(K.square(y - mu)) / phi / n_batch

    # The training cell reads get_weights()[-2] as v1 and get_weights()[-1]
    # as v2 and estimates lam1 / lam2 from them, so weights[-2] must be
    # penalised with lam1 and weights[-1] with lam2 (they used to be swapped).
    v1 = mean_model.weights[-2]
    v2 = mean_model.weights[-1]
    loss += (K.sum(K.square(v1))/lam1 + K.sum(K.square(v2))/lam2) * batch_ratio / n_batch

    return loss

def update_mean_params(mean_model, mean_inputs, y, loss_ftn, optimizer, phi, lam1, lam2, batch_ratio=1.):
    """Take one optimizer step on the network weights for a single batch.

    ``loss_ftn`` is evaluated under a gradient tape with the arguments
    ``(mean_model, mean_inputs, y, phi, lam1, lam2, batch_ratio)``; only
    ``mean_model.trainable_weights`` are updated.

    Returns
    -------
    The loss value computed before the update.
    """
    params = mean_model.trainable_weights
    with tf.GradientTape() as tape:
        batch_loss = loss_ftn(mean_model, mean_inputs, y, phi, lam1, lam2, batch_ratio)
    grads = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(grads, params))
    return batch_loss

def train_mean_one_epoch(mean_model, train_batch, loss_ftn, optimizer, phi, lam1, lam2, batch_ratio=1.):
    """Run ``update_mean_params`` once for every batch in ``train_batch``.

    ``train_batch`` must yield ``(X, Z1, Z2, y)`` tuples.

    Returns
    -------
    list
        The per-batch losses, in iteration order.
    """
    return [
        update_mean_params(mean_model, [X_batch, Z1_batch, Z2_batch], y_batch,
                           loss_ftn, optimizer, phi, lam1, lam2, batch_ratio)
        for X_batch, Z1_batch, Z2_batch, y_batch in train_batch
    ]

def hlik_loss(mean_model, mean_inputs, y, phi, lam1, lam2, batch_ratio=1.):
    """H-likelihood loss for one batch, including the dispersion terms.

    Adds to the mean-parameter loss (scaled residual sum of squares plus
    penalty on the random-effect weights) the term
    ``log|Z'Z D + phi I| + (n - 2q) log(phi)`` divided by ``n``, where
    ``D = diag(lam1, ..., lam1, lam2, ..., lam2)`` and ``n`` is the number of
    rows in ``y``. Reads the notebook-level globals ``q`` and ``Q``.

    Parameters
    ----------
    mean_model : Keras model whose outputs are summed to form the prediction
        and whose last two weights are the random-effect coefficient vectors.
    mean_inputs : list ``[X, Z1, Z2]`` passed to ``mean_model``.
    y : responses for the batch.
    phi : residual variance.
    lam1, lam2 : variances of the first / second random-effect vector.
    batch_ratio : batch_size / N_train (= 1 for full batch or validation).

    Returns
    -------
    Scalar tensor.
    """
    n_batch = np.shape(y)[0]

    mu = K.transpose(K.sum(mean_model(mean_inputs), axis=0))
    loss = K.sum(K.square(y - mu)) / phi / n_batch

    # The training cell reads get_weights()[-2] as v1 and get_weights()[-1]
    # as v2, so weights[-2] pairs with lam1 and weights[-1] with lam2
    # (they used to be swapped).
    v1 = mean_model.weights[-2]
    v2 = mean_model.weights[-1]
    loss += (K.sum(K.square(v1))/lam1 + K.sum(K.square(v2))/lam2) * batch_ratio / n_batch

    ZZ = np.float32(np.concatenate(([mean_inputs[1], mean_inputs[2]]), axis=1))
    # Build D by scaling constant 0/1 masks with lam1 / lam2. Going through
    # np.repeat([lam1, lam2], q) would convert tf.Variables to plain NumPy
    # numbers and cut the log-determinant term out of the gradient with
    # respect to lam1 and lam2.
    mask1 = np.diag(np.repeat(np.float32([1, 0]), q))
    mask2 = np.diag(np.repeat(np.float32([0, 1]), q))
    DD = lam1*mask1 + lam2*mask2
    loss += (tf.linalg.slogdet(K.transpose(ZZ)@ZZ@DD+phi*tf.eye(Q))[1]+(n_batch-2*q)*K.log(phi)) / n_batch

    return loss

def update_hlik(mean_model, mean_inputs, y, loss_ftn, optimizer, phi, lam1, lam2, batch_ratio=1.):
    """Take one optimizer step on the network weights and on phi, lam1, lam2.

    ``loss_ftn`` is evaluated under a gradient tape with the arguments
    ``(mean_model, mean_inputs, y, phi, lam1, lam2, batch_ratio)``. Because
    ``phi``, ``lam1`` and ``lam2`` are differentiated and updated, they are
    presumably expected to be ``tf.Variable`` objects — confirm at the call
    site.

    Returns
    -------
    The loss value computed before the update.
    """
    params = mean_model.trainable_weights + [phi, lam1, lam2]
    with tf.GradientTape() as tape:
        batch_loss = loss_ftn(mean_model, mean_inputs, y, phi, lam1, lam2, batch_ratio)
    grads = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(grads, params))
    return batch_loss

def train_hlik_one_epoch(mean_model, train_batch, loss_ftn, optimizer, phi, lam1, lam2, batch_ratio=1.):
    """Run ``update_hlik`` once for every batch in ``train_batch``.

    ``train_batch`` must yield ``(X, Z1, Z2, y)`` tuples.

    Returns
    -------
    list
        The per-batch losses, in iteration order.
    """
    return [
        update_hlik(mean_model, [X_batch, Z1_batch, Z2_batch], y_batch,
                    loss_ftn, optimizer, phi, lam1, lam2, batch_ratio)
        for X_batch, Z1_batch, Z2_batch, y_batch in train_batch
    ]

In [None]:
def update_disp_mle(ehat, v1, v2, phi, lam1, lam2, ZTZ, n_obs=None):
    """One Newton step for the dispersion parameters on the log scale.

    Works with ``a0 = log(phi)``, ``a1 = log(lam1)``, ``a2 = log(lam2)`` and
    updates each with its own gradient / second-derivative ratio.

    Parameters
    ----------
    ehat : residuals.
    v1, v2 : the two random-effect vectors.
    phi, lam1, lam2 : current (positive) dispersion values.
    ZTZ : square matrix Z'Z with ``2*q`` rows; ``q`` is taken from its size
        (for the notebook's two-component setup this equals the global ``q``).
    n_obs : number of observations entering the ``phi`` gradient. Defaults to
        the notebook-level ``N``.
        NOTE(review): ``N`` is the full simulated sample size; if ``ehat``
        only covers the training split the caller probably wants
        ``n_obs=len(ehat)`` — confirm.

    Returns
    -------
    tuple
        ``(phi, lam1, lam2, diff)`` with the updated values and ``diff`` the
        sum of the absolute Newton steps taken on the log scale.
    """
    n_effects = np.shape(ZTZ)[0]
    n_levels = n_effects // 2
    if n_obs is None:
        n_obs = N

    a0, a1, a2 = np.log(phi), np.log(lam1), np.log(lam2)
    D1 = np.diag(np.repeat([np.exp(a1), 0], n_levels))
    D2 = np.diag(np.repeat([0, np.exp(a2)], n_levels))

    Ainv = inv(ZTZ@(D1 + D2) + np.eye(n_effects)*np.exp(a0))
    A1, A2 = Ainv@ZTZ@D1, Ainv@ZTZ@D2

    ete, vtv1, vtv2 = np.sum(ehat**2), np.sum(v1**2), np.sum(v2**2)

    # Derivatives of the log-determinant term with respect to a0, a1, a2.
    ddetda0 = np.exp(a0)*np.trace(Ainv)
    ddetda1, ddetda2 = np.trace(A1), np.trace(A2)

    grad0 = -n_obs + n_effects + np.exp(-a0)*ete + ddetda0
    grad1 = np.exp(-a1)*vtv1 + ddetda1
    grad2 = np.exp(-a2)*vtv2 + ddetda2

    hess0 = - np.exp(-a0)*ete  + ddetda0 - np.exp(2*a0)*np.trace(Ainv@Ainv)
    hess1 = - np.exp(-a1)*vtv1 + ddetda1 - np.trace(A1@A1)
    hess2 = - np.exp(-a2)*vtv2 + ddetda2 - np.trace(A2@A2)

    step0, step1, step2 = -grad0/hess0, -grad1/hess1, -grad2/hess2
    a0 += step0
    a1 += step1
    a2 += step2

    # Sum the magnitudes of the steps: adding the signed steps first lets
    # steps of opposite sign cancel and report convergence too early.
    diff = np.abs(step0) + np.abs(step1) + np.abs(step2)

    return (np.exp(a0), np.exp(a1), np.exp(a2), diff)

def find_disp_mle(y, mu, v1, v2, phi_init, lam1_init, lam2_init, ZTZ, tolorence=1e-3, max_iter=100):
    """Iterate ``update_disp_mle`` until its step measure falls below a tolerance.

    Parameters
    ----------
    y, mu : responses and fitted means; their difference is the residual.
    v1, v2 : the two random-effect vectors.
    phi_init, lam1_init, lam2_init : starting dispersion values.
    ZTZ : matrix passed through to ``update_disp_mle``.
    tolorence : stop once the absolute step measure is below this value.
    max_iter : maximum number of update steps.

    Returns
    -------
    tuple
        ``(phi, lam1, lam2)`` after the last update performed.
    """
    residuals = y - mu
    phi, lam1, lam2 = phi_init, lam1_init, lam2_init
    for _ in range(max_iter):
        phi, lam1, lam2, step = update_disp_mle(residuals, v1, v2, phi, lam1, lam2, ZTZ)
        converged = np.abs(step) < tolorence
        if converged:
            break
    return (phi, lam1, lam2)

In [None]:
def marginal_loss(mean_model, mean_inputs, y, phi, lam1, lam2):
    """Integrated (marginal) likelihood loss for one batch.

    With ``r = y - f(X)``, ``Z = [Z1, Z2]``, ``D = diag(lam1 I_q, lam2 I_q)``
    and ``A = Z'Z D + phi I_Q`` the value returned is::

        ( r r'/phi - r Z D A^{-1} Z' r'/phi + log|A| + (n - 2q) log(phi) ) / n

    where ``n`` is the number of rows in ``y``. Reads the notebook-level
    globals ``q`` and ``Q``.

    Parameters
    ----------
    mean_model : Keras model with a single output ``f(X)``.
    mean_inputs : list ``[X, Z1, Z2]`` passed to ``mean_model``.
    y : responses for the batch.
    phi, lam1, lam2 : residual variance and the two random-effect variances.

    Returns
    -------
    Scalar tensor.
    """
    n_batch = np.shape(y)[0]

    fX = K.transpose(mean_model(mean_inputs))
    ZZ = np.float32(np.concatenate(([mean_inputs[1], mean_inputs[2]]), axis=1))
    DD = (lam1*np.diag(np.repeat([1,0],q)) + lam2*np.diag(np.repeat([0,1],q)))

    resid = y - fX
    # Q x Q matrix shared by the quadratic-form and log-determinant terms.
    ZtZD_phi = K.transpose(ZZ)@ZZ@DD + phi*tf.eye(Q)

    loss = K.sum(K.square(resid)) / phi / n_batch
    loss -= (resid@ZZ@DD@tf.linalg.inv(ZtZD_phi)@K.transpose(ZZ)@K.transpose(resid)/phi)[0,0] / n_batch
    loss += (tf.linalg.slogdet(ZtZD_phi)[1] + (n_batch-2*q)*K.log(phi)) / n_batch

    return loss

def update_marginal(mean_model, mean_inputs, y, phi, lam1, lam2, loss_ftn, optimizer):
    """Take one optimizer step on the network weights and on phi, lam1, lam2.

    ``loss_ftn`` is evaluated under a gradient tape with the arguments
    ``(mean_model, mean_inputs, y, phi, lam1, lam2)``; gradients are taken
    with respect to ``mean_model.trainable_weights`` followed by ``phi``,
    ``lam1`` and ``lam2``.

    Returns
    -------
    The loss value computed before the update.
    """
    params = mean_model.trainable_weights + [phi, lam1, lam2]
    with tf.GradientTape() as tape:
        batch_loss = loss_ftn(mean_model, mean_inputs, y, phi, lam1, lam2)
    grads = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(grads, params))
    return batch_loss

def train_marginal_one_epoch(mean_model, train_batch, phi, lam1, lam2, loss_ftn, optimizer):
    """Run ``update_marginal`` once for every batch in ``train_batch``.

    ``train_batch`` must yield ``(X, Z1, Z2, y)`` tuples.

    Returns
    -------
    list
        The per-batch losses, in iteration order.
    """
    return [
        update_marginal(mean_model, [X_batch, Z1_batch, Z2_batch], y_batch,
                        phi, lam1, lam2, loss_ftn, optimizer)
        for X_batch, Z1_batch, Z2_batch, y_batch in train_batch
    ]

In [None]:
# --- simulation design ---
n_simul = 20       # number of simulated data sets
N = 10000          # observations per data set
K_num = 2          # number of random-effect components
p = 10             # columns of X
q = 100            # levels per random-effect component
Q = q*K_num        # total number of random effects
lam = .5           # variance used to draw the random effects
sig2 = .5          # variance of the additive noise

# --- starting values for the dispersion parameters ---
phi_init = 1.
lam1_init = 1.
lam2_init = 1.

# --- training ---
batch_size = 1024
max_epochs = 100
patience = 5

mean_lr = 0.002
mean_optimizer = Adam(learning_rate=mean_lr)

In [None]:
# Every history below holds one row per simulation repeat and one column per epoch.
history_shape = (n_simul, max_epochs)

# Elapsed wall-clock time at the end of each epoch.
time_marginal, time_hlik = (np.zeros(history_shape) for _ in range(2))

# Train / validation MSE of the integrated-likelihood fit.
train_mse_marginal, valid_mse_marginal = (np.zeros(history_shape) for _ in range(2))

# Train / validation MSE of the h-likelihood fit.
train_mse_hlik, valid_mse_hlik = (np.zeros(history_shape) for _ in range(2))

# Dispersion parameters of the integrated-likelihood fit.
phi_marginal, lam1_marginal, lam2_marginal = (np.zeros(history_shape) for _ in range(3))

In [None]:
# Simulation study: for every repeat, draw one data set, fit it with the
# integrated (marginal) likelihood and with the h-likelihood, and record
# elapsed time and train / validation MSE after every epoch.
for repeat in range(n_simul):

    K.clear_session()
    seed_everything(repeat)

    # ----- simulate one data set -----
    D = [ lam*np.identity(q) for k in range(K_num)]
    v = [ np.random.multivariate_normal(np.zeros(q), D[k], 1)[0] for k in range(K_num) ]
    z = [ np.random.choice(range(q), size=N, replace=True) for k in range(K_num) ]
    epsilon = np.random.normal(0, np.sqrt(sig2), N)

    X = np.random.uniform(-1, 1, (N,p))
    fX = mean_model(X)

    Z = [ pd.get_dummies(z[k]) for k in range(K_num)]
    Z1 = np.float32(Z[0])
    Z2 = np.float32(Z[1])

    y = np.float32(fX + epsilon + sum([Z[k]@v[k] for k in range(K_num)]) )

    X_train, X_valid, y_train, y_valid, Z1_train, Z1_valid, Z2_train, Z2_valid = train_test_split(
        X, y, Z1, Z2, test_size=0.2, random_state=42)

    N_train, N_valid = np.shape(y_train)[0], np.shape(y_valid)[0]
    batch_ratio = batch_size/N_train
    marginal_batch = tf.data.Dataset.from_tensor_slices((X_train, Z1_train, Z2_train, y_train)).shuffle(N_train).batch(batch_size)
    hlik_batch = tf.data.Dataset.from_tensor_slices((X_train, Z1_train, Z2_train, y_train)).shuffle(N_train).batch(batch_size)

    # ----- integrated (marginal) likelihood fit -----
    M = make_marginal_model()
    # np.inf rather than np.infty: the latter alias was removed in NumPy 2.0.
    phi  = tf.Variable(phi_init,  name='phi',  trainable=True, constraint=lambda x: tf.clip_by_value(x, 1e-18, np.inf))
    lam1 = tf.Variable(lam1_init, name='lam1', trainable=True, constraint=lambda x: tf.clip_by_value(x, 1e-18, np.inf))
    lam2 = tf.Variable(lam2_init, name='lam2', trainable=True, constraint=lambda x: tf.clip_by_value(x, 1e-18, np.inf))
    # A fresh optimizer per fit: sharing one Adam instance across models and
    # repeats carries its step count and moment estimates from one fit into
    # the next, and recent Keras optimizers refuse variables they were not
    # built with.
    marginal_optimizer = Adam(learning_rate=mean_lr)

    patience_marginal = 0
    min_marginal_val_loss = np.inf
    temp_start = time.time()
    for epoch in range(max_epochs):

        # Epoch 0 records the untrained model.
        if epoch!=0:
            marginal_train_loss = train_marginal_one_epoch(M, marginal_batch, phi, lam1, lam2, marginal_loss, marginal_optimizer)
            marginal_val_loss = marginal_loss(M, [X_valid, Z1_valid, Z2_valid], y_valid, phi, lam1, lam2)

        fX_train, fX_valid = M([X_train, Z1_train, Z2_train]), M([X_valid, Z1_valid, Z2_valid])

        ZZ = np.float32(np.concatenate(([Z1_train, Z2_train]), axis=1))
        DD = np.diag(np.repeat([lam1, lam2],q))

        # Predict the random effects from the training residuals.
        v_pred = DD@ZZ.T@(tf.eye(N_train) - ZZ@DD@inv(ZZ.T@ZZ@DD+phi*tf.eye(Q))@ZZ.T)@(y_train.reshape(N_train,1)-fX_train)/phi
        v1_pred, v2_pred = v_pred[:q], v_pred[q:]

        mu_train = fX_train + Z1_train@v1_pred + Z2_train@v2_pred
        mu_valid = fX_valid + Z1_valid@v1_pred + Z2_valid@v2_pred

        train_mse_marginal[repeat, epoch] = (np.sum(np.square(y_train.reshape(N_train,1)-mu_train))/N_train)
        # Average over the validation rows (this used to divide by N_train).
        valid_mse_marginal[repeat, epoch] = (np.sum(np.square(y_valid.reshape(N_valid,1)-mu_valid))/N_valid)
        time_marginal[repeat, epoch] = (time.time() - temp_start)

        # Keep the dispersion path; written after the time stamp so the
        # recorded time is unaffected.
        phi_marginal[repeat, epoch] = phi.numpy()
        lam1_marginal[repeat, epoch] = lam1.numpy()
        lam2_marginal[repeat, epoch] = lam2.numpy()

    # ----- h-likelihood fit -----
    H = make_mean_model()
    phi, lam1, lam2 = phi_init, lam1_init, lam2_init
    hlik_optimizer = Adam(learning_rate=mean_lr)

    patience_hlik = 0
    min_hlik_val_loss = np.inf
    temp_start = time.time()
    for epoch in range(max_epochs):

        # Epoch 0 records the untrained model.
        if epoch!=0:
            hlik_train_loss = train_mean_one_epoch(H, hlik_batch, hlik_loss_mean, hlik_optimizer, phi, lam1, lam2, batch_ratio)
            hlik_val_loss = hlik_loss_mean(H, [X_valid, Z1_valid, Z2_valid], y_valid, phi, lam1, lam2, batch_ratio=(0.2/0.8))

            v1, v2 = H.get_weights()[-2], H.get_weights()[-1]
            # NOTE(review): mu_train here still holds the predictions made
            # before this epoch's weight update — confirm that is intended.
            phi, lam1, lam2 = np.var(y_train-mu_train), np.var(v1), np.var(v2)

        mu_train = np.sum(H([X_train, Z1_train, Z2_train]),axis=0).T
        mu_valid = np.sum(H([X_valid, Z1_valid, Z2_valid]),axis=0).T

        train_mse_hlik[repeat, epoch] = (np.sum(np.square(y_train-mu_train))/N_train)
        valid_mse_hlik[repeat, epoch] = (np.sum(np.square(y_valid-mu_valid))/N_valid)
        time_hlik[repeat, epoch] = (time.time() - temp_start)

In [None]:
plt.figure(figsize=(10,5), dpi=300)

# One entry per panel: subplot position, MSE history of the integrated-
# likelihood fit, MSE history of the h-likelihood fit, y-axis label, title.
panels = (
    (121, train_mse_marginal, train_mse_hlik, 'Train MSE', 'Train MSE'),
    (122, valid_mse_marginal, valid_mse_hlik, 'Validation MSE', 'Validation MSE'),
)
marginal_style = dict(linestyle='dashed', color='orangered', dashes=(10, 5), linewidth=0.5)
hlik_style = dict(linestyle='solid', color='royalblue', linewidth=0.3)

for position, mse_marginal, mse_hlik, y_label, panel_title in panels:
    plt.subplot(position)
    for repeat in range(n_simul):
        # Only the first pair of curves carries a label, so each method
        # shows up exactly once in the legend.
        marginal_label = {'label': 'integrated'} if repeat == 0 else {}
        hlik_label = {'label': 'h-likelihood'} if repeat == 0 else {}
        plt.plot(time_marginal[repeat,], mse_marginal[repeat,], **marginal_style, **marginal_label)
        plt.plot(time_hlik[repeat,], mse_hlik[repeat,], **hlik_style, **hlik_label)
    plt.xlim([0, 10])
    plt.xlabel('Time (sec)', fontsize=12)
    plt.ylabel(y_label, fontsize=12)
    plt.legend(loc='upper right', ncol=1, fontsize=12)
    plt.title(panel_title)

plt.show()