In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, ConvLSTM3D, Conv3D, Flatten, Dense, Dropout, BatchNormalization, GlobalAveragePooling3D, Concatenate, TimeDistributed, MaxPooling3D, AveragePooling3D, GlobalMaxPooling3D, LSTM, Lambda
from tensorflow.keras.models import Model
import numpy as np
import pickle as pkl




In [2]:
# Ask TensorFlow for the physical devices it can see (no device-type filter is
# passed) and print the resulting list.
devices = tf.config.list_physical_devices()
print(devices)

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]


In [3]:
# Pass an empty device list for the 'GPU' device type, i.e. make no GPU visible
# to TensorFlow.
tf.config.set_visible_devices([], 'GPU')

In [4]:
# importing data
def returning_pkl_file_data(path : str):
    """Read a pickle file and return the object stored in it.

    Args:
        path: Location of the pickle file; it is opened in binary read mode.

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.

    Note:
        Unpickling can execute arbitrary code, so only use this on files
        from a trusted source.
    """
    with open(path, mode='rb') as pickle_file:
        return pkl.load(pickle_file)

# Load the four pre-extracted feature pickles (MCI / CN, functional / structural)
# in a fixed order and bind each one to its own name.
mci_func, mci_struct, cn_func, cn_struct = (
    returning_pkl_file_data(pickle_path)
    for pickle_path in (
        r"feature_extraction/MCI_func_52_79_95_79_197.pkl",
        r'feature_extraction/MCI_struct_cat_52_169_205_169.pkl',
        r'feature_extraction/CN_func_42_79_95_79_197.pkl',
        r'feature_extraction/CN_struct_cat_42_169_205_169.pkl',
    )
)

In [5]:
# Stack the MCI and CN functional arrays along the sample axis (MCI first) and
# cast to float16 in the same call.
# `np.concatenate` is used instead of `np.concat`: the latter is an alias that
# only exists in NumPy >= 2.0, while `np.concatenate` works on older releases too.
func_data = np.concatenate((mci_func, cn_func), axis=0, dtype=np.float16)
func_data.shape

(94, 79, 95, 79, 197)

In [6]:
# Stack the MCI and CN structural arrays along the sample axis (MCI first) and
# cast to float16 in the same call.
# `np.concatenate` is used instead of `np.concat`: the latter is an alias that
# only exists in NumPy >= 2.0, while `np.concatenate` works on older releases too.
struct_data = np.concatenate((mci_struct, cn_struct), axis=0, dtype=np.float16)
struct_data.shape

(94, 169, 205, 169)

In [7]:
# Binary labels in the same order as the stacked data: 0 for every MCI sample,
# followed by 1 for every CN sample.
# `np.concatenate` is used instead of `np.concat`: the latter is an alias that
# only exists in NumPy >= 2.0, while `np.concatenate` works on older releases too.
all_labels = np.concatenate((np.zeros((len(mci_func),)), np.ones((len(cn_func),))))
all_labels.shape

(94,)

In [8]:
# Append a trailing length-1 axis to both arrays.
# NOTE: this rebinds the names, so running the cell again appends another axis.
func_data = func_data[..., np.newaxis]
struct_data = struct_data[..., np.newaxis]

func_data.shape, struct_data.shape

((94, 79, 95, 79, 197, 1), (94, 169, 205, 169, 1))

