In [None]:
from tensorflow.keras.layers import ConvLSTM3D, LayerNormalization, MultiHeadAttention, Layer, Dense, Conv3D, GlobalAveragePooling3D, Concatenate, Flatten, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras import Input
import tensorflow as tf
import numpy as np

In [None]:
class SelfAttention3D(Layer):
    """Multi-head self-attention across the time axis of a volumetric sequence.

    Expects 6D input ``(batch, time, depth, height, width, channels)``. Each time
    step's ``(depth, height, width, channels)`` volume is flattened into a single
    feature vector, ``MultiHeadAttention`` attends across the time steps, and the
    result is reshaped back to the input's shape.

    Args:
        embed_dim: Per-head query/key size, forwarded to ``MultiHeadAttention`` as
            ``key_dim``. It is NOT the flattened feature size
            (``depth * height * width * channels``); the attention output is
            projected back to that size.
        num_heads: Number of attention heads.
        **kwargs: Standard ``Layer`` arguments (``name``, ``dtype``, ``trainable``).
            Accepting them is also what lets ``from_config`` rebuild this layer
            when a saved model is loaded.
    """

    def __init__(self, embed_dim, num_heads=8, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.attention = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)

    def call(self, inputs):
        """Apply self-attention over time; the output has the same shape as ``inputs``."""
        dynamic_shape = tf.shape(inputs)

        # Prefer the static per-step sizes whenever Keras knows them, so the
        # flattened feature axis has a static size when MultiHeadAttention creates
        # its projection weights; fall back to runtime sizes otherwise.
        d, h, w, c = (
            static if static is not None else dynamic_shape[axis]
            for axis, static in enumerate(inputs.shape[2:], start=2)
        )

        # 6D -> 3D: (batch, time, depth * height * width * channels)
        flat = tf.reshape(inputs, (dynamic_shape[0], dynamic_shape[1], d * h * w * c))

        # Query and value are the same tensor -> self-attention across time steps.
        attended = self.attention(flat, flat)

        # 3D -> 6D: the attention output keeps the query's last dimension, so it
        # reshapes straight back to the input's shape.
        return tf.reshape(attended, dynamic_shape)

    def get_config(self):
        """Return the constructor arguments so the layer survives model save/load."""
        config = super().get_config()
        config.update({"embed_dim": self.embed_dim, "num_heads": self.num_heads})
        return config


In [None]:
# --- fMRI Model (ConvLSTM) ---
def build_fmri_model(input_shape=(4, 5, 5, 5, 1), filters=64, num_blocks=3,
                     attention_key_dim=64, num_heads=4, dense_units=128):
    """Build the fMRI branch: stacked ConvLSTM3D -> LayerNorm -> self-attention blocks.

    The defaults reproduce the original architecture: three blocks of 64-filter
    ConvLSTM3D layers, each followed by 4-head self-attention, then a 128-unit
    ReLU head.

    Args:
        input_shape: Per-sample shape ``(time, depth, height, width, channels)``.
        filters: Number of ConvLSTM3D filters in every block.
        num_blocks: How many ConvLSTM3D/LayerNorm/attention blocks to stack.
        attention_key_dim: Per-head key size passed to ``SelfAttention3D``.
        num_heads: Number of attention heads in each ``SelfAttention3D``.
        dense_units: Width of the final ReLU ``Dense`` feature layer.

    Returns:
        A ``Model`` named ``"fMRI_Model"`` whose input layer is named
        ``"fmri_input"`` and whose output is a ``dense_units``-wide feature vector.
    """
    input_tensor = Input(shape=input_shape, name="fmri_input")

    x = input_tensor
    for _ in range(num_blocks):
        # padding='same' + return_sequences=True keep (time, depth, height, width)
        # unchanged; only the channel axis becomes `filters`.
        x = ConvLSTM3D(filters=filters, kernel_size=(3, 3, 3), padding="same",
                       return_sequences=True)(x)
        # LayerNormalization defaults to axis=-1, i.e. it normalizes the channel axis.
        x = LayerNormalization()(x)
        # Self-attention across time steps; output shape equals input shape.
        x = SelfAttention3D(embed_dim=attention_key_dim, num_heads=num_heads)(x)

    x = Flatten()(x)
    x = Dense(dense_units, activation="relu")(x)

    return Model(inputs=input_tensor, outputs=x, name="fMRI_Model")

