<a href="https://colab.research.google.com/github/douglasmasho/MedAlgo/blob/main/Survival.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
import os
import glob

# Mount Google Drive so the BraTS 2019 training data under MyDrive/BRATS is readable
# from this Colab runtime.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
import os
import numpy as np
import pandas as pd
import nibabel as nib
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import img_to_array, array_to_img
from tensorflow.keras import layers, models
import tensorflow as tf

# Survival annotations: keep only subjects that have Age, Survival and Status filled in.
SURVIVAL_CSV_PATH = '/content/drive/MyDrive/BRATS/MICCAI_BraTS_2019_Data_Training/processed_survival_data.csv'
survival_data = pd.read_csv(SURVIVAL_CSV_PATH).dropna(subset=['Age', 'Survival', 'Status'])

# Process the survival data
def process_survival(survival):
    """Convert one raw ``Survival`` cell to a non-negative integer number of days.

    Two formats are handled:
      * ``"ALIVE (361 days later)"`` (any letter case) -> ``361``
      * a plain day count, given as a string (``"289"``, ``"289.0"``) or as a
        number (``289`` / ``289.0``) -> ``289``

    Unlike the previous version, deceased subjects are no longer returned as
    *negative* day counts. Survival is a duration and vital status is already a
    separate target (the ``Status`` column), so the sign flip only corrupted the
    regression target.

    Raises:
        ValueError / IndexError: if the value matches neither format.
    """
    # Coerce to text first: pandas may hand us a float/int, and ``'ALIVE' in 289``
    # would raise TypeError.
    text = str(survival).strip()
    if 'ALIVE' in text.upper():
        # "ALIVE (361 days later)" -> "361 days later)" -> "361"
        return int(text.split('(')[1].split()[0])
    # float() first so numeric cells rendered as "289.0" are accepted too.
    return int(float(text))

# Replace each raw Survival value with the day count returned by process_survival.
survival_data = survival_data.assign(Survival=survival_data['Survival'].apply(process_survival))

# Define a function to get 2D slices from 3D MRI images
def get_2d_slices(img_3d):
    """Split a volume into single-channel 2D slices along its last axis.

    Args:
        img_3d: array indexed as ``img_3d[:, :, k]``; one slice is produced for
            each ``k`` in ``range(img_3d.shape[-1])``.

    Returns:
        A list of ``img_3d.shape[-1]`` arrays. Entry ``k`` is ``img_3d[:, :, k]``
        with a trailing length-1 channel axis appended (height, width, 1).
    """
    depth = img_3d.shape[-1]
    return [img_3d[:, :, k][..., np.newaxis] for k in range(depth)]

# Load and preprocess the data
def load_and_preprocess_data(root_dirs, df):
    """Load T1ce volumes for the subjects in ``df`` and expand them into 2D slices.

    Each subject folder ``<root_dir>/<subject>/`` is expected to contain
    ``<subject>_t1ce.nii`` (preferred) or the gzipped ``<subject>_t1ce.nii.gz``.
    Only folders whose name appears in ``df['BraTS19ID']`` are loaded. Every slice
    along the volume's last axis becomes one sample, and all slices of a subject
    share that subject's age, survival and status labels.

    Args:
        root_dirs: iterable of directories whose sub-directories are subjects.
        df: DataFrame with columns ``BraTS19ID``, ``Age``, ``Survival`` and
            ``Status``. If an ID appears more than once, its first row is used.

    Returns:
        Tuple ``(images, ages, survival_days, statuses)`` of NumPy arrays with one
        entry per slice, in sorted-directory order. ``images`` holds 64x64
        single-channel slices. ``statuses`` is 1 where ``Status`` equals "alive"
        (case-insensitive, surrounding whitespace ignored), else 0.
    """
    # Index labels by subject once instead of filtering the whole frame per subject;
    # keep='first' matches the original ``.iloc[0]`` lookup for duplicate IDs.
    labels_by_id = df.drop_duplicates(subset='BraTS19ID', keep='first').set_index('BraTS19ID')

    images = []
    survival_days = []
    statuses = []
    ages = []

    for root_dir in root_dirs:
        # os.listdir order is arbitrary; sort so the sample order (and any seeded
        # split made on the result) is reproducible across runs.
        for subject_dir in sorted(os.listdir(root_dir)):
            subject_path = os.path.join(root_dir, subject_dir)
            if not os.path.isdir(subject_path):
                continue

            # The subject folder name is the BraTS19ID. Check it before touching the
            # filesystem again so unlabeled subjects cost no extra file lookups.
            brats_id = subject_dir
            if brats_id not in labels_by_id.index:
                continue

            img_path = None
            for ext in ('.nii', '.nii.gz'):
                candidate = os.path.join(subject_path, f'{subject_dir}_t1ce{ext}')
                if os.path.isfile(candidate):
                    img_path = candidate
                    break
            if img_path is None:
                continue

            row = labels_by_id.loc[brats_id]
            age = row['Age']
            survival = row['Survival']
            status_flag = 1 if str(row['Status']).strip().lower() == 'alive' else 0

            img_3d = nib.load(img_path).get_fdata()
            for slice_2d in get_2d_slices(img_3d):
                # array_to_img(scale=True) rescales the slice to 0-255 before the
                # resize to the model's 64x64 input.
                img_resized = array_to_img(slice_2d, scale=True).resize((64, 64))
                images.append(img_to_array(img_resized))
                survival_days.append(survival)
                statuses.append(status_flag)
                ages.append(age)

    return np.array(images), np.array(ages), np.array(survival_days), np.array(statuses)

