# Dynamic Aging Index (DAI) - Deep Learning Model for Biological Age Prediction

## Project Overview

This notebook implements a sophisticated deep learning architecture for predicting biological aging dynamics using a combination of:

- **Variational Autoencoder (VAE)** for dimensionality reduction and feature learning
- **Koopman Operator Theory** for modeling temporal dynamics in latent space
- **Dynamic Aging Index (DAI)** as a scalar metric for biological age assessment

## Key Features

- **Multi-temporal prediction**: Predicts biological aging states at different time intervals (3-year, 10-year)
- **Residual neural architecture**: Implements skip connections for improved gradient flow
- **Regularized training**: Uses L1/L2 regularization and dropout for robust learning
- **Custom loss functions**: Combines reconstruction, KL divergence, and temporal prediction losses
- **Scalable design**: Handles high-dimensional biological data (332,909 features)

## Architecture Components

1. **Encoder**: Reduces high-dimensional biological data to a 25-dimensional latent space
2. **Decoder**: Reconstructs original data from latent representations
3. **Koopman Operator**: Models temporal evolution in latent space using linear dynamics
4. **DAI Projection**: Maps latent states to a scalar aging index

## Dataset Structure

The model uses three types of data:
- **Cross-sectional data**: For general feature learning
- **Present state data**: Current biological measurements
- **Future state data**: Target predictions for temporal modeling

---


## 1. Environment Setup and Dependencies

This cell loads all necessary libraries for the Dynamic Aging Index model:

### Core Libraries:
- **TensorFlow/Keras**: Deep learning framework for neural network implementation
- **PyArrow**: Efficient data loading for large biological datasets
- **Pandas**: Data manipulation and analysis
- **NumPy**: Numerical computing
- **Matplotlib**: Visualization
- **Scikit-learn**: Data preprocessing and model evaluation tools

### Key Components:
- **Regularizers**: L1/L2 regularization for preventing overfitting
- **Initializers**: He normal initialization for optimal gradient flow
- **Callbacks**: Learning rate scheduling and model checkpointing
- **Metrics**: Custom loss tracking during training


In [None]:
#Load the necessary libraries

import pyarrow as pa
import pyarrow.parquet as pq
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Input, Lambda, BatchNormalization, Add, Dropout
from tensorflow.keras import layers, metrics, models, regularizers
from tensorflow.keras import models
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras import regularizers
import tensorflow.keras.backend as K
from tensorflow.keras import initializers
from tensorflow.keras.initializers import HeNormal
from tensorflow.keras.callbacks import TensorBoard
from tensorflow import keras
import joblib
import math
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.callbacks import LearningRateScheduler
from sklearn.model_selection import train_test_split


## 2. Data Loading and Preprocessing

### Dataset Structure:
The model requires three types of biological data:

1. **Cross-sectional data (X.parquet)**: 
   - Primary dataset containing 332,909 features
   - Used for learning DNA methylation pattern representations
   - Contains CpG Probes beta values

2. **Present state data (X0_train/test.parquet)**:
   - Current DNA methylation beta values for temporal modeling
   - Split into training and testing sets
   - Normative aging study datasets (phs000853.v2.p2) were used for training and GSE73115 dataset from geo is used  
    for testing.

3. **Future state data (X1_train/test.parquet)**:
   - Target future biological states
   - Corresponds to aging progression over time
   - Used for training temporal dynamics

### Data Processing Steps:
- Extract target variables (y1, y) from cross-sectional data
- Convert all data to float32 for memory efficiency and GPU compatibility
- Split datasets into training, validation, and testing sets
- Maintain temporal relationships in X0 and X1 splits


In [None]:
#Load the datasets

def print_title(title):
    print(f'{50 * "="}')
    print(title)
    print(f'{50 * "="}')


print_title('Loading Data')
df = pq.read_table("C:\\Users\\Data\\crosssectionaldata\\X.parquet", use_threads=True).to_pandas()
X0_train = pq.read_table("C:\\Users\\Data\\X0_train.parquet", use_threads=True).to_pandas()
X0_test = pq.read_table("C:\\Users\\Data\\X0_test.parquet", use_threads=True).to_pandas()
X1_train = pq.read_table("C:\\Users\\Data\\X1_train.parquet", use_threads=True).to_pandas()
X1_test = pq.read_table("C:\\Users\\Data\\X1_test.parquet", use_threads=True).to_pandas()
print_title('Data loading complete')

In [None]:
#Process and split the datasets for validation

X = df.iloc[:, 1:, ]
y1 = df.iloc[:, 0]

y = X.iloc[:, 0]
X = X.iloc[:, 1:, ]

X = X.values.astype(np.float32)
X0_train = X0_train.values.astype(np.float32)
X1_train = X1_train.values.astype(np.float32)
X0_test = X0_test.values.astype(np.float32)
X1_test = X1_test.values.astype(np.float32)

X_train, X_test, y_train, y_test, y1_train, y1_test = train_test_split(X, y, y1, test_size=0.2, random_state=0)

X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=0)
X0_train, X0_val = train_test_split(X0_train, test_size=0.2, random_state=0, shuffle=False)
X1_train, X1_val = train_test_split(X1_train, test_size=0.2, random_state=0, shuffle=False)