In [None]:
# --- sMRI Model (3D CNN) ---
def build_smri_model(input_shape=(10, 10, 10, 1), filters=(32, 64), dense_units=128):
    """Build the sMRI branch: a small 3D CNN followed by global pooling.

    The defaults reproduce the original architecture: Conv3D(32) -> LayerNorm ->
    Conv3D(64) -> LayerNorm -> global average pooling -> 128-unit ReLU head.

    Args:
        input_shape: Per-sample shape ``(depth, height, width, channels)``.
        filters: Filter count for each successive ReLU Conv3D layer; each one is
            followed by a LayerNormalization.
        dense_units: Width of the final ReLU ``Dense`` feature layer.

    Returns:
        A ``Model`` named ``"sMRI_Model"`` whose input layer is named
        ``"smri_input"`` and whose output is a ``dense_units``-wide feature vector.
    """
    smri_input = Input(shape=input_shape, name="smri_input")

    y = smri_input
    for n_filters in filters:
        # padding='same' keeps the spatial size; only the channel count changes.
        y = Conv3D(filters=n_filters, kernel_size=(3, 3, 3), activation="relu", padding="same")(y)
        y = LayerNormalization()(y)

    # Collapse the spatial axes so the head's size does not depend on volume size.
    y = GlobalAveragePooling3D()(y)
    y = Dense(dense_units, activation="relu")(y)

    return Model(inputs=smri_input, outputs=y, name="sMRI_Model")

In [None]:
# --- Combine fMRI & sMRI Models ---
def build_combined_model():
    """Fuse the fMRI and sMRI branches into one binary classifier.

    The two branch embeddings are concatenated, then passed through a 128-unit
    ReLU layer, 50% dropout, and a single sigmoid output unit.

    Returns:
        A two-input ``Model`` named ``"Combined_Model"`` with inputs ordered
        ``[fmri_input, smri_input]``.
    """
    branches = [build_fmri_model(), build_smri_model()]

    fused = Concatenate()([branch.output for branch in branches])
    hidden = Dense(128, activation="relu")(fused)
    hidden = Dropout(0.5)(hidden)
    probability = Dense(1, activation="sigmoid")(hidden)

    return Model(
        inputs=[branch.input for branch in branches],
        outputs=probability,
        name="Combined_Model",
    )

In [None]:
# Build & Compile Model
# Seed Python, NumPy and TensorFlow RNGs before building so weight initialisation
# (and therefore the training run) is reproducible under Restart & Run All.
tf.keras.utils.set_random_seed(42)

# Build & compile the fused fMRI + sMRI binary classifier.
model = build_combined_model()
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.summary()

In [None]:
# Synthetic placeholder data (uniform noise + random binary labels), used only to
# smoke-test the pipeline. A seeded generator makes every re-run identical, and
# float32 matches the dtype Keras computes in (float64 would just be cast).
rng = np.random.default_rng(42)
fmri_data = rng.random((100, 4, 5, 5, 5, 1), dtype=np.float32)
smri_data = rng.random((100, 10, 10, 10, 1), dtype=np.float32)
labels = rng.integers(0, 2, size=(100,))
fmri_data.shape, smri_data.shape, labels.shape

((100, 4, 5, 5, 5, 1), (100, 10, 10, 10, 1), (100,))

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# Split all three arrays with the same row permutation. stratify=labels keeps the
# class ratio equal in both splits, which matters with only 20 validation samples.
(
    train_fmri, val_fmri,
    train_smri, val_smri,
    train_labels, val_labels,
) = train_test_split(
    fmri_data, smri_data, labels,
    test_size=0.2, random_state=42, stratify=labels,
)
train_fmri.shape, val_fmri.shape, train_smri.shape, val_smri.shape, train_labels.shape, val_labels.shape

((80, 4, 5, 5, 5, 1),
 (20, 4, 5, 5, 5, 1),
 (80, 10, 10, 10, 1),
 (20, 10, 10, 10, 1),
 (80,),
 (20,))

In [None]:
# Stop once validation loss stops improving and roll back to the best epoch's
# weights: the recorded run showed ~8 s/step, so running all 30 epochs after the
# model has plateaued makes Restart & Run All needlessly slow.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)

history = model.fit(
    # Dict keys must match the Input layer names ("fmri_input", "smri_input").
    {"fmri_input": train_fmri, "smri_input": train_smri},
    train_labels,
    batch_size=16,
    epochs=30,
    validation_data=(
        {"fmri_input": val_fmri, "smri_input": val_smri},
        val_labels,
    ),
    callbacks=[early_stopping],
    verbose=1,
)

Epoch 1/30
[1m3/5[0m [32m━━━━━━━━━━━━[0m[37m━━━━━━━━[0m [1m15s[0m 8s/step - accuracy: 0.6319 - loss: 0.8396

KeyboardInterrupt: 