# Import Libraries

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Select the GPU before TensorFlow is imported so the setting is guaranteed to
# be in place before the CUDA runtime initialises.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
np.random.seed(530)
import warnings
warnings.filterwarnings('ignore')
from tqdm.auto import tqdm
from sklearn.preprocessing import OneHotEncoder, LabelEncoder

import tensorflow as tf
print(tf.__version__)
# tf.debugging.set_log_device_placement(True)
import keras
print(keras.__version__)
import tensorflow_probability as tfp
tfd = tfp.distributions
from tensorflow.python.client import device_lib

# Allow GPU memory to grow on demand instead of reserving it all up front.
# This must happen before any GPU is initialised; if it is too late TensorFlow
# raises RuntimeError, which is reported instead of aborting the notebook.
gpus = tf.config.list_physical_devices('GPU')
print("Available GPUs:", gpus)
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)
# Fall back to another device when an op has no kernel for the requested one.
tf.config.set_soft_device_placement(True)

import tf2onnx

tf.random.set_seed(530)

## Data Load

In [None]:
# LOOCV fold to load; the directory and every file name below derive from it.
fold = 3
mpii_path = f"../data/UT/loocv/Fold_{fold}"


def _load_fold_array(split, kind):
    """Load ``utm_fold_<fold>_<split>_<kind>.npy`` from ``mpii_path``."""
    return np.load(os.path.join(mpii_path, f"utm_fold_{fold}_{split}_{kind}.npy"))


# Load the arrays and reshape them to the layouts used below:
# ids -> (N,), head poses / gazes -> (N, 2), eye images -> (N, 36, 60).
loocv_id_data_tr = _load_fold_array("train", "ids").flatten()
loocv_hps_data_tr = _load_fold_array("train", "2d_hps").reshape(-1, 2)
loocv_img_data_tr = _load_fold_array("train", "images").reshape(-1, 36, 60)
loocv_gzs_data_tr = _load_fold_array("train", "2d_gazes").reshape(-1, 2)

loocv_id_data_te = _load_fold_array("test", "ids").flatten()
loocv_hps_data_te = _load_fold_array("test", "2d_hps").reshape(-1, 2)
loocv_img_data_te = _load_fold_array("test", "images").reshape(-1, 36, 60)
loocv_gzs_data_te = _load_fold_array("test", "2d_gazes").reshape(-1, 2)

# Sorted set of every subject ID found in either split.
# NOTE(review): the test IDs are included, so the held-out subject gets its own
# one-hot column and is never treated as an unseen cluster downstream; confirm
# this is intended for leave-one-out evaluation.
total_id_data = np.unique(np.concatenate([loocv_id_data_tr, loocv_id_data_te]))

# ID -> integer label -> one-hot vector. Both encoders are fitted on all IDs so
# the train and test matrices share the same columns.
lb_encoder = LabelEncoder()
oh_encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
lb_encoder.fit(total_id_data)
oh_encoder.fit(lb_encoder.transform(total_id_data).reshape(-1, 1))

train_id_encoded = oh_encoder.transform(lb_encoder.transform(loocv_id_data_tr).reshape(-1, 1))
test_id_encoded = oh_encoder.transform(lb_encoder.transform(loocv_id_data_te).reshape(-1, 1))

# Divide pixel values by 255 (presumably 8-bit images -- TODO confirm).
# Casting to float32 first keeps the result float32 whatever the stored dtype;
# dividing the raw array directly can yield a float64 copy at twice the memory.
train_img_normalized = loocv_img_data_tr.astype(np.float32) / 255.0
test_img_normalized = loocv_img_data_te.astype(np.float32) / 255.0

# Dataset layout: ((images, head poses, one-hot IDs), gaze targets).
train_data = ((train_img_normalized, loocv_hps_data_tr, train_id_encoded), loocv_gzs_data_tr)
test_data = ((test_img_normalized, loocv_hps_data_te, test_id_encoded), loocv_gzs_data_te)


# Model Architecture

## Image Regressor Class

