# Import Libraries

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import warnings
warnings.filterwarnings('ignore')

from sklearn.model_selection import StratifiedKFold

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

import keras
from keras.utils import to_categorical

## Data Load

In [None]:
# Directory that holds the pre-extracted MPII ".npy" arrays.
mpii_path = "/Users/dkmoon/Desktop/DKU/ML_LAB/Scholar/data/mpii"

# "within" data: every participant class is meant to be split evenly
# between train and validation.
# The id array holds strings; every 'p' character is stripped and the
# remainder is parsed as an integer (presumably ids look like "pNN" --
# TODO confirm against the data files).
raw_ids = np.load(os.path.join(mpii_path, "within_ids.npy"))
within_id_data = np.char.replace(raw_ids, 'p', '').astype(int)

within_hps_data = np.load(os.path.join(mpii_path, "within_2d_hps.npy"))
within_img_data = np.load(os.path.join(mpii_path, "within_images.npy"))
within_gzs_data = np.load(os.path.join(mpii_path, "within_2d_gazes.npy"))

# Report the shape of each loaded array.
for label, array in (
    ("within_id_data", within_id_data),
    ("within_hps_data", within_hps_data),
    ("within_img_data", within_img_data),
    ("within_gzs_data", within_gzs_data),
):
    print(f"{label} : {array.shape}")

In [None]:
# Collapse the leading dimensions so that axis 0 indexes individual samples.
ids_data = np.reshape(within_id_data, -1)
hps_data = np.reshape(within_hps_data, (-1, 2))
# Images become (N, 36, 60, 1) and are divided by 255.0.
img_data = np.reshape(within_img_data, (-1, 36, 60, 1)) / 255.0
gzs_data = np.reshape(within_gzs_data, (-1, 2))

In [None]:
# Outer splitter: 3 stratified folds over the participant ids (train+val / test).
skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=530)
# Inner splitter: the first of 8 stratified folds holds out 1/8 = 12.5 % of the
# train+val portion for validation (the same ratio as the previous 0.875 cut).
# The indices returned by StratifiedKFold are sorted, so slicing off the tail
# of `train_val_index` would build the validation set from the highest sample
# indices only instead of sampling every participant.
val_skf = StratifiedKFold(n_splits=8, shuffle=True, random_state=530)

# Split data
folds = []
for train_val_index, test_index in skf.split(np.zeros(ids_data.shape[0]), ids_data):
    train_val_ids = ids_data[train_val_index]
    # Positions are relative to `train_val_index`; map them back to sample indices.
    train_pos, val_pos = next(
        val_skf.split(np.zeros(train_val_index.shape[0]), train_val_ids)
    )
    index_by_split = {
        'train': train_val_index[train_pos],
        'val': train_val_index[val_pos],
        'test': test_index,
    }

    fold = {
        split: {
            'ids': ids_data[index],
            'hps': hps_data[index],
            'imgs': img_data[index],
            'gzs': gzs_data[index],
        }
        for split, index in index_by_split.items()
    }
    folds.append(fold)

# Print fold information
for i, fold in enumerate(folds):
    print(f"Fold {i+1}:")
    print(f"  Train set: {fold['train']['ids'].shape[0]} samples")
    print(f"  Val set: {fold['val']['ids'].shape[0]} samples")
    print(f"  Test set: {fold['test']['ids'].shape[0]} samples")

# Model Architecture

## Fixed Effect Subnetwork

