In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, ConvLSTM3D, Conv3D, Flatten, Dense, Dropout, BatchNormalization, GlobalAveragePooling3D, Concatenate, TimeDistributed, MaxPooling3D, AveragePooling3D, GlobalMaxPooling3D, LSTM, Lambda
from tensorflow.keras.models import Model
import numpy as np
import pickle as pkl

2025-03-08 19:47:32.118154: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-08 19:47:32.137556: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741443452.160118 1329386 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741443452.166772 1329386 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-08 19:47:32.191804: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
devices = tf.config.list_physical_devices()
print(devices)

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [3]:
tf.config.set_visible_devices([], 'GPU')

In [4]:
# importing data
def returning_pkl_file_data(path : str):
    with open(path, 'rb') as f:
        temp = pkl.load(f)
    return temp

mci_func = returning_pkl_file_data(r"feature_extraction/MCI_func_52_79_95_79_197.pkl")
mci_struct = returning_pkl_file_data(r'feature_extraction/MCI_struct_cat_52_169_205_169.pkl')
cn_func = returning_pkl_file_data(r'feature_extraction/CN_func_42_79_95_79_197.pkl')
cn_struct = returning_pkl_file_data(r'feature_extraction/CN_struct_cat_42_169_205_169.pkl')

In [5]:
func_data = np.concat((mci_func, cn_func), axis=0, dtype=np.float16)
func_data.shape

(94, 79, 95, 79, 197)

In [6]:
struct_data = np.concat((mci_struct, cn_struct), axis=0, dtype=np.float16)
struct_data.shape

(94, 169, 205, 169)

In [7]:
all_labels = np.concat((np.zeros((len(mci_func),)), np.ones((len(cn_func),))))
all_labels.shape

(94,)

In [8]:
func_data = np.expand_dims(func_data, axis=len(func_data.shape))
struct_data = np.expand_dims(struct_data, axis=len(struct_data.shape))

func_data.shape, struct_data.shape

((94, 79, 95, 79, 197, 1), (94, 169, 205, 169, 1))

In [17]:
func_data.shape

(94, 40, 48, 40, 197, 1)

In [None]:
from skimage.transform import resize

def downsample_volume(volume, new_shape):
    return resize(volume, new_shape, mode='constant')

# Downsample fMRI data from (79, 95, 79, 197, 1) to (40, 48, 40, 100, 1)
new_shape_fmri = (40, 48, 40, 197, 1)
new_shape_smri = (85, 102, 85, 1) 
# func_train.shape, func_test.shape, struct_ train.shape, struct_test.shape, y_train.shape, y_test.shape
func_data = np.array([downsample_volume(sample, new_shape_fmri) for sample in func_data])
struct_data = np.array([downsample_volume(sample, new_shape_smri) for sample in struct_data])
func_data.shape, struct_data.shape


((94, 40, 48, 40, 197, 1), (94, 85, 102, 85, 1, 1))

In [20]:
struct_data = struct_data.squeeze()
struct_data = np.expand_dims(struct_data, axis=-1)
struct_data.shape

(94, 85, 102, 85, 1)

In [21]:
# 2. Create memory-efficient dataset using generator
class BrainDataGenerator(tf.keras.utils.Sequence):
    def __init__(self, func_data, struct_data, labels, batch_size=4):
        self.func_data = func_data
        self.struct_data = struct_data
        self.labels = labels
        self.batch_size = batch_size
        self.indices = np.arange(len(func_data))
        
    def __len__(self):
        return int(np.ceil(len(self.func_data) / self.batch_size))
    
    def __getitem__(self, idx):
        batch_indices = self.indices[idx*self.batch_size:(idx+1)*self.batch_size]
        return (
            {
                'func': self.func_data[batch_indices],
                'struct': self.struct_data[batch_indices]
            },
            self.labels[batch_indices]
        )

In [22]:
from sklearn.model_selection import train_test_split

# Split indices to avoid data duplication
indices = np.arange(len(func_data))
train_idx, test_idx = train_test_split(indices, test_size=0.2, random_state=42)

# Create memory-mapped datasets
train_generator = BrainDataGenerator(
    func_data[train_idx], 
    struct_data[train_idx], 
    all_labels[train_idx]
)

test_generator = BrainDataGenerator(
    func_data[test_idx], 
    struct_data[test_idx], 
    all_labels[test_idx],
    batch_size=len(test_idx)  # Full batch for testing
)

# 4. Convert to tf.data.Dataset with optimized pipeline
def create_tf_dataset(generator, training=True):
    dataset = tf.data.Dataset.from_generator(
        lambda: generator,
        output_signature=(
            {
                'func': tf.TensorSpec(shape=(None, *func_data.shape[1:]), dtype=tf.float16),
                'struct': tf.TensorSpec(shape=(None, *struct_data.shape[1:]), dtype=tf.float16)
            },
            tf.TensorSpec(shape=(None,), dtype=tf.float16)
        )
    )
    
    if training:
        dataset = dataset.repeat()  # Infinite repetition for training
        dataset = dataset.shuffle(100)
        
    dataset = dataset.cache().prefetch(tf.data.AUTOTUNE)
    return dataset

