<a href="https://colab.research.google.com/github/hmMed22/ARTSS/blob/main/ViT_OSS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.utils.class_weight import compute_class_weight
import os
from tensorflow.keras.utils import Sequence


In [None]:


# ---- Global settings shared by the model and the data pipeline ----

# Input geometry
image_size = (224, 224)  # (height, width) every joint image is resized to
patch_size = 16          # side length, in pixels, of one square patch
max_joints = 22          # joint images per patient after padding/truncation

# Transformer encoder
dim = 16        # width of the patch embedding
num_heads = 2   # attention heads in each encoder block
mlp_dim = 32    # hidden width of the feed-forward sub-layer
num_layers = 3  # number of stacked encoder blocks

# Output / batching
num_classes = 1  # a single output unit (regression)
batch_size = 8   # samples per batch
# Custom Layer for Patch Extraction
class Patches(layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        patch_size: Side length, in pixels, of each square patch. It is used
            both as the window size and as the stride, so patches do not
            overlap.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...),
            forwarded to ``layers.Layer``. Accepting them is required for
            Keras to rebuild the layer from its saved config.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Return the patches of ``images`` as a 3-D tensor.

        Args:
            images: Tensor of shape (batch_size, H, W, C).

        Returns:
            Tensor of shape (batch_size, num_patches, patch_dims), where
            ``patch_dims`` is the last dimension produced by
            ``tf.image.extract_patches``.
        """
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )
        # Collapse the spatial patch grid into one "sequence of patches" axis.
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [tf.shape(images)[0], -1, patch_dims])

    def get_config(self):
        """Return the layer config, including ``patch_size``.

        Without this override Keras cannot serialize a layer whose
        constructor takes extra arguments, which breaks ``model.save``
        and later reloading of models that contain this layer.
        """
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

# Positional Encoding Layer
class PositionalEncoding(layers.Layer):
    """Add a trainable positional embedding to a sequence of patch embeddings.

    Args:
        num_patches: Number of patches (sequence length) of the input.
        embed_dim: Size of each patch embedding.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...),
            forwarded to ``layers.Layer``. Accepting them is required for
            Keras to rebuild the layer from its saved config.
    """

    def __init__(self, num_patches, embed_dim, **kwargs):
        super().__init__(**kwargs)
        # Kept as attributes so get_config() can report them.
        self.num_patches = num_patches
        self.embed_dim = embed_dim
        # One embedding per patch position; the leading axis of size 1
        # broadcasts over the batch dimension in call().
        self.pos_embedding = self.add_weight(
            name="pos_embedding",
            shape=(1, num_patches, embed_dim),
            initializer='random_normal',
            trainable=True
        )

    def call(self, x):
        """Return ``x`` with the positional embedding added element-wise."""
        return x + self.pos_embedding

    def get_config(self):
        """Return the layer config, including both constructor arguments.

        Without this override Keras cannot serialize a layer whose
        constructor takes extra arguments, which breaks ``model.save``
        and later reloading of models that contain this layer.
        """
        config = super().get_config()
        config.update({
            "num_patches": self.num_patches,
            "embed_dim": self.embed_dim,
        })
        return config

# Build the Vision Transformer Model
def build_vit_model():
    """Assemble the multi-joint Vision Transformer regressor.

    The model takes a tensor of shape
    (batch, max_joints, image_size[0], image_size[1], 3). Every joint image
    is encoded independently by the same patch-based Transformer encoder,
    the patch tokens are averaged per image, the image embeddings are
    averaged per sample, and a final dense layer produces ``num_classes``
    outputs.

    All hyperparameters are read from the module-level settings.

    Returns:
        An uncompiled ``tf.keras`` functional model.
    """

    def reshape_inputs(tensor):
        # (batch, max_joints, H, W, C) -> (batch * max_joints, H, W, C)
        n_samples = tf.shape(tensor)[0]
        return tf.reshape(
            tensor,
            [n_samples * max_joints, image_size[0], image_size[1], 3],
        )

    def reshape_back(tensor):
        # (batch * max_joints, dim) -> (batch, max_joints, dim)
        n_samples = tf.shape(tensor)[0] // max_joints
        return tf.reshape(tensor, [n_samples, max_joints, dim])

    def encoder_block(tokens):
        # Pre-norm self-attention followed by a residual connection.
        normed = layers.LayerNormalization(epsilon=1e-6)(tokens)
        attended = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=dim
        )(normed, normed)
        after_attention = layers.Add()([tokens, attended])

        # Pre-norm feed-forward network followed by a residual connection.
        normed = layers.LayerNormalization(epsilon=1e-6)(after_attention)
        hidden = layers.Dense(mlp_dim, activation='relu')(normed)
        hidden = layers.Dense(dim)(hidden)
        return layers.Add()([after_attention, hidden])

    inputs = layers.Input(shape=(max_joints, image_size[0], image_size[1], 3))

    # Fold the joint axis into the batch axis so each image is encoded alone.
    per_joint = layers.Lambda(
        reshape_inputs,
        output_shape=(image_size[0], image_size[1], 3),
    )(inputs)

    # Patch extraction and linear projection to the embedding width.
    tokens = Patches(patch_size)(per_joint)
    tokens = layers.Dense(dim)(tokens)

    # Learned positional embedding, one vector per patch position.
    num_patches = (image_size[0] // patch_size) * (image_size[1] // patch_size)
    tokens = PositionalEncoding(num_patches, dim)(tokens)

    # Stack of Transformer encoder blocks.
    for _ in range(num_layers):
        tokens = encoder_block(tokens)

    # Mean over patches: (batch * max_joints, dim)
    joint_embedding = layers.GlobalAveragePooling1D()(tokens)

    # Restore the joint axis, then take the mean over joints: (batch, dim)
    joint_embedding = layers.Lambda(
        reshape_back, output_shape=(max_joints, dim)
    )(joint_embedding)
    sample_embedding = layers.GlobalAveragePooling1D()(joint_embedding)

    # Light regularisation before the output head.
    sample_embedding = layers.Dropout(0.1)(sample_embedding)

    # Output head: (batch, num_classes)
    outputs = layers.Dense(num_classes)(sample_embedding)

    return models.Model(inputs=inputs, outputs=outputs)

# Build one model instance for inspection
vit_model = build_vit_model()

# Huber loss for training; MAE and RMSE are tracked as metrics, in that order
vit_model.compile(
    loss=tf.keras.losses.Huber(),
    optimizer='adam',
    metrics=[
        tf.keras.metrics.MeanAbsoluteError(name='mae'),
        tf.keras.metrics.RootMeanSquaredError(name='rmse'),
    ],
)

# Print the layer-by-layer architecture
vit_model.summary()


In [None]:
# Batching and training length
batch_size = 8  # samples per batch
epochs = 50     # upper bound on epochs per fold

# Callback patience, counted in epochs without val_loss improvement
early_stopping_patience = 10
learning_rate_reduction_patience = 5

# Data loading: replace 'your_data.csv' with the actual CSV file path
data_path = 'your_data.csv'
df = pd.read_csv(data_path)

# The CSV is expected to provide 'patient_id', 'image_path' and 'score_avg' columns
image_paths = df['image_path'].tolist()
scores = df['score_avg'].tolist()

# Bin the scores so the splits below can be stratified.
# Intervals are right-closed: 0 -> [0, 50], 1 -> (50, 100], 2 -> (100, inf).
# include_lowest=True keeps a score of exactly 0 in bin 0; with the pandas
# default such rows fall outside every interval and are labelled NaN, which
# then corrupts the stratification column.
bins = [0, 50, 100, np.inf]
labels = [0, 1, 2]
df['score_bin'] = pd.cut(df['score_avg'], bins=bins, labels=labels, include_lowest=True)

# Stratified 70/30 train-test split on the score bin
train_df, test_df = train_test_split(df, test_size=0.3, stratify=df['score_bin'], random_state=42)

# Random augmentation used for training batches: small rotations and shifts
# plus horizontal flips. Shear and zoom are currently disabled.
data_gen = ImageDataGenerator(
    horizontal_flip=True,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    fill_mode='nearest',
)

# Padding function to ensure each sample has the same number of joints
def pad_images(images, max_joints):
    """Return exactly ``max_joints`` images, truncating or zero-padding.

    Args:
        images: Sequence of per-joint images. The input is never modified.
        max_joints: Number of images the result must contain.

    Returns:
        ``images[:max_joints]`` when at least ``max_joints`` images are
        given. Otherwise a new list holding the original images followed by
        all-zero float32 arrays of shape
        ``(image_size[0], image_size[1], 3)``.
    """
    padding_needed = max_joints - len(images)
    if padding_needed <= 0:
        return images[:max_joints]

    # float32 matches the dtype produced by tf.image.resize; float64 padding
    # would make np.stack upcast the whole stack and double its memory use.
    padding_images = [
        np.zeros((image_size[0], image_size[1], 3), dtype=np.float32)
        for _ in range(padding_needed)
    ]
    # list(...) lets any sequence (e.g. a tuple) be padded, not only lists.
    return list(images) + padding_images

# Custom Data Generator
class DataGenerator(Sequence):
    """Keras ``Sequence`` yielding batches of per-patient image stacks.

    Each row of ``df`` describes one sample: the ``image_path`` column holds
    one or more image paths separated by ``;`` and ``score_avg`` holds the
    regression target.

    Args:
        df: DataFrame with ``image_path`` and ``score_avg`` columns.
        batch_size: Number of rows per batch (the last batch may be smaller).
        augment: When True, every image is passed through
            ``data_gen.random_transform``.
        **kwargs: Forwarded to ``Sequence.__init__``.
    """

    def __init__(self, df, batch_size, augment=False, **kwargs):
        # The Keras base class expects its constructor to be called.
        super().__init__(**kwargs)
        self.df = df
        self.batch_size = batch_size
        self.augment = augment
        self.indices = np.arange(len(self.df))

    def __len__(self):
        """Return the number of batches per epoch (rounded up)."""
        return int(np.ceil(len(self.df) / self.batch_size))

    def __getitem__(self, index):
        """Return the ``(images, labels)`` pair for batch number ``index``."""
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        batch_df = self.df.iloc[batch_indices]
        return self.__data_generation(batch_df)

    def __data_generation(self, batch_df):
        """Load, pad and optionally augment the images for ``batch_df``.

        Returns:
            A tuple ``(images, labels)`` where ``images`` is a float32 array
            of shape (batch, max_joints, H, W, 3) and ``labels`` is a 1-D
            array with the ``score_avg`` values.
        """
        batch_images = []
        batch_labels = []

        for joined_paths, score in zip(batch_df['image_path'], batch_df['score_avg']):
            # channels=3 forces RGB output, so grayscale or CMYK JPEGs cannot
            # produce a channel count that differs from the zero padding and
            # from the model input.
            patient_images = [
                tf.image.resize(
                    tf.image.decode_jpeg(tf.io.read_file(img_path), channels=3),
                    image_size,
                )
                for img_path in joined_paths.split(';')  # paths are ';' separated
            ]

            # Pad or truncate to max_joints, then stack: (max_joints, H, W, C)
            patient_images = pad_images(patient_images, max_joints)
            patient_images = np.stack(patient_images)

            if self.augment:
                patient_images = np.stack(
                    [data_gen.random_transform(img) for img in patient_images]
                )

            batch_images.append(patient_images)
            batch_labels.append(score)

        # Explicit float32 keeps the batch at half the size of a float64 one.
        return np.array(batch_images, dtype=np.float32), np.array(batch_labels)

    def on_epoch_end(self):
        """Reshuffle the row order so the next epoch sees new batches."""
        np.random.shuffle(self.indices)

# Instantiate the data generators for the full training split and the
# held-out test split. Only the training generator applies augmentation.
# NOTE(review): train_generator is not referenced again in this section; the
# cross-validation loop below builds its own per-fold generators.
train_generator = DataGenerator(train_df, batch_size=batch_size, augment=True)
test_generator = DataGenerator(test_df, batch_size=batch_size, augment=False)

# Stratified k-fold cross-validation over the training split, stratified on
# the binned score column so every fold sees a similar score distribution.
n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

# Per-fold validation metrics, appended in fold order
rmse_list = []
mae_list = []
huber_list = []

for fold, (train_index, val_index) in enumerate(skf.split(train_df, train_df['score_bin'])):
    print(f"Training fold {fold + 1}/{n_splits}...")
    # skf.split yields positional indices, hence .iloc
    fold_train_df = train_df.iloc[train_index]
    fold_val_df = train_df.iloc[val_index]

    # Create fold-specific data generators (augmentation on the training part only)
    fold_train_generator = DataGenerator(fold_train_df, batch_size=batch_size, augment=True)
    fold_val_generator = DataGenerator(fold_val_df, batch_size=batch_size, augment=False)

    # Build and compile a fresh model so no weights carry over between folds
    model = build_vit_model()
    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.Huber(),
        metrics=[tf.keras.metrics.MeanAbsoluteError(name='mae'), tf.keras.metrics.RootMeanSquaredError(name='rmse')]
    )

    # Define callbacks: halve the learning rate when val_loss plateaus, and
    # stop early (restoring the best weights) when it stops improving
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=learning_rate_reduction_patience, min_lr=1e-6, verbose=1)
    early_stopping = EarlyStopping(monitor='val_loss', patience=early_stopping_patience, restore_best_weights=True, verbose=1)

    # Train the model, running every batch of both generators each epoch
    model.fit(
        fold_train_generator,
        validation_data=fold_val_generator,
        epochs=epochs,
        steps_per_epoch=len(fold_train_generator),
        validation_steps=len(fold_val_generator),
        callbacks=[reduce_lr, early_stopping]
    )

    # Evaluate the model on the validation set; the result order is the loss
    # followed by the metrics in the order given to compile(): mae, rmse
    val_loss, val_mae, val_rmse = model.evaluate(fold_val_generator)
    huber_list.append(val_loss)
    mae_list.append(val_mae)
    rmse_list.append(val_rmse)

    # Save model for each fold (HDF5 file, numbered from 1)
    model.save(f'model_fold_{fold + 1}.h5')

    # Log metrics for each fold
    print(f"Fold {fold + 1} - Huber Loss: {val_loss:.4f}, MAE: {val_mae:.4f}, RMSE: {val_rmse:.4f}")

# Mean of each validation metric over all folds
average_huber, average_mae, average_rmse = (
    np.mean(fold_values) for fold_values in (huber_list, mae_list, rmse_list)
)

print(f"Average Huber Loss: {average_huber:.4f}")
print(f"Average MAE: {average_mae:.4f}")
print(f"Average RMSE: {average_rmse:.4f}")

# Persist the per-fold lines followed by the averages to a text log
with open('cross_validation_metrics.txt', 'w') as f:
    f.writelines(
        f"Fold {fold + 1} - Huber Loss: {huber_list[fold]:.4f}, MAE: {mae_list[fold]:.4f}, RMSE: {rmse_list[fold]:.4f}\n"
        for fold in range(n_splits)
    )
    f.writelines((
        f"\nAverage Huber Loss: {average_huber:.4f}\n",
        f"Average MAE: {average_mae:.4f}\n",
        f"Average RMSE: {average_rmse:.4f}\n",
    ))

# Evaluate on the held-out test generator.
# NOTE: `model` here is the model trained in the last cross-validation fold.
test_loss, test_mae, test_rmse = model.evaluate(test_generator)
print(f'Test MAE: {test_mae:.4f}, Test RMSE: {test_rmse:.4f}')