In [None]:
class ImageRegressor(keras.Model):
    '''
    Simple 2D image regressor with 7 convolution blocks and 2 final dense layers.
    '''
    
    def __init__(self, name='regressor', **kwargs):
        """Simple 2D image regressor with 7 convolution blocks and 2 final dense layers.

        Args:
            name (str, optional): Model name. Defaults to 'regressor'.
        """        
        
        super(ImageRegressor, self).__init__(name=name, **kwargs)
        
        # Spatial sizes below assume 36 x 60 inputs. MaxPool2D defaults to
        # 'valid' padding (floor division), except maxpool5 which uses 'same'.
        # 36 x 60
        self.conv0 = keras.layers.Conv2D(64, 3, padding='same', name='conv0')
        self.bn0 = keras.layers.BatchNormalization(name='bn0')
        self.elu0 = keras.layers.ELU(name='elu0')
        self.maxpool0 = keras.layers.MaxPool2D(name='maxpool0')
        # 18 x 30
        self.conv1 = keras.layers.Conv2D(128, 3, padding='same', name='conv1')
        self.dropout1 = keras.layers.Dropout(0.5, name='dropout1')
        self.bn1 = keras.layers.BatchNormalization(name='bn1')
        self.elu1 = keras.layers.ELU(name='elu1')
        self.maxpool1 = keras.layers.MaxPool2D(name='maxpool1')
        # 9 x 15
        self.conv2 = keras.layers.Conv2D(128, 3, padding='same', name='conv2')
        self.bn2 = keras.layers.BatchNormalization(name='bn2')
        self.elu2 = keras.layers.ELU(name='elu2')
        self.maxpool2 = keras.layers.MaxPool2D(name='maxpool2')
        # 4 x 7
        self.conv3 = keras.layers.Conv2D(256, 3, padding='same', name='conv3')
        self.dropout3 = keras.layers.Dropout(0.5, name='dropout3')
        self.bn3 = keras.layers.BatchNormalization(name='bn3')
        self.elu3 = keras.layers.ELU(name='elu3')
        self.maxpool3 = keras.layers.MaxPool2D(name='maxpool3')
        # 2 x 3
        self.conv4 = keras.layers.Conv2D(256, 3, padding='same', name='conv4')
        self.bn4 = keras.layers.BatchNormalization(name='bn4')
        self.elu4 = keras.layers.ELU(name='elu4')
        self.maxpool4 = keras.layers.MaxPool2D(name='maxpool4')
        # 1 x 1
        self.conv5 = keras.layers.Conv2D(512, 3, padding='same', name='conv5')
        self.dropout5 = keras.layers.Dropout(0.5, name='dropout5')
        self.bn5 = keras.layers.BatchNormalization(name='bn5')
        self.elu5 = keras.layers.ELU(name='elu5')
        self.maxpool5 = keras.layers.MaxPool2D(padding='same', name='maxpool5')
        # 1 x 1
        self.conv6 = keras.layers.Conv2D(512, 3, padding='same', name='conv6')
        self.elu6 = keras.layers.ELU(name='elu6')
        self.flatten = keras.layers.Flatten(name='flatten')
        self.dense = keras.layers.Dense(256, name='dense', activation='elu',
                                        kernel_regularizer=keras.regularizers.L1L2(l1=0.01))
        self.out = keras.layers.Dense(2, name='output')
        
    def call(self, inputs, return_layer_activations=False, training=None):
        """Regress a 2-vector from an image and its head pose.

        Args:
            inputs: tuple ``(images, head_poses)``. Images go through the conv
                stack; head poses are concatenated to the flattened features.
            return_layer_activations (bool): if True, also return the
                intermediate activations (used by the adversary).
            training: accepted so an explicit ``training=`` from the caller
                reaches BatchNorm/Dropout through the Keras call context.

        Returns:
            ``y`` of shape (batch, 2), or the tuple
            ``(c0, c1, c2, c3, c4, c5, c6, h, y)`` when
            ``return_layer_activations`` is True.
        """
        images, head_poses = inputs
        
        c0 = self.conv0(images)
        c0 = self.bn0(c0)
        c0 = self.elu0(c0)
        
        c1 = self.maxpool0(c0)
        c1 = self.conv1(c1)
        c1 = self.dropout1(c1)
        c1 = self.bn1(c1)
        c1 = self.elu1(c1)
        
        c2 = self.maxpool1(c1)
        c2 = self.conv2(c2)
        c2 = self.bn2(c2)
        c2 = self.elu2(c2)
        
        c3 = self.maxpool2(c2)
        c3 = self.conv3(c3)
        c3 = self.dropout3(c3)
        c3 = self.bn3(c3)
        c3 = self.elu3(c3)
        
        c4 = self.maxpool3(c3)
        c4 = self.conv4(c4)
        c4 = self.bn4(c4)
        c4 = self.elu4(c4)
        
        c5 = self.maxpool4(c4)
        c5 = self.conv5(c5)
        c5 = self.dropout5(c5)
        c5 = self.bn5(c5)
        c5 = self.elu5(c5)
        
        c6 = self.maxpool5(c5)
        c6 = self.conv6(c6)
        c6 = self.elu6(c6)
        h = self.flatten(c6)
        # Plain tf.concat instead of building a new Concatenate layer on every
        # call; the cast guards against head poses arriving as float64.
        h = tf.concat([h, tf.cast(head_poses, h.dtype)], axis=-1)
        h = self.dense(h)
        y = self.out(h)
        if return_layer_activations:
            return c0, c1, c2, c3, c4, c5, c6, h, y
        return y

## Adversarial Regressor Class