In [None]:
class FixedEffectSubnetwork(keras.Model):
    """Convolutional regressor that also returns its intermediate activations.

    Seven 3x3 'same'-padded convolution stages with ELU activations are
    chained, with default-sized max-pooling between consecutive stages.
    Stages 0-5 use batch normalisation; stages 1, 3 and 5 additionally apply
    dropout between the convolution and the batch normalisation. The
    flattened output of the last stage is concatenated with the head-pose
    input, fed to a 256-unit ELU dense layer (L1-regularised kernel) and
    then to a 2-unit linear output layer.

    ``call`` takes ``(images, head_poses)`` and returns the tuple
    ``(c0, c1, c2, c3, c4, c5, c6, h, y)`` where ``c0``..``c6`` are the
    per-stage activations, ``h`` is the dense-layer output and ``y`` the
    2-unit prediction.

    The spatial sizes in the comments below assume a 36 x 60 input image.
    """

    def __init__(self, name='regressor', **kwargs):
        super().__init__(name=name, **kwargs)
        layers = keras.layers

        # Stage 0 -- operates on 36 x 60
        self.conv0 = layers.Conv2D(64, 3, padding='same', name='conv0')
        self.bn0 = layers.BatchNormalization(name='bn0')
        self.act0 = layers.ELU(name='act0')
        self.maxpool0 = layers.MaxPool2D(name='maxpool0')

        # Stage 1 -- operates on 18 x 30
        self.conv1 = layers.Conv2D(128, 3, padding='same', name='conv1')
        self.dropout1 = layers.Dropout(0.3, name='dropout1')
        self.bn1 = layers.BatchNormalization(name='bn1')
        self.act1 = layers.ELU(name='act1')
        self.maxpool1 = layers.MaxPool2D(name='maxpool1')

        # Stage 2 -- operates on 9 x 15
        self.conv2 = layers.Conv2D(128, 3, padding='same', name='conv2')
        self.bn2 = layers.BatchNormalization(name='bn2')
        self.act2 = layers.ELU(name='act2')
        self.maxpool2 = layers.MaxPool2D(name='maxpool2')

        # Stage 3 -- operates on 4 x 7 ('valid' pooling floors odd sizes)
        self.conv3 = layers.Conv2D(256, 3, padding='same', name='conv3')
        self.dropout3 = layers.Dropout(0.3, name='dropout3')
        self.bn3 = layers.BatchNormalization(name='bn3')
        self.act3 = layers.ELU(name='act3')
        self.maxpool3 = layers.MaxPool2D(name='maxpool3')

        # Stage 4 -- operates on 2 x 3
        self.conv4 = layers.Conv2D(256, 3, padding='same', name='conv4')
        self.bn4 = layers.BatchNormalization(name='bn4')
        self.act4 = layers.ELU(name='act4')
        self.maxpool4 = layers.MaxPool2D(name='maxpool4')

        # Stage 5 -- operates on 1 x 1
        self.conv5 = layers.Conv2D(512, 3, padding='same', name='conv5')
        self.dropout5 = layers.Dropout(0.5, name='dropout5')
        self.bn5 = layers.BatchNormalization(name='bn5')
        self.act5 = layers.ELU(name='act5')
        # 'same' padding so that pooling a 1 x 1 map is still valid.
        self.maxpool5 = layers.MaxPool2D(padding='same', name='maxpool5')

        # Stage 6 -- operates on 1 x 1, no batch normalisation
        self.conv6 = layers.Conv2D(512, 3, padding='same', name='conv6')
        self.act6 = layers.ELU(name='act6')
        self.flatten = layers.Flatten(name='flatten')
        self.dense = layers.Dense(
            256,
            name='dense',
            activation='elu',
            kernel_regularizer=keras.regularizers.L1L2(l1=0.01),
        )
        self.out = layers.Dense(2, name='fe_output')

    def call(self, inputs):
        """Run the network on ``(images, head_poses)``.

        Returns:
            Tuple ``(c0, c1, c2, c3, c4, c5, c6, h, y)``; see the class
            docstring.
        """
        images, head_poses = inputs

        feat0 = self.act0(self.bn0(self.conv0(images)))

        feat1 = self.conv1(self.maxpool0(feat0))
        feat1 = self.act1(self.bn1(self.dropout1(feat1)))

        feat2 = self.conv2(self.maxpool1(feat1))
        feat2 = self.act2(self.bn2(feat2))

        feat3 = self.conv3(self.maxpool2(feat2))
        feat3 = self.act3(self.bn3(self.dropout3(feat3)))

        feat4 = self.conv4(self.maxpool3(feat3))
        feat4 = self.act4(self.bn4(feat4))

        feat5 = self.conv5(self.maxpool4(feat4))
        feat5 = self.act5(self.bn5(self.dropout5(feat5)))

        feat6 = self.act6(self.conv6(self.maxpool5(feat5)))

        # The head pose joins the image features right before the dense layer.
        flat = self.flatten(feat6)
        hidden = self.dense(keras.layers.Concatenate()([flat, head_poses]))
        prediction = self.out(hidden)

        return feat0, feat1, feat2, feat3, feat4, feat5, feat6, hidden, prediction

## Random Effect Subnetwork

In [None]:
def make_prior_fn(df, scale):
    """Return a zero-argument factory for a zero-centred Student-t prior.

    Args:
        df: Degrees of freedom passed to ``tfd.StudentT``.
        scale: Scale passed to ``tfd.StudentT``.

    Returns:
        A callable that builds ``tfd.StudentT(df, loc=0.0, scale)`` each
        time it is invoked.
    """
    def prior_fn():
        prior = tfd.StudentT(df=df, loc=0.0, scale=scale)
        return prior

    return prior_fn

