In [None]:
import numpy as np
import tensorflow as tf
import os
from tensorflow.keras import Input
from tensorflow.keras import models as M
from tensorflow.keras import layers as L
from tensorflow.keras import backend as keras
from tensorflow.keras.utils import plot_model
import trimesh
import fpsample
from transformers import TFBertModel, BertTokenizer
import tensorflow_graphics as tfg
from tensorflow_graphics.nn.loss import chamfer_distance
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint







## Extracting Point clouds from meshes, extracting file names to create the dataset.
 - This works directly with the download URL for the MSN dataset.
 - This step can take a long time, so you may prefer to train the model on a smaller dataset or to extract only certain classes.

In [None]:
import trimesh
from io import BytesIO
from urllib.parse import urlparse, parse_qs


def normalize_point_cloud(point_cloud):
    """
    Centre a point cloud at the origin and scale it into the unit sphere.

    The input is never modified; a new floating-point array is returned.

    Args:
        point_cloud (array-like): Points of shape (N, D).

    Returns:
        np.ndarray: Array of shape (N, D) whose mean is at the origin and whose
        farthest point lies at distance 1.0 from the origin. If every point
        coincides (maximum distance is 0), the centred all-zero cloud is
        returned instead of dividing by zero.
    """
    points = np.asarray(point_cloud)
    if np.issubdtype(points.dtype, np.floating):
        # Work on a copy so the caller's array is left untouched.
        points = points.copy()
    else:
        # In-place float arithmetic is not possible on integer arrays.
        points = points.astype(np.float64)

    # Move the centroid to the origin.
    points -= np.mean(points, axis=0)

    # Largest Euclidean distance from the origin after centring.
    max_distance = np.max(np.sqrt(np.sum(points**2, axis=1)))

    # Scale so that all points fall within [-1.0, 1.0]; skip degenerate clouds.
    if max_distance > 0:
        points /= max_distance

    return points


def process_and_save_point_clouds(links_url, point_cloud_dir='', num_points=6144, timeout=60.0):
    """
    Fetches 3D mesh files from links provided in a text file, converts them to point clouds,
    and saves the normalized point clouds as PLY files.

    Each link is processed independently: a failure is printed and the loop moves on
    to the next link.

    Args:
        links_url (str): URL of the text file containing one mesh link per line.
        point_cloud_dir (str): Directory to save the point cloud PLY files. An empty
            string means the current working directory.
        num_points (int): Number of points to sample from each mesh to create the point cloud.
        timeout (float): Socket timeout in seconds for every HTTP request.
    """
    # Standard-library HTTP client, imported locally so this cell does not rely on
    # an import made elsewhere in the notebook.
    from urllib.request import urlopen

    def fetch(url):
        # urlopen raises urllib.error.HTTPError for 4xx/5xx responses.
        with urlopen(url, timeout=timeout) as response:
            charset = response.headers.get_content_charset() or 'utf-8'
            return response.read(), charset

    # Step 1: Read the file containing the links (blank lines are ignored)
    def read_links(url):
        payload, charset = fetch(url)
        return [line.strip() for line in payload.decode(charset).splitlines() if line.strip()]

    # Step 2: Convert the 3D meshes to point clouds directly from URLs
    def mesh_to_point_cloud_from_url(url, num_points):
        payload, _ = fetch(url)
        mesh = trimesh.load(BytesIO(payload), file_type='stl')
        points, _ = trimesh.sample.sample_surface(mesh, num_points)
        return points

    # Step 3: Save the point cloud as a PLY file
    def save_point_cloud_as_ply(points, save_path):
        trimesh.PointCloud(points).export(save_path)

    # Extract the file name from the 'files' query parameter of the URL
    def extract_file_name(url):
        query_params = parse_qs(urlparse(url).query)
        return query_params.get('files', [None])[0]

    links = read_links(links_url)

    # Ensure the directory exists; os.makedirs('') would raise, so skip it for
    # the "current directory" default.
    if point_cloud_dir:
        os.makedirs(point_cloud_dir, exist_ok=True)

    # Process each link, convert to point cloud and save as a PLY file
    for url in links:
        try:
            file_name = extract_file_name(url)
            if file_name is None:
                print(f"Failed to extract file name from {url}")
                continue

            point_cloud = mesh_to_point_cloud_from_url(url, num_points)
            # Swap the extension regardless of its case (.stl / .STL).
            base_name = os.path.splitext(file_name)[0] + '.ply'
            point_cloud_path = os.path.join(point_cloud_dir, base_name)
            save_point_cloud_as_ply(normalize_point_cloud(point_cloud), point_cloud_path)
        except Exception as e:
            # Best effort: report the failing link and continue with the rest.
            print(f"Failed to process {url}: {e}")

## Utility functions for data preprocessing.

In [None]:
def remove_knn_points_by_index(points, point_index, num_remove):
    """
    Drop the `num_remove` points closest to `points[point_index]`.

    The anchor point itself is at distance zero, so it is among the removed points.

    Args:
        points (np.ndarray): Array of shape (N, D).
        point_index (int): Row index of the anchor point.
        num_remove (int): How many nearest points to delete.

    Returns:
        np.ndarray: A new array with the selected rows removed.
    """
    offsets = points - points[point_index]
    dist_to_anchor = np.linalg.norm(offsets, axis=1)
    nearest_first = np.argsort(dist_to_anchor)
    return np.delete(points, nearest_first[:num_remove], axis=0)