In [None]:
class AdversarialRegressor(keras.Model):
    '''
    Domain adversarial regressor for the ImageRegressor. Receives the
    layer activations from the ImageRegressor as inputs and predicts
    cluster membership.
    '''
    def __init__(self, n_clusters, name='adversary', **kwargs):
        """Domain adversarial regressor for the ImageRegressor. Receives the
        layer activations from the ImageRegressor as inputs and predicts
        cluster membership.

        Args: 
            n_clusters (int): number of clusters 
            name (str, optional): Model name. Defaults to 'adversary'.
        """        
        super(AdversarialRegressor, self).__init__(name=name, **kwargs)
        self.n_clusters = n_clusters
        
        # Spatial sizes assume the regressor saw 36 x 60 inputs; pooling mirrors
        # ImageRegressor so each pooled map lines up with the next activation.
        # 36 x 60
        self.conv0 = keras.layers.Conv2D(64, 3, padding='same', name='conv0')
        self.bn0 = keras.layers.BatchNormalization(name='bn0')
        self.elu0 = keras.layers.ELU(name='elu0')
        self.maxpool0 = keras.layers.MaxPool2D(name='maxpool0')
        # 18 x 30
        self.conv1 = keras.layers.Conv2D(64, 3, padding='same', name='conv1')
        self.dropout1 = keras.layers.Dropout(0.5, name='dropout1')
        self.bn1 = keras.layers.BatchNormalization(name='bn1')
        self.elu1 = keras.layers.ELU(name='elu1')
        self.maxpool1 = keras.layers.MaxPool2D(name='maxpool1')
        # 9 x 15
        self.conv2 = keras.layers.Conv2D(128, 3, padding='same', name='conv2')
        self.bn2 = keras.layers.BatchNormalization(name='bn2')
        self.elu2 = keras.layers.ELU(name='elu2')
        self.maxpool2 = keras.layers.MaxPool2D(name='maxpool2')
        # 4 x 7
        self.conv3 = keras.layers.Conv2D(128, 3, padding='same', name='conv3')
        self.dropout3 = keras.layers.Dropout(0.5, name='dropout3')
        self.bn3 = keras.layers.BatchNormalization(name='bn3')
        self.elu3 = keras.layers.ELU(name='elu3')
        self.maxpool3 = keras.layers.MaxPool2D(name='maxpool3')
        # 2 x 3
        self.conv4 = keras.layers.Conv2D(256, 3, padding='same', name='conv4')
        self.bn4 = keras.layers.BatchNormalization(name='bn4')
        self.elu4 = keras.layers.ELU(name='elu4')
        self.maxpool4 = keras.layers.MaxPool2D(name='maxpool4')
        # 1 x 1
        self.conv5 = keras.layers.Conv2D(256, 3, padding='same', name='conv5')
        self.dropout5 = keras.layers.Dropout(0.5, name='dropout5')
        self.bn5 = keras.layers.BatchNormalization(name='bn5')
        self.elu5 = keras.layers.ELU(name='elu5')
        self.maxpool5 = keras.layers.MaxPool2D(padding='same', name='maxpool5')
        # 1 x 1
        self.conv6 = keras.layers.Conv2D(512, 3, padding='same', name='conv6')
        self.elu6 = keras.layers.ELU(name='elu6')
        self.flatten = keras.layers.Flatten(name='flatten')
        self.dense = keras.layers.Dense(256, name='dense', activation='elu')
        self.out = keras.layers.Dense(n_clusters, name='output', activation='softmax')
        
    def call(self, inputs, training=None):
        """Predict cluster probabilities from the regressor's activations.

        Args:
            inputs: tuple ``(c0, c1, c2, c3, c4, c5, c6, h)`` as returned by
                ``ImageRegressor.call(..., return_layer_activations=True)``
                without its final output.
            training: accepted so an explicit ``training=`` from the caller
                reaches BatchNorm/Dropout through the Keras call context.

        Returns:
            Softmax probabilities of shape (batch, n_clusters).
        """
        c0, c1, c2, c3, c4, c5, c6, h = inputs
        
        x = self.conv0(c0)
        x = self.bn0(x)
        x = self.elu0(x)
        
        # Each stage pools and then concatenates the matching regressor
        # activation along the channel axis (tf.concat, rather than creating
        # a new Concatenate layer on every call).
        x = self.maxpool0(x)
        x = tf.concat([x, c1], axis=-1)
        x = self.conv1(x)
        x = self.dropout1(x)
        x = self.bn1(x)
        x = self.elu1(x)
        
        x = self.maxpool1(x)
        x = tf.concat([x, c2], axis=-1)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.elu2(x)
        
        x = self.maxpool2(x)
        x = tf.concat([x, c3], axis=-1)
        x = self.conv3(x)
        x = self.dropout3(x)
        x = self.bn3(x)
        x = self.elu3(x)
        
        x = self.maxpool3(x)
        x = tf.concat([x, c4], axis=-1)
        x = self.conv4(x)
        x = self.bn4(x)
        x = self.elu4(x)
        
        x = self.maxpool4(x)
        x = tf.concat([x, c5], axis=-1)
        x = self.conv5(x)
        x = self.dropout5(x)
        x = self.bn5(x)
        x = self.elu5(x)
        
        x = self.maxpool5(x)
        # c6 is not concatenated here.
        # NOTE(review): with 36 x 60 inputs both x and c6 are 1 x 1 spatially,
        # so the shapes may in fact line up; adding c6 would change conv6's
        # weights, so it is left out as in the original design.
        x = self.conv6(x)
        x = self.elu6(x)
        x = self.flatten(x)
        x = tf.concat([x, h], axis=-1)
        x = self.dense(x)
        x = self.out(x)
        return x

## Random Effects Class

In [None]:
def make_prior_fn(prior_scale):
    """Return a factory that builds a fixed, zero-mean prior parameter vector.

    The factory takes ``(kernel_size, bias_size=0, prior_scale=...)`` and
    returns a zero-argument callable producing a float32 vector of length
    ``2 * (kernel_size + bias_size)``: the first half holds the prior means
    (all zeros), the second half holds ``prior_scale`` repeated.
    """
    def _prior_fn(kernel_size, bias_size=0, prior_scale=prior_scale):
        size = kernel_size + bias_size
        loc_vec = tf.zeros(size, dtype=tf.float32)
        scale_vec = tf.fill([size], tf.cast(prior_scale, tf.float32))

        def prior_fn():
            return tf.concat([loc_vec, scale_vec], axis=-1)

        return prior_fn

    return _prior_fn