def make_posterior_fn(df, loc_init_scale, scale_init_min, scale_init_range):
    """Return a factory for a Student-t posterior with variable loc and scale.

    Args:
        df: Degrees of freedom of the Student-t posterior.
        loc_init_scale: Standard deviation of the zero-mean normal
            initializer used for the location variable.
        scale_init_min: Lower bound of the uniform initializer used for the
            unconstrained scale variable.
        scale_init_range: Width of that uniform initializer; the upper bound
            is ``scale_init_min + scale_init_range``.

    Returns:
        A callable ``posterior_fn(units)`` that creates two float32
        variables of shape ``(units,)`` and returns a ``tfd.StudentT`` whose
        scale is ``softplus(scale_variable)``.
    """
    def posterior_fn(units):
        loc_initializer = tf.random_normal_initializer(mean=0.0, stddev=loc_init_scale)
        scale_initializer = tf.random_uniform_initializer(
            minval=scale_init_min, maxval=scale_init_min + scale_init_range)

        loc = tf.Variable(initial_value=loc_initializer(shape=(units,)), name="loc", dtype=tf.float32)
        raw_scale = tf.Variable(initial_value=scale_initializer(shape=(units,)), name="scale", dtype=tf.float32)

        # Calling tf.nn.softplus(raw_scale) directly would evaluate it once,
        # right here, and hand the distribution a constant tensor: the
        # variable would then never influence the posterior again.
        # DeferredTensor re-applies softplus every time the scale is read,
        # so the distribution stays connected to the variable.
        scale = tfp.util.DeferredTensor(raw_scale, tf.nn.softplus)

        return tfd.StudentT(df=df, loc=loc, scale=scale)
    return posterior_fn

def kl_divergence_student_t(posterior, prior, num_samples=1000, seed=530):
    """Monte Carlo estimate of KL(posterior || prior).

    Draws ``num_samples`` samples from ``posterior`` and averages the
    difference of the two log-densities over all sampled values.

    Args:
        posterior: Distribution to sample from; must provide ``sample`` and
            ``log_prob``.
        prior: Distribution providing ``log_prob``.
        num_samples: Number of draws requested from ``posterior``.
        seed: Seed forwarded to ``posterior.sample``.

    Returns:
        Scalar tensor: the mean of ``log q(x) - log p(x)`` over the draws.
    """
    draws = posterior.sample(num_samples, seed=seed)
    log_ratio = posterior.log_prob(draws) - prior.log_prob(draws)
    return tf.reduce_mean(log_ratio)

class RandomEffectSubnetwork(keras.layers.Layer):
    """Layer that scales its input by values of a Student-t posterior.

    In training mode the input is multiplied by fresh posterior samples and
    a KL penalty (posterior vs. prior, weighted by ``kl_weight``) is
    registered through ``add_loss``. Otherwise the posterior mean is used as
    a deterministic multiplier. If ``l1_weight`` is truthy, an L1 penalty on
    the posterior mean is registered in both modes.

    Args:
        units: Number of posterior components.
        df: Degrees of freedom shared by posterior and prior.
        post_loc_init_scale: Forwarded to ``make_posterior_fn``.
        post_scale_init_min: Forwarded to ``make_posterior_fn``.
        post_scale_init_range: Forwarded to ``make_posterior_fn``.
        prior_scale: Scale of the prior built by ``make_prior_fn``.
        kl_weight: Multiplier of the KL penalty.
        l1_weight: Optional multiplier of the L1 penalty.
        name: Layer name.
    """

    def __init__(self,
                 units,
                 df,
                 post_loc_init_scale,
                 post_scale_init_min,
                 post_scale_init_range,
                 prior_scale,
                 kl_weight=0.001,
                 l1_weight=None,
                 name=None):
        super().__init__(name=name)
        self.kl_weight = kl_weight
        self.l1_weight = l1_weight
        self.units = units

        build_posterior = make_posterior_fn(
            df, post_loc_init_scale, post_scale_init_min, post_scale_init_range)
        self.posterior = build_posterior(units)

        build_prior = make_prior_fn(df, prior_scale)
        self.prior = build_prior()

    def call(self, inputs, training=None):
        """Multiply ``inputs`` by posterior samples (training) or the mean."""
        # Work in float32 and append a trailing axis to broadcast over units.
        x = tf.expand_dims(tf.cast(inputs, dtype=tf.float32), axis=-1)

        if training:
            batch_size = tf.shape(x)[0]
            # One posterior draw per batch row: [batch_size, units]
            draws = self.posterior.sample(sample_shape=(batch_size,))
            # Insert a middle axis so the draws broadcast against `x`.
            draws = tf.reshape(draws, [batch_size, 1, self.units])
            outputs = x * draws

            kl = kl_divergence_student_t(self.posterior, self.prior, num_samples=500, seed=530)
            self.add_loss(self.kl_weight * tf.reduce_sum(kl))
        else:
            # Deterministic path: the posterior mean stands in for a sample.
            point_estimate = tf.reshape(self.posterior.mean(), [1, 1, self.units])
            outputs = x * point_estimate

        if self.l1_weight:
            l1_penalty = tf.reduce_sum(tf.abs(self.posterior.mean()))
            self.add_loss(self.l1_weight * l1_penalty)

        return outputs