def preprocess_data(root_folder, num_gt_points=6144, num_views=2, num_remove=2048, max_length=128):
    """
    Preprocesses point cloud data and text labels for training by generating partial
    inputs, eye seeds, tokenized text, and ground-truth sets from .ply files.

    The text label of a file is the part of its name after the first underscore, with
    remaining underscores turned into spaces (for a name without any underscore the
    whole stem is used).

    Args:
        root_folder (str): The root directory (searched recursively) containing .ply files.
        num_gt_points (int): Points kept per ground-truth cloud via farthest point sampling.
        num_views (int): Number of partial clouds generated per ground-truth cloud.
        num_remove (int): Number of nearest neighbours removed around each anchor point
            to create a partial cloud.
        max_length (int): Token length the text labels are padded/truncated to.

    Returns:
        tuple: Five NumPy arrays, in this order:
            - input_set (float32): Partial point clouds,
              shape (samples, num_gt_points - num_remove, D).
            - eye_seeds (float32): One random value in [0, 1) per sample, shape (samples, 1, 1).
            - input_ids (int32): Tokenized text input IDs, shape (samples, max_length).
            - attention_masks (int32): Attention masks for the tokenized text.
            - GT_set (float32): Ground-truth (full) point clouds.

    Raises:
        ValueError: If no usable .ply file is found under `root_folder`.
    """
    input_set = []
    eye_seeds = []
    text_set = []
    GT_set = []

    for dirpath, dirnames, filenames in os.walk(root_folder):
        for filename in filenames:
            if not filename.endswith('.ply'):
                continue

            base_name = os.path.splitext(filename)[0]
            # [-1] instead of [1]: a stem without '_' falls back to the whole stem
            # rather than raising IndexError.
            class_name_part = base_name.split('_', 1)[-1]
            class_name = class_name_part.split('.')[0].replace('_', ' ')
            file_path = os.path.join(dirpath, filename)

            # Load the stored points and subsample the ground truth
            cloud = trimesh.load(file_path)
            GT = np.array(cloud.vertices)
            if len(GT) < num_gt_points:
                # A smaller cloud cannot be stacked with the others into one array.
                print(f"Skipping {file_path}: {len(GT)} points, {num_gt_points} required")
                continue
            GT = GT[fpsample.fps_sampling(GT, num_gt_points, start_idx=0)]

            # Pick anchor points with farthest point sampling and cut a hole around each
            for idx in fpsample.fps_sampling(GT, num_views, start_idx=0):
                input_set.append(remove_knn_points_by_index(GT, idx, num_remove))
                eye_seeds.append(np.random.rand(1, 1))
                text_set.append(class_name)
                GT_set.append(GT)

    if not text_set:
        raise ValueError(f"No usable .ply point clouds found under {root_folder!r}")

    # Tokenize the text labels using the BERT tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    encoded_inputs = tokenizer.batch_encode_plus(
        text_set,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='tf'
    )
    input_ids = encoded_inputs['input_ids']
    attention_masks = encoded_inputs['attention_mask']

    return (
        np.array(input_set, dtype=np.float32),
        np.array(eye_seeds, dtype=np.float32),
        np.array(input_ids, dtype=np.int32),
        np.array(attention_masks, dtype=np.int32),
        np.array(GT_set, dtype=np.float32),
    )

## loss functions
 - Code for the density-aware Chamfer distance loss (DCD)
 - Code for the vanilla Chamfer distance (CD)

In [None]:
def eye_seed(X):
    """Return an all-zero tensor of shape (X.shape[0], 1, 1)."""
    leading_dim = X.shape[0]
    return tf.zeros((leading_dim, 1, 1))

def distance_matrix(array1, array2):
    """
    Pairwise Euclidean distances between two batches of point sets.

    Args:
        array1: Tensor of shape (batch, n1, features).
        array2: Tensor of shape (batch, n2, features).

    Returns:
        Tensor of shape (batch, n1, n2) where entry [b, i, j] is the distance
        between array1[b, i] and array2[b, j].
    """
    # Broadcasting (batch, n1, 1, f) against (batch, 1, n2, f) yields every pair
    # without materialising tiled copies, and works when n1 != n2.
    diff = tf.expand_dims(array1, 2) - tf.expand_dims(array2, 1)
    return tf.norm(diff, axis=-1)

def min_distances_and_indices(array1, array2):
    """
    Nearest-neighbour distances and indices in both directions.

    Returns:
        tuple: (min distance from each array1 point to array2,
                min distance from each array2 point to array1,
                index of the nearest array2 point for each array1 point,
                index of the nearest array1 point for each array2 point).
    """
    pairwise = distance_matrix(array1, array2)

    # Reducing the last axis scans array2; reducing the second-to-last scans array1.
    dist_1_to_2 = tf.reduce_min(pairwise, axis=-1)
    dist_2_to_1 = tf.reduce_min(pairwise, axis=-2)
    nearest_in_2 = tf.argmin(pairwise, axis=-1)
    nearest_in_1 = tf.argmin(pairwise, axis=-2)

    return dist_1_to_2, dist_2_to_1, nearest_in_2, nearest_in_1