def make_posterior_fn(post_loc_init_scale, post_scale_init_min, post_scale_init_range):
    """Return a factory for a trainable mean-field posterior parameter vector.

    The factory takes ``(kernel_size, bias_size=0)``, creates two trainable
    float32 ``tf.Variable``s of length ``kernel_size + bias_size``, and returns
    a zero-argument callable producing ``concat([loc, softplus(raw_scale)])``.

    Means start from ``N(0, post_loc_init_scale)``; raw scales start uniformly in
    ``[post_scale_init_min, post_scale_init_min + post_scale_init_range]``.
    """
    upper_scale = post_scale_init_min + post_scale_init_range

    def _re_posterior_fn(kernel_size, bias_size=0):
        size = kernel_size + bias_size

        loc_init = keras.initializers.RandomNormal(mean=0, stddev=post_loc_init_scale)
        loc_var = tf.Variable(initial_value=loc_init(shape=(size,), dtype=tf.float32),
                              trainable=True)

        raw_scale_init = keras.initializers.RandomUniform(minval=post_scale_init_min,
                                                          maxval=upper_scale)
        raw_scale_var = tf.Variable(initial_value=raw_scale_init(shape=(size,), dtype=tf.float32),
                                    trainable=True)

        def posterior_fn():
            # softplus keeps the returned scales strictly positive.
            positive_scale = tf.nn.softplus(raw_scale_var)
            return tf.concat([loc_var, positive_scale], axis=-1)

        return posterior_fn

    return _re_posterior_fn

class RandomEffects(keras.layers.Layer):
    """Variational random effects with one Gaussian effect vector per cluster.

    The input is a cluster design matrix ``z`` of shape ``(batch, n_clusters)``
    (one-hot rows in this notebook). The layer keeps a mean-field Gaussian
    posterior ``N(loc, scale)`` over a weight matrix ``W`` of shape
    ``(n_clusters, units)`` and a fixed ``N(0, prior_scale)`` prior on every
    entry. It returns ``z @ W`` with shape ``(batch, units)``.

    During training ``W`` is sampled independently for every example, and the
    weighted KL divergence to the prior is registered with ``add_loss``.
    Otherwise the posterior means are used.

    Args:
        units: length of each cluster's effect vector.
        post_loc_init_scale: std of the normal initializer for posterior means.
        post_scale_init_min: lower bound of the initial posterior std.
        post_scale_init_range: width of the uniform range the initial posterior
            std is drawn from.
        prior_scale: std of the zero-mean Gaussian prior (used as-is).
        kl_weight: multiplier applied to the KL divergence.
        l1_weight: optional L1 penalty on the posterior means.
        name: layer name.
    """

    def __init__(self, 
                 units: int=1, 
                 post_loc_init_scale: float=0.05, 
                 post_scale_init_min: float=0.05,
                 post_scale_init_range: float=0.05,
                 prior_scale: float=0.05,
                 kl_weight: float=0.001,
                 l1_weight: float=None,
                 name=None) -> None:  
        super(RandomEffects, self).__init__(name=name)
        self.units = units
        self.post_loc_init_scale = post_loc_init_scale
        self.post_scale_init_min = post_scale_init_min
        self.post_scale_init_range = post_scale_init_range
        self.prior_scale = prior_scale
        self.kl_weight = kl_weight
        self.l1_weight = l1_weight

    def build(self, input_shape):
        """Create the (n_clusters, units) posterior parameters.

        The parameters are created with ``add_weight`` rather than held as
        closure ``tf.Variable``s, so Keras tracks, trains and saves them.
        """
        n_clusters = int(input_shape[-1])
        shape = (n_clusters, self.units)

        self.posterior_loc = self.add_weight(
            name='posterior_loc', shape=shape,
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=self.post_loc_init_scale),
            trainable=True)

        low = self.post_scale_init_min
        high = self.post_scale_init_min + self.post_scale_init_range

        def _raw_scale_init(shape, dtype=None):
            # Store the std in softplus-inverse space so that softplus(raw)
            # starts uniformly in [low, high]; clamp to keep log(expm1(.)) finite.
            s = tf.random.uniform(shape, minval=low, maxval=high, dtype=dtype or tf.float32)
            s = tf.maximum(s, 1e-6)
            return tf.math.log(tf.math.expm1(s))

        self.posterior_scale_raw = self.add_weight(
            name='posterior_scale_raw', shape=shape,
            initializer=_raw_scale_init, trainable=True)
        super(RandomEffects, self).build(input_shape)

    def _posterior_scale(self):
        # softplus is applied exactly once; the 1e-5 floor keeps log(scale) finite.
        return 1e-5 + tf.nn.softplus(tf.convert_to_tensor(self.posterior_scale_raw))

    def call(self, inputs, training=None):
        """Return ``z @ W`` of shape (batch, units) for design matrix ``inputs``."""
        z = tf.cast(inputs, dtype=tf.float32)
        loc = tf.convert_to_tensor(self.posterior_loc)

        if training:
            scale = self._posterior_scale()
            # One weight sample per example: eps has shape (batch, n_clusters, units).
            eps = tf.random.normal(tf.concat([tf.shape(z)[:1], tf.shape(loc)], axis=0))
            weights = loc + eps * scale
            outputs = tf.einsum('bk,bku->bu', z, weights)

            # Closed-form KL(N(loc, scale) || N(0, prior_scale)) summed over all
            # entries. prior_scale is used directly (no softplus).
            prior_scale = tf.constant(self.prior_scale, dtype=tf.float32)
            prior_var = tf.square(prior_scale)
            kl = 0.5 * tf.reduce_sum(
                tf.square(loc) / prior_var
                + tf.square(scale) / prior_var
                - 1.0
                + 2.0 * (tf.math.log(prior_scale) - tf.math.log(scale))
            )
            self.add_loss(self.kl_weight * kl)
        else:
            # Outside training use the posterior means.
            outputs = tf.matmul(z, loc)

        if self.l1_weight:
            self.add_loss(self.l1_weight * tf.reduce_sum(tf.abs(loc)))

        return outputs

## Domain Adversarial Image Regressor Class

