In [1]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, Flatten, Dense, Dropout, LSTM, TimeDistributed, BatchNormalization, Concatenate, Softmax
import tensorflow.keras.backend as K




In [2]:
# Self-Attention Mechanism
def self_attention(inputs):
    """
    Self-attention mechanism for feature refinement.
    inputs: (batch_size, time_steps, features) OR (batch_size, features)
    """
    query = Dense(inputs.shape[-1])(inputs)  # Learn query
    key = Dense(inputs.shape[-1])(inputs)  # Learn key
    value = Dense(inputs.shape[-1])(inputs)  # Learn value

    # Compute attention scores (Scaled Dot-Product Attention)
    scores = tf.matmul(query, key, transpose_b=True)  # (batch, time_steps, time_steps)
    scores = scores / tf.sqrt(tf.cast(tf.shape(inputs)[-1], tf.float32))  # Scale
    attention_weights = Softmax(axis=-1)(scores)  # Apply softmax

    # Compute weighted sum of values
    attended_output = tf.matmul(attention_weights, value)  # (batch, time_steps, features)

    return attended_output

In [3]:
# Define input shapes
time_steps = 20  # fMRI time steps
depth, height, width, channels = 64, 64, 64, 1  # 3D volume shape
num_classes = 2  # Number of classification categories

# --- fMRI Branch (3D-CNN → LSTM → Self-Attention) ---
fmri_input = Input(shape=(time_steps, depth, height, width, channels), name="fMRI_Input")

x = TimeDistributed(Conv3D(32, (3, 3, 3), activation='relu', padding='same'))(fmri_input)
x = TimeDistributed(MaxPooling3D((2, 2, 2)))(x)
x = TimeDistributed(BatchNormalization())(x)

x = TimeDistributed(Conv3D(64, (3, 3, 3), activation='relu', padding='same'))(x)
x = TimeDistributed(MaxPooling3D((2, 2, 2)))(x)
x = TimeDistributed(BatchNormalization())(x)

x = TimeDistributed(Flatten())(x)
x = LSTM(64, return_sequences=True)(x)  # Keep sequence for attention

# Apply Self-Attention on LSTM Outputs
x = self_attention(x)
fmri_features = LSTM(64, return_sequences=False)(x)  # Extract final feature representation





In [4]:
# --- sMRI Branch (3D-CNN → Self-Attention) ---
smri_input = Input(shape=(depth, height, width, channels), name="sMRI_Input")

y = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(smri_input)
y = MaxPooling3D((2, 2, 2))(y)
y = BatchNormalization()(y)

y = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(y)
y = MaxPooling3D((2, 2, 2))(y)
y = BatchNormalization()(y)

y = Flatten()(y)
y = self_attention(y)  # Apply Self-Attention on extracted features
smri_features = Dense(64, activation='relu')(y)

ResourceExhaustedError: {{function_node __wrapped__StatelessRandomUniformV2_device_/job:localhost/replica:0/task:0/device:CPU:0}} OOM when allocating tensor with shape[262144,262144] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator mklcpu [Op:StatelessRandomUniformV2] name: 

In [None]:
# --- Fusion with Self-Attention ---
merged = Concatenate()([fmri_features, smri_features])  # Merge fMRI & sMRI features
merged = self_attention(merged)  # Apply Self-Attention to combined features

# --- Final Classification ---
z = Dense(64, activation='relu')(merged)
z = Dropout(0.5)(z)
output = Dense(num_classes, activation='softmax', name="Output")(z)  

# Build and compile the model
model = Model(inputs=[fmri_input, smri_input], outputs=output)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Print Model Summary
model.summary()
