In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Read the statistics table; the path is relative to the notebook's working directory.
data = pd.read_csv(
    filepath_or_buffer='../dataset/final_data/cerebellum.CerebNet.stats980.csv'
)

In [2]:
# Preview the first five rows to sanity-check the column layout.
data.head(5)

Unnamed: 0,Group,Age,Sex,Volume_mm30,normMean0,normStdDev0,normMin0,normMax0,Volume_mm31,normMean1,...,Volume_mm328,normMean28,normStdDev28,normMin28,normMax28,Volume_mm329,normMean29,normStdDev29,normMin29,normMax29
0,AD,78,M,12137.632,107.9979,6.6932,78.0,126.0,37830.688,80.8978,...,1778.488,84.8443,11.8863,43.0,111.0,4975.197,81.3287,12.636,31.0,111.0
1,AD,66,M,12228.416,108.5842,7.3703,77.0,128.0,43448.425,79.0423,...,2183.821,79.9011,14.1439,37.0,116.0,5486.459,77.786,14.652,30.0,116.0
2,AD,77,M,11190.005,116.1092,7.59,83.0,135.0,41515.417,89.4409,...,1933.458,88.9038,11.2102,51.0,109.0,5094.094,84.7262,12.3657,36.0,109.0
3,AD,73,M,11765.181,111.5798,7.4022,75.0,137.0,40915.116,83.3961,...,1806.855,82.8725,14.0701,37.0,110.0,4479.252,80.2851,14.072,30.0,110.0
4,AD,62,M,12482.472,107.1022,6.2091,75.0,122.0,46933.489,82.0554,...,1997.634,87.1267,10.2202,52.0,111.0,5481.872,84.6852,11.5048,38.0,111.0


In [3]:
# Labels are read from the first column (the AD / CN / MCI group).
y = data.iloc[:, 0].to_numpy()

# Every column from index 3 onward is a per-node statistic; the columns come in
# consecutive groups of five (Volume_mm3, normMean, normStdDev, normMin, normMax).
X = data.iloc[:, 3:].to_numpy()
# X: (n_samples, node_num, node_dim)
X = X.reshape(len(X), -1, 5)

X.shape, y.shape

((980, 30, 5), (980,))

In [4]:
# standardize
# z-score every (node, feature) entry using statistics taken across samples
feature_mean = X.mean(axis=0)
feature_std = X.std(axis=0)
# A feature that is constant across samples has std == 0; dividing by it would
# produce NaN/inf. Use 1 instead so such features become all zeros.
feature_std = np.where(feature_std == 0, 1.0, feature_std)
X_standardized = (X - feature_mean) / feature_std
X_standardized.shape

(980, 30, 5)

In [5]:
# Class balance: distinct labels in y and how often each occurs.
np.unique(y, return_counts=True)

(array(['AD', 'CN', 'MCI'], dtype=object), array([221, 282, 477]))

In [6]:
# Integer-encode the labels: CN -> 0, MCI -> 1, AD -> 2; anything else -> -1.
y_tri = np.select(
    [y == "AD", y == "MCI", y == "CN"],
    [2, 1, 0],
    default=-1,
)

In [7]:
from sklearn.metrics.pairwise import cosine_similarity

# generate neighbors matrix (masked)
def generate_neighbors_matrix(X_standardized, threshold=0.5):
    """Build a binary node-to-node neighbor mask for every sample.

    For each sample, the pairwise cosine similarity between its node feature
    vectors is computed. A pair of nodes is marked as neighbors (1) when the
    similarity is strictly greater than ``threshold``, otherwise 0.

    Parameters
    ----------
    X_standardized : array-like of shape (n_samples, node_num, node_dim)
        Node features per sample.
    threshold : float, default=0.5
        Similarity cut-off; the comparison is strict (``>``).

    Returns
    -------
    numpy.ndarray of int, shape (n_samples, node_num, node_num)
        The 0/1 neighbor mask. The shape of the underlying similarity tensor
        is printed as a side effect.
    """
    X_standardized = np.asarray(X_standardized)
    n_samples, node_num = X_standardized.shape[:2]

    # compute coefficients
    # Fill a preallocated buffer so the result keeps its 3-D shape even when
    # there are no samples.
    X_similarity = np.empty((n_samples, node_num, node_num))
    for sample_idx, sample in enumerate(X_standardized):
        X_similarity[sample_idx] = cosine_similarity(sample)
    print("X_similarity.shape:", X_similarity.shape)

    return (X_similarity > threshold).astype(int)

# One binary mask per sample: node pairs whose cosine similarity exceeds 0.8
# are treated as neighbors.
X_neighbors = generate_neighbors_matrix(X_standardized, threshold=0.8)
print("X_neighbors.shape: ", X_neighbors.shape)

X_similarity.shape: (980, 30, 30)
X_neighbors.shape:  (980, 30, 30)