## 3. Data Normalization

### MinMaxScaler Implementation


In [None]:
#Scale the datasets

scaler = MinMaxScaler()

# Fit the scaler on the training data and transform the training set
X_train_scaled = scaler.fit_transform(X_train)

# Transform the validation and test sets using the same scaler
X_val_scaled = scaler.transform(X_val)
X_test_scaled = scaler.transform(X_test)

# Similarly, fit and transform X0_train, X0_val, X0_test for the second dataset
X0_train_scaled = scaler.transform(X0_train)
X0_val_scaled = scaler.transform(X0_val)
X0_test_scaled = scaler.transform(X0_test)

# And fit and transform X1_train, X1_val, X1_test for the third dataset
X1_train_scaled = scaler.transform(X1_train)
X1_val_scaled = scaler.transform(X1_val)
X1_test_scaled = scaler.transform(X1_test)

## 5. Model Architecture Parameters

### Key Dimensions:
- **Input shape**: 332,909 features (high-dimensional biological data)
- **Latent dimension**: 25 (compressed representation)
- **Batch size**: 54 (optimized for GPU memory and convergence)



In [None]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)


In [None]:
input_shape = (332909,)
latent_dim = 25
batch_size = 54

## 6. Autoencoder Architecture Implementation

### Residual Block Design:
The `residual_block` function implements skip connections for improved gradient flow:

- **Input processing**: Dense layer with ReLU activation
- **Batch normalization**: Stabilizes training and accelerates convergence
- **Dropout**: Prevents overfitting (40% dropout rate)
- **Skip connection**: Adds original input to processed features
- **Dimension matching**: Automatically adjusts shortcut dimensions when needed

### Encoder Architecture:
**Purpose**: Compresses high-dimensional biological data to 25-dimensional latent space

**Layer progression**:
1. Input (332,909) ‚Üí Dense (600) + BatchNorm + Dropout
2. Residual blocks: 600 ‚Üí 400 ‚Üí 200 ‚Üí 100 ‚Üí 50
3. Output: Mean, log variance, and sampled latent vector

**Key features**:
- **Progressive compression**: Gradually reduces dimensionality
- **Regularization**: L1/L2 regularization (0.01) and dropout (0.4)
- **He normal initialization**: Optimal for ReLU activations

### Decoder Architecture:
**Purpose**: Reconstructs original data from latent representations

**Layer progression**:
1. Input (25) ‚Üí Dense (50) + BatchNorm + Dropout
2. Residual blocks: 50 ‚Üí 100 ‚Üí 200 ‚Üí 400 ‚Üí 600
3. Output: Dense (332,909) with sigmoid activation

**Design principles**:
- **Symmetric structure**: Mirrors encoder architecture
- **Sigmoid output**: Ensures reconstruction values in [0,1] range
- **Skip connections**: Maintains gradient flow through deep layers


In [None]:
#Autoencoder Architecture

def residual_block(x, filters, kernel_size=3, stride=1, l1_reg=0.01, l2_reg=0.01, dropout_rate=0.4, initializer='he_normal'):
    shortcut = x
    x = layers.Dense(filters, activation="relu", 
                     kernel_regularizer=regularizers.l1_l2(l1=l1_reg, l2=l2_reg),
                     kernel_initializer=initializer)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(dropout_rate)(x)
    x = layers.Dense(filters, activation=None, 
                     kernel_regularizer=regularizers.l1_l2(l1=l1_reg, l2=l2_reg),
                     kernel_initializer=initializer)(x)
    x = layers.BatchNormalization()(x)
    
    # Adjust the shortcut if necessary
    if shortcut.shape[-1] != filters:
        shortcut = layers.Dense(filters, activation=None, 
                                kernel_regularizer=regularizers.l1_l2(l1=l1_reg, l2=l2_reg),
                                kernel_initializer=initializer)(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)
    
    x = layers.add([x, shortcut])
    x = layers.ReLU()(x)
    return x