In [None]:
class DomainAdversarialImageRegressor(keras.Model):
    '''
    Domain adversarial 2D image regressor which learns the regression 
    task while competing with an adversary, which learns to predict cluster 
    membership from the regressor's layer activations.
    '''
    
    def __init__(self, n_clusters, name='da_regressor', **kwargs):
        """Domain adversarial 2D image regressor which learns the regression 
        task while competing with an adversary, which learns to predict cluster 
        membership from the regressor's layer activations.

        Args:
            n_clusters (int): number of clusters
            name (str, optional): Model name. Defaults to 'da_regressor'.
        """        
        super(DomainAdversarialImageRegressor, self).__init__(name=name, **kwargs)
        self.regressor = ImageRegressor()
        self.adversary = AdversarialRegressor(n_clusters)
        self.n_clusters = n_clusters
        
    def call(self, inputs, training=None):
        """Predict from ``(images, head_poses, clusters)``; clusters are not used."""
        x, hps, _ = inputs
        return self.regressor((x, hps), training=training)
    
    def compile(self,
                loss_regressor=None,
                loss_adversary=None,
                loss_regressor_weight=1.0,
                loss_gen_weight=0.1,
                metric_regressor=None,
                opt_regressor=None,
                opt_adversary=None
                ):
        """Configure losses, loss weights, metric and the two optimizers.

        ``None`` selects a fresh default object: MeanAbsoluteError (loss and
        metric), CategoricalCrossentropy (adversary loss) and
        Nadam(learning_rate=1e-4) for each optimizer. Defaults are created per
        call because these objects are stateful; default instances built once at
        class-definition time would be shared by every model compiled with them.
        """
        super().compile()
        
        self.loss_regressor = (loss_regressor if loss_regressor is not None
                               else keras.losses.MeanAbsoluteError())
        self.loss_adversary = (loss_adversary if loss_adversary is not None
                               else keras.losses.CategoricalCrossentropy())
        self.loss_regressor_weight = loss_regressor_weight
        self.loss_gen_weight = loss_gen_weight
        self.metric_regressor = (metric_regressor if metric_regressor is not None
                                 else keras.metrics.MeanAbsoluteError())
        self.opt_regressor = (opt_regressor if opt_regressor is not None
                              else keras.optimizers.Nadam(learning_rate=0.0001))
        self.opt_adversary = (opt_adversary if opt_adversary is not None
                              else keras.optimizers.Nadam(learning_rate=0.0001))
        
        self.loss_regressor_tracker = keras.metrics.Mean(name='reg_loss')
        self.loss_gen_tracker = keras.metrics.Mean(name='gen_loss')
        self.loss_adversary_tracker = keras.metrics.Mean(name='adv_loss')
        self.loss_total_tracker = keras.metrics.Mean(name='total_loss')
        
    @property
    def metrics(self):
        return [self.loss_regressor_tracker,
                self.loss_gen_tracker,
                self.loss_adversary_tracker,
                self.loss_total_tracker,
                self.metric_regressor]

    @staticmethod
    def _unpack_batch(data):
        '''Split a batch into (images, headpose, clusters, labels, sample_weights).'''
        if len(data) == 3:
            (images, headpose, clusters), labels, sample_weights = data
        else:
            (images, headpose, clusters), labels = data
            sample_weights = None
        return images, headpose, clusters, labels, sample_weights
        
    def _compute_update_loss(self, loss_reg, loss_gen):
        '''Compute total loss and update loss running means.

        The generator term is SUBTRACTED: the regressor competes with the
        adversary by maximising its cluster-prediction loss. Adding it (as done
        previously) would train the features to make clusters easier to predict.
        '''
        self.loss_regressor_tracker.update_state(loss_reg)
        self.loss_gen_tracker.update_state(loss_gen)
        
        loss_total = (self.loss_regressor_weight * loss_reg) - (self.loss_gen_weight * loss_gen)
        self.loss_total_tracker.update_state(loss_total)
        
        return loss_total
    
    def train_step(self, data):
        images, headpose, clusters, labels, sample_weights = self._unpack_batch(data)
            
        # 1) Train the adversary to predict clusters from the regressor's
        #    activations. training=True is explicit because sub-models called
        #    directly from train_step otherwise run BatchNorm/Dropout in
        #    inference mode.
        with tf.GradientTape() as gt:
            reg_outs = self.regressor((images, headpose), training=True,
                                      return_layer_activations=True)
            clusters_pred = self.adversary(reg_outs[:-1], training=True)
            loss_adv = self.loss_adversary(clusters, clusters_pred, sample_weight=sample_weights)

        grads_adv = gt.gradient(loss_adv, self.adversary.trainable_weights)
        self.opt_adversary.apply_gradients(zip(grads_adv, self.adversary.trainable_weights))
        self.loss_adversary_tracker.update_state(loss_adv)
        
        # 2) Train the regressor on its task while working against the adversary.
        with tf.GradientTape() as gt2:
            reg_outs = self.regressor((images, headpose), training=True,
                                      return_layer_activations=True)
            y_pred = reg_outs[-1]
            loss_reg = self.loss_regressor(labels, y_pred, sample_weight=sample_weights)

            clusters_pred = self.adversary(reg_outs[:-1], training=False)
            loss_gen = self.loss_adversary(clusters, clusters_pred, sample_weight=sample_weights)
            
            loss_total = self._compute_update_loss(loss_reg, loss_gen)
            
        grads = gt2.gradient(loss_total, self.regressor.trainable_weights)
        self.opt_regressor.apply_gradients(zip(grads, self.regressor.trainable_weights))
        
        self.metric_regressor.update_state(labels, y_pred, sample_weight=sample_weights)
        return {m.name: m.result() for m in self.metrics}
    
    def test_step(self, data):
        # Mirrors train_step without gradient updates. The previous version
        # passed return_layer_activations to self.call (which does not accept
        # it) and read a cluster_list attribute this class does not define.
        images, headpose, clusters, labels, sample_weights = self._unpack_batch(data)

        reg_outs = self.regressor((images, headpose), training=False,
                                  return_layer_activations=True)
        y_pred = reg_outs[-1]
        clusters_pred = self.adversary(reg_outs[:-1], training=False)

        loss_reg = self.loss_regressor(labels, y_pred, sample_weight=sample_weights)
        loss_gen = self.loss_adversary(clusters, clusters_pred, sample_weight=sample_weights)

        self.loss_adversary_tracker.update_state(loss_gen)
        self._compute_update_loss(loss_reg, loss_gen)

        self.metric_regressor.update_state(labels, y_pred, sample_weight=sample_weights)
        return {m.name: m.result() for m in self.metrics}