In [19]:
# 2. Create memory-efficient dataset using generator
class BrainDataGenerator(tf.keras.utils.Sequence):
    """Batch provider that slices paired functional/structural arrays.

    Each item is a ``(inputs, labels)`` tuple where ``inputs`` is a dict with
    the keys ``'func'`` and ``'struct'``. The last batch may be smaller than
    ``batch_size`` when the number of samples is not a multiple of it.

    Args:
        func_data: Array-like of functional samples, indexed along axis 0.
        struct_data: Array-like of structural samples, same length as
            ``func_data``.
        labels: Array-like of per-sample labels, same length as ``func_data``.
        batch_size: Number of samples per batch (must be >= 1).

    Raises:
        ValueError: If the three inputs differ in length or ``batch_size`` < 1.
    """

    def __init__(self, func_data, struct_data, labels, batch_size=4):
        # Let the Keras base class initialise its own state.
        super().__init__()

        if not (len(func_data) == len(struct_data) == len(labels)):
            raise ValueError(
                "func_data, struct_data and labels must have the same length, got "
                f"{len(func_data)}, {len(struct_data)} and {len(labels)}"
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.func_data = func_data
        self.struct_data = struct_data
        self.labels = labels
        self.batch_size = batch_size
        self.indices = np.arange(len(func_data))

    def __len__(self):
        """Return the number of batches, counting a trailing partial batch."""
        return int(np.ceil(len(self.func_data) / self.batch_size))

    def __getitem__(self, idx):
        """Return batch number ``idx`` as ``({'func': ..., 'struct': ...}, labels)``.

        Negative indices count from the end. An out-of-range index raises
        ``IndexError`` instead of silently yielding an empty batch; without
        this, iteration that relies on ``__getitem__`` would never terminate.
        """
        n_batches = len(self)
        if idx < 0:
            idx += n_batches
        if not 0 <= idx < n_batches:
            raise IndexError(f"batch index out of range for {n_batches} batches")

        batch_indices = self.indices[idx*self.batch_size:(idx+1)*self.batch_size]
        return (
            {
                'func': self.func_data[batch_indices],
                'struct': self.struct_data[batch_indices]
            },
            self.labels[batch_indices]
        )

In [20]:
from sklearn.model_selection import train_test_split

# Split sample positions (80 % train / 20 % test, fixed seed) and use them to
# select rows from the data arrays and the labels.
indices = np.arange(func_data.shape[0])
train_idx, test_idx = train_test_split(indices, random_state=42, test_size=0.2)

# NOTE: indexing with an integer index array copies the selected rows, so the
# train and test subsets below are separate in-memory copies of the data
# (nothing here is memory-mapped).
train_generator = BrainDataGenerator(
    func_data=func_data[train_idx],
    struct_data=struct_data[train_idx],
    labels=all_labels[train_idx],
)

test_generator = BrainDataGenerator(
    func_data=func_data[test_idx],
    struct_data=struct_data[test_idx],
    labels=all_labels[test_idx],
    batch_size=test_idx.size,  # the whole held-out set as a single batch
)

# 4. Convert to tf.data.Dataset with optimized pipeline
def create_tf_dataset(generator, training=True):
    """Wrap a batch generator in a ``tf.data.Dataset``.

    Every dataset element is one batch produced by ``generator[i]``: a dict
    with the keys ``'func'`` and ``'struct'`` plus a label vector, all
    converted to float16.

    Args:
        generator: Object supporting ``len()`` and integer indexing that
            returns ``({'func': ..., 'struct': ...}, labels)`` batches,
            e.g. a ``BrainDataGenerator``.
        training: If True, batches are visited in a new random order on every
            pass and the dataset repeats indefinitely. If False, batches are
            yielded once, in order.

    Returns:
        A prefetching ``tf.data.Dataset``.

    Changes relative to the previous pipeline, and why:
      * ``cache()`` was applied after ``repeat()``, i.e. to an infinite stream,
        so the cache could never be completed and would only grow. The data
        already lives in memory inside ``generator``, so no cache is used.
      * ``shuffle(100)`` buffered up to 100 whole batches; with the volume
        sizes used in this notebook each batch is on the order of a gigabyte.
        The batch order is now permuted by index instead, which needs no
        buffer. As before, shuffling is at batch granularity.
      * Per-sample shapes are taken from the generator's own arrays rather
        than from the notebook globals, which may have been rebound since the
        generator was created.
    """
    def _per_sample_shape(attr_name, fallback):
        # Prefer the arrays held by the generator; fall back to the notebook
        # global for generator objects that do not expose them.
        source = getattr(generator, attr_name, None)
        if source is None:
            source = fallback()
        return tuple(source.shape[1:])

    func_shape = _per_sample_shape('func_data', lambda: func_data)
    struct_shape = _per_sample_shape('struct_data', lambda: struct_data)

    def _iter_batches():
        # Invoked afresh each time the dataset is iterated (including every
        # repeat), so each pass over the data gets its own permutation.
        batch_order = np.arange(len(generator))
        if training:
            np.random.shuffle(batch_order)
        for batch_idx in batch_order:
            yield generator[int(batch_idx)]

    dataset = tf.data.Dataset.from_generator(
        _iter_batches,
        output_signature=(
            {
                'func': tf.TensorSpec(shape=(None, *func_shape), dtype=tf.float16),
                'struct': tf.TensorSpec(shape=(None, *struct_shape), dtype=tf.float16)
            },
            tf.TensorSpec(shape=(None,), dtype=tf.float16)
        )
    )

    if training:
        dataset = dataset.repeat()  # Infinite repetition for training

    return dataset.prefetch(tf.data.AUTOTUNE)

# Training data: repeated and shuffled.
train_dataset = create_tf_dataset(train_generator, training=True)
# Held-out data: must be finite and unshuffled, so switch the training-only
# steps off explicitly (the function's default is training=True).
test_dataset = create_tf_dataset(test_generator, training=False)

In [21]:
# --- fMRI Model (ConvLSTM) ---
def build_fmri_model():
    """Build the functional branch: a per-step 3D CNN followed by an LSTM.

    The per-sample input shape is read from the module-level ``func_data``
    array at call time (every axis after the first).

    Returns:
        A Keras ``Model`` named ``"fMRI_Model"`` whose input is called
        ``"func"`` and whose output is the 128-unit LSTM state.
    """
    volumes = Input(shape=func_data.shape[1:], name="func", dtype=tf.float16)

    # Move input axis 4 to position 1 so TimeDistributed iterates over it.
    # NOTE(review): assumes axis 4 is the time axis of the scan - confirm.
    sequence = Lambda(lambda t: tf.transpose(t, perm=[0, 4, 1, 2, 3, 5]))(volumes)

    # Spatial feature extractor, applied independently to every step.
    first_conv = Conv3D(32, (3, 3, 3), padding="same", activation="relu")
    sequence = TimeDistributed(first_conv)(sequence)

    halve_resolution = MaxPooling3D(2)
    sequence = TimeDistributed(halve_resolution)(sequence)

    second_conv = Conv3D(64, (3, 3, 3), padding="same", activation="relu")
    sequence = TimeDistributed(second_conv)(sequence)

    # Collapse the three spatial axes into one 64-value vector per step.
    spatial_average = GlobalAveragePooling3D()
    sequence = TimeDistributed(spatial_average)(sequence)

    # Summarise the sequence of per-step vectors with a single LSTM.
    embedding = LSTM(128)(sequence)

    return Model(inputs=volumes, outputs=embedding, name="fMRI_Model")

# --- sMRI Model (3D CNN) ---
def build_smri_model():
    """Build the structural branch: a small 3D CNN with a dense head.

    The per-sample input shape is read from the module-level ``struct_data``
    array at call time (every axis after the first).

    Returns:
        A Keras ``Model`` named ``"sMRI_Model"`` whose input is called
        ``"struct"`` and whose output is a 128-unit ReLU dense layer.
    """
    volume = Input(shape=struct_data.shape[1:], name="struct", dtype=tf.float16)

    # Two identical stages: unpadded 3x3x3 convolution, then batch normalisation.
    features = volume
    for _ in range(2):
        features = Conv3D(
            filters=32, kernel_size=(3, 3, 3), activation="relu", padding="valid"
        )(features)
        features = BatchNormalization()(features)

    # Average over the spatial axes, then project to the branch embedding.
    pooled = GlobalAveragePooling3D()(features)
    embedding = Dense(128, activation="relu")(pooled)

    return Model(inputs=volume, outputs=embedding, name="sMRI_Model")

# --- Combine fMRI & sMRI Models ---
def build_combined_model():
    """Join the functional and structural branches into one binary classifier.

    The two branch outputs are concatenated, passed through a 128-unit ReLU
    layer with 50 % dropout, and mapped to a single sigmoid unit.

    Returns:
        A Keras ``Model`` named ``"Combined_Model"`` taking the functional and
        structural inputs (in that order) and producing one probability.
    """
    fmri_branch = build_fmri_model()
    smri_branch = build_smri_model()

    fused = Concatenate()([fmri_branch.output, smri_branch.output])
    hidden = Dense(128, activation="relu")(fused)
    hidden = Dropout(0.5)(hidden)
    probability = Dense(1, activation="sigmoid")(hidden)

    return Model(
        inputs=[fmri_branch.input, smri_branch.input],
        outputs=probability,
        name="Combined_Model",
    )

# Build & Compile Model
# Instantiate the two-branch network and configure it for binary
# classification (Adam optimiser, binary cross-entropy, accuracy metric).
model = build_combined_model()
model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)
model.summary()


