In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, ConvLSTM3D, Conv3D, Flatten, Dense, Dropout, BatchNormalization, GlobalAveragePooling3D, Concatenate
from tensorflow.keras.models import Model
import numpy as np
import pickle as pkl

2025-03-07 19:04:49.247836: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-07 19:04:49.268915: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741354489.292276 4105416 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741354489.299176 4105416 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-07 19:04:49.324701: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
devices = tf.config.list_physical_devices()
print(devices)

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [3]:
tf.config.set_visible_devices([], 'GPU')

In [5]:
# importing data
def returning_pkl_file_data(path : str):
    with open(path, 'rb') as f:
        temp = pkl.load(f)
    return temp

mci_func = returning_pkl_file_data(r"feature_extraction/MCI_func_52_79_95_79_197.pkl")
mci_struct = returning_pkl_file_data(r'feature_extraction/MCI_struct_52_169_205_169.pkl')
cn_func = returning_pkl_file_data(r'feature_extraction/CN_func_30_79_95_79_197.pkl')
cn_struct = returning_pkl_file_data(r'feature_extraction/cat12_CN_structural_features.pkl')

In [6]:
func_data = np.concat((mci_func, cn_func), axis=0)
func_data.shape

(82, 79, 95, 79, 197)

In [7]:
struct_data = np.concat((mci_struct, cn_struct), axis=0)
struct_data.shape

(82, 169, 205, 169)

In [8]:
all_labels = np.concat((np.zeros((52,)), np.ones((30,))))
all_labels.shape

(82,)

In [12]:
func_data = np.expand_dims(func_data, axis=len(func_data.shape))
struct_data = np.expand_dims(struct_data, axis=len(struct_data.shape))

func_data.shape, struct_data.shape

((82, 79, 95, 79, 197, 1), (82, 169, 205, 169, 1))

In [13]:
from sklearn.model_selection import train_test_split

func_train, func_test, struct_train, struct_test, y_train, y_test = train_test_split(
    func_data, struct_data, all_labels, test_size=0.2, random_state=42
)

func_train.shape, func_test.shape, struct_train.shape, struct_test.shape, y_train.shape, y_test.shape

((65, 79, 95, 79, 197, 1),
 (17, 79, 95, 79, 197, 1),
 (65, 169, 205, 169, 1),
 (17, 169, 205, 169, 1),
 (65,),
 (17,))

In [2]:
# --- fMRI Model (ConvLSTM) ---
def build_fmri_model():
    fmri_input = Input(shape=(79, 95, 79, 197, 1), name="fmri_input")

    x = ConvLSTM3D(filters=2, kernel_size=(3,3,3), padding="same", return_sequences=True, activation="relu")(fmri_input)
    x = BatchNormalization()(x)

    x = ConvLSTM3D(filters=2, kernel_size=(3,3,3), padding="same", return_sequences=True, activation="relu")(x)
    x = BatchNormalization()(x)

    x = Flatten()(x)
    x = Dense(128, activation="relu")(x)

    return Model(inputs=fmri_input, outputs=x, name="fMRI_Model")

# --- sMRI Model (3D CNN) ---
def build_smri_model():
    smri_input = Input(shape=(169, 205, 169, 1), name="smri_input")

    y = Conv3D(filters=32, kernel_size=(3,3,3), activation="relu", padding="valid")(smri_input)
    y = BatchNormalization()(y)
    y = Conv3D(filters=64, kernel_size=(3,3,3), activation="relu", padding="valid")(y)
    y = BatchNormalization()(y)

    y = GlobalAveragePooling3D()(y)
    y = Dense(128, activation="relu")(y)

    return Model(inputs=smri_input, outputs=y, name="sMRI_Model")

# --- Combine fMRI & sMRI Models ---
def build_combined_model():
    fmri_model = build_fmri_model()
    smri_model = build_smri_model()

    combined = Concatenate()([fmri_model.output, smri_model.output])
    combined = Dense(128, activation="relu")(combined)
    combined = Dropout(0.5)(combined)
    output = Dense(1, activation="softmax")(combined)

    model = Model(inputs=[fmri_model.input, smri_model.input], outputs=output, name="Combined_Model")
    return model

# Build & Compile Model
model = build_combined_model()
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.summary()


I0000 00:00:1741354496.134740 4105416 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 44782 MB memory:  -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:4b:00.0, compute capability: 8.6
2025-03-07 19:05:07.541783: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:497] Allocator (GPU_0_bfc) ran out of memory trying to allocate 111.39GiB (rounded to 119603522560)requested by op StatelessRandomUniformV2
If the cause is memory fragmentation maybe the environment variable 'TF_GPU_ALLOCATOR=cuda_malloc_async' will improve the situation. 
Current allocation summary follows.
Current allocation summary follows.
2025-03-07 19:05:07.541872: I external/local_xla/xla/tsl/framework/bfc_allocator.cc:1053] BFCAllocator dump for GPU_0_bfc
2025-03-07 19:05:07.541898: I external/local_xla/xla/tsl/framework/bfc_allocator.cc:1060] Bin (256): 	Total Chunks: 27, Chunks in use: 26. 6.8KiB allocated for chunks. 6.5KiB in use in bin. 240B client-requested in use 

RuntimeError: pybind11::error_already_set: MISMATCH of original and normalized active exception types: ORIGINAL _NotOkStatusException REPLACED BY KeyboardInterrupt: <EMPTY MESSAGE>