In [8]:
# X_neighbors.sum(axis=1).sum(axis=1) - 30

In [25]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, MultiHeadAttention, Dense, Masking
from tensorflow.keras.models import Model
from tensorflow.keras.layers import LayerNormalization


# Data dimensions are read off the standardized feature tensor.
n_samples = X_standardized.shape[0]
node_num, node_dim = X_standardized.shape[1:]
# Attention hyper-parameters.
num_heads, embed_dim = 8, 16


# Define the model

class ReduceMeanLayer(tf.keras.layers.Layer):
    """Layer that averages its input over the last axis, dropping that axis."""

    def call(self, inputs):
        averaged = tf.math.reduce_mean(inputs, axis=-1)
        return averaged

# Symbolic model input: one (node_num, node_dim) feature matrix per sample.
inputs = Input(shape=(node_num, node_dim))
# embedding: linear projection of each node's features to embed_dim
embedding = Dense(units=embed_dim)(inputs)

def Attention(embedding, num_heads=num_heads, embed_dim=embed_dim, attention_mask=None):
    """Apply one masked multi-head self-attention block with residual Add & Norm.

    Parameters
    ----------
    embedding : tensor
        Node embeddings; used as both query and value of the attention layer
        and added back to its output before normalization.
    num_heads : int
        Number of attention heads.
    embed_dim : int
        Forwarded to ``MultiHeadAttention`` as ``key_dim``.
    attention_mask : array or tensor, optional
        Mask forwarded to the attention layer. When omitted, the notebook-level
        ``X_neighbors`` array is used, which matches the previous behavior.

    Returns
    -------
    tensor
        Layer-normalized sum of the attention output and ``embedding``.

    Notes
    -----
    ``X_neighbors`` holds one mask per dataset sample. A model built with the
    default mask therefore only lines up with batches that contain every
    sample in its original order. Pass a per-batch mask (e.g. a second model
    input) to lift that restriction.
    """
    if attention_mask is None:
        # Backward-compatible default: fall back to the precomputed global mask.
        attention_mask = X_neighbors

    # Multi-head self-attention layer with masking
    attended = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)(
        embedding, embedding, attention_mask=attention_mask
    )
    # Add & Norm
    return LayerNormalization(epsilon=1e-6)(attended + embedding)

# Stack three attention blocks on top of the embedding.
attention_output = embedding
# The loop variable is deliberately not called `round`: at notebook scope that
# would overwrite the builtin `round` for every later cell.
for block_idx in range(3):
    attention_output = Attention(attention_output)

# Feed forward layer
# Compress the attention output to shape (batch, node_num) by averaging the
# embedding axis
compressed_output = ReduceMeanLayer()(attention_output)

# 3 classes classification
outputs = Dense(3, activation='softmax')(compressed_output)

model = Model(inputs=inputs, outputs=outputs)
model.summary()

In [26]:
batch_size = 20

# Group consecutive samples into mini-batches:
# (n_batches, batch_size, node_num, node_dim).
# NOTE(review): this batches the raw X, not X_standardized -- confirm which is intended.
X_reshaped = X.reshape((-1, batch_size) + X.shape[1:])

# Same grouping for the integer labels: (n_batches, batch_size).
y_reshaped = y_tri.reshape((-1, batch_size))

X_reshaped.shape, y_reshaped.shape

((49, 20, 30, 5), (49, 20))

In [27]:
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)
# The attention mask built into the model (X_neighbors) is a constant with one
# entry per sample, in dataset order. Each step must therefore see the whole
# dataset in that order:
#   * batch_size is the full sample count rather than a hard-coded number;
#   * shuffle=False, because fit() shuffles samples by default, which would
#     pair each sample with another sample's mask.
model.fit(
    X_standardized,
    y_tri,
    epochs=10,
    batch_size=len(X_standardized),
    shuffle=False,
)

Epoch 1/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.3337 - loss: 1.0986
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 168ms/step - accuracy: 0.4867 - loss: 1.0980
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 170ms/step - accuracy: 0.4867 - loss: 1.0973
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 173ms/step - accuracy: 0.4867 - loss: 1.0966
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 172ms/step - accuracy: 0.4867 - loss: 1.0959
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 168ms/step - accuracy: 0.4867 - loss: 1.0952
Epoch 7/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 185ms/step - accuracy: 0.4867 - loss: 1.0944
Epoch 8/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 174ms/step - accuracy: 0.4867 - loss: 1.0936
Epoch 9/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[

<keras.src.callbacks.history.History at 0x74cf4ad997e0>

array(['AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD',
       'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD',
       'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD',
       'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD',
       'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD',
       'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD',
       'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD',
       'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD',
       'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD',
       'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD',
       'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD',
       'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD',
       'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD',
       'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'AD', 'A