## Adversarial Classifier

In [None]:
class AdversarialRegressor(keras.Model):
    """Cluster predictor fed with the fixed-effect network's activations.

    Despite its name, the output layer is a softmax over ``n_clusters``
    classes. The network mirrors the stage layout of the fixed-effect
    subnetwork: after each pooling step its own feature map is concatenated
    with the corresponding fixed-effect activation before the next
    convolution. The flattened result is concatenated with ``h`` and passed
    through a 256-unit ELU dense layer and the softmax output.

    ``call`` expects the sequence ``(c0, c1, c2, c3, c4, c5, c6, h)``.

    The spatial sizes in the comments below assume ``c0`` is 36 x 60.
    """

    def __init__(self, n_clusters, name='adversary', **kwargs):
        super().__init__(name=name, **kwargs)
        self.n_clusters = n_clusters
        layers = keras.layers

        # Stage 0 -- 36 x 60
        self.conv0 = layers.Conv2D(64, 3, padding='same', name='conv0')
        self.bn0 = layers.BatchNormalization(name='bn0')
        self.act0 = layers.ELU(name='act0')
        self.maxpool0 = layers.MaxPool2D(name='maxpool0')
        # Stage 1 -- 18 x 30
        self.conv1 = layers.Conv2D(64, 3, padding='same', name='conv1')
        self.dropout1 = layers.Dropout(0.3, name='dropout1')
        self.bn1 = layers.BatchNormalization(name='bn1')
        self.act1 = layers.ELU(name='act1')
        self.maxpool1 = layers.MaxPool2D(name='maxpool1')
        # Stage 2 -- 9 x 15
        self.conv2 = layers.Conv2D(128, 3, padding='same', name='conv2')
        self.bn2 = layers.BatchNormalization(name='bn2')
        self.act2 = layers.ELU(name='act2')
        self.maxpool2 = layers.MaxPool2D(name='maxpool2')
        # Stage 3 -- 4 x 7 ('valid' pooling floors odd sizes)
        self.conv3 = layers.Conv2D(128, 3, padding='same', name='conv3')
        self.dropout3 = layers.Dropout(0.3, name='dropout3')
        self.bn3 = layers.BatchNormalization(name='bn3')
        self.act3 = layers.ELU(name='act3')
        self.maxpool3 = layers.MaxPool2D(name='maxpool3')
        # Stage 4 -- 2 x 3
        self.conv4 = layers.Conv2D(256, 3, padding='same', name='conv4')
        self.bn4 = layers.BatchNormalization(name='bn4')
        self.act4 = layers.ELU(name='act4')
        self.maxpool4 = layers.MaxPool2D(name='maxpool4')
        # Stage 5 -- 1 x 1
        self.conv5 = layers.Conv2D(256, 3, padding='same', name='conv5')
        self.dropout5 = layers.Dropout(0.3, name='dropout5')
        self.bn5 = layers.BatchNormalization(name='bn5')
        self.act5 = layers.ELU(name='act5')
        self.maxpool5 = layers.MaxPool2D(padding='same', name='maxpool5')
        # Stage 6 -- 1 x 1, no batch normalisation
        self.conv6 = layers.Conv2D(512, 3, padding='same', name='conv6')
        self.act6 = layers.ELU(name='act6')
        self.flatten = layers.Flatten(name='flatten')

        self.dense = layers.Dense(256, name='dense', activation='elu')
        self.out = layers.Dense(n_clusters, name='adv_output', activation='softmax')

    def call(self, inputs):
        """Predict cluster probabilities from the fixed-effect activations."""
        # Feature maps produced by the fixed-effect subnetwork.
        c0, c1, c2, c3, c4, c5, c6, h = inputs

        z = self.maxpool0(self.act0(self.bn0(self.conv0(c0))))

        z = self.conv1(keras.layers.Concatenate()([z, c1]))
        z = self.maxpool1(self.act1(self.bn1(self.dropout1(z))))

        z = self.conv2(keras.layers.Concatenate()([z, c2]))
        z = self.maxpool2(self.act2(self.bn2(z)))

        z = self.conv3(keras.layers.Concatenate()([z, c3]))
        z = self.maxpool3(self.act3(self.bn3(self.dropout3(z))))

        z = self.conv4(keras.layers.Concatenate()([z, c4]))
        z = self.maxpool4(self.act4(self.bn4(z)))

        z = self.conv5(keras.layers.Concatenate()([z, c5]))
        z = self.maxpool5(self.act5(self.bn5(self.dropout5(z))))

        z = self.conv6(keras.layers.Concatenate()([z, c6]))
        z = self.flatten(self.act6(z))

        z = self.dense(keras.layers.Concatenate()([z, h]))
        cluster_probs = self.out(z)  # z_pred
        return cluster_probs

