In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, ConvLSTM2D, GlobalAveragePooling2D, Conv3D, Flatten, Dense, Dropout, LayerNormalization, GlobalAveragePooling3D, Concatenate, TimeDistributed, MaxPooling3D, AveragePooling3D, GlobalMaxPooling3D, LSTM, Lambda
from tensorflow.keras.models import Model
import numpy as np
import pickle as pkl




In [2]:
# List every physical device TensorFlow can see (CPU plus any GPUs) before hiding any of them.
devices = tf.config.list_physical_devices(device_type=None)
print(devices)

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [3]:
# Hide all GPUs from TensorFlow so the ops below are placed on the CPU.
tf.config.set_visible_devices(devices=[], device_type="GPU")

In [4]:
# importing data
def returning_pkl_file_data(path: str):
    """Load and return the single object stored in the pickle file at ``path``.

    Security note: ``pickle`` can execute arbitrary code while loading, so only
    call this on files you produced yourself or otherwise trust.
    """
    # The file is closed by the context manager as soon as the object is returned.
    with open(path, "rb") as handle:
        return pkl.load(handle)

# Load the MCI and CN cohorts' functional and structural feature arrays, one pickle each.
mci_func = returning_pkl_file_data(path=r"feature_extraction/MCI_func_52_79_95_79_197.pkl")
mci_struct = returning_pkl_file_data(path=r'feature_extraction/MCI_struct_cat_52_169_205_169.pkl')
cn_func = returning_pkl_file_data(path=r'feature_extraction/CN_func_42_79_95_79_197.pkl')
cn_struct = returning_pkl_file_data(path=r'feature_extraction/CN_struct_cat_42_169_205_169.pkl')

In [None]:
# Stack MCI subjects first, then CN subjects, along the sample axis as float32.
func_data = np.concatenate([mci_func, cn_func], axis=0, dtype=np.float32)
func_data.shape

(94, 79, 95, 79, 197)

In [None]:
# Same stacking order as the functional data: MCI subjects first, then CN, as float32.
struct_data = np.concatenate([mci_struct, cn_struct], axis=0, dtype=np.float32)
struct_data.shape

(94, 169, 205, 169)

In [7]:
# One float label per stacked subject: 0.0 for each MCI subject, then 1.0 for each CN subject.
all_labels = np.repeat([0.0, 1.0], [len(mci_func), len(cn_func)])
all_labels.shape

(94,)

In [8]:
# func_data keeps its four per-sample axes (presumably depth, height, width, time): the
# fMRI branch permutes exactly those four axes, so it gets no channel axis.
# struct_data needs a trailing channel axis for Conv3D. The ndim guard makes this cell
# idempotent, so re-running it does not keep appending axes.
if struct_data.ndim == 4:
    struct_data = np.expand_dims(struct_data, axis=-1)

func_data.shape, struct_data.shape

((94, 79, 95, 79, 197), (94, 169, 205, 169, 1))

In [9]:
from sklearn.model_selection import train_test_split

# Hold out 20% of subjects for evaluation. Stratifying on the labels keeps the MCI/CN
# ratio identical in both splits, which matters with a cohort this small.
func_train, func_test, struct_train, struct_test, y_train, y_test = train_test_split(
    func_data, struct_data, all_labels,
    test_size=0.2, random_state=42, stratify=all_labels,
)

func_train.shape, func_test.shape, struct_train.shape, struct_test.shape, y_train.shape, y_test.shape

((75, 79, 95, 79, 197),
 (19, 79, 95, 79, 197),
 (75, 169, 205, 169, 1),
 (19, 169, 205, 169, 1),
 (75,),
 (19,))

In [19]:
import tensorflow as tf
from tensorflow.keras import layers, Model
import math