At:
  /home/tripti/.local/lib/python3.10/site-packages/tensorflow/python/eager/core.py(42): __init__
  /home/tripti/.local/lib/python3.10/site-packages/tensorflow/python/ops/gen_stateless_random_ops_v2.py(569): stateless_random_uniform_v2
  /home/tripti/.local/lib/python3.10/site-packages/tensorflow/python/ops/stateless_random_ops.py(403): stateless_random_uniform
  /home/tripti/.local/lib/python3.10/site-packages/tensorflow/python/util/dispatch.py(1260): op_dispatch_handler
  /home/tripti/.local/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py(150): error_handler
  /home/tripti/.local/lib/python3.10/site-packages/keras/src/backend/tensorflow/random.py(34): uniform
  /home/tripti/.local/lib/python3.10/site-packages/keras/src/initializers/random_initializers.py(306): __call__
  /home/tripti/.local/lib/python3.10/site-packages/keras/src/backend/tensorflow/core.py(48): <lambda>
  /home/tripti/.local/lib/python3.10/site-packages/tensorflow/python/ops/resource_variable_ops.py(2057): _init_from_args
  /home/tripti/.local/lib/python3.10/site-packages/tensorflow/python/ops/resource_variable_ops.py(1873): __init__
  /home/tripti/.local/lib/python3.10/site-packages/tensorflow/python/ops/variables.py(201): __call__
  /home/tripti/.local/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py(150): error_handler
  /home/tripti/.local/lib/python3.10/site-packages/tensorflow/python/ops/resource_variable_ops.py(357): default_variable_creator_v2
  /home/tripti/.local/lib/python3.10/site-packages/tensorflow/python/ops/variables.py(51): default_variable_creator_v2
  /home/tripti/.local/lib/python3.10/site-packages/tensorflow/python/ops/variables.py(1223): <lambda>
  /home/tripti/.local/lib/python3.10/site-packages/tensorflow/python/ops/variables.py(1230): _variable_call
  /home/tripti/.local/lib/python3.10/site-packages/tensorflow/python/ops/variables.py(198): __call__
  /home/tripti/.local/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py(150): error_handler
  /home/tripti/.local/lib/python3.10/site-packages/keras/src/backend/tensorflow/core.py(47): _initialize_with_initializer
  /home/tripti/.local/lib/python3.10/site-packages/keras/src/backend/common/variables.py(170): __init__
  /home/tripti/.local/lib/python3.10/site-packages/keras/src/layers/layer.py(541): add_weight
  /home/tripti/.local/lib/python3.10/site-packages/keras/src/layers/core/dense.py(109): build
  /home/tripti/.local/lib/python3.10/site-packages/keras/src/layers/layer.py(226): build_wrapper
  /home/tripti/.local/lib/python3.10/site-packages/keras/src/layers/layer.py(1365): _maybe_build
  /home/tripti/.local/lib/python3.10/site-packages/keras/src/layers/layer.py(826): __call__
  /home/tripti/.local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py(117): error_handler
  /tmp/ipykernel_4105416/3782669251.py(12): build_fmri_model
  /tmp/ipykernel_4105416/3782669251.py(32): build_combined_model
  /tmp/ipykernel_4105416/3782669251.py(44): <module>
  /usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py(3577): run_code
  /usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py(3517): run_ast_nodes
  /usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py(3334): run_cell_async
  /usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py(128): _pseudo_sync_runner
  /usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py(3130): _run_cell
  /usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py(3075): run_cell
  /usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py(549): run_cell
  /usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py(449): do_execute
  /usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py(778): execute_request
  /usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py(362): execute_request
  /usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py(437): dispatch_shell
  /usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py(534): process_one
  /usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py(545): dispatch_queue
  /usr/lib/python3.10/asyncio/events.py(80): _run
  /usr/lib/python3.10/asyncio/base_events.py(1909): _run_once
  /usr/lib/python3.10/asyncio/base_events.py(603): run_forever
  /usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py(205): start
  /usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py(739): start
  /usr/local/lib/python3.10/dist-packages/traitlets/config/application.py(1075): launch_instance
  /usr/local/lib/python3.10/dist-packages/ipykernel_launcher.py(18): <module>
  /usr/lib/python3.10/runpy.py(86): _run_code
  /usr/lib/python3.10/runpy.py(196): _run_module_as_main


In [None]:
fmri_data = np.random.rand(100, 5, 4, 4, 4, 1)
smri_data = np.random.rand(100, 10, 10, 10, 1)
labels = np.random.randint(0, 2, size=(100,))
fmri_data.shape, smri_data.shape, labels.shape

((100, 5, 4, 4, 4, 1), (100, 10, 10, 10, 1), (100,))

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
train_fmri, val_fmri, train_smri, val_smri, train_labels, val_labels = train_test_split(fmri_data, smri_data, labels, test_size=0.2, random_state=42)
train_fmri.shape, val_fmri.shape, train_smri.shape, val_smri.shape, train_labels.shape, val_labels.shape

((80, 5, 4, 4, 4, 1),
 (20, 5, 4, 4, 4, 1),
 (80, 10, 10, 10, 1),
 (20, 10, 10, 10, 1),
 (80,),
 (20,))

In [None]:
history = model.fit(
    {"fmri_input": train_fmri, "smri_input": train_smri},  # Dictionary format for inputs
    train_labels,  # Output labels
    batch_size=16,
    epochs=30,
    validation_data=(
        {"fmri_input": val_fmri, "smri_input": val_smri},
        val_labels
    ),
    verbose=1
)

Epoch 1/30




[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.3910 - loss: 1.0554



[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 1s/step - accuracy: 0.4009 - loss: 1.0785 - val_accuracy: 0.5000 - val_loss: 0.6930
Epoch 2/30
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 1s/step - accuracy: 0.4660 - loss: 0.9222 - val_accuracy: 0.5000 - val_loss: 0.6934
Epoch 3/30
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.3827 - loss: 0.7989

KeyboardInterrupt: 