# Directories holding the per-subject folders for both glioma grades.
hgg_dir = '/content/drive/MyDrive/BRATS/MICCAI_BraTS_2019_Data_Training/HGG'
lgg_dir = '/content/drive/MyDrive/BRATS/MICCAI_BraTS_2019_Data_Training/LGG'
all_dirs = [hgg_dir, lgg_dir]

# Split at the *subject* level before slicing. Every slice of a volume carries the
# same labels, so a slice-level split places near-identical images of one patient in
# both train and validation, leaking information and inflating validation metrics.
subject_ids = np.array(sorted(survival_data['BraTS19ID'].unique()))
train_ids, val_ids = train_test_split(subject_ids, test_size=0.2, random_state=42)

train_df = survival_data[survival_data['BraTS19ID'].isin(train_ids)]
val_df = survival_data[survival_data['BraTS19ID'].isin(val_ids)]

# Each call returns (images, ages, survival_days, statuses) for its subjects only.
X_train, age_train, y_train, status_train = load_and_preprocess_data(all_dirs, train_df)
X_val, age_val, y_val, status_val = load_and_preprocess_data(all_dirs, val_df)

# Define the model
def create_model(input_shape, loss_weights=None):
    """Build and compile a CNN with an image input and an age input.

    The image branch is three Conv2D/MaxPooling2D stages followed by a 128-unit
    dense layer. Age passes through a 64-unit dense layer. The two branches are
    concatenated and feed two heads:
      * ``survival`` - linear regression output, trained with MSE;
      * ``status``   - sigmoid output, trained with binary cross-entropy and
        reported with accuracy.

    Args:
        input_shape: shape of one image sample, e.g. ``(64, 64, 1)``.
        loss_weights: optional dict mapping output name (``'survival'``,
            ``'status'``) to a scalar weight in the total loss. ``None`` (the
            default) keeps the previous behavior of summing the losses unweighted.
            Because the survival MSE is computed on raw day counts, it can be
            orders of magnitude larger than the status cross-entropy, so
            down-weighting ``'survival'`` lets the status head influence training.

    Returns:
        A compiled ``tf.keras.Model`` taking ``[image, age]`` and returning
        ``[survival, status]``.
    """
    inputs = tf.keras.Input(shape=input_shape)
    x = layers.Conv2D(32, (3, 3), activation='relu')(inputs)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(64, (3, 3), activation='relu')(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(128, (3, 3), activation='relu')(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Flatten()(x)
    x = layers.Dense(128, activation='relu')(x)

    # Scalar age per sample.
    age_input = tf.keras.Input(shape=(1,))
    age_dense = layers.Dense(64, activation='relu')(age_input)

    combined = layers.concatenate([x, age_dense])

    # Output layer for survival days
    survival_output = layers.Dense(1, name='survival')(combined)

    # Output layer for status
    status_output = layers.Dense(1, activation='sigmoid', name='status')(combined)

    model = tf.keras.Model(inputs=[inputs, age_input], outputs=[survival_output, status_output])
    model.compile(optimizer='adam',
                  loss={'survival': 'mse', 'status': 'binary_crossentropy'},
                  loss_weights=loss_weights,
                  metrics={'status': 'accuracy'})
    return model

# Create and train the model
input_shape = (64, 64, 1)
model = create_model(input_shape)

# Save the weights with the best validation status accuracy seen during training.
# NOTE(review): relative paths resolve in the runtime's working directory, not in
# Drive - confirm that is where the checkpoint should live.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    'best_glioma_survival_model.keras',
    monitor='val_status_accuracy',
    save_best_only=True,
    mode='max'
)

history = model.fit(
    [X_train, age_train],
    {'survival': y_train, 'status': status_train},
    epochs=200,
    batch_size=32,
    validation_data=([X_val, age_val], {'survival': y_val, 'status': status_val}),
    callbacks=[checkpoint_callback]
)

# Evaluate the final-epoch weights (not the checkpointed best) on the validation set.
# return_dict=True labels each value by metric name, so the report does not depend
# on the positional layout of a plain result list.
results = model.evaluate(
    [X_val, age_val],
    {'survival': y_val, 'status': status_val},
    return_dict=True
)

print("Evaluation Results:")
for metric_name, value in results.items():
    print(f"{metric_name}: {value:.4f}")

# Save the final-epoch model alongside the best checkpoint.
model.save('glioma_survival_model_with_status.keras')


Epoch 1/200
[1m822/822[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 9ms/step - loss: 157521.1562 - status_accuracy: 0.9875 - val_loss: 129528.3281 - val_status_accuracy: 0.9889
Epoch 2/200
[1m822/822[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 6ms/step - loss: 130219.5547 - status_accuracy: 0.9882 - val_loss: 121136.0938 - val_status_accuracy: 0.9915
Epoch 3/200
[1m822/822[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 6ms/step - loss: 118872.1094 - status_accuracy: 0.9868 - val_loss: 104909.1250 - val_status_accuracy: 0.9865
Epoch 4/200
[1m822/822[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 6ms/step - loss: 97584.2266 - status_accuracy: 0.9847 - val_loss: 91046.1641 - val_status_accuracy: 0.9909
Epoch 5/200
[1m822/822[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 6ms/step - loss: 79708.1797 - status_accuracy: 0.9840 - val_loss: 77931.0859 - val_status_accuracy: 0.9916
Epoch 6/200
[1m822/822[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m 

IndexError: list index out of range