def _fscore(dist1, dist2, threshold=0.0001):
    """
    F-score between two point sets from their nearest-neighbour distances.

    Each precision is the fraction of points whose nearest-neighbour distance is
    below `threshold`; the F-score is their harmonic mean (0 where both are 0).

    Returns:
        tuple: (f1, precision_1, precision_2), each of shape (batch,).
    """
    precision_1 = tf.reduce_mean(tf.cast(dist1 < threshold, tf.float32), axis=1)
    precision_2 = tf.reduce_mean(tf.cast(dist2 < threshold, tf.float32), axis=1)
    # divide_no_nan returns 0 instead of NaN when both precisions are 0.
    f1 = tf.math.divide_no_nan(2.0 * precision_1 * precision_2, precision_1 + precision_2)
    return f1, precision_1, precision_2


def calc_cd(output, gt, calc_f1=False, return_raw=False, normalize=False, separate=False):
    """
    Chamfer distance between predicted and ground-truth point clouds.

    Args:
        output: Predicted points, shape (batch, n_out, features).
        gt: Ground-truth points, shape (batch, n_gt, features).
        calc_f1 (bool): Also append the F-score (threshold 0.0001) to the result.
        return_raw (bool): Also append the raw nearest-neighbour distances and indices.
        normalize (bool): Accepted for interface compatibility; not used.
        separate (bool): Return the two directional terms concatenated along axis 0
            instead of combined.

    Returns:
        list: [cd_p, cd_t] (or their "separate" variants), followed by f1 when
        `calc_f1` is set, followed by [dist1, dist2, idx1, idx2] when `return_raw`
        is set. dist1/idx1 describe, for every gt point, its nearest output point;
        dist2/idx2 describe, for every output point, its nearest gt point.
    """
    dist1, dist2, idx1, idx2 = min_distances_and_indices(gt, output)
    mean1 = tf.reduce_mean(dist1, axis=1)
    mean2 = tf.reduce_mean(dist2, axis=1)
    cd_p = (tf.sqrt(mean1) + tf.sqrt(mean2)) / 2
    cd_t = mean1 + mean2
    if separate:
        res = [tf.concat([tf.reduce_mean(tf.sqrt(dist1), axis=1, keepdims=True),
                          tf.reduce_mean(tf.sqrt(dist2), axis=1, keepdims=True)], axis=0),
               tf.concat([tf.reduce_mean(dist1, axis=1, keepdims=True),
                          tf.reduce_mean(dist2, axis=1, keepdims=True)], axis=0)]
    else:
        res = [cd_p, cd_t]
    if calc_f1:
        # Uses the helper above; the name `fscore` is not defined in this notebook.
        f1, _, _ = _fscore(dist1, dist2, 0.0001)
        res.append(f1)
    if return_raw:
        res.extend([dist1, dist2, idx1, idx2])
    return res

def calc_dcd(x, gt, alpha=1, n_lambda=1, return_raw=False, non_reg=False):
    """
    Density-aware Chamfer distance (DCD) between two batches of point clouds.

    Every nearest-neighbour match contributes 1 - exp(-alpha * distance) / count**n_lambda,
    where `count` is how many points share the same nearest neighbour.

    When passed to `Model.compile(loss=...)`, Keras calls this as
    `calc_dcd(y_true, y_pred)`, so `x` receives the targets.

    Args:
        x: Point clouds of shape (batch, n_x, features).
        gt: Point clouds of shape (batch, n_gt, features).
        alpha: Scale applied to distances inside the exponential.
        n_lambda: Exponent applied to the match counts.
        return_raw (bool): If True, return the full result list instead of only the loss.
        non_reg (bool): If True, the point-count ratios are clamped to at least 1.

    Returns:
        The scalar loss tensor by default. With `return_raw=True`, the list
        [loss, cd_p, cd_t, dist1, dist2, idx1, idx2].
    """
    x = tf.cast(x, tf.float32)
    gt = tf.cast(gt, tf.float32)
    batch_size = tf.shape(x)[0]
    n_x = tf.shape(x)[1]
    n_gt = tf.shape(gt)[1]

    ratio_x_to_gt = tf.cast(n_x, tf.float32) / tf.cast(n_gt, tf.float32)
    ratio_gt_to_x = tf.cast(n_gt, tf.float32) / tf.cast(n_x, tf.float32)
    if non_reg:
        frac_12 = tf.maximum(1.0, ratio_x_to_gt)
        frac_21 = tf.maximum(1.0, ratio_gt_to_x)
    else:
        frac_12 = ratio_x_to_gt
        frac_21 = ratio_gt_to_x

    cd_p, cd_t, dist1, dist2, idx1, idx2 = calc_cd(x, gt, return_raw=True)
    exp_dist1 = tf.exp(-dist1 * alpha)
    exp_dist2 = tf.exp(-dist2 * alpha)

    def compute_loss(b):
        # Per-sample terms; bincount counts how often each point is chosen as a
        # nearest neighbour, and gathering maps that count back onto every match.
        idx1_b = tf.gather(idx1, b)
        idx2_b = tf.gather(idx2, b)

        count1 = tf.math.bincount(idx1_b, minlength=tf.cast(n_x, tf.int64))
        weight1 = tf.cast(tf.gather(count1, idx1_b), tf.float32)
        weight1 = tf.pow(weight1, n_lambda)
        weight1 = tf.pow((weight1 + 1e-6), -1) * frac_21
        loss1 = tf.reduce_mean(-exp_dist1[b] * weight1 + 1.0)

        count2 = tf.math.bincount(idx2_b, minlength=tf.cast(n_gt, tf.int64))
        weight2 = tf.cast(tf.gather(count2, idx2_b), tf.float32)
        weight2 = tf.pow(weight2, n_lambda)
        weight2 = tf.pow((weight2 + 1e-6), -1) * frac_12
        loss2 = tf.reduce_mean(-exp_dist2[b] * weight2 + 1.0)
        return loss1, loss2

    # fn_output_signature replaces the deprecated `dtype` argument of tf.map_fn.
    loss1, loss2 = tf.map_fn(
        compute_loss,
        tf.range(batch_size),
        fn_output_signature=(tf.float32, tf.float32),
    )
    loss = tf.reduce_mean(loss1 + loss2)

    if return_raw:
        # Previously this list was assembled but never returned.
        return [loss, cd_p, cd_t, dist1, dist2, idx1, idx2]
    return loss

