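I have improved classification accuracy by combining a statistical model with a TensorFlow-based CNN model.  
I believe I can use this experience to contribute to projects on geo data and image-classification models.  
Geographic data are statistics recorded at specific coordinates, so they are not fundamentally different from image data.  
For that reason, instead of a long cover letter, I am attaching an introduction to my project and its code.  
In short, my system changes the last part of a CNN (the fully connected layer). That layer would normally be a general linear model trained with a stochastic gradient method (SGM); my system instead uses  
a statistical model structure and its estimation procedure. The inputs are regenerated at every step by the CNN's feature extraction,  
so the model cannot make population-level inferences the way a conventional statistical model can. However, it reflects the structure of the data better than a general linear model.  
Additional statistical data can also be supplied before training starts, and the classification rate improved in proportion to the information added.  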
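The claim that coordinate-indexed statistics are image-like can be made concrete with a minimal sketch. It assumes point observations given as (latitude, longitude, value) arrays and averages the values into an H×W grid with one channel. That is the (height, width, channels) layout a CNN takes as input. The function name `rasterize_points` and the bounds/shape arguments are illustrative assumptions, not part of the project code.

```python
import numpy as np

def rasterize_points(lat, lon, values, bounds, shape):
    """Average point values into an (H, W, 1) grid; cells without points are 0.

    bounds = (lat_min, lat_max, lon_min, lon_max), shape = (H, W).
    Row 0 corresponds to the lowest latitude bin.
    """
    lat_min, lat_max, lon_min, lon_max = bounds
    h, w = shape
    grid_range = [[lat_min, lat_max], [lon_min, lon_max]]
    # per-cell sum of values and per-cell point count (rows = latitude, cols = longitude)
    sums, _, _ = np.histogram2d(lat, lon, bins=[h, w], range=grid_range, weights=values)
    counts, _, _ = np.histogram2d(lat, lon, bins=[h, w], range=grid_range)
    # per-cell mean, leaving empty cells at 0
    grid = np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0)
    return grid[..., np.newaxis]  # add a channel axis, like a 1-channel image
```

For example, the points (0.1, 0.1) with values 2 and 4 and the point (0.9, 0.9) with value 5 on a 2×2 grid over [0, 1]² give the cell means 3 and 5 on the diagonal. Stacking several statistics as extra channels would give a multi-channel "image".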


## Multivariate Longitudinal Convolution Neural Network

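People have long tried to explain neural network models in terms of linear and nonlinear statistical models.  
In practice, however, the two fields have trained and applied their models independently,  
because neural networks are built on unstructured data, whereas statistical models work with structured data.  
Yet even unstructured data can break the independence assumption. Longitudinal data take repeated measurements from a limited number of subjects,  
so the observations are not independent of one another. If unstructured data with this kind of structure are handled with a plain CNN architecture and  
gradient-descent-based training, accuracy suffers, so a new approach is needed.  
This code aims to improve prediction accuracy on image data with longitudinal characteristics. It does so by combining the structures and training methods of neural network models and linear models.  
Training alternates between CNN training and a Hamiltonian Monte Carlo (HMC) procedure, repeating the two as a single loop.  
The CNN part learns the effects that do not depend on Subject, Time, or Group (fixed effects). The effects of Subject, Time, and Group (random effects) are learned in the HMC part.  
The covariance of Y, which reflects the interaction of Subject, Time, and Group, is obtained with the modified Cholesky decomposition (MCD) and the hypersphere decomposition.  
On simulated data, this approach gave more accurate results than a conventional CNN architecture.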
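The statistical part of the model, as read from the code below, can be summarized as follows. The summary uses 1-based indices. Two symbols are introduced here: $g_i$ is subject $i$'s group indicator and $\phi_{lm}$ are the hypersphere angles. The CNN supplies the features $x_{it}$ and the fixed effects $\beta_k$. HMC samples $\lambda$, $\nu$ and $\alpha$, and $\delta$ is a matrix of ones in the demo. The autoregressive term is written with the observed previous response $y_{i,t-1,g}$, which is the intended reading of the code. The covariance line is the per-subject covariance used in the likelihood, with each subject's $T \times K$ responses stacked time-major.

$$
\begin{aligned}
y_{itk} &= x_{it}^\top \beta_k
  + \mathbf{1}[t>1] \sum_{g=1}^{K} \delta_{kg}\,\alpha_{kg}\,\bigl(y_{i,t-1,g} - x_{i,t-1}^\top \beta_g\bigr)
  + \varepsilon_{itk}, \\
\varepsilon_{it} &\sim \mathcal{N}_K(0, D_i), \qquad
D_i = S_i R S_i, \qquad
S_i = \operatorname{diag}\bigl(\exp(\lambda_{k1} + \lambda_{k2}\, g_i)\bigr)_{k=1}^{K}, \\
R &= F F^\top, \qquad
\phi_{lm} = \frac{\pi\, e^{\nu_l + \nu_m}}{1 + e^{\nu_l + \nu_m}}, \\
F_{11} &= 1, \quad F_{l1} = \cos\phi_{l1}\ (l>1), \quad
F_{lm} = \cos\phi_{lm}\prod_{j<m}\sin\phi_{lj}\ (1<m<l), \quad
F_{ll} = \prod_{j<l}\sin\phi_{lj}\ (l>1), \\
\operatorname{Cov}(y_i) &= I_T \otimes D_i .
\end{aligned}
$$

Each row of $F$ has unit length, so $R$ is a correlation matrix, and $S_i$ carries the group-dependent standard deviations.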


###################################################
# MULTILONGCNN MODEL : COMBINE CNN AND MULTIVARIATE LONGITUDINAL MODEL
# CODE BLOCK INDEX
# 1. BASE INFORMATION
# 1-1. LIBRARY LOAD
# 1-2. PARAMETER DESIGNATION
# 1-3. LOAD AND CREATE BASE DATASET
#
# 2. STATISTICAL MODEL CREATION(CREATE Dit)
# 2-1. SET CONDITIONAL INDEX FOR Dit FUNCTION
# 2-2. CREATE AND DISTRIBUTE "F" FOR Dit FUNCTION
# 2-3. CREATE Dit FUNCTION
#      2-3-1. CREATE F_TMP FOR F
#      2-3-2. CALCULATE F WITH F_TMP AND 2-2 ALONG WITH MATRIX LOCATION INFO
#      2-3-3. CREATE F USING SparseMatrix
#      2-3-4. CALCULATE Dit USING F, Ri, Si
#
# 3. STATISTICAL MODEL CREATION(CREATE Y ACCORDING TO STAT MODEL)
# 3-1. CREATE COVARIANCE MATRIX
# 3-2. CREATE RESIDUAL(FOLLOWS NORMAL, COV:3-1)
# 3-3. CREATE ID/TIME/Y_DIM SPECIFIC EFFECT
# 3-4. CREATE Y ELEMENT CREATION FUNCTION USING 3-1, 3-2, 3-3. USE SparseMatrix METHOD
# 3-5. CREATE Y
#
# 4. CONSTRUCT ARCHITECTURE(A BIT MESSY. NEED OPTIMIZATION)
# 4-1. DESIGNATE CNN MODEL, LOSS FUNCTION, OPTIMIZER AND OTHER PARAMS FOR CNN
# 4-2. READ AND CREATE IMAGE(ONLY FOR DEMO. SHOULD BE ADJUSTED FOR REAL DATA)
#      4-2-1. READ IMAGES FROM FOLDER. FILE NAME CONTAINS I AND T INFO
#      4-2-2. CREATE TRUE Y ACCORDING TO ANSWER MODEL. THIS SHOULD BE ADJUSTED WHEN TREATING REAL DATA
# 4-3. SPECIFY FUNCTION FOR CNN AND BETA UPDATE
#      4-3-1. USE IMAGE(x_image) AND Y ONLY AFFECTED BY BETA(y_fixed)
#      4-3-2. UPDATE CNN AND BETA USING INPUTS
#      4-3-3. RETURN POST FORWARDED X VALUE(x_input) FROM IMAGES(x_image)
#             AND BETA
# 4-4. SET BETA, CNN LEARNING PROCEDURE
#      4-4-1. GET POST FORWARDED X VALUE(x_input) FROM IMAGES(x_image)
#      4-4-2. CONCATENATE X INPUTS WHICH IS NOT X IMAGE INPUT(meta_data)
#      4-4-3. FROM GIVEN LAMBDA, ALPHA, DELTA, AND NU, GET FIXED EFFECT Y VALUE
#             THIS THEORETICALLY CREATES Y AFFECTED ONLY BY BETA
#             (THIS IS ONLY ESTIMATED VALUE SINCE PARAMETERS ARE ALL ESTIMATED VALUE)
#      4-4-4. UPDATE BETA AND CNN(CALCULATION OF x_input) N TIMES USING 4-3
# 4-5. CREATE TOTAL LEARNING PROCEDURE
#      4-5-1. SET MCMC TO TRAIN PARAMETERS EXCEPT BETA
#      4-5-2. REPEAT LEARNING BETA PROCEDURE(4-4) AND MCMC PROCEDURE(4-5-1)
#      4-5-*. 4-5-2 STEP IS UNNECESSARY ONCE TENSORFLOW'S MEMORY LEAK PROBLEM IS SOLVED
#################################################################




import collections
import matplotlib.pyplot as plt
import numpy as np
import warnings
import datetime

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

# import tensorflow_datasets as tfds
import tensorflow_probability as tfp
import math


tfk = tf.keras
tfkl = tf.keras.layers
tfpl = tfp.layers
tfd = tfp.distributions
tfb = tfp.bijectors
date = datetime.date.today()
today = date.strftime("%Y%m%d")

warnings.simplefilter('ignore')



from random import choices


plt.style.use("ggplot")
warnings.filterwarnings('ignore')

# set parameters
N = 100     # number of samples
T = 10      # repeated number per sample
K = 3       # output dimension
P = 128     # dimension of model parameters

# parameter of each parameter's distribution
sigma2_beta = 1e5
sigma2_alpha = 1e5
sigma2_lambda = 1e5
sigma2_nu = 1e5
p_kgm = 0.5

num_chains = 4
num_results = 2000
num_burnin_steps = 500
tf.keras.backend.set_floatx('float64')

x_pre = tf.Variable(tf.random.normal([N, T, 125], dtype=tf.float64), trainable=True)

group_indices = tf.cast(np.array(choices([0, 1], k=N) * T).reshape(T, N).T, dtype=tf.float64)
group_indices = tf.Variable(tf.reshape(group_indices, shape=[N, T, 1]))
time_ij = tf.random.normal([N, T, 1], dtype=tf.float64)
group_time = group_indices * time_ij

y = tf.Variable(tf.zeros([N, T, K], dtype=tf.float64), trainable=True)

# true fixed effects, one array per output dimension
# (a chained `a = b = c = np.zeros(P)` would make all three names alias one array)
beta_1_unif = np.zeros(P)
beta_2_unif = np.zeros(P)
beta_3_unif = np.zeros(P)

for i in range(0,P):
    if i == 0:
        sample = np.random.uniform(0.1, 0.3, 3)
        beta_1_unif[i] = sample[0]
        beta_2_unif[i] = sample[1]
        beta_3_unif[i] = sample[2]
    elif i < 20:
        beta_1_unif[i] = np.random.uniform(-1, -0.8)
        beta_2_unif[i] = np.random.uniform(-0.9, -0.5)
        beta_3_unif[i] = np.random.uniform(-0.8, -0.5)
    else:
        beta_1_unif[i] = np.random.uniform(0.5, 0.6)
        beta_2_unif[i] = np.random.uniform(0.5, 0.7)
        beta_3_unif[i] = np.random.uniform(0.7, 0.8)

beta_1 = np.array(beta_1_unif).reshape(1, P)
beta_2 = np.array(beta_2_unif).reshape(1, P)
beta_3 = np.array(beta_3_unif).reshape(1, P)


beta_all = tf.Variable(np.concatenate((beta_1.T, beta_2.T, beta_3.T), axis=1), dtype=tf.float64, trainable=False)
alpha = tf.Variable(tf.cast(tf.reshape([.3, .4, .1, .1, .3, .1, .1, .3, .2], shape=[K, K]), dtype=tf.float64), trainable=False)
lambd = tf.Variable(tf.cast(tf.reshape([0.1, 0.2, 0.2, -0.1, 0.1, 0.1], shape=[K, 2]), dtype=tf.float64), trainable=False)
nu = tf.Variable(tf.cast([-0.5, -0.4, -0.3], dtype=tf.float64), trainable=False)

#optional delta value
true_delta = tf.ones_like(alpha, dtype=tf.float64)

#saved loss value of cnn model to prevent error
model_loss = np.inf


#End pre-setting parameters
######################################################################################################################
# create Dit to Estimate error of y

def combine(x, y):
  xx, yy = tf.meshgrid(x, y, indexing='ij')
  return tf.cast(tf.stack([tf.reshape(xx, [-1]), tf.reshape(yy, [-1])], axis=1), dtype=tf.float64)

# get every index combination of Dit matrix
comb_lm = combine(tf.range(0,K,1), tf.range(0,K,1))

# condition 1 = [0,0]
lm_cond_1 = tf.reduce_all(tf.equal(comb_lm, tf.constant([0,0], dtype=tf.float64)), axis=1)
lm_idx_1 = tf.cast(tf.boolean_mask(comb_lm, lm_cond_1), dtype=tf.int32)

#condition 2 = [1:, ]
lm_cond_2 = tf.logical_and(tf.equal(tf.constant([0], dtype=tf.float64), comb_lm[:,1]),
                           tf.less(tf.constant([0], dtype=tf.float64), comb_lm[:,0]))
lm_idx_2 = tf.cast(tf.boolean_mask(comb_lm, lm_cond_2), dtype=tf.int32)

lm_cond_3_1 = tf.logical_and(tf.less(comb_lm[:,1], comb_lm[:,0]),
                             tf.less(tf.constant([0], dtype=tf.float64), comb_lm[:,1]))
lm_cond_3_2 = tf.logical_and(lm_cond_3_1, tf.less(comb_lm[:,1], K-1))
lm_idx_3 = tf.cast(tf.boolean_mask(comb_lm, lm_cond_3_2), dtype=tf.int32)

lm_cond_4_1 = tf.equal(comb_lm[:,1], comb_lm[:,0])
lm_cond_4_2 = tf.logical_and(lm_cond_4_1, tf.less(tf.constant([0], dtype=tf.float64), comb_lm[:,1]))
lm_idx_4 = tf.cast(tf.boolean_mask(comb_lm, lm_cond_4_2), dtype=tf.int32)


@tf.function
def assign_lm(lm_i, F_tmp, comb_idx, F):
    idx = tf.cast(comb_idx[lm_i], tf.int32)

    def updates_1(idx_0, idx_1, F_tmp):
        return tf.constant([1.], dtype=tf.float64)
    def updates_2(idx_0, idx_1, F_tmp):
        return tf.cast(tf.reshape(tf.cos(F_tmp[idx_0, 0]), shape=[1]), dtype=tf.float64)
    def updates_3(idx_0, idx_1, F_tmp):
        return tf.cast(tf.reshape(tf.cos(F_tmp[idx_0, idx_1]) * tf.reduce_prod(tf.sin(F_tmp[idx_0, 0:idx_1])), shape=[1]), dtype=tf.float64)
    def updates_4(idx_0, idx_1, F_tmp):
        return tf.cast(tf.reshape(tf.reduce_prod(tf.sin(F_tmp[idx_0, 0:idx_0])), shape=[1]), dtype=tf.float64)
    def updates_9(idx_0, idx_1, F_tmp):
        return tf.cast(tf.constant([0.], dtype=tf.float64), dtype=tf.float64)

    _cond_1 = tf.reduce_all(tf.equal(idx, lm_idx_1))
    _cond_2 = tf.reduce_any(tf.reduce_all(tf.equal(idx, lm_idx_2),1))
    _cond_3 = tf.reduce_any(tf.reduce_all(tf.equal(idx, lm_idx_3),1))
    _cond_4 = tf.reduce_any(tf.reduce_all(tf.equal(idx, lm_idx_4),1))

    F_elem = tf.cond(_cond_1,
                     true_fn= lambda: updates_1(idx[0], idx[1], F_tmp),
                     false_fn= lambda: tf.cond(_cond_2,
                                      true_fn= lambda: updates_2(idx[0], idx[1], F_tmp),
                                      false_fn= lambda: tf.cond(_cond_3,
                                                       true_fn=lambda: updates_3(idx[0], idx[1], F_tmp),
                                                       false_fn=lambda: tf.cond(_cond_4,
                                                                        true_fn=lambda: updates_4(idx[0], idx[1], F_tmp),
                                                                        false_fn=lambda: updates_9(idx[0], idx[1], F_tmp)))))

    F = tf.concat([F,F_elem], axis=0)

    return tf.add(lm_i, tf.ones(shape=(), dtype=tf.int32)), F_tmp, comb_idx, F




@tf.function
def Dit(i, nu, lambd, group_indices):
    K = lambd.shape[0]
    cK = tf.range(0, K, 1)
    cK_1, cK_0 = tf.meshgrid(cK, cK)
    # flatten the K x K index grids (works for any K, not only K = 3)
    F_tmp = tf.add(tf.gather(nu, tf.reshape(cK_0, shape=[-1])),
                   tf.gather(nu, tf.reshape(cK_1, shape=[-1])))
    F_tmp = tf.reshape(F_tmp, shape=[K, K])
    F_tmp = (tf.exp(F_tmp) * math.pi) / (1 + tf.exp(F_tmp))
    F = tf.constant([], dtype=tf.float64)

    lm_i = tf.zeros(shape=(), dtype=tf.int32)
    lm_c = lambda lm_i, F_tmp, comb_lm, F: tf.less(lm_i, comb_lm.shape[0])

    lm_i, _, _, F = tf.while_loop(lm_c, assign_lm, [lm_i, F_tmp, comb_lm, F],
                                  shape_invariants=[lm_i.get_shape(), F_tmp.get_shape(),
                                                    comb_lm.get_shape(), tf.TensorShape([None])])

    F = tf.SparseTensor(indices=tf.cast(comb_lm, dtype=tf.int64), values=F, dense_shape=[K, K])
    F = tf.sparse.to_dense(F)
    F = tf.cast(F, dtype=tf.float64)

    Ri = tf.linalg.matmul(F, F, transpose_b=True)
    Si_pre = tf.concat([tf.constant([1.], dtype=tf.float64), group_indices[i, 0]], axis=0)
    Si = tf.linalg.diag(tf.exp(tf.reduce_sum(tf.cast(lambd, dtype=tf.float64) * Si_pre, axis=1)))

    return tf.linalg.matmul(tf.linalg.matmul(Si, Ri), Si)

#End create Dit
#################################################################################################################




def create_covariance(idx, nu, lambd, group_indices):
    # identity over the T time points, so the covariance is I_T kron D_i
    operator_1 = tf.linalg.LinearOperatorFullMatrix(tf.eye(T, dtype=tf.float64))
    operator_2 = tf.linalg.LinearOperatorFullMatrix(Dit(idx, nu, lambd, group_indices))
    operator = tf.linalg.LinearOperatorKronecker([operator_1, operator_2])
    return operator.to_dense()
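
# --- Reference sketch (illustrative; not called by the pipeline) ---
# A plain-NumPy version of Dit/create_covariance for cross-checking the TensorFlow
# code on small inputs. The function name np_reference_covariance is hypothetical.
# It follows the logic above:
#   angles phi[l, m] = pi * sigmoid(nu[l] + nu[m]),
#   a lower-triangular F whose rows have unit length (hypersphere decomposition),
#   R = F F^T (a correlation matrix),
#   S_i = diag(exp(lambd[:, 0] + lambd[:, 1] * g_i)),
#   D_i = S_i R S_i, and the per-subject covariance I_T kron D_i.
# Assumed usage: np_reference_covariance(nu.numpy(), lambd.numpy(),
#   float(group_indices[i, 0, 0]), T) should match
#   create_covariance(i, nu, lambd, group_indices).numpy().
import numpy as np

def np_reference_covariance(nu_vals, lambd_vals, group_i, n_time):
    nu_vals = np.asarray(nu_vals, dtype=np.float64)
    lambd_vals = np.asarray(lambd_vals, dtype=np.float64)
    k = nu_vals.shape[0]
    # angles in (0, pi); pi * e^a / (1 + e^a) == pi / (1 + e^-a)
    phi = np.pi / (1.0 + np.exp(-(nu_vals[:, None] + nu_vals[None, :])))
    f = np.zeros((k, k))
    f[0, 0] = 1.0
    for l in range(1, k):
        for m in range(l):
            # empty product (m = 0) is 1, giving f[l, 0] = cos(phi[l, 0])
            f[l, m] = np.cos(phi[l, m]) * np.prod(np.sin(phi[l, :m]))
        f[l, l] = np.prod(np.sin(phi[l, :l]))
    r = f @ f.T
    s = np.diag(np.exp(lambd_vals[:, 0] + lambd_vals[:, 1] * group_i))
    d = s @ r @ s
    return np.kron(np.eye(n_time), d)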



# create_y returns a (2, N, T, K) tensor:
#   result[0] : y_fixed = true_y minus the autoregressive part (target for the CNN/beta step)
#   result[1] : y       = fixed effect plus the autoregressive part (mean of y)
# For t = 0 : y[i,0,k] = x[i,0,:] . beta[:,k]
# For t > 0 : y[i,t,k] = x[i,t,:] . beta[:,k]
#             + sum_g delta[k,g] * alpha[k,g] * (true_y[i,t-1,g] - x[i,t-1,:] . beta[:,g])
# The autoregressive term is computed from the true_y and alpha arguments,
# so alpha sampled by HMC enters the likelihood.
# lambd, nu and group_indices are kept in the signature for compatibility but are unused here.

@tf.function
def create_y(x, beta, lambd, nu, delta, alpha, group_indices, true_y):
    x = tf.cast(x, dtype=tf.float64)
    beta = tf.cast(beta, dtype=tf.float64)
    true_y = tf.cast(true_y, dtype=tf.float64)

    # fixed effect x_it . beta_k for every (i, t, k): shape (N, T, K)
    fixed_effect = tf.einsum('ntp,pk->ntk', x, beta)

    # residual of the previous time point: shape (N, T-1, K)
    prev_resid = true_y[:, :-1, :] - fixed_effect[:, :-1, :]
    # weighted sum over g of delta[k,g] * alpha[k,g] * prev_resid[..., g]
    ar_weights = tf.cast(delta, dtype=tf.float64) * tf.cast(alpha, dtype=tf.float64)
    ar_effect = tf.einsum('ntg,kg->ntk', prev_resid, ar_weights)

    # there is no autoregressive term at t = 0
    other_effect = tf.concat([tf.zeros_like(fixed_effect[:, :1, :]), ar_effect], axis=1)

    y_fixed = tf.subtract(true_y, other_effect)
    y = tf.add(fixed_effect, other_effect)
    return tf.stack([y_fixed, y], axis=0)



def create_y_est_block(predicted_x_input, betas_input, lambd, nu, true_delta, alpha, group_indices, y_answer):
    dist_block = []
    est_y = create_y(predicted_x_input, betas_input, lambd, nu, true_delta, alpha, group_indices, y_answer)[1]
    est_y = tf.reshape(est_y, shape=(N, T*K))
    for i in range(0, N):
        i_dist = tfd.MultivariateNormalFullCovariance(loc=est_y[i, :],
                                                      covariance_matrix=create_covariance(i, nu, lambd,
                                                                                          group_indices))
        dist_block.append(i_dist)
    return dist_block
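
# --- Reference sketch (illustrative; not called by the pipeline) ---
# Subjects are modeled as independent, so the log-likelihood of the stacked y
# (length N*T*K) under tfd.Blockwise of the per-subject multivariate normals is the
# sum of the per-subject log-densities. The NumPy helpers below (hypothetical names)
# compute that sum through a Cholesky factorization. They can be used to sanity-check
# the y_est term of the joint log-probability on small inputs.
import numpy as np

def np_mvn_logpdf(x, mean, cov):
    """Log-density of N(mean, cov) at x, computed via the Cholesky factor of cov."""
    diff = np.asarray(x, dtype=np.float64) - np.asarray(mean, dtype=np.float64)
    chol = np.linalg.cholesky(cov)
    z = np.linalg.solve(chol, diff)              # z = L^{-1} diff, so z.z = diff' cov^{-1} diff
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (diff.size * np.log(2.0 * np.pi) + log_det + z @ z)

def np_block_loglik(y_blocks, mean_blocks, cov_blocks):
    """Sum of per-subject MVN log-densities (one block per subject)."""
    return sum(np_mvn_logpdf(y_b, m_b, c_b)
               for y_b, m_b, c_b in zip(y_blocks, mean_blocks, cov_blocks))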


#end of create y
#######################################################################################################################
#######################################################################################################################

# create beta estimation model(cnn)
# create cnn model
# first, create data feed for test
# then create cnn

from tensorflow.keras import Model


batch_size=20

main_input = tf.keras.Input(shape=(32,32,3), dtype=tf.float64, name='main_input')
Rn50 = tf.keras.applications.ResNet50(
    include_top=False, weights='imagenet', input_tensor=main_input, input_shape=(32,32,3),
    pooling=None, classes=1000)
mx = Rn50.output
mx_out = tf.keras.layers.Dense(125, use_bias=False)(mx)
mx_out = tf.keras.layers.Flatten(name='aux_output')(mx_out)

auxiliary_input = tf.keras.Input(shape=(3,), name='aux_input')
x_cct = tf.keras.layers.concatenate([mx_out, auxiliary_input])
main_output = tf.keras.layers.Dense(3, name='main_output', use_bias=False)(x_cct)


Mymodel = Model(inputs={'main_input' : main_input, 'aux_input' : auxiliary_input},
                outputs={'main_output' : main_output, 'aux_output' : mx_out})

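# only the beta head (main_output) is trained; aux_output only exposes the image features
Mymodel.compile(optimizer=tf.keras.optimizers.Adam(),
                loss={'main_output' : 'mse', 'aux_output' : 'mse'},
                loss_weights={'main_output' : 1.0, 'aux_output' : 0.0})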


loss_history = []
fm_loss_history = []
beta_loss_history = []
x_loss_history = []




# read images

# images are expected at ./images/image_i_{i:04d}_t_{t:04d}.jpg (1-based subject i, time t)
x_image_list = []
idx_x_list = []
for i in range(1, N+1):
    for t in range(1, T+1):
        image = tf.io.read_file('./images/image_i_{:04d}_t_{:04d}.jpg'.format(i, t))
        image = tf.image.decode_jpeg(image, channels=3)
        # convert_image_dtype already rescales uint8 pixels to [0, 1],
        # so no further division by 255 is needed
        image = tf.image.convert_image_dtype(image, dtype=tf.float64)
        image = tf.image.resize(image, size=[32, 32])
        x_image_list.append(np.reshape(image, (32, 32, 3)))
        idx_x_list.append([i-1, t-1])  # 0-based (subject, time) index

x_image = np.stack(x_image_list).astype(np.float64)
idx_x = np.array(idx_x_list, dtype=np.float64)


# make meta data dataset
meta_data = tf.concat([group_indices, time_ij, group_time], axis=2)
meta_data_flatten = tf.reshape(meta_data, [-1,3])
x_input = tf.concat([x_pre, meta_data], axis=2)


# function that returns beta
# required input : y_fixed for answer comparison, image(=x) for input
# this needs i, t information as indexing method
# P: dimension of beta. number of image_output + number of meta_data
# in example case, meta_data is group_indicies, time_ij, group_time

dummy_x_answer = tf.reshape(x_pre, shape=(N*T, 125))

def train_beta(x_image, idx_x, meta_data, y_fixed_estimation, Mymodel, i_loss):
    #step1 : calculate beta fixed
    #step2 : get x, y_fixed
    #step3 : update gradient
    batch_size = 20

    y_fixed_flatten = tf.reshape(y_fixed_estimation, shape=(N*T, K))
    meta_data_flatten2 = tf.reshape(meta_data, shape=(N*T, 3))
    Mymodel.fit(x={'main_input' : x_image, 'aux_input' : meta_data_flatten2},
                y={'main_output' : y_fixed_flatten, 'aux_output' : dummy_x_answer},
                epochs=50, batch_size=60)


    # beta is the kernel of the bias-free Dense(3) head named 'main_output'
    beta = tf.constant(Mymodel.get_layer('main_output').get_weights()[0])
    i_loss_prev = i_loss
    i_loss = Mymodel.evaluate({'main_input' : x_image, 'aux_input' : meta_data_flatten2},
                              {'main_output' : y_fixed_flatten, 'aux_output' : dummy_x_answer})[0]
    
    if i_loss > i_loss_prev and i_loss_prev < 0.5 :
        while i_loss > 10.:
            tf.print("execute additional training due to cnn loss spike")
            Mymodel.fit(x={'main_input' : x_image, 'aux_input' : meta_data_flatten2},
                        y={'main_output' : y_fixed_flatten, 'aux_output' : dummy_x_answer}, 
                        epochs=10, batch_size=60)
            i_loss = Mymodel.evaluate({'main_input' : x_image, 'aux_input' : meta_data_flatten2}, 
                                      {'main_output' : y_fixed_flatten, 'aux_output' : dummy_x_answer})[0]
    
                
    x_complete = Mymodel.predict({'main_input' : x_image, 'aux_input' : meta_data_flatten2})['aux_output']
    x_complete = tf.reshape(x_complete, shape=(N, T, 125))
    x_complete = tf.concat([x_complete, meta_data], axis=2)

    return i_loss, x_complete, beta



def x_beta_normal(x_image, meta_data, Mymodel,
                  lambd, nu, delta, alpha, y_answer, model_loss_prev):
    tf.print("initiating x_beta_normal")
    meta_data_flatten = tf.reshape(meta_data, [-1,3])
    meta_group_indices = tf.expand_dims(meta_data[:,:,0], axis=2)
    x_input_img_part = Mymodel.predict({'main_input' : x_image, 'aux_input' : meta_data_flatten})['aux_output']
    beta = Mymodel.get_layer('main_output').weights[0]
    x_input_img_part = tf.reshape(x_input_img_part, [N, T, 125])
    x_input = tf.concat([x_input_img_part, meta_data], axis=2)
    #meta_data[:,:,0] means group_indices
    y_fixed_estimation = create_y(x_input, beta, lambd, nu, delta, 
                                  alpha, meta_group_indices, y_answer)[0]

    model_loss, x_post_forwarded, betas = train_beta(
        x_image, idx_x, meta_data, y_fixed_estimation, 
        Mymodel, model_loss_prev)

    return model_loss, x_post_forwarded, betas


y_answer = create_y(x_input, beta_all, lambd, nu, true_delta, alpha, group_indices, tf.zeros_like(y, dtype=tf.float64))
y_fixed_answer=y_answer[0]
y_answer = y_answer[1]



####################################################################################################################

#create total model procedure
def multivariate_longitudinal_cnn_model(predicted_x_input, meta_data, y_answer, betas_input):
  # joint distribution of the HMC parameters and y, given the CNN output and beta.
  # Using the current y estimate directly would be best, but JointDistributionNamed
  # does not allow circular references between its components,
  # so y is recreated here from the CNN features and beta.
  # sigma2_* values are used as the prior standard deviation (scale) of each parameter.
  return tfd.JointDistributionNamed(dict(
      lambd=    tfd.MultivariateNormalDiag(loc=tf.zeros(shape=[K, 2], dtype=tf.float64),
                                           scale_diag=sigma2_lambda * tf.ones(shape=[K, 2], dtype=tf.float64)),  # lambd
      nu=       tfd.MultivariateNormalDiag(loc=tf.zeros(shape=[3], dtype=tf.float64),
                                           scale_diag=sigma2_nu * tf.ones(shape=[3], dtype=tf.float64)),  # nu
      alpha=    tfd.MultivariateNormalDiag(loc=tf.zeros(shape=[K, K], dtype=tf.float64),
                                           scale_diag=sigma2_alpha * tf.ones(shape=[K, K], dtype=tf.float64)),  # alpha
      y_est=    lambda alpha, nu, lambd: tfd.Blockwise(  # y
          create_y_est_block(predicted_x_input, betas_input, lambd, nu, true_delta, alpha,
                             tf.expand_dims(meta_data[:,:,0], axis=2), y_answer), dtype_override=tf.float64)
                             #meta_data[:,:,0] is group_indices
  ))




def sample_contextual_effects(num_results, num_burnin_steps, initial_state_input):
  """Samples lambd, nu and alpha with HMC; returns the samples and the acceptance rate."""

  hmc = tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=multivariate_longitudinal_cnn_log_prob,
      num_leapfrog_steps=2,
      step_size=0.003) #default 0.003

  initial_state = [
      tf.zeros([K, 2], dtype=tf.float64, name='init_lambd'),
      tf.zeros([3], dtype=tf.float64, name='init_nu'),
      tf.zeros([K, K], dtype=tf.float64, name='init_alpha')
  ]
  unconstraining_bijectors = [
      tfb.Identity(),  # lambd
      tfb.Identity(),  # nu
      tfb.Identity()  # alpha
  ]
  kernel = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=hmc, bijector=unconstraining_bijectors)

  samples, kernel_results = tfp.mcmc.sample_chain(
      num_results=num_results,
      num_burnin_steps=num_burnin_steps,
      current_state=initial_state_input,
      kernel=kernel,
      parallel_iterations=10)

  acceptance_probs = tf.reduce_mean(
      tf.cast(kernel_results.inner_results.is_accepted, tf.float64), axis=0)

  return samples, acceptance_probs
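
# --- Reference sketch (illustrative; not called by the pipeline) ---
# One HamiltonianMonteCarlo transition written out in NumPy, for a target with a
# known log-density and gradient. num_leapfrog_steps and step_size play the same
# roles as in tfp.mcmc.HamiltonianMonteCarlo above. The function name np_hmc_step
# is hypothetical, and this is the textbook algorithm, not TFP's internal code.
import numpy as np

def np_hmc_step(q, log_prob, grad_log_prob, step_size, num_leapfrog_steps, rng):
    q = np.asarray(q, dtype=np.float64)
    p = rng.standard_normal(q.shape)              # resample momentum ~ N(0, I)
    q_new, p_new = q.copy(), p.copy()
    # leapfrog integration: half momentum step, alternating full steps, half momentum step
    p_new += 0.5 * step_size * grad_log_prob(q_new)
    for step in range(num_leapfrog_steps):
        q_new += step_size * p_new
        if step < num_leapfrog_steps - 1:
            p_new += step_size * grad_log_prob(q_new)
    p_new += 0.5 * step_size * grad_log_prob(q_new)
    # Metropolis correction on the change of total energy H = -log p(q) + |p|^2 / 2
    current_h = -log_prob(q) + 0.5 * np.sum(p * p)
    proposed_h = -log_prob(q_new) + 0.5 * np.sum(p_new * p_new)
    if np.log(rng.uniform()) < current_h - proposed_h:
        return q_new, True
    return q, False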


# !!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!
# THIS PROCEDURE TAKES A LONG TIME TO RUN
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

for i in range(0, 45):
    print(i)
    if i == 0:
        lambd_i= tf.zeros([K, 2], dtype=tf.float64)
        nu_i= tf.zeros([3], dtype=tf.float64)
        alpha_i= tf.zeros([K, K], dtype=tf.float64)


    else:
        lambd_i= contextual_effects_samples[0][-1]
        nu_i= contextual_effects_samples[1][-1]
        alpha_i= contextual_effects_samples[2][-1]


    initial_state_inp = [
        tf.constant(lambd_i, shape=[K, 2], name='init_lambd'),
        tf.constant(nu_i, shape=[3], name='init_nu'),
        tf.constant(alpha_i, shape=[K, K], name='init_alpha'),
    ]

    model_loss, predicted_x_input, betas_input = x_beta_normal(x_image, meta_data, Mymodel,
                                                   lambd_i, nu_i, true_delta, alpha_i, y_answer, model_loss)


    def multivariate_longitudinal_cnn_log_prob(lambd, nu, alpha):
        """Computes the joint log prob pinned at the observed `y_answer`."""
        log_prob_list = multivariate_longitudinal_cnn_model(predicted_x_input,
                                                            meta_data, y_answer, betas_input).log_prob_parts(
            dict(lambd=lambd, nu=nu, alpha=alpha, y_est=tf.reshape(y_answer, shape=(N*T*K))))
        log_prob_list2 = [tf.reduce_sum(log_prob_list['lambd']),
                          tf.reduce_sum(log_prob_list['nu']),
                          tf.reduce_sum(log_prob_list['alpha']),
                          tf.reduce_sum(log_prob_list['y_est'])]
        result = tf.reduce_sum(log_prob_list2)
        tf.print(result)
        return result

    samples, acceptance_probs = sample_contextual_effects(
        num_results=50, num_burnin_steps=10, initial_state_input=initial_state_inp)

    MultivariateLongitudinalCnnModel = collections.namedtuple(
        'MultivariateLongitudinalCnnModel',
        ['lambd', 'nu', 'alpha'])

    print('Acceptance Probabilities: ', acceptance_probs.numpy())

    contextual_effects_samples = MultivariateLongitudinalCnnModel._make(samples)

    if i == 0:
        lambda_all_samples = contextual_effects_samples[0]
        nu_all_samples = contextual_effects_samples[1]
        alpha_all_samples = contextual_effects_samples[2]
    else:
        lambda_all_samples = tf.concat([lambda_all_samples, contextual_effects_samples[0]], axis=0)
        nu_all_samples = tf.concat([nu_all_samples, contextual_effects_samples[1]], axis=0)
        alpha_all_samples = tf.concat([alpha_all_samples, contextual_effects_samples[2]], axis=0)

    Mymodel.save_weights('ckpt_{}_{:03d}'.format(today, i))
    print(i)
    print("end")
    
Mymodel.save('./multilong_cnn_model_201016.h5')



# check whether parameters converge
import matplotlib.pyplot as plt

fig_lambda, axs_l = plt.subplots(3, 2, sharex=True)
for r in range(3):
    for c in range(2):
        axs_l[r, c].plot(lambda_all_samples[:, r, c])
        axs_l[r, c].set_title("lambda[{},{}]".format(r, c))
fig_lambda.tight_layout()

fig_nu, axs_n = plt.subplots(3, sharex=True)
for r in range(3):
    axs_n[r].plot(nu_all_samples[:, r])
    axs_n[r].set_title("nu[{}]".format(r))
fig_nu.tight_layout()


fig_alpha, axs_a = plt.subplots(3, 3, sharex=True)
for r in range(3):
    for c in range(3):
        axs_a[r, c].plot(alpha_all_samples[:, r, c])
        axs_a[r, c].set_title("alpha[{},{}]".format(r, c))
fig_alpha.tight_layout()



# save parameter samples
np.savez('sample_saves_201016', alpha_samples=alpha_all_samples.numpy(),
                        lambda_samples=lambda_all_samples.numpy(), nu_samples=nu_all_samples.numpy(),
                        x_pre = x_pre.numpy(), y_answer=y_answer.numpy(), beta_all=beta_all.numpy())
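
The trace plots above give a visual convergence check. A numeric complement is the split-$\hat{R}$ statistic. Below is a minimal NumPy sketch; the function name `split_rhat` is hypothetical. It treats the first and second halves of one chain as two chains, which fits here because each parameter's samples are concatenated into a single chain. Values close to 1 (commonly below about 1.01 to 1.1) suggest the two halves agree. This is a heuristic, not a proof of convergence.

```python
import numpy as np

def split_rhat(chain):
    """Split-R-hat for a 1-D chain, using its first and second halves as two chains."""
    chain = np.asarray(chain, dtype=np.float64)
    n = chain.shape[0] // 2
    halves = np.stack([chain[:n], chain[n:2 * n]])   # shape (2, n)
    chain_means = halves.mean(axis=1)
    within = halves.var(axis=1, ddof=1).mean()       # W: mean within-half variance
    between = n * chain_means.var(ddof=1)            # B: between-half variance
    var_plus = (n - 1) / n * within + between / n    # pooled variance estimate
    return np.sqrt(var_plus / within)
```

For example, `split_rhat(nu_all_samples[:, 0].numpy())` would summarize the `nu[0]` trace, assuming the sample tensors built in the loop above.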