## Mixed Effects Regressor

In [None]:
class MixedEffectsRegressor(DomainAdversarialImageRegressor):
    '''
    Mixed-effects version of DomainAdversarialImageRegressor.

    The shared image regressor supplies the fixed effects. Two RandomEffects
    layers driven by the cluster one-hot input add random slopes (multiplying
    the 256-d dense features by ``1 + slopes``) and random intercepts (added to
    the 2 outputs).
    '''
    def __init__(self, 
                 cluster_list=None,
                 slope_post_init_scale=0.1,
                 intercept_post_init_scale=0.1,
                 slope_prior_scale=0.25,
                 intercept_prior_scale=0.25,
                 kl_weight=1e-3,
                 name='me_regressor', **kwargs):
        """
        Args:
            cluster_list: cluster IDs, one per one-hot column. String IDs have
                their first character dropped before conversion to int (e.g.
                's07' -> 7); numeric IDs are converted with int().
            slope_post_init_scale, intercept_post_init_scale: posterior-mean
                init std of the slope / intercept random effects.
            slope_prior_scale, intercept_prior_scale: prior std of the slope /
                intercept random effects.
            kl_weight: KL weight for both random-effect layers.
            name (str, optional): Model name. Defaults to 'me_regressor'.

        Raises:
            ValueError: if ``cluster_list`` is empty or None.
        """
        cluster_ids = [int(cid[1:]) if isinstance(cid, str) else int(cid)
                       for cid in (cluster_list if cluster_list is not None else [])]
        if not cluster_ids:
            raise ValueError('cluster_list must contain at least one cluster ID')

        super(MixedEffectsRegressor, self).__init__(len(cluster_ids), name=name, **kwargs)
        # Numeric cluster IDs, presumably in one-hot column order (confirm with caller).
        self.cluster_list = tf.constant(cluster_ids, dtype=tf.int64)
                
        # Single slope and intercept RE layer
        self.re_slopes = RandomEffects(units=256,
                                       post_loc_init_scale=slope_post_init_scale,
                                       prior_scale=slope_prior_scale,
                                       kl_weight=kl_weight,
                                       name='re_slopes')
        
        self.re_intercept = RandomEffects(units=2,
                                       post_loc_init_scale=intercept_post_init_scale,
                                       prior_scale=intercept_prior_scale,
                                       kl_weight=kl_weight,
                                       name='re_intercept')
        
    def call(self, inputs, training=None, return_layer_activations=False):
        """Predict from ``(images, head_poses, clusters)``.

        Returns ``y`` of shape (batch, 2), or ``(c0, ..., c6, h, y)`` when
        ``return_layer_activations`` is True. The activations c0..h do not
        depend on ``clusters``; only ``y`` does.
        """
        x, hps, z = inputs
        
        # Add a channel axis to images given as (batch, H, W).
        if x.shape[-1] != 1:
            x = tf.expand_dims(x, -1)

        c0 = self.regressor.conv0(x)
        c0 = self.regressor.bn0(c0)
        c0 = self.regressor.elu0(c0)
        
        c1 = self.regressor.maxpool0(c0)
        c1 = self.regressor.conv1(c1)
        c1 = self.regressor.dropout1(c1)
        c1 = self.regressor.bn1(c1)
        c1 = self.regressor.elu1(c1)
        
        c2 = self.regressor.maxpool1(c1)
        c2 = self.regressor.conv2(c2)
        c2 = self.regressor.bn2(c2)
        c2 = self.regressor.elu2(c2)
        
        c3 = self.regressor.maxpool2(c2)
        c3 = self.regressor.conv3(c3)
        c3 = self.regressor.dropout3(c3)
        c3 = self.regressor.bn3(c3)
        c3 = self.regressor.elu3(c3)
        
        c4 = self.regressor.maxpool3(c3)
        c4 = self.regressor.conv4(c4)
        c4 = self.regressor.bn4(c4)
        c4 = self.regressor.elu4(c4)
        
        c5 = self.regressor.maxpool4(c4)
        c5 = self.regressor.conv5(c5)
        c5 = self.regressor.dropout5(c5)
        c5 = self.regressor.bn5(c5)
        c5 = self.regressor.elu5(c5)
        
        c6 = self.regressor.maxpool5(c5)
        c6 = self.regressor.conv6(c6)
        c6 = self.regressor.elu6(c6)
        h = self.regressor.flatten(c6)
        h = tf.concat([h, hps], axis=-1)
        h = self.regressor.dense(h)
        
        slopes = self.re_slopes(z, training=training)
        intercepts = self.re_intercept(z, training=training)
        # RandomEffects variants that return (batch, n_clusters, units) are
        # averaged over the cluster axis; 2-D (batch, units) outputs pass through.
        if len(slopes.shape) == 3:
            slopes = tf.reduce_mean(slopes, axis=1)
        if len(intercepts.shape) == 3:
            intercepts = tf.reduce_mean(intercepts, axis=1)

        # Random slopes scale the fixed-effect features; intercepts shift the output.
        y = self.regressor.out(h * (1 + slopes))
        y = y + intercepts
        
        if return_layer_activations:
            return c0, c1, c2, c3, c4, c5, c6, h, y
        return y
    
    def compile(self,
            loss_regressor=None,
            loss_adversary=None,
            loss_regressor_weight=1.0,
            loss_gen_weight=0.1,
            metric_regressor=None,
            opt_regressor=None,
            opt_adversary=None
            ):
        """Same as the parent compile, plus a KL-divergence tracker.

        ``None`` selects a fresh default object: MeanAbsoluteError (loss and
        metric), CategoricalCrossentropy (adversary loss) and
        Nadam(learning_rate=1e-4) for each optimizer. Defaults are resolved here
        so stateful objects are never shared between models.
        """
        if loss_regressor is None:
            loss_regressor = keras.losses.MeanAbsoluteError()
        if loss_adversary is None:
            loss_adversary = keras.losses.CategoricalCrossentropy()
        if metric_regressor is None:
            metric_regressor = keras.metrics.MeanAbsoluteError()
        if opt_regressor is None:
            opt_regressor = keras.optimizers.Nadam(learning_rate=0.0001)
        if opt_adversary is None:
            opt_adversary = keras.optimizers.Nadam(learning_rate=0.0001)

        super(MixedEffectsRegressor, self).compile(loss_regressor=loss_regressor,
                                                    loss_adversary=loss_adversary,
                                                    loss_regressor_weight=loss_regressor_weight,
                                                    loss_gen_weight=loss_gen_weight,
                                                    metric_regressor=metric_regressor,
                                                    opt_regressor=opt_regressor,
                                                    opt_adversary=opt_adversary)
        self.loss_kld_tracker = tf.keras.metrics.Mean(name='kld')
        
    @property
    def metrics(self):
        return [self.loss_regressor_tracker,
                self.loss_gen_tracker,
                self.loss_adversary_tracker,
                self.loss_kld_tracker,
                self.loss_total_tracker,
                self.metric_regressor]

    def _regression_loss(self, labels, y_pred, sample_weights=None):
        '''Sum of the regression loss over output columns (one term per column).

        Shared by train_step and test_step so the reported reg_loss has the same
        scale in training and validation. Columns are sliced as (batch, 1) so
        per-sample weights apply to each sample's loss.
        '''
        return tf.add_n([
            self.loss_regressor(labels[:, i:i + 1], y_pred[:, i:i + 1],
                                sample_weight=sample_weights)
            for i in range(y_pred.shape[1])
        ])
        
    def _compute_update_loss(self, loss_reg, loss_gen, training=True):
        '''Compute total loss and update loss running means.

        The generator term is SUBTRACTED so the regressor works against the
        adversary (it is rewarded for making clusters hard to predict).
        '''
        self.loss_regressor_tracker.update_state(loss_reg)
        self.loss_gen_tracker.update_state(loss_gen)
        if training:
            # KL (and optional L1) terms registered by the RE layers during
            # the most recent forward pass.
            kld = tf.reduce_sum(self.re_slopes.losses) + tf.reduce_sum(self.re_intercept.losses)
            self.loss_kld_tracker.update_state(kld)
        else:
            # KLD can't be computed at inference time because posteriors are simplified to 
            # point estimates
            kld = 0.0
        loss_total = (self.loss_regressor_weight * loss_reg) \
            - (self.loss_gen_weight * loss_gen) \
            + kld
        self.loss_total_tracker.update_state(loss_total)
        
        return loss_total
    
    def train_step(self, data):
        if len(data) == 3:
            (images, headpose, clusters), labels, sample_weights = data
        else:
            (images, headpose, clusters), labels = data
            sample_weights = None

        # 1) Train the adversary. training=True is explicit: sub-models called
        #    directly from train_step otherwise run in inference mode.
        with tf.GradientTape() as gt:
            reg_outs = self((images, headpose, clusters), training=True, return_layer_activations=True)
            clusters_pred = self.adversary(reg_outs[:-1], training=True)
            loss_adv = self.loss_adversary(clusters, clusters_pred, sample_weight=sample_weights)

        grads_adv = gt.gradient(loss_adv, self.adversary.trainable_weights)
        self.opt_adversary.apply_gradients(zip(grads_adv, self.adversary.trainable_weights))
        self.loss_adversary_tracker.update_state(loss_adv)

        # 2) Train the regressor AND the random-effect posteriors; the latter
        #    were previously left out of the gradient update.
        main_weights = (self.regressor.trainable_weights
                        + self.re_slopes.trainable_weights
                        + self.re_intercept.trainable_weights)
        with tf.GradientTape() as gt2:
            reg_outs = self((images, headpose, clusters), training=True, return_layer_activations=True)
            y_pred = reg_outs[-1]
            loss_reg = self._regression_loss(labels, y_pred, sample_weights)

            clusters_pred = self.adversary(reg_outs[:-1], training=False)
            loss_gen = self.loss_adversary(clusters, clusters_pred, sample_weight=sample_weights)

            loss_total = self._compute_update_loss(loss_reg, loss_gen)

        grads = gt2.gradient(loss_total, main_weights)
        self.opt_regressor.apply_gradients(zip(grads, main_weights))

        self.metric_regressor.update_state(labels, y_pred, sample_weight=sample_weights)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        if len(data) == 3:
            (images, headpose, clusters), labels, sample_weights = data
        else:
            (images, headpose, clusters), labels = data
            sample_weights = None
        
        # The activations c0..h do not depend on the cluster input, so one pass
        # with the given clusters is enough to query the adversary.
        reg_outs = self((images, headpose, clusters), training=False, return_layer_activations=True)
        adversary_output = self.adversary(reg_outs[:-1], training=False)
        # Match dtypes so tf.where below can mix the two tensors.
        clusters = tf.cast(clusters, adversary_output.dtype)

        # A row is a known cluster iff it has a non-zero entry. Rows for IDs the
        # encoder never saw (OneHotEncoder(handle_unknown='ignore')) are all
        # zero. The previous argmax test mapped those rows to cluster 0, and it
        # compared column indices with parsed ID numbers.
        is_known_cluster = tf.reduce_max(clusters, axis=1) > 0

        # Use the true clusters where known, otherwise the adversary's prediction.
        new_clusters = tf.where(tf.expand_dims(is_known_cluster, axis=1), clusters, adversary_output)

        y_pred = self((images, headpose, new_clusters), training=False)
        
        loss_reg = self._regression_loss(labels, y_pred, sample_weights)
        loss_gen = self.loss_adversary(new_clusters, adversary_output, sample_weight=sample_weights)
        
        self.loss_adversary_tracker.update_state(loss_gen)
        # training=False: no KL is available from the point-estimate forward pass.
        self._compute_update_loss(loss_reg, loss_gen, training=False)
        
        self.metric_regressor.update_state(labels, y_pred, sample_weight=sample_weights)
        return {m.name: m.result() for m in self.metrics}