def chamfer_distance_loss(y_true, y_pred):
    """Keras-style loss wrapper around TensorFlow Graphics' `chamfer_distance.evaluate`."""
    # `chamfer_distance` is the tensorflow_graphics.nn.loss submodule imported at the top.
    return chamfer_distance.evaluate(y_true, y_pred)


## Utility functions for the model

In [None]:
def pairwise_distance(xyz1, xyz2):
    """
    Squared Euclidean distance between every point of `xyz2` and every point of `xyz1`.

    Args:
        xyz1: Tensor of shape (batch, n, c).
        xyz2: Tensor of shape (batch, m, c).

    Returns:
        Tensor of shape (batch, m, n); entry [b, j, i] is the squared distance
        between xyz2[b, j] and xyz1[b, i].
    """
    num_ref = xyz1.shape[1]
    num_channels = xyz1.shape[2]
    num_query = xyz2.shape[1]

    # Lay both sets out on a common (batch, m, n, c) grid.
    ref_grid = tf.reshape(xyz1, (-1, 1, num_ref, num_channels))
    ref_grid = tf.tile(ref_grid, [1, num_query, 1, 1])
    query_grid = tf.reshape(xyz2, (-1, num_query, 1, num_channels))
    query_grid = tf.tile(query_grid, [1, 1, num_ref, 1])

    return tf.reduce_sum((ref_grid - query_grid)**2, -1)

def knn_point(k, xyz1, xyz2):
    """
    Find, for every point in `xyz2`, its `k` nearest points in `xyz1`.

    Returns:
        tuple: (squared distances, indices into xyz1), both of shape (batch, m, k),
        ordered from nearest to farthest.
    """
    # top_k selects the largest values, so search on the negated distances.
    negated = -pairwise_distance(xyz1, xyz2)
    neg_nearest, nearest_idx = tf.math.top_k(negated, k)
    return -neg_nearest, nearest_idx

class UniformSampler(tf.keras.layers.Layer):
    """
    Keras layer that draws random point indices for each item in a batch.

    For an input of shape (batch, data_size, ...) the layer returns an int32 tensor
    of shape (batch, num_points) with values drawn uniformly from [0, data_size).
    Indices are drawn independently, so the same index can appear more than once.
    The layer has no weights.

    Args:
        num_points (int): Number of indices to draw per batch item.
        seed (int): Seed forwarded to `tf.random.uniform`.
    """

    def __init__(self, num_points, seed=42, **kwargs):
        super(UniformSampler, self).__init__(**kwargs)
        self.num_points = num_points
        self.seed = seed

    def call(self, inputs):
        # Dynamic shapes so the layer works with an unknown batch size.
        batch_size = tf.shape(inputs)[0]
        data_size = tf.shape(inputs)[1]
        indices = tf.random.uniform(
            shape=(batch_size, self.num_points),
            minval=0,
            maxval=data_size,
            dtype=tf.int32,
            seed=self.seed
        )
        return indices

    def compute_output_shape(self, input_shape):
        # call() returns indices only, i.e. a rank-2 tensor without a feature axis.
        return (input_shape[0], self.num_points)

    def get_config(self):
        config = super(UniformSampler, self).get_config()
        config.update({
            "num_points": self.num_points,
            "seed": self.seed
        })
        return config

def sample_and_group(args, nsample):
    """
    Select centre points and gather the features of their nearest neighbours.

    Args:
        args: Sequence (xyz, pts, fps_idx) holding coordinates, per-point features
            and the indices of the points chosen as centres.
        nsample (int): Number of nearest neighbours grouped around each centre.

    Returns:
        tuple: (new_xyz, new_pts) where new_xyz are the centre coordinates and
        new_pts concatenates, per neighbour, the feature offset from the centre
        with the centre's own feature.
    """
    xyz, pts, fps_idx = args

    centre_selector = tf.expand_dims(fps_idx, -1)
    new_xyz = tf.gather_nd(xyz, centre_selector, batch_dims=1)
    centre_feats = tf.gather_nd(pts, centre_selector, batch_dims=1)

    _, neighbour_idx = knn_point(nsample, xyz, new_xyz)
    neighbour_feats = tf.gather_nd(pts, tf.expand_dims(neighbour_idx, -1), batch_dims=1)

    # Repeat each centre feature once per neighbour so it lines up with the group.
    tiled_centres = tf.tile(tf.expand_dims(centre_feats, 2), (1, 1, nsample, 1))
    new_pts = tf.concat([neighbour_feats - tiled_centres, tiled_centres], axis=-1)
    return new_xyz, new_pts

