In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, ConvLSTM2D, GlobalAveragePooling2D, Conv3D, Flatten, Dense, Dropout, LayerNormalization, GlobalAveragePooling3D, Concatenate, TimeDistributed, MaxPooling3D, AveragePooling3D, GlobalMaxPooling3D, LSTM, Lambda
from tensorflow.keras.models import Model
import numpy as np
import pickle as pkl

2025-03-09 07:29:12.988384: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-09 07:29:13.005247: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741485553.025704 1360977 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741485553.032047 1360977 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-09 07:29:13.052758: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
devices = tf.config.list_physical_devices()
print(devices)

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [3]:
tf.config.set_visible_devices([], 'GPU')

In [4]:
# importing data
def returning_pkl_file_data(path : str):
    with open(path, 'rb') as f:
        temp = pkl.load(f)
    return temp

mci_func = returning_pkl_file_data(r"feature_extraction/MCI_func_52_79_95_79_197.pkl")
mci_struct = returning_pkl_file_data(r'feature_extraction/MCI_struct_cat_52_169_205_169.pkl')
cn_func = returning_pkl_file_data(r'feature_extraction/CN_func_42_79_95_79_197.pkl')
cn_struct = returning_pkl_file_data(r'feature_extraction/CN_struct_cat_42_169_205_169.pkl')

In [5]:
func_data = np.concat((mci_func, cn_func), axis=0, dtype=np.float16)
func_data.shape

(94, 79, 95, 79, 197)

In [6]:
struct_data = np.concat((mci_struct, cn_struct), axis=0, dtype=np.float16)
struct_data.shape

(94, 169, 205, 169)

In [7]:
all_labels = np.concat((np.zeros((len(mci_func),)), np.ones((len(cn_func),))))
all_labels.shape

(94,)

In [8]:
# func_data = np.expand_dims(func_data, axis=len(func_data.shape))
struct_data = np.expand_dims(struct_data, axis=len(struct_data.shape))

func_data.shape, struct_data.shape

((94, 79, 95, 79, 197), (94, 169, 205, 169, 1))

In [9]:
from sklearn.model_selection import train_test_split

func_train, func_test, struct_train, struct_test, y_train, y_test = train_test_split(
    func_data, struct_data, all_labels, test_size=0.2, random_state=42
)

func_train.shape, func_test.shape, struct_train.shape, struct_test.shape, y_train.shape, y_test.shape

((75, 79, 95, 79, 197),
 (19, 79, 95, 79, 197),
 (75, 169, 205, 169, 1),
 (19, 169, 205, 169, 1),
 (75,),
 (19,))

In [10]:
# --- fMRI Model (ConvLSTM) ---
def build_fmri_model():
    fmri_input = Input(shape=func_data.shape[1:], name="fmri_input", dtype=tf.float16)
    x = Lambda(lambda x: tf.transpose(x, perm=[0, 4, 1, 2, 3]))(fmri_input)
    # Apply Conv3D independently to each time step
    x = (ConvLSTM2D(2, kernel_size=3, return_sequences=True, padding="same", activation="relu"))(x)
    # x = (MaxPooling3D(2))(x)  # Downsample spatial dimensions
    x = (ConvLSTM2D(2, kernel_size=3, return_sequences=True, padding="same", activation="relu"))(x)
    x = TimeDistributed(GlobalAveragePooling2D())(x)  # Shape: (batch, time=197, 64)

    # Temporal modeling with LSTM
    # x = LSTM(128)(x)  # Output shape: (batch, 128)
    x = Flatten()(x)
    x = Dense(128)(x)
    return Model(inputs=fmri_input, outputs=x, name="fMRI_Model")

# --- sMRI Model (3D CNN) ---
def build_smri_model():
    smri_input = Input(shape=struct_data.shape[1:], name="smri_input", dtype=tf.float16)

    y = Conv3D(filters=32, kernel_size=3, activation="relu", padding="valid")(smri_input)
    y = LayerNormalization()(y)
    y = Conv3D(filters=32, kernel_size=3, activation="relu", padding="valid")(y)
    y = LayerNormalization()(y)

    y = GlobalAveragePooling3D()(y)
    y = Dense(128, activation="relu")(y)

    return Model(inputs=smri_input, outputs=y, name="sMRI_Model")

# --- Combine fMRI & sMRI Models ---
def build_combined_model():
    fmri_model = build_fmri_model()
    smri_model = build_smri_model()

    combined = Concatenate()([fmri_model.output, smri_model.output])
    combined = Dense(128, activation="relu")(combined)
    combined = Dropout(0.5)(combined)
    output = Dense(1, activation="sigmoid")(combined)

    model = Model(inputs=[fmri_model.input, smri_model.input], outputs=output, name="Combined_Model")
    return model

# Build & Compile Model
model = build_combined_model()
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.summary()


In [26]:
func_data = np.random.rand(100, 4, 4, 4, 5)
struct_data = np.random.rand(100, 10, 10, 10, 1)
all_labels = np.random.randint(0, 2, size=(100,))
func_data.shape, struct_data.shape, all_labels.shape

((100, 4, 4, 4, 5), (100, 10, 10, 10, 1), (100,))

In [None]:
from sklearn.model_selection import train_test_split

In [27]:
func_train, func_test, struct_train, struct_test, y_train, y_test = train_test_split(func_data, struct_data, all_labels, test_size=0.2, random_state=42)
func_train.shape, func_test.shape, struct_train.shape, struct_test.shape, y_train.shape, y_test.shape

((80, 4, 4, 4, 5),
 (20, 4, 4, 4, 5),
 (80, 10, 10, 10, 1),
 (20, 10, 10, 10, 1),
 (80,),
 (20,))

In [None]:
history = model.fit(
    {"fmri_input": func_train, "smri_input": struct_train},  # Dictionary format for inputs
    y_train,  # Output labels
    batch_size=8,
    epochs=10,
    validation_data=(
        {"fmri_input": func_test, "smri_input": struct_test},
        y_test
    ),
    verbose=1
)

Epoch 1/10


I0000 00:00:1741485775.473754 1361475 service.cc:148] XLA service 0x70ca64005950 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1741485775.473879 1361475 service.cc:156]   StreamExecutor device (0): Host, Default Version
2025-03-09 07:32:56.208236: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1741485782.668432 1361475 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1590s[0m 156s/step - accuracy: 0.5679 - loss: 22898246.0000 - val_accuracy: 0.6842 - val_loss: nan
Epoch 2/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1485s[0m 148s/step - accuracy: 0.5242 - loss: nan - val_accuracy: 0.6842 - val_loss: nan
Epoch 3/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1479s[0m 148s/step - accuracy: 0.5250 - loss: nan - val_accuracy: 0.6842 - val_loss: nan
Epoch 4/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1472s[0m 147s/step - accuracy: 0.4963 - loss: nan - val_accuracy: 0.6842 - val_loss: nan
Epoch 5/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1483s[0m 148s/step - accuracy: 0.5371 - loss: nan - val_accuracy: 0.6842 - val_loss: nan
Epoch 6/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1478s[0m 148s/step - accuracy: 0.5436 - loss: nan - val_accuracy: 0.6842 - val_loss: nan
Epoch 7/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0