def build_encoder(input_shape, latent_dim, l1_reg=0.01, l2_reg=0.01, dropout_rate=0.4, initializer='he_normal'):
    encoder_inputs = layers.Input(shape=input_shape)
    x = layers.Dense(600, activation="relu", 
                     kernel_regularizer=regularizers.l1_l2(l1=l1_reg, l2=l2_reg),
                     kernel_initializer=initializer)(encoder_inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(dropout_rate)(x)
    
    # Add residual blocks
    x = residual_block(x, 600, l1_reg=l1_reg, l2_reg=l2_reg, initializer=initializer)
    x = residual_block(x, 400, l1_reg=l1_reg, l2_reg=l2_reg, dropout_rate=dropout_rate, initializer=initializer)
    x = residual_block(x, 200, l1_reg=l1_reg, l2_reg=l2_reg, initializer=initializer)
    x = residual_block(x, 100, l1_reg=l1_reg, l2_reg=l2_reg, dropout_rate=dropout_rate, initializer=initializer)
    x = residual_block(x, 50, l1_reg=l1_reg, l2_reg=l2_reg, dropout_rate=dropout_rate)
        
    z_mean = layers.Dense(latent_dim, name="z_mean", kernel_initializer=initializer)(x)
    z_log_var = layers.Dense(latent_dim, name="z_log_var", kernel_initializer=initializer)(x)
    z = Sampling()([z_mean, z_log_var])
    encoder = models.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
    return encoder


encoder = build_encoder(input_shape, latent_dim, l1_reg=0.01, l2_reg=0.01, dropout_rate=0.4, initializer='he_normal')

def build_decoder(latent_dim, output_shape, l1_reg=0.01, l2_reg=0.01, dropout_rate=0.4, initializer='he_normal'):
    latent_inputs = layers.Input(shape=(latent_dim,))
    x = layers.Dense(50, activation="relu", 
                     kernel_regularizer=regularizers.l1_l2(l1=l1_reg, l2=l2_reg),
                     kernel_initializer=initializer)(latent_inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(dropout_rate)(x)
    
    # Add residual blocks
    
    x = residual_block(x, 50, l1_reg=l1_reg, l2_reg=l2_reg, initializer=initializer)
    x = residual_block(x, 100, l1_reg=l1_reg, l2_reg=l2_reg, dropout_rate=dropout_rate, initializer=initializer)
    x = residual_block(x, 200, l1_reg=l1_reg, l2_reg=l2_reg, initializer=initializer)
    x = residual_block(x, 400, l1_reg=l1_reg, l2_reg=l2_reg, dropout_rate=dropout_rate, initializer=initializer)
    x = residual_block(x, 600, l1_reg=l1_reg, l2_reg=l2_reg, initializer=initializer)
    
    decoder_outputs = layers.Dense(output_shape, activation="sigmoid", kernel_initializer=initializer)(x)  # Use linear activation for real-valued outputs
    decoder = models.Model(latent_inputs, decoder_outputs, name="decoder")
    return decoder


decoder = build_decoder(latent_dim, 332909, l1_reg=0.01, l2_reg=0.01, dropout_rate=0.4, initializer='he_normal')


## 7. Dynamic Aging Index (DAI) Projection

### ScalarTransformation Layer:
**Purpose**: Maps the 25-dimensional latent space to a single scalar value representing biological age

**Implementation**:
- **Input**: 25-dimensional latent vector
- **Output**: Single scalar value (DAI score)
- **Activation**: Linear (no activation function)
- **Interpretation**: Higher values indicate accelerated aging

### DAI Applications:
- **Biological age assessment**: Quantifies aging beyond chronological age
- **Health monitoring**: Tracks aging progression over time
- **Risk stratification**: Identifies individuals with accelerated aging
- **Intervention evaluation**: Measures effectiveness of anti-aging treatments

### Key Benefits:
- **Interpretability**: Single number for easy understanding
- **Clinical relevance**: Directly applicable to healthcare decisions
- **Temporal modeling**: Enables prediction of future aging states


In [None]:
#Projector block where Latent space is reduced to 1D scalar value - DAI

class ScalarTransformation(layers.Layer):
    def __init__(self, **kwargs):
        super(ScalarTransformation, self).__init__(**kwargs)
        self.dense = layers.Dense(1, activation=None)  # Single scalar output

    def call(self, inputs):
        return self.dense(inputs)

## 8. Variational Autoencoder (VAE) Implementation

### VAE Class Overview:
The VAE combines encoder, decoder, and DAI projection into a unified architecture for biological age modeling.

### Key Components:
1. **Encoder**: Compresses input data to latent representations
2. **Decoder**: Reconstructs data from latent codes
3. **Scalar Transformation**: Maps latent states to DAI scores

### Multi-Modal Processing:
The VAE processes three types of input simultaneously:
- **Cross-sectional data**: General biological features
- **Present state data**: Current biological measurements
- **Future state data**: Target aging states

### Outputs:
- **Reconstruction**: Reconstructed cross-sectional data
- **Latent representations**: Present and future latent states
- **DAI scores**: Scalar aging indices for present and future states
- **Statistical parameters**: Mean and variance for variational inference


In [None]:
#Variational autoencoder class

class VAE(keras.Model):
    def __init__(self, encoder, decoder, scalar_transformation, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.scalar_transformation = scalar_transformation
      

    def call(self, inputs):
        print(len(inputs))
        z_mean_cross, z_log_var_cross, z_cross = self.encoder(inputs[0])
        z_mean_present, z_log_var_present, z_present = self.encoder(inputs[1])
        z_mean_future, z_log_var_future, z_future = self.encoder(inputs[2])

        scalar_present = self.scalar_transformation(z_present)
        scalar_future = self.scalar_transformation(z_future)
        
        reconstruction = self.decoder(z_cross)

        return reconstruction, z_present, z_future, z_mean_cross, z_log_var_cross, scalar_present, scalar_future

## 9. Koopman Operator Implementation

### Theoretical Foundation:
The Koopman operator provides a linear framework for modeling nonlinear dynamical systems by lifting the state space to an infinite-dimensional function space.

### Key Concepts:
- **Linear dynamics in latent space**: Models temporal evolution as linear transformations
- **Eigenvalue decomposition**: Uses complex conjugate pairs and real eigenvalues
- **Time-varying parameters**: Adapts to different biological aging trajectories

### Architecture Components:

#### 1. Omega Networks:
- **Purpose**: Learn time-varying parameters for temporal dynamics
- **Structure**: Small neural networks (8‚Üí4‚Üí2‚Üílatent_dim)
- **Input**: Latent state coordinates
- **Output**: Omega parameters for eigenvalue computation

#### 2. Complex Conjugate Blocks:
- **Function**: Handles oscillatory dynamics in biological systems
- **Mathematical form**: 2x2 rotation and scaling matrices
- **Parameters**: Real frequency and exponential decay rate

#### 3. Varying Multiplication:
- **Purpose**: Applies learned dynamics to predict future states
- **Process**: Combines complex and real eigenvalue components
- **Output**: Evolved latent state after time delta_t

### Model Parameters:
- **Latent dimension**: 25 (compressed biological state)
- **Complex pairs**: 5 (oscillatory dynamics)
- **Real eigenvalues**: 15 (exponential dynamics)
- **Time step**: 3 years (biological aging interval)


In [None]:
#Koopman operator class

class KoopmanOperator(tf.Module):
    def __init__(self, params):
        super().__init__()
        self.params = params
        self.latent_dim = params['latent_dim']  # Use latent_dim directly from params, no default
        
        # Create all omega networks once during initialization
        self.omega_nets = self.create_all_omega_nets()
        
        # Create transformation layer once during initialization
        self.transformation_layer = layers.Dense(1, activation=None)
    
    def form_complex_conjugate_block(self, omegas, delta_t):
        scale = tf.exp(omegas[:, 1] * delta_t)
        entry11 = tf.multiply(scale, tf.cos(omegas[:, 0] * delta_t))
        entry12 = tf.multiply(scale, tf.sin(omegas[:, 0] * delta_t))
        row1 = tf.stack([entry11, -entry12], axis=1)  # [None, 2]
        row2 = tf.stack([entry12, entry11], axis=1)  # [None, 2]
        result = tf.stack([row1, row2], axis=2)
        print("form_complex_conjugate_block - result shape:", result.shape)
        return result

    def varying_multiply(self, y, omegas, delta_t):
        num_real = self.params.get('num_real', 0)
        num_complex_pairs = self.params.get('num_complex_pairs', 0)
        complex_list = []
        real_list = []

        for j in range(num_complex_pairs):
            ind = 2 * j
            ystack = tf.stack([y[:, ind:ind + 2], y[:, ind:ind + 2]], axis=2)  # [None, 2, 2]
            L_stack = self.form_complex_conjugate_block(omegas[j], delta_t)
            elmtwise_prod = tf.multiply(ystack, L_stack)
            complex_list.append(tf.reduce_sum(elmtwise_prod, 1))

        if len(complex_list) > 0:
            complex_part = tf.concat(complex_list, axis=1)
            print("varying_multiply - complex_part shape:", complex_part.shape)

        for j in range(num_real):
            ind = 2 * num_complex_pairs + j
            temp = y[:, ind]
            real_list.append(tf.multiply(temp[:, tf.newaxis], tf.exp(omegas[num_complex_pairs + j] * delta_t)))

        if len(real_list) > 0:
            real_part = tf.concat(real_list, axis=1)
            print("varying_multiply - real_part shape:", real_part.shape)

        # Ensure the final result has the correct shape
        if len(complex_list) > 0 and len(real_list) > 0:
            result = tf.concat([complex_part, real_part], axis=1)
            result = result[:, :self.latent_dim]  # Trim to latent_dim
            print("varying_multiply - result shape (complex + real):", result.shape)
            return result
        elif len(complex_list) > 0:
            return complex_part
        else:
            return real_part
        
    def create_all_omega_nets(self):
        omega_nets = []
        for j in range(self.params['num_complex_pairs']):
            temp_name = f'OC{j + 1}'
            omega_net = self.create_one_omega_net(temp_name)  # Create model
            omega_nets.append(omega_net)
    
        for j in range(self.params['num_real']):
            temp_name = f'OR{j + 1}'
            omega_net = self.create_one_omega_net(temp_name)  # Create model
            omega_nets.append(omega_net)
    
        print("create_all_omega_nets - number of omega_nets:", len(omega_nets))
        return omega_nets

    def create_one_omega_net(self, name_prefix):
        latent_inputs = layers.Input(shape=(self.latent_dim,))
        
        x = layers.Dense(8, activation="relu", name=f'{name_prefix}_dense1')(latent_inputs)
        x = layers.BatchNormalization(name=f'{name_prefix}_batchnorm1')(x)
        x = layers.Dropout(0.4, name=f'{name_prefix}_dropout1')(x)       
        
        x = residual_block(x, 8, l1_reg=0.01, l2_reg=0.01)
        x = residual_block(x, 4, l1_reg=0.01, l2_reg=0.01, dropout_rate=0.4)
        x = residual_block(x, 2, l1_reg=0.01, l2_reg=0.01)
        
        omega_params = layers.Dense(self.latent_dim, name=f'{name_prefix}_output')(x)
        omegas = tf.keras.Model(latent_inputs, omega_params, name=name_prefix)
        
        return omegas

    def apply_omega_nets(self, ycoords):
        omegas = []
        for j in range(self.params['num_complex_pairs']):
            ind = 2 * j
            pair_of_columns = ycoords[:, ind:ind + 2]
            radius_of_pair = tf.reduce_sum(tf.square(pair_of_columns), axis=1, keepdims=True)
            radius_of_pair = tf.tile(radius_of_pair, [1, self.latent_dim])
            omega_output = self.omega_nets[j](radius_of_pair)
            print(f"apply_omega_nets - omega_net {j} output shape:", omega_output.shape)
            omegas.append(omega_output)
    
        for j in range(self.params['num_real']):
            ind = 2 * self.params['num_complex_pairs'] + j
            one_column = ycoords[:, ind]
            one_column = tf.tile(one_column[:, tf.newaxis], [1, self.latent_dim])
            omega_output = self.omega_nets[self.params['num_complex_pairs'] + j](one_column)
            print(f"apply_omega_nets - omega_net {self.params['num_complex_pairs'] + j} output shape:", omega_output.shape)
            omegas.append(omega_output)
    
        return omegas

    def compute_future_state(self, current_state, delta_t):
        """
        Compute future state based on current state and varying delta_t.
        """
        ycoords = current_state
        omegas = self.apply_omega_nets(ycoords)
        
        # Adjust varying time steps (delta_t)
        future_state = self.varying_multiply(current_state, omegas, delta_t)
        
        # Apply transformation to future state
        trans_future_state = self.transformation_layer(future_state)
        
        return future_state, trans_future_state



# Example parameters for model creation
params = {
    'input_shape': (332909,),  # Example input shape
    'latent_dim': 25,          # Latent space dimension
    'l1_reg': 0.01,            # L1 regularization strength
    'l2_reg': 0.01,            # L2 regularization strength
    'dropout_rate': 0.4,       # Dropout rate
    'delta_t': 3,              # Time step size
    'num_real': 15,            # Number of real eigenvalues
    'num_complex_pairs': 5,   # Number of complex conjugate eigenvalue pairs
    'output_shape': 332909,    # Output shape
}

## 10. Koopman Model Wrapper

### KoopmanModel Class:
A high-level wrapper that orchestrates the Koopman operator for multi-step temporal predictions.

### Key Features:

#### Multi-Step Prediction:
- **Flexible time intervals**: Supports different prediction horizons
- **Iterative evolution**: Uses predicted states as input for next predictions
- **Custom time steps**: Allows varying time intervals (e.g., 3-year, 10-year predictions)

#### Prediction Process:
1. **Input**: Current latent state
2. **Evolution**: Apply Koopman operator for specified time interval
3. **Iteration**: Use evolved state for next prediction step
4. **Output**: Sequence of future states and transformed predictions

#### Applications:
- **Short-term prediction**: 3-year biological aging forecast
- **Long-term prediction**: 10-year aging trajectory
- **Custom intervals**: Flexible time horizons for different clinical needs


In [None]:
#Defining the KoopmanModel class

class KoopmanModel(tf.keras.Model):
    def __init__(self, koopman_operator):
        super(KoopmanModel, self).__init__()
        self.koopman_operator = koopman_operator

    def call(self, input_present, num_future=1, time_intervals=None):
        future_states = []
        trans_future_states = []
        current_state = input_present
        
        # If no custom time intervals are provided, use a fixed delta_t
        if time_intervals is None:
            time_intervals = [self.koopman_operator.params['delta_t']] * num_future
        
        for i in range(num_future):
            delta_t = time_intervals[i]  # Use the appropriate delta_t for each step
            g_next_state, trans_future_state = self.koopman_operator.compute_future_state(current_state, delta_t)
            
            future_states.append(g_next_state)
            trans_future_states.append(trans_future_state)
            
            current_state = g_next_state  # Update current state to the newly predicted state
        
        return future_states, trans_future_states


## 11. Integrated Model Architecture

### MyModel Class Overview:
The main model class that integrates VAE and Koopman operator into a unified framework for biological aging prediction.

### Architecture Integration:
1. **VAE Component**: Handles feature learning and dimensionality reduction
2. **Koopman Component**: Models temporal dynamics in latent space
3. **Custom Training Loop**: Implements multi-objective optimization

### Key Features:

#### Multi-Modal Processing:
- **Cross-sectional learning**: General biological feature representations
- **Temporal modeling**: Present-to-future state transitions
- **Auxiliary reconstruction**: Future state reconstruction for validation

#### Output Components:
- **Reconstruction**: Reconstructed cross-sectional data
- **Latent states**: Present and future compressed representations
- **DAI predictions**: Scalar aging indices
- **Future reconstructions**: Reconstructed future biological states

#### Training Strategy:
- **End-to-end optimization**: Joint training of all components
- **Multi-objective loss**: Balances reconstruction, prediction, and regularization
- **Gradient clipping**: Prevents exploding gradients during training


## 12. Custom Loss Functions

### Multi-Objective Loss Design:
The model optimizes multiple objectives simultaneously to ensure both reconstruction quality and temporal prediction accuracy.

### Loss Components:

#### 1. Reconstruction Loss (Weight: 10.0)
- **Purpose**: Ensures accurate reconstruction of biological data
- **Formula**: Mean squared error between input and reconstructed data
- **Importance**: Maintains fidelity of biological information

#### 2. KL Divergence Loss (Weight: 1.0)
- **Purpose**: Regularizes latent space for meaningful representations
- **Formula**: Standard VAE KL divergence between prior and posterior
- **Benefit**: Encourages smooth, interpretable latent space

#### 3. Linear Dynamics Loss (Weight: 100.0)
- **Purpose**: Ensures Koopman predictions match actual future latent states
- **Formula**: MSE between predicted and actual latent states
- **Critical**: Core temporal modeling objective

#### 4. Future State Loss (Weight: 100.0)
- **Purpose**: Aligns DAI predictions with actual future aging states
- **Formula**: MSE between predicted and actual DAI scores
- **Application**: Direct biological age prediction accuracy

#### 5. Auxiliary Loss (Weight: 10.0)
- **Purpose**: Validates future state reconstruction quality
- **Formula**: MSE between reconstructed and actual future biological data
- **Validation**: Ensures future predictions are biologically plausible

#### 6. L-Infinity Loss (Weight: 1.0)
- **Purpose**: Controls maximum reconstruction errors
- **Formula**: Maximum absolute difference across all features
- **Stability**: Prevents extreme reconstruction errors

### Loss Weighting Strategy:
- **High weights (100.0)**: Temporal prediction accuracy (most critical)
- **Medium weights (10.0)**: Reconstruction quality
- **Low weights (1.0)**: Regularization and stability


In [None]:
#The entire architecture with custom loss functions and gradient initiation

class MyModel(tf.keras.Model):
    def __init__(self, vae, koopman, loss_weights=None, **kwargs):
        super().__init__(**kwargs)
        self.vae = vae
        self.koopman = koopman        
        self.total_loss_tracker = metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = metrics.Mean(name="kl_loss")
        self.linear_dynamics_loss_tracker = metrics.Mean(name="linear_dynamics_loss")
        self.future_state_loss_tracker = metrics.Mean(name="future_state_loss")
        self.aux_loss_tracker = metrics.Mean(name="aux_loss")
        self.l_inf_loss_tracker = metrics.Mean(name="l_inf_loss")

        if loss_weights is None:
            loss_weights = {
                "reconstruction_loss": 10.0,
                "kl_loss": 1.0,
                "linear_dynamics_loss": 100.0,
                "future_state_loss": 100.0,
                "l_inf_loss": 1.0
            }
        self.loss_weights = loss_weights

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
            self.linear_dynamics_loss_tracker,
            self.future_state_loss_tracker,
            self.aux_loss_tracker,
            self.l_inf_loss_tracker
        ]
    
    def call(self, inputs, num_future=1, time_intervals=None):
        """
        Call method adjusted to handle multiple future time intervals (3 years and 10 years).
        """
        reconstruction, z_present, z_future, z_mean_cross, z_log_var_cross, scalar_present, scalar_future = self.vae(inputs)
        
        future_states = []
        trans_future_states = []
        
        current_state = z_present
        
        # Adjust for varying time intervals
        if time_intervals is None:
            time_intervals = [self.koopman.koopman_operator.params['delta_t']] * num_future
        
        for i in range(num_future):
            # Use the respective time interval for each future step
            delta_t = time_intervals[i]
            g_next_state, trans_future_state = self.koopman.koopman_operator.compute_future_state(current_state, delta_t)
            future_states.append(g_next_state)
            trans_future_states.append(trans_future_state)
            
            # Update current state to the newly predicted state
            current_state = g_next_state

        aux_reconstructed = [self.vae.decoder(g_next) for g_next in future_states]
        
        # Return the states
        return reconstruction, z_present, z_future, z_mean_cross, z_log_var_cross, scalar_present, scalar_future, future_states, trans_future_states, aux_reconstructed


    def compute_losses(self, inputs, reconstruction, z_present, z_future, z_mean_cross, z_log_var_cross, scalar_present, scalar_future, k_trans, k_untrans, aux_reconstructed, num_future=1):
        input_data_cross, input_data_present, input_data_future = inputs
    
        # Reconstruction loss (L_recon)
        reconstruction_loss = tf.reduce_mean(tf.reduce_sum(tf.square(input_data_cross - reconstruction), axis=-1))
        
        # KL divergence loss (remains as is)
        kl_loss = -0.5 * tf.reduce_mean(1 + z_log_var_cross - tf.square(z_mean_cross) - tf.exp(z_log_var_cross))
    
        # Future state prediction loss (L_pred)
        linear_dynamics_loss = 0.0
        for i in range(num_future):
            linear_dynamics_loss += tf.reduce_mean(tf.reduce_sum(tf.square(k_untrans[i] - z_future[i]), axis=-1))


        future_state_loss = 0.0
        for i in range(num_future):
            future_state_loss += tf.reduce_mean(tf.reduce_sum(tf.square(k_trans[i] - scalar_future[i]), axis=-1))


        aux_loss = tf.reduce_mean(tf.reduce_sum(tf.square(aux_reconstructed[-1] - input_data_future), axis=-1))
    
        # Infinity Norm Loss (L_inf)
        l_inf_loss = tf.reduce_max(tf.abs(input_data_cross - reconstruction)) + tf.reduce_max(tf.abs(input_data_future - aux_reconstructed[-1]))
    
        # Apply loss weights
        total_loss = (
            self.loss_weights["reconstruction_loss"] * (reconstruction_loss + aux_loss) +
            self.loss_weights["linear_dynamics_loss"] * linear_dynamics_loss +  
            self.loss_weights["kl_loss"] * kl_loss +
            self.loss_weights["future_state_loss"] * future_state_loss +
            self.loss_weights["l_inf_loss"] * l_inf_loss
        )
        
        return total_loss, reconstruction_loss, kl_loss, linear_dynamics_loss, future_state_loss, aux_loss, l_inf_loss

    def train_step(self, data):
        data_unpacked = data[0]
        input_data_cross, input_data_present, input_data_future = data_unpacked
        
        # Define time intervals: 3-year and 10-year prediction
        time_intervals = [3]  # Use [3] for single-step prediction or modify as needed
        
        with tf.GradientTape() as tape:
            reconstruction, z_present, z_future, z_mean_cross, z_log_var_cross, scalar_present, scalar_future, k_untrans, k_trans, aux_reconstructed = self(
                data_unpacked, num_future=len(time_intervals), time_intervals=time_intervals, training=True
            )
        
            # Compute losses
            total_loss, reconstruction_loss, kl_loss, linear_dynamics_loss, future_state_loss, aux_loss, l_inf_loss = self.compute_losses(
                (input_data_cross, input_data_present, input_data_future),
                reconstruction, z_present, z_future, z_mean_cross, z_log_var_cross, scalar_present, scalar_future, k_trans, k_untrans, aux_reconstructed, num_future=len(time_intervals)
            )
        
        # Compute gradients
        gradients = tape.gradient(total_loss, self.trainable_variables)
        
        # Apply gradients
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        # Update metrics
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.linear_dynamics_loss_tracker.update_state(linear_dynamics_loss)
        self.future_state_loss_tracker.update_state(future_state_loss)
        self.aux_loss_tracker.update_state(aux_loss)
        self.l_inf_loss_tracker.update_state(l_inf_loss)
        
        return {m.name: m.result() for m in self.metrics}
    
    def test_step(self, data):
        data_unpacked = data[0]
        input_data_cross, input_data_present, input_data_future = data_unpacked
        
        # Define time intervals for testing: e.g., 3-year and 10-year prediction
        time_intervals = [3]
        
        reconstruction, z_present, z_future, z_mean_cross, z_log_var_cross, scalar_present, scalar_future, k_trans, k_untrans, aux_reconstructed = self(
            data_unpacked, num_future=len(time_intervals), time_intervals=time_intervals, training=False
        )
        
        # Compute losses
        total_loss, reconstruction_loss, kl_loss, linear_dynamics_loss, future_state_loss, aux_loss, l_inf_loss = self.compute_losses(
            (input_data_cross, input_data_present, input_data_future),
            reconstruction, z_present, z_future, z_mean_cross, z_log_var_cross, scalar_present, scalar_future, k_trans, k_untrans, aux_reconstructed, num_future=len(time_intervals)
        )
        
        # Update metrics
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.linear_dynamics_loss_tracker.update_state(linear_dynamics_loss)
        self.future_state_loss_tracker.update_state(future_state_loss)
        self.aux_loss_tracker.update_state(aux_loss)
        self.l_inf_loss_tracker.update_state(l_inf_loss)
        
        return {m.name: m.result() for m in self.metrics}


## 13. Training Configuration and Utilities

### Training Utilities:

#### Time Tracking:
- **hms_string function**: Converts elapsed seconds to hours:minutes:seconds format
- **Purpose**: Monitor training duration for resource planning
- **Application**: Performance benchmarking and optimization

#### Learning Rate Scheduling:
- **Strategy**: Exponential decay every 250 epochs
- **Decay factor**: 0.1 (reduces learning rate by 90%)
- **Purpose**: Fine-tuning in later training stages
- **Benefits**: Improved convergence and stability

#### Logging Configuration:
- **TensorFlow logging**: Set to ERROR level to reduce verbosity
- **Purpose**: Clean output during training
- **Focus**: Essential information only

### Training Parameters:
- **Optimizer**: Adam with learning rate 0.0001
- **Gradient clipping**: Prevents exploding gradients (clipvalue=1.0, clipnorm=1.0)
- **Checkpointing**: Saves model weights every epoch
- **Validation**: Monitors performance on validation set


In [None]:
def hms_string(sec_elapsed):
    h = int(sec_elapsed / (60 * 60))
    m = int((sec_elapsed % (60 * 60)) / 60)
    s = sec_elapsed % 60
    return "{}:{:>02}:{:>05.2f}".format(h, m, s)

tf.get_logger().setLevel('ERROR')

In [None]:
def lr_scheduler(epoch, lr):
    if epoch % 250 == 0 and epoch != 0:
        return lr * 0.1  # reduce learning rate by a factor of 10
    else:
        return lr

# Create a learning rate scheduler callback
lr_scheduler_callback = LearningRateScheduler(lr_scheduler)

model.save_weights(filepath.format(epoch=0))

## 14. Model Initialization and Training Setup

### Model Assembly:
The main execution block creates and configures the complete Dynamic Aging Index model:

#### Component Creation:
1. **ScalarTransformation**: DAI projection layer
2. **VAE**: Variational autoencoder with encoder, decoder, and DAI projection
3. **KoopmanOperator**: Temporal dynamics modeling
4. **KoopmanModel**: High-level Koopman wrapper
5. **MyModel**: Integrated model combining all components

#### Training Configuration:
- **Optimizer**: Adam with adaptive learning rate and gradient clipping
- **Checkpointing**: Saves model weights to specified directory
- **Device**: GPU acceleration for faster training
- **Initial weights**: Saved at epoch 0 for backup

#### Key Features:
- **Modular design**: Each component can be modified independently
- **Checkpoint system**: Automatic model saving for recovery
- **GPU optimization**: Leverages CUDA acceleration
- **Gradient stability**: Clipping prevents training instability


In [None]:
if __name__ == '__main__':
    import time

    start = time.time()

    scalar_transformation = ScalarTransformation()    
    vae = VAE(encoder, decoder, scalar_transformation)
    koopman_operator = KoopmanOperator(params)
    koopman = KoopmanModel(koopman_operator)
    
    model = MyModel(vae, koopman)
    checkpoint_path = 'C:\\Users\\Best model\\saved-model-{epoch:02d}DAF.ckpt'
    checkpoint = tf.train.Checkpoint(model=model)

    model.compile(optimizer=keras.optimizers.Adam(learning_rate = 0.0001, clipvalue=1.0, clipnorm=1.0))
    

## 15. Model Training Execution

### Training Process:
The final cell executes the complete training pipeline for the Dynamic Aging Index model.

#### Training Configuration:
- **Epochs**: 1000 (extensive training for convergence)
- **Device**: GPU:0 for acceleration
- **Validation**: Every epoch for monitoring
- **Callbacks**: Checkpoint saving and learning rate scheduling

#### Training Features:
- **Multi-objective optimization**: Simultaneous optimization of all loss components
- **Automatic checkpointing**: Saves best models during training
- **Validation monitoring**: Tracks performance on unseen data
- **Time tracking**: Records total training duration

#### Expected Outcomes:
- **Convergence**: Loss values should decrease over epochs
- **Model weights**: Saved checkpoints for inference
- **Training time**: Recorded for performance analysis
- **Validation metrics**: Performance on held-out data

### Training Benefits:
- **End-to-end learning**: All components trained jointly
- **Robust optimization**: Multiple loss objectives ensure comprehensive learning
- **Scalable training**: Handles large biological datasets efficiently
- **Reproducible results**: Deterministic training with fixed random seeds


In [None]:
#Begins the training

with tf.device('/GPU:0'):
    hist = model.fit(
        train_loader,
        epochs=1000,
        validation_data=val_loader,
        validation_freq=1,
        callbacks=[checkpoint]
    )

    elapsed = time.time() - start
    print(f'Training time: {hms_string(elapsed)}')
#     print(hist.history)

---

## üéØ Summary and Usage Instructions

### Model Overview
The Dynamic Aging Index (DAI) model is a sophisticated deep learning architecture that combines:
- **Variational Autoencoders** for biological feature learning
- **Koopman Operator Theory** for temporal dynamics modeling
- **Multi-objective optimization** for robust biological age prediction

### Key Capabilities
‚úÖ **High-dimensional data processing**: Handles 332,909 biological features  
‚úÖ **Temporal prediction**: Forecasts aging states 3-10 years into the future  
‚úÖ **Biological age quantification**: Provides interpretable DAI scores  
‚úÖ **Robust training**: Multi-objective loss with regularization  
‚úÖ **Scalable architecture**: GPU-accelerated training and inference  

### Usage Workflow

#### 1. **Data Preparation**
- Ensure your data follows the required format (cross-sectional, present, future states)
- Normalize features using MinMaxScaler
- Split into train/validation/test sets

#### 2. **Model Training**
```python
# Run the training cell to train the complete model
# Training will automatically save checkpoints
# Monitor loss curves for convergence
```

#### 3. **Model Inference**
```python
# Load trained model from checkpoint
# Input: Biological measurements (332,909 features)
# Output: DAI score and future state predictions
```

#### 4. **Interpretation**
- **DAI Score**: Higher values indicate accelerated aging
- **Future Predictions**: Biological states at specified time horizons
- **Confidence**: Monitor reconstruction and prediction losses

### Performance Expectations
- **Training time**: 10-20 hours (depending on hardware)
- **Memory requirements**: 8-16 GB GPU memory recommended
- **Convergence**: Typically 500-800 epochs for stable results
- **Accuracy**: Validates on held-out biological data

### Applications
üè• **Clinical Research**: Biological age assessment and aging biomarker discovery  
üß¨ **Drug Development**: Anti-aging intervention evaluation  
üìä **Health Monitoring**: Personalized aging trajectory prediction  
üî¨ **Biomarker Discovery**: Identification of key aging indicators  

### Technical Requirements
- **Python 3.8+**
- **TensorFlow 2.x**
- **CUDA-compatible GPU** (recommended)
- **16+ GB RAM**
- **Sufficient storage** for large biological datasets

---

*This notebook provides a complete implementation of the Dynamic Aging Index model for biological age prediction. The architecture combines state-of-the-art deep learning techniques with theoretical foundations from dynamical systems theory to provide accurate, interpretable predictions of biological aging.*