Model: "Combined_Model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 func (InputLayer)           [(None, 5, 4, 4, 4, 1)]      0         []                            
                                                                                                  
 struct (InputLayer)         [(None, 10, 10, 10, 1)]      0         []                            
                                                                                                  
 lambda_3 (Lambda)           (None, 4, 5, 4, 4, 1)        0         ['func[0][0]']                
                                                                                                  
 conv3d_12 (Conv3D)          (None, 8, 8, 8, 32)          896       ['struct[0][0]']              
                                                                                     

In [22]:
# Small synthetic stand-in data for quick experiments.
# NOTE: this rebinds `func_data`, `struct_data` and `all_labels`, replacing the
# arrays loaded earlier in the notebook; anything that reads these globals after
# this cell (e.g. the model builders) will see the synthetic shapes.
# A seeded generator makes the synthetic arrays identical on every run.
synthetic_rng = np.random.default_rng(42)
func_data = synthetic_rng.random((100, 5, 4, 4, 4, 1))
struct_data = synthetic_rng.random((100, 10, 10, 10, 1))
all_labels = synthetic_rng.integers(0, 2, size=(100,))
func_data.shape, struct_data.shape, all_labels.shape

((100, 5, 4, 4, 4, 1), (100, 10, 10, 10, 1), (100,))

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# Split the three arrays with one call so their rows stay aligned.
# The previous version referenced `fmri_data`, `smri_data` and `labels`, which
# are not defined anywhere in this notebook; the arrays are actually named
# `func_data`, `struct_data` and `all_labels`.
train_fmri, val_fmri, train_smri, val_smri, train_labels, val_labels = train_test_split(
    func_data, struct_data, all_labels, test_size=0.2, random_state=42
)
train_fmri.shape, val_fmri.shape, train_smri.shape, val_smri.shape, train_labels.shape, val_labels.shape

((80, 5, 4, 4, 4, 1),
 (20, 5, 4, 4, 4, 1),
 (80, 10, 10, 10, 1),
 (20, 10, 10, 10, 1),
 (80,),
 (20,))

In [23]:
# Train for 10 epochs. One epoch is one pass over the training batches, and
# validation consumes a single batch from the test dataset per epoch.
training_steps = len(train_generator)
history = model.fit(
    x=train_dataset,
    epochs=10,
    steps_per_epoch=training_steps,
    validation_data=test_dataset,
    validation_steps=1,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