def LBR(tensor, C, seq_name, use_bias=True, activation=None, LeakyAlpha=0.0):
    """
    Apply a Dense layer followed by ReLU (or LeakyReLU), wrapped as a named sub-model.

    Args:
        tensor: Input tensor; the Dense layer acts on its last axis.
        C (int): Number of output channels.
        seq_name (str): Name of the sub-model and prefix for its layer names.
        use_bias (bool): Whether the Dense layer has a bias.
        activation: Activation passed to the Dense layer itself.
        LeakyAlpha (float): 0.0 selects ReLU; any other value is the LeakyReLU slope.

    Returns:
        The sub-model's output for `tensor`.
    """
    block_in = Input(shape=tensor.shape[1:], name=seq_name+'_input')
    linear = L.Dense(C, use_bias=use_bias, activation=activation, name=seq_name+'_lin')

    # Both variants share the same layer name.
    if LeakyAlpha != 0.0:
        nonlinearity = L.LeakyReLU(alpha=LeakyAlpha, name=seq_name+'_ReLU')
    else:
        nonlinearity = L.ReLU(name=seq_name+'_ReLU')

    block_out = nonlinearity(linear(block_in))
    block = M.Model(inputs=block_in, outputs=block_out, name=seq_name)
    return block(tensor)

def Self_Attention(tensor, seq_name):
    """
    Self-attention block with a residual connection, wrapped as a sub-model named `seq_name`.

    Queries and keys are bias-free projections to C//4 channels and values to C
    channels, where C is the size of the input's last axis. The attention-weighted
    values are subtracted from the input, passed through an LBR block, and the
    result is added back to the input.

    Args:
        tensor: Rank-3 input (the code reads the channel count from axis 2).
        seq_name (str): Name of the sub-model and prefix for all of its layer names.

    Returns:
        The sub-model's output for `tensor`; the last axis keeps C channels.
    """
    x_in = Input(shape=tensor.shape[1:], name=seq_name+'_input')
    C = x_in.shape[2]
    W_q = L.Dense(C//4, use_bias=False, activation=None, name=seq_name+'_Q')
    W_k = L.Dense(C//4, use_bias=False, activation=None, name=seq_name+'_K')
    W_v = L.Dense(C, use_bias=False, activation=None, name=seq_name+'_V')
    x_q = W_q(x_in)
    x_k = W_k(x_in)
    # Copies W_q's current (initial) weights into W_k once, here at build time.
    # NOTE(review): W_q and W_k stay separate layers, so this does not keep them
    # tied during training - confirm that initial sharing is what is intended.
    W_k.set_weights(W_q.get_weights())
    # Transpose keys to (batch, channels, points) for the matmul below.
    x_k = L.Lambda(lambda t: tf.transpose(t, perm=(0,2,1)), name=seq_name+'_KT')(x_k)
    x_v = W_v(x_in)
    energy = L.Lambda(lambda ts: tf.matmul(ts[0],ts[1]), name=seq_name+'_matmul1')([x_q, x_k])
    # Softmax over axis 1, then rescale so every row sums to (almost) 1 along axis 2.
    attention = L.Softmax(axis=1, name=seq_name+'_softmax')(energy)
    attention = L.Lambda(lambda t: t / (1e-9 + tf.reduce_sum(t, axis=2, keepdims=True)), name=seq_name+'_l1norm')(attention)
    x_r = L.Lambda(lambda ts: tf.matmul(ts[0],ts[1]), name=seq_name+'_matmul2')([attention, x_v])
    # Offset between the input and its attention-weighted version.
    x_r = L.Lambda(lambda ts: tf.subtract(ts[0],ts[1]), name=seq_name+'_subtract')([x_in, x_r])
    x_r = LBR(x_r, C, seq_name+'_LBR', use_bias=True)
    # Residual connection.
    x_out = L.Lambda(lambda ts: tf.add(ts[0],ts[1]), name=seq_name+'_add')([x_in, x_r])
    model = M.Model(inputs=x_in, outputs=x_out, name=seq_name)
    return model(tensor)


def Cross_Attention(args, seq_name):
    """
    Cross-attention block with a residual connection, wrapped as a sub-model named `seq_name`.

    Queries are projected from the D tensor; keys and values are projected from the
    E tensor. The attention-weighted values are subtracted from the D tensor, passed
    through an LBR block, and the result is added back to the D tensor.

    Args:
        args: Pair (E_tensor, D_tensor) of rank-3 tensors (channel counts are read
            from axis 2 of each).
        seq_name (str): Name of the sub-model and prefix for all of its layer names.

    Returns:
        The sub-model's output; its last axis has the same size as D_tensor's.
    """
    E_tensor, D_tensor = args
    xE_in = Input(shape=E_tensor.shape[1:], name=seq_name+'_input-E')
    C = xE_in.shape[2]
    xD_in = Input(shape=D_tensor.shape[1:], name=seq_name+'_input-D')
    out_dim = xD_in.shape[2]
    # Query/key width is a quarter of the E channel count; values match the D width.
    W_q = L.Dense(C//4, use_bias=False, activation=None, name=seq_name+'_Q')
    W_k = L.Dense(C//4, use_bias=False, activation=None, name=seq_name+'_K')
    W_v = L.Dense(out_dim, use_bias=False, activation=None, name=seq_name+'_V')
    x_q = W_q(xD_in)
    x_k = W_k(xE_in)
    # Transpose keys to (batch, channels, points) for the matmul below.
    x_k = L.Lambda(lambda t: tf.transpose(t, perm=(0,2,1)), name=seq_name+'_KT')(x_k)
    x_v = W_v(xE_in)
    energy = L.Lambda(lambda ts: tf.matmul(ts[0],ts[1]), name=seq_name+'_matmul1')([x_q, x_k])
    # Softmax over axis 1, then rescale so every row sums to (almost) 1 along axis 2.
    attention = L.Softmax(axis=1, name=seq_name+'_softmax')(energy)
    attention = L.Lambda(lambda t: t / (1e-9 + tf.reduce_sum(t, axis=2, keepdims=True)), name=seq_name+'_l1norm')(attention)
    x_r = L.Lambda(lambda ts: tf.matmul(ts[0],ts[1]), name=seq_name+'_matmul2')([attention, x_v])
    # Offset between the D input and the attended E values.
    x_r = L.Lambda(lambda ts: tf.subtract(ts[0],ts[1]), name=seq_name+'_subtract')([xD_in, x_r])
    x_r = LBR(x_r, out_dim, seq_name+'_LBR', use_bias=True)
    # Residual connection on the D branch.
    x_out = L.Lambda(lambda ts: tf.add(ts[0],ts[1]), name=seq_name+'_add')([xD_in, x_r])
    model = M.Model(inputs=[xE_in,xD_in], outputs=x_out, name=seq_name)
    return model([E_tensor,D_tensor])

def copy_and_mapping(tensor, nmul, seq_name):
    """
    Upsample a point-feature tensor by a factor of `nmul`, wrapped as a named sub-model.

    Each input point yields `nmul` output points with (channels // nmul) features:
    a strided transposed convolution produces `nmul` distinct features per point,
    and a Dense projection of the point, repeated `nmul` times, is added to them.

    Args:
        tensor: Input of shape (batch, points, channels).
        nmul (int): Upsampling factor.
        seq_name (str): Name of the sub-model and prefix for its layer names.

    Returns:
        Tensor of shape (batch, points * nmul, channels // nmul).
    """
    block_in = Input(shape=tensor.shape[1:], name=seq_name+'_input')

    # Insert a singleton axis that the transposed convolution expands to `nmul`.
    expanded = L.Lambda(lambda t: tf.expand_dims(t, 2), name=seq_name+'_expand')(block_in)
    out_channels = expanded.shape[-1]//nmul

    mapped = L.Conv2DTranspose(out_channels, (1,nmul), (1,nmul), use_bias=True,
                               activation=None, name=seq_name+'_convT')(expanded)

    copied = L.Dense(out_channels, use_bias=True, activation=None, name=seq_name+'_lin')(expanded)
    copied = L.Lambda(lambda t: tf.tile(t, [1,1,nmul,1]), name=seq_name+'_tile')(copied)

    merged = L.Lambda(lambda ts: tf.add(ts[0],ts[1]), name=seq_name+'_add')([mapped, copied])

    # Fold the (points, nmul) axes back into a single point axis.
    npoint = merged.shape[1]*merged.shape[2]
    block_out = L.Lambda(lambda t: tf.reshape(t, [-1,npoint,t.shape[3]]), name=seq_name+'_reshape')(merged)

    block = M.Model(inputs=block_in, outputs=block_out, name=seq_name)
    return block(tensor)






## Point Encoder

In [None]:
def PCT_encoder(xyz):
    """
    Point-cloud encoder: maps a point cloud to one global 4096-channel feature vector.

    Pipeline: two LBR embeddings, two sample-and-group stages with max pooling,
    two stacks of four self-attention blocks, and a final max pool over points.

    Args:
        xyz: Point coordinates; presumably (batch, num_points, 3), as created in
            PCT_AE_Multimodal.build_model - confirm for other callers.

    Returns:
        Tensor of shape (batch, 1, 4096).
    """
    # Per-point embedding.
    x = LBR(xyz, 64, 'E-IN_LBR1', use_bias=False)
    x = LBR(x, 128, 'E-IN_LBR2', use_bias=False)
    # Stage 1: draw 4096 random centre indices (UniformSampler draws uniformly at
    # random, not by farthest point sampling, despite the `fps_idx` name) and group
    # the 32 nearest neighbours of each centre.
    fps_idx = UniformSampler(4096)(xyz)
    new_xyz, new_feature = L.Lambda(sample_and_group, arguments={'nsample':32}, name='E-SG1')([xyz, x, fps_idx])
    x = LBR(new_feature, 512, 'E-SG1_LBR1', use_bias=False)
    # Max over the neighbour axis collapses each group to one feature vector.
    x = L.Lambda(lambda t: tf.reduce_max(t, axis=2), name='E-SG1_MaxPool')(x)
    # Stage 2: same procedure with 2048 centres drawn from the stage-1 centres.
    fps_idx = UniformSampler(2048)(new_xyz)
    new_xyz, new_feature = L.Lambda(sample_and_group, arguments={'nsample':32}, name='E-SG2')([new_xyz, x, fps_idx])
    x = LBR(new_feature, 1024, 'E-SG2_LBR1', use_bias=False)
    x = L.Lambda(lambda t: tf.reduce_max(t, axis=2), name='E-SG2_MaxPool')(x)
    # First stack of four chained self-attention blocks.
    x1 = Self_Attention(x, 'E-SA1')
    x2 = Self_Attention(x1, 'E-SA2')
    x3 = Self_Attention(x2, 'E-SA3')
    x4 = Self_Attention(x3, 'E-SA4')
    # Concatenate all four attention outputs and the stack's input along channels.
    x0 = L.Lambda(lambda ts: tf.concat(ts, axis=2), name='E-SA_Concat')([x1,x2,x3,x4])
    x = L.Lambda(lambda ts: tf.concat(ts, axis=2), name='E-OUT_Concat')([x0,x])
    x = LBR(x, 2048, 'E-OUT_LBR', use_bias=False, LeakyAlpha=0.2)
    # Second stack of four chained self-attention blocks.
    x1 = Self_Attention(x, 'E-SA5')
    x2 = Self_Attention(x1, 'E-SA6')
    x3 = Self_Attention(x2, 'E-SA7')
    x4 = Self_Attention(x3, 'E-SA8')
    # Unlike the first stack, the stack's input is not concatenated here.
    x0 = L.Lambda(lambda ts: tf.concat(ts, axis=2), name='E-SA_Concat2')([x1,x2,x3,x4])
    x = LBR(x0, 4096, 'E-OUT_LBR1', use_bias=False, LeakyAlpha=0.2)
    # Global max pool over the point axis; keepdims leaves a length-1 point axis.
    output_feats = L.Lambda(lambda t: tf.reduce_max(t, axis=1, keepdims=True), name='E-OUT_MaxPool')(x)
    return output_feats

## Point Decoder

In [None]:
def pct_decoder(input_feats, input_eye_seed):
    """
    Point-cloud decoder: expands a global feature into 6144 output points.

    Two stages of four chained cross-attention blocks each; the first works on 1024
    query points and is upsampled x3 (3072 points), the second is upsampled x2
    (6144 points). Three LBR blocks and a final Dense layer map features to 3 values
    per point.

    Args:
        input_feats: Global feature; assumes shape (batch, 1, channels) as produced
            by PCT_encoder - confirm for other callers.
        input_eye_seed: Per-sample seed added to the identity matrices; assumes shape
            (batch, 1, 1) as declared in PCT_AE_Multimodal.build_model.

    Returns:
        Tensor with 3 channels per output point.
    """
    # Repeat the global feature once per stage-1 query point.
    m_feats = L.Lambda(lambda x: tf.tile(x, [1,1024,1]), name = 'D-IN_replicate')(input_feats)
    # Identity matrix shifted by the seed: row i acts as the initial code of query point i.
    input_eye = input_eye_seed + tf.eye(1024,1024)
    x = L.Dense(4096//4, use_bias=False, activation=None, name='D1-IN')(input_eye)
    # Stage 1: queries attend to the replicated global feature.
    x1 = Cross_Attention([m_feats,x] , 'D-STA1')
    x2 = Cross_Attention([m_feats,x1], 'D-STA2')
    x3 = Cross_Attention([m_feats,x2], 'D-STA3')
    x4 = Cross_Attention([m_feats,x3], 'D-STA4')
    x0 = L.Lambda(lambda ts: tf.concat(ts, axis=2), name='D1-STA_Concat')([x1,x2,x3,x4])
    x = L.Lambda(lambda ts: tf.concat(ts, axis=2), name='D1-OUT_Concat')([x0,x])
    # Upsample x3: 1024 -> 3072 points; these features serve as keys/values in stage 2.
    m_feats2 = copy_and_mapping(x, 3, 'D1-OUT_CopyAndMapping')
    input_eye2 = input_eye_seed + tf.eye(3072,3072)
    x = L.Dense(1024//4, use_bias=False, activation=None, name='D2-IN')(input_eye2)
    # Stage 2: 3072 queries attend to the upsampled stage-1 features.
    x1 = Cross_Attention([m_feats2,x] , 'D2-STA1')
    x2 = Cross_Attention([m_feats2,x1], 'D2-STA2')
    x3 = Cross_Attention([m_feats2,x2], 'D2-STA3')
    x4 = Cross_Attention([m_feats2,x3], 'D2-STA4')
    x0 = L.Lambda(lambda ts: tf.concat(ts, axis=2), name='D2-STA_Concat')([x1,x2,x3,x4])
    x = L.Lambda(lambda ts: tf.concat(ts, axis=2), name='D2-OUT_Concat')([x0,x])
    # Upsample x2: 3072 -> 6144 points.
    x = copy_and_mapping(x, 2, 'D2-OUT_CopyAndMapping')
    # Per-point head mapping features to coordinates.
    x = LBR(x,128, 'D-OUT_LBR1', use_bias=False)
    x = LBR(x,128, 'D-OUT_LBR2', use_bias=False)
    x = LBR(x, 64, 'D-OUT_LBR3', use_bias=False, LeakyAlpha=0.2)
    output_points = L.Dense(3, activation=None, name='D-OUT_lin')(x)
    return output_points

## Bert Text Encoder

In [None]:
def bert_model(input_ids, attention_mask, model_name='bert-base-uncased', max_length=128):
    """
    Encode tokenized text with a frozen pretrained BERT and project it to 4096 channels.

    Args:
        input_ids: Token ID tensor.
        attention_mask: Attention mask tensor matching `input_ids`.
        model_name (str): Name of the pretrained checkpoint to load.
        max_length (int): Not used inside this function.

    Returns:
        Tensor of shape (batch, 1, 4096).
    """
    text_encoder = TFBertModel.from_pretrained(model_name)
    # Only the Dense projection below is trained.
    text_encoder.trainable = False

    pooled = text_encoder([input_ids, attention_mask]).pooler_output
    projected = tf.keras.layers.Dense(4096, activation='relu')(pooled)

    # Add a length-1 axis at position 1.
    return tf.expand_dims(projected, axis = 1)

## Building the Multi-modal point cloud autoencoder

In [None]:
class PCT_AE_Multimodal:
    """
    Builder for the multi-modal (point cloud + text) autoencoder.

    The finished Keras model is available as the `model` attribute. Its inputs are,
    in order: points, eye seed, token IDs, attention mask.
    """

    def __init__(self, num_input_points=4096, max_length=128, bert_model=bert_model, PCT_encoder=PCT_encoder, pct_decoder=pct_decoder):
        """
        Args:
            num_input_points (int): Number of points in each (partial) input cloud.
            max_length (int): Token length of the text inputs.
            bert_model: Callable (input_ids, attention_mask) -> text feature.
            PCT_encoder: Callable (points) -> point-cloud feature.
            pct_decoder: Callable (feature, eye_seed) -> output points.
        """
        self.num_input_points = num_input_points
        self.max_length = max_length
        self.bert_model = bert_model
        self.PCT_encoder = PCT_encoder
        self.pct_decoder = pct_decoder
        self.model = self.build_model()

    def build_model(self):
        """Wire inputs, both encoders and the decoder into one Keras model.

        Raises:
            ValueError: If any of the three component callables is missing.
        """
        eye_seed = Input(shape=(1, 1), name='input_eye_seed')
        xyz = Input(shape=(self.num_input_points, 3), name='input_points')
        input_ids = Input(shape=(self.max_length,), dtype=tf.int32, name='input_ids')
        attention_mask = Input(shape=(self.max_length,), dtype=tf.int32, name='attention_mask')

        components = (self.bert_model, self.PCT_encoder, self.pct_decoder)
        if not all(components):
            raise ValueError("Bert model, PCT encoder, and PCT decoder must be provided.")

        text_encoded = self.bert_model(input_ids, attention_mask)
        cloud_encoded = self.PCT_encoder(xyz)
        # The two modalities are fused by element-wise addition.
        multi_encoded = cloud_encoded + text_encoded
        decoded_points = self.pct_decoder(multi_encoded, eye_seed)

        model_inputs = [xyz, eye_seed, input_ids, attention_mask]
        return M.Model(inputs=model_inputs, outputs=decoded_points)

## Instantiating the model

In [None]:

# Build the autoencoder and keep only the underlying Keras model as `AE`.
AE = PCT_AE_Multimodal(
    bert_model=bert_model,
    PCT_encoder=PCT_encoder,
    pct_decoder=pct_decoder,
).model

# Compile with Adam and the density-aware Chamfer distance as the loss.
initial_learning_rate = 1e-7
optimizer = Adam(learning_rate=initial_learning_rate)
AE.compile(optimizer=optimizer, loss=calc_dcd)

## Preparing the data

In [None]:
# Directory (searched recursively) that holds the .ply point clouds.
root_folder = "path/to/your/ply/files"

# Build the training arrays: partial clouds, eye seeds, tokenized labels and targets.
(input_set,
 eye_seeds,
 input_ids,
 attention_masks,
 GT_set) = preprocess_data(root_folder)

## Defining callbacks

In [None]:
# Where the best weights are written.
save_dir = 'path/to/your/directory'
weights_path = os.path.join(save_dir, 'weights_file_name.h5')

# Multiply the learning rate by 0.9 after 5 epochs without val_loss improvement.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.9, patience=5,
                              verbose=1, mode='auto', min_lr=1e-20)

# Save weights only, and only when val_loss reaches a new minimum.
checkpoint_callback = ModelCheckpoint(filepath=weights_path, monitor='val_loss',
                                      verbose=1, save_best_only=True,
                                      save_weights_only=True, mode='min')

### Model Training Overview
 - Training was performed over six weeks on 6 NVIDIA RTX A6000 GPUs (40 GB VRAM each) and 8 CPUs with 1000 GB of RAM.

 - Training was stopped after 322 epochs because performance had not improved for 12 consecutive epochs. We used the Adam optimizer with an initial learning rate of 1e-7. The ReduceLROnPlateau callback lowered the learning rate whenever the validation loss plateaued, which helped the model converge.


In [None]:
 # Train on the partial clouds; the last 10% of the samples are held out for validation.
 AE.fit(
     x=[input_set, eye_seeds, input_ids, attention_masks],
     y=GT_set,
     batch_size=8,
     epochs=500,
     shuffle=True,
     validation_split=0.1,
     verbose=1,
     callbacks=[reduce_lr, checkpoint_callback],
 )