# Training

In [None]:
# Build the mixed-effects model; every subject ID becomes a cluster.
model = MixedEffectsRegressor(cluster_list=total_id_data)

# Compile: MAE for gaze regression, categorical cross-entropy for the adversary,
# and a separate Nadam optimizer for each sub-network.
model.compile(
    loss_regressor=keras.losses.MeanAbsoluteError(),
    loss_adversary=keras.losses.CategoricalCrossentropy(),
    loss_regressor_weight=1.0,
    loss_gen_weight=0.1,
    metric_regressor=tf.keras.metrics.MeanAbsoluteError(),
    opt_regressor=tf.keras.optimizers.Nadam(learning_rate=0.0001),
    opt_adversary=tf.keras.optimizers.Nadam(learning_rate=0.0001),
)

# Callbacks (currently disabled):
# early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
# reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-6)

# Train. Inputs are (images, head poses, one-hot subject IDs); targets are the 2-D gaze values.
history = model.fit(
    x=train_data[0],
    y=train_data[1],
    epochs=100,
    batch_size=32,
    # validation_data=val_data,
    # callbacks=[early_stopping, reduce_lr],
)

In [None]:
# Predict gaze for the test split; the inputs are
# (normalized images, head poses, one-hot subject IDs).
test_inputs = test_data[0]
pred_gzs = model.predict(test_inputs)