class PositionalEncoding(layers.Layer):
    """Add fixed sinusoidal positional encodings to a sequence of embeddings.

    Even feature indices receive ``sin(pos / 10000^(2k / d_model))`` and odd indices
    receive the matching ``cos`` term.

    Args:
        d_model: Size of the last (feature) axis of the inputs.
        max_len: Stored for configuration only; it is not used to truncate or pad.
        **kwargs: Standard Keras layer arguments (e.g. ``name``), forwarded to ``Layer``.
    """

    def __init__(self, d_model, max_len=5000, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.max_len = max_len

    def get_angles(self, position, i):
        """Return ``position / 10000^(2 * (i // 2) / d_model)``, broadcast over both inputs."""
        angles = 1 / tf.pow(10000., (2 * (i // 2)) / tf.cast(self.d_model, tf.float32))
        return position * angles

    def call(self, inputs):
        # assumes inputs are (batch, seq_len, d_model): axis 1 is read as the sequence axis
        seq_length = tf.shape(inputs)[1]
        position = tf.range(seq_length, dtype=tf.float32)[:, tf.newaxis]  # (seq_len, 1)
        i = tf.range(self.d_model, dtype=tf.float32)[tf.newaxis, :]       # (1, d_model)
        angle_rads = self.get_angles(position, i)                          # (seq_len, d_model)

        # sin on even indices, cos on odd indices: each mask multiplies only its own term.
        # (Multiplying the whole sum by cos_mask would zero every even column.)
        sin_mask = tf.cast(tf.range(self.d_model) % 2 == 0, tf.float32)
        cos_mask = 1 - sin_mask
        pos_encoding = tf.sin(angle_rads) * sin_mask + tf.cos(angle_rads) * cos_mask
        pos_encoding = pos_encoding[tf.newaxis, ...]  # (1, seq_len, d_model), broadcast over batch

        # Match the input dtype so the addition also works under a non-float32 dtype policy.
        return inputs + tf.cast(pos_encoding, inputs.dtype)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized and rebuilt."""
        config = super().get_config()
        config.update({"d_model": self.d_model, "max_len": self.max_len})
        return config

class TransformerEncoderBlock(layers.Layer):
    """Post-norm Transformer encoder block: self-attention, then a position-wise feed-forward net.

    Each sub-layer's output goes through dropout, a residual add, and LayerNormalization.

    Args:
        d_model: Feature size of inputs and outputs; also passed as the per-head ``key_dim``.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward sub-layer.
        dropout: Dropout rate applied to both sub-layer outputs.
        **kwargs: Standard Keras layer arguments (e.g. ``name``), forwarded to ``Layer``.
    """

    def __init__(self, d_model, num_heads, ff_dim, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        # Keep constructor arguments for get_config().
        self.d_model = d_model
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout

        # NOTE(review): key_dim is the size *per head*, so the projections here are
        # num_heads * d_model wide; d_model // num_heads is the more common choice.
        self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.ffn = tf.keras.Sequential([
            layers.Dense(ff_dim, activation="relu"),
            layers.Dense(d_model)
        ])
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)

    def call(self, inputs, training=False):
        # Self-attention sub-layer (query = value = inputs, no mask), then residual + norm.
        attn_output = self.attn(inputs, inputs, training=training)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)

        # Feed-forward sub-layer, then residual + norm.
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized and rebuilt."""
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
        })
        return config

def build_transformer_model(input_shape=(func_data.shape[1:]),  # (depth, height, width, time)
                            embedding_dim=128,
                            num_heads=8,
                            ff_dim=512,
                            num_layers=4,
                            dropout=0.1,
                            use_mean_pooling=False):
    """Build a Transformer classifier that treats each fMRI time point as one token.

    Each time point's 3-D volume is flattened into one vector and projected to
    ``embedding_dim``. The tokens then get sinusoidal positional encodings, pass through
    ``num_layers`` encoder blocks, are pooled, and are classified by one sigmoid unit.

    Args:
        input_shape: Per-sample shape ``(depth, height, width, time)``. It must have exactly
            four axes with time last, because ``Permute((4, 1, 2, 3))`` moves axis 4 to the
            front. The default is a Python default argument, so it is read from the global
            ``func_data`` when this function is defined.
        embedding_dim: Token embedding size (``d_model``) used throughout the encoder.
        num_heads: Attention heads per encoder block.
        ff_dim: Hidden width of each block's feed-forward sub-layer.
        dropout: Dropout rate inside each encoder block.
        use_mean_pooling: If True, average the encoder outputs over time. Otherwise use the
            output at the first time step. No learned [CLS] token is prepended, so the
            "first token" is simply time point 0; it still attends to every time point.

    Returns:
        A ``Model`` named "MeanTransformer" or "Transformer" that maps ``fmri_input``
        to a single sigmoid probability.

    Raises:
        ValueError: If ``input_shape`` does not have exactly four axes.
    """
    input_shape = tuple(input_shape)
    # Fail early with a clear message instead of an opaque Permute/Reshape error.
    if len(input_shape) != 4:
        raise ValueError(
            "build_transformer_model expects input_shape=(depth, height, width, time); "
            f"got {input_shape!r} ({len(input_shape)} axes)"
        )
    depth, height, width, _ = input_shape

    # Input layer
    fmri_input = layers.Input(shape=input_shape, name="fmri_input", dtype=tf.float32)

    # (batch, depth, height, width, time) -> (batch, time, depth, height, width)
    x = layers.Permute((4, 1, 2, 3))(fmri_input)
    # -> (batch, time, depth * height * width): one flattened volume per token
    x = layers.Reshape((-1, depth * height * width))(x)

    # Token embedding
    x = layers.Dense(embedding_dim, activation="relu")(x)

    # Positional encoding
    x = PositionalEncoding(embedding_dim)(x)

    # Transformer encoder
    for _ in range(num_layers):
        x = TransformerEncoderBlock(embedding_dim, num_heads, ff_dim, dropout)(x)

    # Pooling
    if use_mean_pooling:
        pooled_output = layers.GlobalAveragePooling1D()(x)
    else:  # first time step (no CLS token is added)
        pooled_output = x[:, 0, :]

    # Classification
    output = layers.Dense(1, activation="sigmoid")(pooled_output)

    return Model(inputs=fmri_input, outputs=output, name="MeanTransformer" if use_mean_pooling else "Transformer")

# Two-branch model: the fMRI Transformer above fused with an sMRI 3D CNN
def build_combined_model(fmri_shape=(func_data.shape[1:]),
                         smri_shape=(struct_data.shape[1:]),
                         transformer_params=None):
    """Build a two-branch binary classifier that fuses fMRI and sMRI features.

    The fMRI branch is the Transformer from ``build_transformer_model``, and the sMRI
    branch is a small 3D CNN. Their feature vectors are concatenated and passed through
    a dense head that ends in a single sigmoid unit.

    Args:
        fmri_shape: Per-sample fMRI shape, forwarded to ``build_transformer_model``.
        smri_shape: Per-sample sMRI shape, including the trailing channel axis.
        transformer_params: Keyword arguments for ``build_transformer_model``. ``None``
            selects the defaults defined in the body.

    Note:
        The default shapes are Python default arguments. They are read from the global
        ``func_data`` / ``struct_data`` once, when this function is defined. Re-run the
        defining cell, or pass shapes explicitly, after those arrays are reassigned.

    Returns:
        A ``Model`` with inputs ``[fmri_input, smri_input]`` and one sigmoid output.
    """
    # Default transformer parameters
    if transformer_params is None:
        transformer_params = {
            'embedding_dim': 128,
            'num_heads': 8,
            'ff_dim': 512,
            'num_layers': 4,
            'dropout': 0.1,
            'use_mean_pooling': False
        }

    # fMRI branch. Fuse the pooled Transformer embedding, i.e. the tensor that feeds the
    # branch's own final sigmoid Dense, rather than that Dense's 1-unit probability.
    # Concatenating the probability would reduce the whole fMRI branch to one scalar
    # next to 128 sMRI features.
    fmri_model = build_transformer_model(input_shape=fmri_shape, **transformer_params)
    fmri_features = fmri_model.layers[-1].input

    # sMRI branch (3D CNN)
    smri_input = layers.Input(shape=smri_shape, name="smri_input", dtype=tf.float32)
    y = layers.Conv3D(32, 3, activation="relu", padding="same")(smri_input)
    y = layers.MaxPooling3D(2)(y)
    y = layers.Conv3D(64, 3, activation="relu", padding="same")(y)
    y = layers.GlobalAveragePooling3D()(y)
    y = layers.Dense(128, activation="relu")(y)
    smri_model = Model(inputs=smri_input, outputs=y, name="sMRI_Model")

    # Late fusion head
    combined = layers.Concatenate()([fmri_features, smri_model.output])
    combined = layers.Dense(256, activation="relu")(combined)
    combined = layers.Dropout(0.5)(combined)
    output = layers.Dense(1, activation="sigmoid")(combined)

    return Model(inputs=[fmri_model.input, smri_model.input], outputs=output, name="Combined_Model")

# Pin the global dtype policy to plain float32 before any layers are instantiated.
tf.keras.mixed_precision.set_global_policy('float32')

# Build the two-branch model, then compile it for binary classification.
model = build_combined_model()
model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)
model.summary()

Model: "Combined_Model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 fmri_input (InputLayer)     [(None, 4, 4, 4, 5)]         0         []                            
                                                                                                  
 permute_8 (Permute)         (None, 5, 4, 4, 4)           0         ['fmri_input[0][0]']          
                                                                                                  
 reshape_8 (Reshape)         (None, 5, 64)                0         ['permute_8[0][0]']           
                                                                                                  
 dense_20 (Dense)            (None, 5, 128)               8320      ['reshape_8[0][0]']           
                                                                                     

In [16]:
# Synthetic stand-in data for a quick smoke test of the training loop.
# NOTE: this rebinds func_data, struct_data and all_labels, replacing the real arrays
# assembled earlier in the notebook. A seeded Generator makes the run reproducible.
rng = np.random.default_rng(42)
func_data = rng.random((100, 4, 4, 4, 5), dtype=np.float32)
struct_data = rng.random((100, 10, 10, 10, 1), dtype=np.float32)
all_labels = rng.integers(0, 2, size=(100,))
func_data.shape, struct_data.shape, all_labels.shape

((100, 4, 4, 4, 5), (100, 10, 10, 10, 1), (100,))

In [17]:
from sklearn.model_selection import train_test_split

In [18]:
# 80/20 split of the synthetic data, stratified on the labels so both splits keep the
# class ratio.
func_train, func_test, struct_train, struct_test, y_train, y_test = train_test_split(
    func_data, struct_data, all_labels,
    test_size=0.2, random_state=42, stratify=all_labels,
)
func_train.shape, func_test.shape, struct_train.shape, struct_test.shape, y_train.shape, y_test.shape

((80, 4, 4, 4, 5),
 (20, 4, 4, 4, 5),
 (80, 10, 10, 10, 1),
 (20, 10, 10, 10, 1),
 (80,),
 (20,))

In [20]:
# Inputs are given as dicts keyed by the Input layer names ("fmri_input", "smri_input"),
# so Keras matches each array to its branch by name.
history = model.fit(
    x={"fmri_input": func_train, "smri_input": struct_train},
    y=y_train,
    validation_data=({"fmri_input": func_test, "smri_input": struct_test}, y_test),
    batch_size=8,
    epochs=10,
    verbose=1,
)

Epoch 1/10


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