train_dataset = create_tf_dataset(train_generator)
test_dataset = create_tf_dataset(test_generator)

In [23]:
# --- fMRI Model (ConvLSTM) ---
def build_fmri_model():
    fmri_input = Input(shape=func_data.shape[1:], name="func", dtype=tf.float16)
    x = Lambda(lambda x: tf.transpose(x, perm=[0, 4, 1, 2, 3, 5]))(fmri_input)
    # Apply Conv3D independently to each time step
    x = TimeDistributed(Conv3D(32, (3,3,3), padding="same", activation="relu"))(x)
    x = TimeDistributed(MaxPooling3D(2))(x)  # Downsample spatial dimensions
    x = TimeDistributed(Conv3D(64, (3,3,3), padding="same", activation="relu"))(x)
    x = TimeDistributed(GlobalAveragePooling3D())(x)  # Shape: (batch, time=197, 64)
    # Temporal modeling with LSTM
    x = LSTM(128)(x)  # Output shape: (batch, 128)
    return Model(inputs=fmri_input, outputs=x, name="fMRI_Model")

# --- sMRI Model (3D CNN) ---
def build_smri_model():
    smri_input = Input(shape=struct_data.shape[1:], name="struct", dtype=tf.float16)

    y = Conv3D(filters=32, kernel_size=(3,3,3), activation="relu", padding="valid")(smri_input)
    y = BatchNormalization()(y)
    y = Conv3D(filters=32, kernel_size=(3,3,3), activation="relu", padding="valid")(y)
    y = BatchNormalization()(y)

    y = GlobalAveragePooling3D()(y)
    y = Dense(128, activation="relu")(y)

    return Model(inputs=smri_input, outputs=y, name="sMRI_Model")

# --- Combine fMRI & sMRI Models ---
def build_combined_model():
    fmri_model = build_fmri_model()
    smri_model = build_smri_model()

    combined = Concatenate()([fmri_model.output, smri_model.output])
    combined = Dense(128, activation="relu")(combined)
    combined = Dropout(0.5)(combined)
    output = Dense(1, activation="sigmoid")(combined)

    model = Model(inputs=[fmri_model.input, smri_model.input], outputs=output, name="Combined_Model")
    return model

# Build & Compile Model
model = build_combined_model()
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.summary()


In [22]:
func_data = np.random.rand(100, 5, 4, 4, 4, 1)
struct_data = np.random.rand(100, 10, 10, 10, 1)
all_labels = np.random.randint(0, 2, size=(100,))
func_data.shape, struct_data.shape, all_labels.shape

((100, 5, 4, 4, 4, 1), (100, 10, 10, 10, 1), (100,))

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
train_fmri, val_fmri, train_smri, val_smri, train_labels, val_labels = train_test_split(fmri_data, smri_data, labels, test_size=0.2, random_state=42)
train_fmri.shape, val_fmri.shape, train_smri.shape, val_smri.shape, train_labels.shape, val_labels.shape

((80, 5, 4, 4, 4, 1),
 (20, 5, 4, 4, 4, 1),
 (80, 10, 10, 10, 1),
 (20, 10, 10, 10, 1),
 (80,),
 (20,))

In [24]:
history = model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=10,
    steps_per_epoch=len(train_generator),
    validation_steps=1
)

Epoch 1/10


2025-03-08 21:12:21.746300: E tensorflow/core/util/util.cc:131] oneDNN supports DT_HALF only on platforms with AVX-512. Falling back to the default Eigen-based implementation if present.
I0000 00:00:1741448550.293230 1330295 service.cc:148] XLA service 0x7475a4006260 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1741448550.293272 1330295 service.cc:156]   StreamExecutor device (0): Host, Default Version
2025-03-08 21:12:31.582372: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1741448759.869110 1330295 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
2025-03-08 21:16:09.898249: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:497] Allocator (mklcpu) ran out of memory trying to allocate 1.56TiB (rounded to 1717586539520)requested by op 
20

ResourceExhaustedError: Graph execution error:

Detected at node StatefulPartitionedCall defined at (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main

  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code

  File "/usr/local/lib/python3.10/dist-packages/ipykernel_launcher.py", line 18, in <module>

  File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 1075, in launch_instance

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 739, in start

  File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 205, in start

  File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever

  File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once

  File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 534, in process_one

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 362, in execute_request

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 778, in execute_request

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 449, in do_execute

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py", line 549, in run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3075, in run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3577, in run_code

  File "/tmp/ipykernel_1329386/699743898.py", line 1, in <module>

  File "/home/tripti/.local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/home/tripti/.local/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 368, in fit

  File "/home/tripti/.local/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 216, in function

  File "/home/tripti/.local/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 129, in multi_step_on_iterator

Out of memory while trying to allocate 1717586539408 bytes.
	 [[{{node StatefulPartitionedCall}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_multi_step_on_iterator_44208]