In [None]:
def mae_keras(predictedGaze, groundtruthGaze, is_3d=False, deg=False):
    """Return the mean angular error, in degrees, between two sets of gaze vectors.

    Args:
        predictedGaze: predicted gaze, shape (N, 2) angles or (N, 3) vectors.
        groundtruthGaze: ground-truth gaze in the same format.
        is_3d (bool): True when both inputs are already 3-D vectors; otherwise
            they are converted with ``convert_to_xyz_tf``.
        deg (bool): True when 2-D angle inputs are given in degrees.

    Returns:
        tf.Tensor: scalar mean of the per-row angles between the two sets.
    """
    pred = tf.cast(predictedGaze, dtype=tf.float32)
    truth = tf.cast(groundtruthGaze, dtype=tf.float32)

    if not is_3d:
        pred = convert_to_xyz_tf(pred, deg)
        truth = convert_to_xyz_tf(truth, deg)

    def _unit(vectors):
        return vectors / tf.norm(vectors, axis=1, keepdims=True)

    # Dot product of unit vectors, clipped into arccos' domain against rounding.
    cosine = tf.clip_by_value(tf.reduce_sum(_unit(pred) * _unit(truth), axis=1), -1, 1)
    angles_deg = tf.acos(cosine) * 180 / tf.constant(np.pi)

    return tf.reduce_mean(angles_deg)

def convert_to_xyz_tf(spherical, deg=False):
    """Convert 2-D gaze angles to unit 3-D direction vectors.

    For each row ``(a, b) = spherical[i]`` the result is
    ``(-cos(a) * sin(b), -sin(a), -cos(a) * cos(b))``, normalised to unit length.
    Column 0 presumably holds pitch and column 1 yaw (the MPIIGaze-style
    convention) -- confirm against the dataset.

    Args:
        spherical: tensor or array of shape (N, 2); cast to float32.
        deg (bool): True when the angles are in degrees rather than radians.

    Returns:
        tf.Tensor of shape (N, 3).
    """
    spherical = tf.cast(spherical, tf.float32)
    if deg:
        spherical = spherical * tf.constant(np.pi) / 180

    a = spherical[:, 0]
    b = spherical[:, 1]
    # tf.stack replaces per-row scatter updates whose Python-built index lists
    # needed a static batch size and failed on empty input.
    xyz = tf.stack([-tf.cos(a) * tf.sin(b),
                    -tf.sin(a),
                    -tf.cos(a) * tf.cos(b)], axis=1)

    return xyz / tf.norm(xyz, axis=1, keepdims=True)

In [None]:
# Mean angular error of the test predictions against the ground-truth gaze.
ground_truth_gzs = test_data[1]
error = mae_keras(pred_gzs, ground_truth_gzs)
print("Mean Angular Error:", error.numpy(), "degrees")