## Mixed Effect Network

In [None]:
class MixedEffectNetwork(keras.Model):
    """Fixed-effect regressor combined with random slopes/intercepts.

    The model owns a ``FixedEffectSubnetwork``, two
    ``RandomEffectSubnetwork`` layers (256-unit slopes, 2-unit intercepts),
    an ``AdversarialRegressor`` that predicts the cluster from the
    fixed-effect activations, and a 2-unit dense layer producing the
    mixed-effect prediction.

    Args:
        cluster_list: Sequence of integer cluster ids. Its length sets the
            adversary's number of output classes; ``None`` means empty.
        slope_post_init_scale: Initial loc spread of the slope posterior.
        intercept_post_init_scale: Initial loc spread of the intercept
            posterior.
        slope_scale: Prior scale for the slopes.
        intercept_scale: Prior scale for the intercepts.
        kl_weight: Weight of the KL penalty of both random-effect layers.
        df: Degrees of freedom of the Student-t prior/posterior.
        name: Model name.

    NOTE(review): ``opt_adversary`` is stored by ``compile`` but never used;
    the adversary's weights are updated by ``opt_regressor`` together with
    all other weights and the adversary loss is *added* to the objective.
    Confirm whether a separate adversary update with a subtracted
    (adversarial) term was intended.
    """

    def __init__(self,
                 cluster_list=None,
                 slope_post_init_scale=0.1,
                 intercept_post_init_scale=0.1,
                 slope_scale=0.25,
                 intercept_scale=0.25,
                 kl_weight=1e-3,
                 df=3,
                 name='me_network', **kwargs):
        super().__init__(name=name, **kwargs)

        # Avoid a shared mutable default argument.
        if cluster_list is None:
            cluster_list = []

        # Fixed-effect subnetwork
        self.fixed_effect_subnetwork = FixedEffectSubnetwork(name='fe_network')

        # Random-effect slopes
        self.re_slopes = RandomEffectSubnetwork(units=256,
                                                df=df,
                                                post_loc_init_scale=slope_post_init_scale,
                                                post_scale_init_min=0.01,
                                                post_scale_init_range=0.02,
                                                prior_scale=slope_scale,
                                                kl_weight=kl_weight,
                                                name='re_slopes')
        # Random-effect intercepts
        self.re_intercept = RandomEffectSubnetwork(units=2,
                                                   df=df,
                                                   post_loc_init_scale=intercept_post_init_scale,
                                                   post_scale_init_min=0.01,
                                                   post_scale_init_range=0.02,
                                                   prior_scale=intercept_scale,
                                                   kl_weight=kl_weight,
                                                   name='re_intercept')

        # Adversarial cluster classifier
        self.n_clusters = len(cluster_list)
        self.cluster_list = tf.constant(cluster_list, dtype=tf.int64)
        self.adversary = AdversarialRegressor(self.n_clusters, name='z_predictor')

        # Mixed-effect output
        self.me_out = keras.layers.Dense(2, name='me_output')

    def call(self, inputs, training=False):
        """Forward pass on ``(images, head_poses, cluster_ids)``.

        Returns:
            Tuple ``(c0, c1, c2, c3, c4, c5, c6, h, y_fixed, re_slope,
            re_intercept, y_mixed)``: the nine outputs of the fixed-effect
            subnetwork, the two random-effect terms and the mixed-effect
            prediction.
        """
        images, head_poses, cluster_ids = inputs
        # Fixed Effect Subnetwork
        c0, c1, c2, c3, c4, c5, c6, h, y_fixed = self.fixed_effect_subnetwork((images, head_poses))

        # Random Effect Subnetwork
        # NOTE(review): `cluster_ids` is a vector of integer ids when called
        # from train_step / predict but a matrix of one-hot rows in the
        # second pass of test_step -- confirm this is intended.
        re_slope = self.re_slopes(cluster_ids, training=training)
        re_intercept = self.re_intercept(cluster_ids, training=training)

        # Rank-3 random-effect outputs are averaged over their middle axis.
        if len(re_slope.shape) == 3:
            re_slope = tf.reduce_mean(re_slope, axis=1)
        if len(re_intercept.shape) == 3:
            re_intercept = tf.reduce_mean(re_intercept, axis=1)

        # Mixed Effect Network: slopes rescale `h`, intercepts shift the output.
        y_mixed = self.me_out(h * (1 + re_slope))
        y_mixed = y_mixed + re_intercept

        return c0, c1, c2, c3, c4, c5, c6, h, y_fixed, re_slope, re_intercept, y_mixed

    def compile(self,
                loss_me_regressor=None,
                loss_fe_regressor=None,
                loss_adversary=None,
                fe_regressor_weight=1.0,
                adv_gen_weight=0.1,

                metric_regressor=None,
                opt_regressor=None,
                opt_adversary=None):
        """Store losses, weights, metric and optimizers for the custom steps.

        Every argument left as ``None`` gets a fresh default instance:
        ``MeanAbsoluteError`` losses for both regressors,
        ``CategoricalCrossentropy`` for the adversary, a
        ``MeanAbsoluteError`` metric and ``Nadam(learning_rate=0.0001)``
        optimizers. Creating them here rather than in the signature keeps
        stateful optimizers/metrics from being shared between models.
        """
        if loss_me_regressor is None:
            loss_me_regressor = keras.losses.MeanAbsoluteError()
        if loss_fe_regressor is None:
            loss_fe_regressor = keras.losses.MeanAbsoluteError()
        if loss_adversary is None:
            loss_adversary = keras.losses.CategoricalCrossentropy()
        if metric_regressor is None:
            metric_regressor = keras.metrics.MeanAbsoluteError()
        if opt_regressor is None:
            opt_regressor = keras.optimizers.Nadam(learning_rate=0.0001)
        if opt_adversary is None:
            opt_adversary = keras.optimizers.Nadam(learning_rate=0.0001)

        # Register the regressor optimizer as `self.optimizer` so that
        # callbacks which adjust `model.optimizer` (e.g. learning-rate
        # schedules) act on the optimizer that train_step really uses.
        super().compile(optimizer=opt_regressor)

        self.loss_me_regressor = loss_me_regressor
        self.loss_fe_regressor = loss_fe_regressor
        self.loss_adversary = loss_adversary
        self.fe_regressor_weight = fe_regressor_weight
        self.adv_gen_weight = adv_gen_weight

        self.metric_regressor = metric_regressor
        self.opt_regressor = opt_regressor
        self.opt_adversary = opt_adversary

        self.loss_me_regressor_tracker = tf.keras.metrics.Mean(name='me_regressor_loss')
        self.loss_fe_regressor_tracker = tf.keras.metrics.Mean(name='fe_regressor_loss')
        self.loss_adversary_tracker = tf.keras.metrics.Mean(name='adversary_loss')
        self.loss_kld_tracker = tf.keras.metrics.Mean(name='kld')
        self.loss_total_tracker = tf.keras.metrics.Mean(name='total_loss')

    @property
    def metrics(self):
        """Loss trackers and the regression metric, in reporting order."""
        return [self.loss_me_regressor_tracker,
                self.loss_fe_regressor_tracker,
                self.loss_adversary_tracker,
                self.loss_kld_tracker,
                self.loss_total_tracker,
                self.metric_regressor]

    def _one_hot_clusters(self, clusters):
        """One-hot encode integer cluster ids with TensorFlow ops.

        ``tf.one_hot`` is used instead of the NumPy-based ``to_categorical``
        so the encoding also works on symbolic tensors inside a compiled
        train/test step. The ids are used directly as class indices.
        """
        indices = tf.reshape(tf.cast(clusters, tf.int32), [-1])
        return tf.one_hot(indices, depth=self.n_clusters, dtype=tf.float32)

    def _compute_update_loss(self, loss_me_reg, loss_fe_reg, loss_adv, training=True):
        """Update the loss trackers and return the combined loss.

        The total is ``loss_me_reg + fe_regressor_weight * loss_fe_reg +
        adv_gen_weight * loss_adv + kld``, where ``kld`` sums the losses
        registered by both random-effect layers in training mode and is 0
        otherwise.
        """
        self.loss_me_regressor_tracker.update_state(loss_me_reg)
        self.loss_fe_regressor_tracker.update_state(loss_fe_reg)

        self.loss_adversary_tracker.update_state(loss_adv)
        if training:
            kld = tf.reduce_sum(self.re_slopes.losses) + tf.reduce_sum(self.re_intercept.losses)
            self.loss_kld_tracker.update_state(kld)
        else:
            # KLD can't be computed at inference time because posteriors are simplified to
            # point estimates
            kld = 0

        loss_total = loss_me_reg + (self.fe_regressor_weight * loss_fe_reg) + (self.adv_gen_weight * loss_adv) + kld
        self.loss_total_tracker.update_state(loss_total)

        return loss_total

    # Model training
    def train_step(self, data):
        """Run one optimisation step on ``((images, head_poses, clusters), labels)``."""
        (images, head_poses, clusters), labels = data

        # One-hot encoded clusters (target of the adversary)
        clusters_one_hot = self._one_hot_clusters(clusters)

        with tf.GradientTape() as tape:
            outputs = self((images, head_poses, clusters), training=True)
            features = list(outputs[:8])
            y_fixed, y_mixed = outputs[8], outputs[11]

            # The adversary is invoked outside of `call`, so the training
            # flag is not inherited and has to be passed explicitly.
            clusters_pred = self.adversary(features, training=True)
            loss_adv = self.loss_adversary(clusters_one_hot, clusters_pred)

            loss_fe_reg = self.loss_fe_regressor(labels, y_fixed)
            loss_me_reg = self.loss_me_regressor(labels, y_mixed)
            loss_total = self._compute_update_loss(loss_me_reg, loss_fe_reg, loss_adv, training=True)

        trainable = self.trainable_weights
        grads = tape.gradient(loss_total, trainable)
        self.opt_regressor.apply_gradients(zip(grads, trainable))

        self.metric_regressor.update_state(labels, y_mixed)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate one batch of ``((images, headpose, clusters), labels)``.

        A first pass produces the fixed-effect activations for the
        adversary. Rows whose cluster id appears in ``cluster_list`` keep
        their one-hot encoding; all other rows use the adversary's
        prediction. A second pass with that cluster representation yields
        the predictions that are scored.
        """
        (images, headpose, clusters), labels = data

        # One-hot encoded clusters
        clusters_one_hot = self._one_hot_clusters(clusters)

        # First pass: activations for the cluster predictor
        outputs = self((images, headpose, clusters), training=False)
        features = list(outputs[:8])

        # Recover the cluster index of every row
        cluster_indices = tf.argmax(clusters_one_hot, axis=1)

        # Check if cluster IDs are known (row-wise membership in cluster_list)
        is_known_cluster = tf.reduce_any(
            tf.equal(tf.expand_dims(cluster_indices, axis=1),
                     tf.expand_dims(self.cluster_list, axis=0)),
            axis=1)

        # Calculate adversary output once and reuse
        adversary_output = self.adversary(features, training=False)

        # Use known or predicted clusters
        new_clusters = tf.where(tf.expand_dims(is_known_cluster, axis=1), clusters_one_hot, adversary_output)

        # Main Regression with potentially new clusters
        outputs_new = self((images, headpose, new_clusters), training=False)
        y_fixed_new, y_mixed_new = outputs_new[8], outputs_new[11]

        loss_me_reg = self.loss_me_regressor(labels, y_mixed_new)
        loss_fe_reg = self.loss_fe_regressor(labels, y_fixed_new)
        loss_adv = self.loss_adversary(new_clusters, adversary_output)

        _ = self._compute_update_loss(loss_me_reg, loss_fe_reg, loss_adv, training=False)

        self.metric_regressor.update_state(labels, y_mixed_new)
        return {m.name: m.result() for m in self.metrics}

# Training

## Mean Angular Error

In [None]:
def mae_keras(predictedGaze, groundtruthGaze, is_3d=False, deg=False):
    '''
    Mean angular error (in degrees) between two batches of gaze directions.

    Args:
    predictedGaze (tf.Tensor): Predicted gaze, one row per sample.
    groundtruthGaze (tf.Tensor): Ground-truth gaze, one row per sample.
    is_3d (bool): If False (default) both inputs are treated as 2-D angle
        pairs and converted with ``convert_to_xyz_tf`` first; if True they
        are used as vectors directly.
    deg (bool): Forwarded to ``convert_to_xyz_tf``; set when the angle
        pairs are given in degrees. Default is False.

    Returns:
    tf.Tensor: Scalar mean of the per-sample angles, in degrees.
    '''
    unit_vectors = []
    for gaze in (predictedGaze, groundtruthGaze):
        vec = tf.cast(gaze, dtype=tf.float32)
        if not is_3d:
            vec = convert_to_xyz_tf(vec, deg)
        # Scale every row to unit length.
        unit_vectors.append(vec / tf.linalg.norm(vec, axis=1, keepdims=True))
    pred_unit, true_unit = unit_vectors

    # Row-wise dot product = cosine of the angle; clip to acos's domain.
    cosine = tf.clip_by_value(tf.reduce_sum(pred_unit * true_unit, axis=1), -1, 1)

    degrees = tf.math.acos(cosine) * 180 / tf.constant(np.pi)

    return tf.reduce_mean(degrees)

def convert_to_xyz_tf(spherical, deg=False):
    """Convert rows of two angles into unit 3-D direction vectors.

    With ``a = spherical[:, 0]`` and ``b = spherical[:, 1]`` the result rows
    are ``(-cos(a) * sin(b), -sin(a), -cos(a) * cos(b))``, normalised to
    unit length. (Presumably ``a`` is pitch and ``b`` is yaw -- confirm
    against the dataset convention.)

    The components are assembled with ``tf.stack`` rather than by
    scattering into a zero tensor through Python-built index lists, so the
    function needs no statically known batch size and creates no
    per-sample Python objects.

    Args:
        spherical: Array or tensor with at least two columns of angles.
        deg: If True the angles are given in degrees and are converted to
            radians first.

    Returns:
        float32 tensor with three columns, one unit vector per input row.
    """
    spherical = tf.cast(spherical, tf.float32)
    if deg:
        spherical = spherical * tf.constant(np.pi) / 180

    first_angle = spherical[:, 0]
    second_angle = spherical[:, 1]

    xyz = tf.stack(
        [
            -tf.cos(first_angle) * tf.sin(second_angle),
            -tf.sin(first_angle),
            -tf.cos(first_angle) * tf.cos(second_angle),
        ],
        axis=1,
    )

    # Re-normalise to absorb floating-point rounding.
    xyz = xyz / tf.norm(xyz, axis=1, keepdims=True)

    return xyz

## t-loss function

In [None]:
# Loss function modelled on the t-distribution
def t_loss(k):
    """Build a loss computing the mean of ``log(k + (y_true - y_pred)**2)``.

    Args:
        k: Constant added to the squared error inside the logarithm.

    Returns:
        A function ``loss(y_true, y_pred, sample_weight=None)``. When
        ``sample_weight`` is given, the per-element values are multiplied
        by it before averaging.
    """
    def loss(y_true, y_pred, sample_weight=None):
        residual_sq = tf.square(y_true - y_pred)
        per_element = tf.math.log(k + residual_sq)
        if sample_weight is None:
            return tf.reduce_mean(per_element)
        return tf.reduce_mean(tf.multiply(sample_weight, per_element))

    return loss

## Train Loop

In [None]:
df = 3
unique_ids = np.unique(ids_data)

# Model training and evaluation using K-fold cross-validation
mae_list = []
histories = []
for i, fold in enumerate(folds):
    print(f"Training fold {i+1}...")

    # Create the model
    model = MixedEffectNetwork(cluster_list=unique_ids, df=df)

    # Compile the model
    model.compile(
        loss_me_regressor=t_loss(df),
        loss_fe_regressor=t_loss(df),
        loss_adversary=keras.losses.CategoricalCrossentropy(),
        fe_regressor_weight=1.0,
        adv_gen_weight=0.1,
        metric_regressor=tf.keras.metrics.MeanAbsoluteError(),
        opt_regressor=tf.keras.optimizers.Adam(learning_rate=0.0001),
        opt_adversary=tf.keras.optimizers.Adam(learning_rate=0.0001)
    )

    # Prepare training and validation data
    train_data = (fold['train']['imgs'], fold['train']['hps'], fold['train']['ids']), fold['train']['gzs']
    val_data = (fold['val']['imgs'], fold['val']['hps'], fold['val']['ids']), fold['val']['gzs']

    # Callbacks
    # The model reports its own trackers ('total_loss', 'me_regressor_loss',
    # ...) and no plain 'loss', so 'val_loss' never appears in the logs and a
    # callback monitoring it would never trigger. Monitor the validation
    # total loss instead.
    # NOTE(review): ReduceLROnPlateau adjusts `model.optimizer`; confirm that
    # this is the optimizer the custom train_step actually applies.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_total_loss', mode='min', patience=10, restore_best_weights=True)
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_total_loss', mode='min', factor=0.2, patience=5, min_lr=1e-6)

    # Train the model
    history = model.fit(
        x=train_data[0],
        y=train_data[1],
        epochs=350,
        batch_size=32,
        validation_data=val_data,
        callbacks=[early_stopping, reduce_lr]
    )

    histories.append(history)

    # Evaluate on test set
    test_data = (fold['test']['imgs'], fold['test']['hps'], fold['test']['ids']), fold['test']['gzs']
    # predict() returns every output of the model's call(); the mixed-effect
    # gaze prediction (y_mixed) is the last one.
    test_outputs = model.predict(test_data[0])
    pred_gzs = test_outputs[-1]

    mae_error = mae_keras(pred_gzs, test_data[1])
    mae_list.append(mae_error)
    print(f"Test results for fold {i+1}: {mae_error}")

print(f"\nMean Error : {sum(mae_list)/len(mae_list)}")