# Installations

In [None]:

!pip uninstall -y tensorboard jax jaxlib
!pip install tensorboard==2.11


In [None]:
# 🔁 Reinstall NumPy FIRST (so PyTorch detects it during install)
!pip install numpy==1.24.4


In [None]:
# 🔧 Reinstall torch + friends (CUDA 11.8)
!pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118


In [None]:
!pip install tensorboard==2.11

In [None]:
!pip install mamba-ssm==1.1.0

In [None]:

# 🔁 Sanity check after the pinned installs: report the NumPy/Torch versions the
# kernel actually loaded, and exercise Torch -> NumPy conversion (the interop the
# NumPy-first install order above is meant to keep working).
import numpy as np
import torch

x = torch.randn(3, 3)
version_report = {
    "✅ NumPy version:": np.__version__,
    "✅ Torch version:": torch.__version__,
}
for caption, version in version_report.items():
    print(caption, version)
print("✅ NumPy conversion:", x.numpy())

In [None]:

!pip install torcheeg
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, roc_auc_score

import os
from torcheeg.trainers import DANNTrainer
import math
# The KAN import below uses the package path `efficient_kan.src.efficient_kan`,
# presumably resolved relative to this working directory — keep the chdir first.
os.chdir('/content/drive/MyDrive/MyMethod')
from efficient_kan.src.efficient_kan.kan import KAN

# NOTE(review): nilearn is installed unpinned. `nilearn.input_data` is the legacy
# module name (newer releases expose NiftiLabelsMasker from `nilearn.maskers`),
# so this import may break on a fresh install — TODO pin nilearn or switch module.
# ConnectivityMeasure, fetch_abide_pcp, NiftiLabelsMasker, fetch_atlas_craddock_2012,
# shuffle, train_test_split and pandas appear unused elsewhere in this notebook.
!pip install nilearn
import os
import numpy as np
import pandas as pd
from nilearn.connectome import ConnectivityMeasure
from nilearn.datasets import fetch_abide_pcp
from nilearn.input_data import NiftiLabelsMasker
from nilearn.datasets import fetch_atlas_craddock_2012
from sklearn.utils import shuffle
import matplotlib.pyplot as plt

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split



import torch
# NOTE(review): causal-conv1d is only installed in a later cell; if mamba_ssm
# needs it at import time, this import must run after that install — confirm
# on a fresh kernel.
from mamba_ssm import Mamba

# import os
# os.chdir('/content/mamba')
# from mamba.mamba_ssm.modules.mamba_simple import Mamba


In [None]:
# Pinned causal-conv1d, intended to accompany mamba-ssm==1.1.0.
# NOTE(review): this cell runs after `from mamba_ssm import Mamba` in the cell
# above; install it before mamba_ssm is first imported if that import needs it.
# !pip install -U causal_conv1d
!pip install causal-conv1d==1.1.1

In [None]:
# Smoke test for the Mamba block on the GPU: a (batch, length, dim) input must
# come back with exactly the same shape.
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
mamba_config = dict(
    d_model=dim,  # model dimension; the block has roughly 3 * expand * d_model^2 parameters
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # local convolution width
    expand=2,     # block expansion factor
)
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(**mamba_config).to("cuda")
y = model(x)
assert y.shape == x.shape
for tensor in (x, y):
    print(tensor)

# Dataloaders and Training

In [8]:
# Define PyTorch Dataset
class fMRIDataset_domain(Dataset):
    """Subject-level dataset yielding ``([fc, cc], label, site)`` samples.

    Every input is converted to a tensor once, in ``__init__``. Only ``fc``,
    ``cc``, ``labels`` and the numeric site codes are returned by
    ``__getitem__``; the phenotypic fields (age, gender, handedness, IQ, eye)
    are stored as attributes but are not part of a sample.

    Args:
        fc: per-subject fc data, indexed along the first axis.
        cc: per-subject connectivity vectors, indexed along the first axis.
        labels: class label per subject (stored as int64).
        age, gender, fiq, viq, piq, eye: per-subject phenotypic values.
        numeric_institutions: numeric site code per subject (the domain label).
        numeric_handedness: numeric handedness code per subject.
    """

    def __init__(self, fc, cc, labels, age, gender, numeric_institutions, fiq, viq, piq, eye, numeric_handedness):
        self.fc = torch.tensor(fc, dtype=torch.float32)
        self.cc = torch.tensor(cc, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.age = torch.tensor(age, dtype=torch.float32)
        self.gender = torch.tensor(gender, dtype=torch.float32)
        self.handedness = torch.tensor(numeric_handedness, dtype=torch.float32)
        self.fiq = torch.tensor(fiq, dtype=torch.float32)
        self.viq = torch.tensor(viq, dtype=torch.float32)
        self.piq = torch.tensor(piq, dtype=torch.float32)
        self.eye = torch.tensor(eye, dtype=torch.float32)
        self.site = torch.tensor(numeric_institutions, dtype=torch.float32)
        # fc/cc samples are returned on this device; labels and site stay on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``([fc[idx], cc[idx]], labels[idx], site[idx])``.

        Index first, then transfer: only this subject's rows are copied to
        ``self.device``, instead of re-wrapping and moving the full ``fc``/``cc``
        tensors on every call (O(dataset) time and memory per sample).
        """
        fc_sample = self.fc[idx].to(self.device)
        cc_sample = self.cc[idx].to(self.device)
        return [fc_sample, cc_sample], self.labels[idx], self.site[idx]



In [9]:
def custom_collate_domain(batch):
    """Collate ``([fc, cc], label, domain)`` samples into batched tensors.

    For each sample, ``fc`` is flattened and concatenated with ``cc``, so the
    model receives a single flat feature row per subject.

    Args:
        batch (list of tuples): samples shaped like ``fMRIDataset_domain`` output,
            ``([fc, cc], label, domain)``. ``fc`` is a tensor of any shape
            (flattened here), ``cc`` is a 1-D tensor, and ``label``/``domain``
            are tensors of equal shape across the batch.

    Returns:
        tuple: ``(concatenated_features, labels, domains)``.
        ``concatenated_features`` has shape ``(batch_size, fc.numel() + cc.numel())``;
        ``labels`` and ``domains`` are stacked along a new leading batch dimension.
    """
    concatenated_features = torch.stack(
        [torch.cat((sample[0][0].flatten(), sample[0][1]), dim=0) for sample in batch]
    )
    labels = torch.stack([sample[1] for sample in batch])
    domains = torch.stack([sample[2] for sample in batch])
    return concatenated_features, labels, domains


In [10]:
# Define PyTorch Dataset
class fMRIDataset(Dataset):
    """Subject-level dataset yielding ``([fc, cc], label)`` samples.

    Every input is converted to a tensor once, in ``__init__``. Only ``fc``,
    ``cc`` and ``labels`` are returned by ``__getitem__``; the phenotypic fields
    and site codes are stored as attributes but are not part of a sample.

    Args:
        fc: per-subject fc data, indexed along the first axis.
        cc: per-subject connectivity vectors, indexed along the first axis.
        labels: class label per subject (stored as int64).
        age, gender, fiq, viq, piq, eye: per-subject phenotypic values.
        numeric_institutions: numeric site code per subject.
        numeric_handedness: numeric handedness code per subject.
    """

    def __init__(self, fc, cc, labels, age, gender, numeric_institutions, fiq, viq, piq, eye, numeric_handedness):
        self.fc = torch.tensor(fc, dtype=torch.float32)
        self.cc = torch.tensor(cc, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.age = torch.tensor(age, dtype=torch.float32)
        self.gender = torch.tensor(gender, dtype=torch.float32)
        self.handedness = torch.tensor(numeric_handedness, dtype=torch.float32)
        self.fiq = torch.tensor(fiq, dtype=torch.float32)
        self.viq = torch.tensor(viq, dtype=torch.float32)
        self.piq = torch.tensor(piq, dtype=torch.float32)
        self.eye = torch.tensor(eye, dtype=torch.float32)
        self.site = torch.tensor(numeric_institutions, dtype=torch.float32)
        # fc/cc samples are returned on this device; labels stay on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``([fc[idx], cc[idx]], labels[idx])``.

        Index first, then transfer: only this subject's rows are copied to
        ``self.device``, instead of re-wrapping and moving the full ``fc``/``cc``
        tensors on every call (O(dataset) time and memory per sample).
        """
        fc_sample = self.fc[idx].to(self.device)
        cc_sample = self.cc[idx].to(self.device)
        return [fc_sample, cc_sample], self.labels[idx]



In [11]:
# Define PyTorch Dataset
class fMRIDataset_target(Dataset):
    """Subject-level dataset yielding ``([fc, cc], site)`` samples (no class label).

    Used for the unlabeled target domain. Every input is converted to a tensor
    once, in ``__init__``; labels and phenotypic fields are stored as attributes
    but only ``fc``, ``cc`` and the numeric site code are returned.

    Args:
        fc: per-subject fc data, indexed along the first axis.
        cc: per-subject connectivity vectors, indexed along the first axis.
        labels: class label per subject (stored, not returned).
        age, gender, fiq, viq, piq, eye: per-subject phenotypic values.
        numeric_institutions: numeric site code per subject (returned per sample).
        numeric_handedness: numeric handedness code per subject.
    """

    def __init__(self, fc, cc, labels, age, gender, numeric_institutions, fiq, viq, piq, eye, numeric_handedness):
        self.fc = torch.tensor(fc, dtype=torch.float32)
        self.cc = torch.tensor(cc, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.age = torch.tensor(age, dtype=torch.float32)
        self.gender = torch.tensor(gender, dtype=torch.float32)
        self.handedness = torch.tensor(numeric_handedness, dtype=torch.float32)
        self.fiq = torch.tensor(fiq, dtype=torch.float32)
        self.viq = torch.tensor(viq, dtype=torch.float32)
        self.piq = torch.tensor(piq, dtype=torch.float32)
        self.eye = torch.tensor(eye, dtype=torch.float32)
        self.site = torch.tensor(numeric_institutions, dtype=torch.float32)
        # fc/cc samples are returned on this device; the site code stays on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``([fc[idx], cc[idx]], site[idx])``.

        Index first, then transfer: only this subject's rows are copied to
        ``self.device``, instead of re-wrapping and moving the full ``fc``/``cc``
        tensors on every call (O(dataset) time and memory per sample).
        """
        fc_sample = self.fc[idx].to(self.device)
        cc_sample = self.cc[idx].to(self.device)
        return [fc_sample, cc_sample], self.site[idx]



In [12]:
def custom_collate(batch):
    """Collate ``([fc, cc], target)`` samples into batched tensors.

    For each sample, ``fc`` is flattened and concatenated with ``cc``, so the
    model receives a single flat feature row per subject.

    Args:
        batch (list of tuples): samples shaped like ``fMRIDataset`` /
            ``fMRIDataset_target`` output, ``([fc, cc], target)``. ``fc`` is a
            tensor of any shape (flattened here), ``cc`` is a 1-D tensor, and
            ``target`` is the class label (``fMRIDataset``) or the site code
            (``fMRIDataset_target``).

    Returns:
        tuple: ``(concatenated_features, labels)``. ``concatenated_features``
        has shape ``(batch_size, fc.numel() + cc.numel())``; ``labels`` stacks
        the second element of every sample along a new leading batch dimension.
    """
    concatenated_features = torch.stack(
        [torch.cat((sample[0][0].flatten(), sample[0][1]), dim=0) for sample in batch]
    )
    labels = torch.stack([sample[1] for sample in batch])
    return concatenated_features, labels


In [None]:
# > Determine if you want PCA or NOT below :
# PCA_flag=True reduces the cc features with PCA in the target-loader code below.
# NOTE(review): KAN_Mamba_Class builds nn.Linear(76636, 512) for the cc part, so
# PCA-reduced features only fit if that input size is changed to match.
PCA_flag = False

import numpy as np
import math

# File "A": connectivity matrices (stored under the "fc" key, used here as `cc`)
# and the reference subject order that every other array is aligned to below.
npz_file_path = '/content/drive/MyDrive/MyMethod/home/Dataset/ABIDE_CC400_NOTQC_TPE_withSITE_shuffle_randomSeed1234_filtGlobal.npz'
# npz_file_path = '/content/drive/MyDrive/MyMethod/home/Dataset/ABIDE_CC200_PCC_withSITE_shuffle_randomSeed1234_.npz'


cc_data = np.load(npz_file_path)

# Extract data
cc = cc_data["fc"]  # per-subject square matrices; the upper-triangle step below indexes them as (N, 392, 392)
print(cc.shape)
cc_subject_ids = cc_data["subject"]



# Example arrays
# A: Reference arrays (same shuffle)

# File "B": fc data, labels and phenotypic fields. Its subject order can differ
# from file A; it is re-ordered to A's order in the alignment step below.
# Load data from .npz file
npz_file_path = "/content/drive/MyDrive/MyMethod/ABIDE_Data_Directory/ABIDE_CC400_NotCC_NotQC_shufflerandomSeed42.npz"
# npz_file_path = '/content/drive/MyDrive/MyMethod/home/Dataset/ABIDE_CC200_null_shuffle_randomSeed1234_.npz'

data = np.load(npz_file_path)

# B:
# Extract keys
fc = data["fc"]  # Shape: (1035, 316, 392); KAN_Mamba_Class reshapes each fc block to (316, 392)
print(fc.shape)
labels = data["label"]  # Shape: (1035,)
subject_ids = data["subject"]  # Shape: (1035,)
age = data["age"]  # Shape: (1035,)
gender = data["gender"]  # Shape: (1035,)
handedness = data["handedness"]  # Shape: (1035,)
fiq = data["fiq"]  # Shape: (1035,)
viq = data["viq"]  # Shape: (1035,)
piq = data["piq"]  # Shape: (1035,)
eye = data["eye"]
site = data["site"]
# site = [item.split('_')[0] for item in data['subject']]

# Adjust labels to start from 0 (in place: the smallest label value becomes 0)
labels -= labels.min()


import numpy as np

# Encode handedness as floats: 'R' -> 0.0, 'L' -> 1.0, and the literal string
# 'nan' (missing) -> 0.0, i.e. missing is treated like right-handed.
# Any value not in the mapping (another category, or a float NaN rather than the
# string 'nan') falls back to -1.0.
handedness_mapping = {'R': 0.0, 'L': 1.0, 'nan': 0.0}
numeric_handedness = np.array(
    [handedness_mapping[h] if h in handedness_mapping else -1.0 for h in handedness],
    dtype=np.float32,
)
print(numeric_handedness)

# Encode each site string as a float index into the sorted list of unique site
# names (np.unique sorts), so a given set of sites always gets the same codes.
institutions = site
unique_values = np.unique(institutions)
string_to_num_mapping = dict(zip(unique_values, range(len(unique_values))))
numeric_institutions = np.array(
    [string_to_num_mapping[name] for name in institutions], dtype=np.float32
)






# Re-order file B's per-subject arrays so their rows follow file A's subject
# order (cc_subject_ids).
# A_to_index: subject id -> row position in A (a later duplicate id wins).
A_to_index = dict(zip(cc_subject_ids, range(len(cc_subject_ids))))
# matching_indices[i]: row in A of B's i-th subject (KeyError if A lacks it).
matching_indices = [A_to_index[value] for value in subject_ids]
# Sorting B's rows by their position in A yields B in A's order; computed once
# and reused for every array.
b_to_a_order = np.argsort(matching_indices)

subject_ids_aligned = subject_ids[b_to_a_order]
labels_aligned = labels[b_to_a_order]
data_aligned = fc[b_to_a_order]
age_aligned = age[b_to_a_order]
gender_aligned = gender[b_to_a_order]
handedness_aligned = numeric_handedness[b_to_a_order]
fiq_aligned = fiq[b_to_a_order]
viq_aligned = viq[b_to_a_order]
piq_aligned = piq[b_to_a_order]
eye_aligned = eye[b_to_a_order]
site_aligned = numeric_institutions[b_to_a_order]


# normalize
def normalize_data(data):
    """Min-max scale ``data`` to the range [0, 1].

    Args:
        data: array-like of numbers (a list is accepted).

    Returns:
        np.ndarray of floats with the same shape as ``data``. NaN entries are
        ignored when computing the range and stay NaN, so a single missing
        value no longer turns the whole column into NaN. If all non-NaN values
        are identical (zero range), they map to 0.0 instead of dividing by zero.

    Side effect:
        prints the value range (max - min).
    """
    values = np.asarray(data, dtype=float)
    min_val = np.nanmin(values)
    max_val = np.nanmax(values)
    value_range = max_val - min_val
    print(value_range)
    if value_range == 0:
        # Constant column: avoid 0/0 -> NaN for every entry.
        return values - min_val
    return (values - min_val) / value_range
def _impute_iq(values, fill=100, sentinel=-9999):
    """Return ``values`` as a list, with missing IQ scores replaced by ``fill``.

    An entry counts as missing when it is the empty string, NaN, or the
    ``-9999`` missing-value code, in either string or numeric form. The
    string-only comparison (``x == '-9999'``) let a numeric -9999 through,
    where it would dominate the min-max scaling. Other entries are returned
    as floats.
    """
    cleaned = []
    for x in values:
        if isinstance(x, str):
            x = x.strip()
            if x == "":
                cleaned.append(fill)
                continue
        value = float(x)
        cleaned.append(fill if math.isnan(value) or value == sentinel else value)
    return cleaned


age_aligned = normalize_data(age_aligned)

# IQ scores: impute missing entries with 100 (the norm mean of IQ scales), then scale.
fiq_aligned = normalize_data(_impute_iq(fiq_aligned))
viq_aligned = normalize_data(_impute_iq(viq_aligned))
piq_aligned = normalize_data(_impute_iq(piq_aligned))
# handedness_aligned already holds the numeric codes built from handedness_mapping.
handedness_aligned = normalize_data(handedness_aligned)
eye_aligned = normalize_data(eye_aligned)



# Vectorise each subject's connectivity matrix by keeping only the strict upper
# triangle (k=1 excludes the diagonal). Assumes the matrices are symmetric, so
# no information is lost — TODO confirm for the chosen connectivity measure.
# The matrix size is read from `cc` (392 for CC400) instead of being hard-coded.
n_rois = cc.shape[1]
triu_indices = np.triu_indices(n_rois, k=1)
upper_triangular = cc[:, triu_indices[0], triu_indices[1]]  # (N, n_rois*(n_rois-1)//2) = (N, 76636) for 392 ROIs


# Target-domain loader: all subjects from the site with the most data -------------
unique_sites, site_counts = np.unique(site_aligned, return_counts=True)
most_common_site = unique_sites[np.argmax(site_counts)]
print(f"Site with most data: {most_common_site}, Count: {np.max(site_counts)}")

# Filter every per-subject array to the selected site
site_mask = site_aligned == most_common_site
fc_filtered = data_aligned[site_mask]
upper_triangular_filtered = upper_triangular[site_mask]
site_filtered = site_aligned[site_mask]
labels_filtered = labels_aligned[site_mask]
age_filtered = age_aligned[site_mask]
gender_filtered = gender_aligned[site_mask]
handedness_filtered = handedness_aligned[site_mask]
fiq_filtered = fiq_aligned[site_mask]
viq_filtered = viq_aligned[site_mask]
piq_filtered = piq_aligned[site_mask]
eye_filtered = eye_aligned[site_mask]

# Optional PCA on the cc features
if PCA_flag:
    # Imported here: the notebook never imported PCA, so PCA_flag=True used to raise NameError.
    from sklearn.decomposition import PCA

    # NOTE(review): n_components must not exceed min(n_subjects, n_features).
    # PCA is fitted on ALL subjects before the CV split (validation subjects
    # influence the projection), and KAN_Mamba_Class's first cc layer expects
    # 76636 inputs unless constructed with a matching size — confirm both.
    n_components_pca = 1028  # Number of principal components to keep
    pca = PCA(n_components=n_components_pca)
    pca.fit(upper_triangular)
    upper_triangular_PCA = pca.transform(upper_triangular)
    upper_triangular_filtered_PCA = pca.transform(upper_triangular_filtered)
    upper_triangular_filtered = upper_triangular_filtered_PCA
    upper_triangular = upper_triangular_PCA
    print(f'upper_triangular.shape:{upper_triangular.shape}')
    print(f'upper_triangular_filtered.shape:{upper_triangular_filtered.shape}')

# Create the target dataset and loader (samples are ([fc, cc], site); no labels are used)
filtered_dataset = fMRIDataset_target(fc_filtered, upper_triangular_filtered, labels_filtered, age_filtered, gender_filtered, site_filtered, fiq_filtered, viq_filtered, piq_filtered, eye_filtered, handedness_filtered)
filtered_loader = DataLoader(filtered_dataset, batch_size=8, shuffle=True, collate_fn=custom_collate)


# ----------------------------------------------

# Collect all aligned per-subject arrays in one dictionary
data_dict = {
    'fc': data_aligned,
    'cc': upper_triangular,
    'age': age_aligned,
    'gender': gender_aligned,
    'handedness': handedness_aligned,
    'fiq': fiq_aligned,
    'viq': viq_aligned,
    'piq': piq_aligned,
    'labels': labels_aligned,
    'subject_ids': subject_ids_aligned,
    'eye': eye_aligned,
    'site': site_aligned
}

# Verify alignment: row i of every array must describe the same subject as row i of cc.
alignment_ok = np.array_equal(cc_subject_ids, subject_ids_aligned)
print("cc_subject_ids:", cc_subject_ids)
print("subject_ids_aligned:", subject_ids_aligned)
print("Is subject alignment correct?", alignment_ok)
if not alignment_ok:
    # Continuing would pair fc/labels with another subject's cc features.
    raise ValueError(
        "Subject order of the fc/phenotype file does not match the cc file; "
        "fc/cc rows would be paired with the wrong subjects."
    )


In [None]:
import torch.nn as nn
class KAN_Mamba_Class(nn.Module):
    """Two-branch feature extractor: KAN over cc features, Mamba over the fc sequence.

    The input is the flat row built by ``custom_collate``: the first
    ``n_timepoints * n_rois`` values are the flattened fc block, and the remaining
    ``cc_dim`` values are the connectivity vector. Each branch yields 64
    features, and the output concatenates them into 128 features.

    Args:
        grid_size, spline_order, scale_noise, scale_base, scale_spline, grid_eps:
            forwarded to the KAN layer.
        n_timepoints: number of rows of the fc block (default 316).
        n_rois: number of columns of the fc block, also Mamba's ``d_model``
            (default 392).
        cc_dim: length of the cc vector. Defaults to ``n_rois * (n_rois - 1) // 2``
            (76636 for 392), the strict upper triangle; pass the reduced size if
            the cc features were PCA-projected.
        device: device the input slices (and the KAN layer) are moved to
            (default 'cuda:0').
    """

    def __init__(self, grid_size=12, spline_order=10, scale_noise=0.0, scale_base=1.0,
                 scale_spline=1.0, grid_eps=0.02, n_timepoints=316, n_rois=392,
                 cc_dim=None, device='cuda:0'):
        super(KAN_Mamba_Class, self).__init__()

        self.device = device
        self.n_timepoints = n_timepoints
        self.n_rois = n_rois
        if cc_dim is None:
            cc_dim = n_rois * (n_rois - 1) // 2
        self.cc_dim = cc_dim

        # Layers are created in the same order as before so that parameter
        # initialisation consumes the RNG identically under a fixed seed.

        # --- KAN branch (cc features): Linear -> KAN -> ReLU -> Linear -> Softmax ---
        self.fc_layer1 = nn.Linear(cc_dim, 512)
        self.kan_layer = KAN(
            [512, 256],
            grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise,
            scale_base=scale_base, scale_spline=scale_spline, grid_eps=grid_eps).to(device)
        self.relu_layer = nn.ReLU()
        self.fc_layer2 = nn.Linear(256, 64)
        # Shared by both branches; normalises each 64-dim branch output over dim 1.
        # NOTE(review): a softmax on intermediate features confines each half to the
        # probability simplex — confirm this is intended rather than a classifier leftover.
        self.softmax_layer = nn.Softmax(dim=1)

        # --- Mamba branch: fc as a sequence of n_timepoints vectors of size n_rois ---
        self.mamba = Mamba(
            d_model=n_rois,
            d_state=128,
            d_conv=4,
            expand=2,
            dt_min=0.001,
            dt_max=0.1,
            dt_init='random',
            dt_scale=1.0,
            dt_init_floor=1e-4,
            conv_bias=True,
            bias=False,
            use_fast_path=True,
            layer_idx=None
        )
        self.mamb_fc = nn.Sequential(
            nn.Linear(n_rois, 128),
            nn.ReLU(),
            nn.Linear(128, 64)
        )

    def forward(self, x):
        """Split the flat input, run both branches and concatenate their outputs.

        Args:
            x: tensor of shape (batch, n_timepoints * n_rois + cc_dim).

        Returns:
            tensor of shape (batch, 128): [KAN branch (64) | Mamba branch (64)].
        """
        num_data = x.shape[0]
        split = self.n_timepoints * self.n_rois

        # fc block -> (batch, n_timepoints, n_rois) sequence for Mamba
        m = x[:, :split].reshape(num_data, self.n_timepoints, self.n_rois).to(self.device)
        # cc vector for the KAN branch
        xx = x[:, split:].to(self.device)

        # KAN branch
        xx = self.fc_layer1(xx)
        xx = self.relu_layer(self.kan_layer(xx))
        xx = self.softmax_layer(self.fc_layer2(xx))

        # Mamba branch: sequence model, mean over the time axis, MLP, softmax
        m = self.mamba(m)
        m = m.mean(dim=1)
        m = self.softmax_layer(self.mamb_fc(m))

        return torch.cat((xx, m), dim=1)
        # return xx



In [15]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

# Legend names for the site codes used by plot_tsne_domain.
# NOTE(review): assumes a site code is the position of the site name in the sorted
# list of unique site strings (np.unique in the preprocessing cell) and that exactly
# these 20 sites are present — confirm if the dataset changes.
_ABIDE_SITE_NAMES = {
    0: 'CALTECH',
    1: 'CMU',
    2: 'KKI',
    3: 'LEUVEN_1',
    4: 'LEUVEN_2',
    5: 'MAX_MUN',
    6: 'NYU',
    7: 'OHSU',
    8: 'OLIN',
    9: 'PITT',
    10: 'SBL',
    11: 'SDSU',
    12: 'STANFORD',
    13: 'TRINITY',
    14: 'UCLA_1',
    15: 'UCLA_2',
    16: 'UM_1',
    17: 'UM_2',
    18: 'USM',
    19: 'YALE'
}


def _plot_tsne(features, labels, epoch, label_names, n_components=2):
    """Embed ``features`` with t-SNE and draw a 2-D scatter coloured by ``labels``.

    Args:
        features: array of shape (n_samples, n_features).
        labels: per-sample labels (array-like of length n_samples).
        epoch: shown in the plot title.
        label_names: mapping from label value to legend text; unmapped labels
            are shown as ``str(label)``.
        n_components: t-SNE output dimensionality. Only 2 is plotted; for any
            other value nothing would be drawn, so the embedding is skipped.
    """
    if n_components != 2:
        return

    tsne = TSNE(n_components=n_components, random_state=42)
    reduced_features = tsne.fit_transform(features)
    labels = np.asarray(labels)

    plt.figure(figsize=(10, 7))
    # 3-unit margin around the embedding so points at the edges are not clipped.
    plt.xlim(reduced_features[:, 0].min() - 3, reduced_features[:, 0].max() + 3)
    plt.ylim(reduced_features[:, 1].min() - 3, reduced_features[:, 1].max() + 3)

    unique_labels = np.unique(labels)
    print(unique_labels)
    # plt.get_cmap exists in every Matplotlib version; plt.cm.get_cmap was
    # deprecated in 3.7 and removed in 3.9.
    colors = plt.get_cmap('tab20', len(unique_labels))

    for idx, label in enumerate(unique_labels):
        indices = np.where(labels == label)[0]
        plt.scatter(
            reduced_features[indices, 0],  # First t-SNE component
            reduced_features[indices, 1],  # Second t-SNE component
            label=label_names.get(label, f"{label}"),
            alpha=0.7,
            color=colors(idx),
        )

    plt.title(f't-SNE Visualization of Features (Epoch {epoch})')
    plt.xlabel('t-SNE Component 1')
    plt.ylabel('t-SNE Component 2')
    plt.legend(title="Labels")
    plt.grid(True)
    plt.show()


def plot_tsne_label(features, labels, epoch, n_components=2):
    """t-SNE plot of ``features`` coloured by diagnosis label (0 -> "ASD", 1 -> "TC")."""
    _plot_tsne(features, labels, epoch, {0: "ASD", 1: "TC"}, n_components)


def plot_tsne_domain(features, labels, epoch, n_components=2):
    """t-SNE plot of ``features`` coloured by site code, with site names in the legend."""
    _plot_tsne(features, labels, epoch, _ABIDE_SITE_NAMES, n_components)

In [None]:
import torcheeg
from torcheeg.trainers import DANNTrainer

import math
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, roc_auc_score

import torch.nn as nn

def _kan_head(input_dim, hidden_dim, output_dim, dropout_prob):
    """Build the shared head: KAN(input->hidden) -> ReLU -> Dropout -> Linear(hidden->output).

    Returns raw scores (logits). There is no final Softmax: the training loss is
    presumably nn.CrossEntropyLoss inside torcheeg's DANNTrainer (confirm), and it
    applies log-softmax itself, so a Softmax here would be a double softmax that
    flattens gradients. Argmax predictions are unaffected. Parameterised layer
    indices (layers.0 = KAN, layers.3 = Linear) are unchanged, so existing
    state_dicts still load.
    """
    return nn.Sequential(
        KAN([input_dim, hidden_dim],
            grid_size=5, spline_order=5, scale_noise=0.0, scale_base=1.0,
            scale_spline=1.0, grid_eps=0.02),
        nn.ReLU(),
        nn.Dropout(dropout_prob),
        nn.Linear(hidden_dim, output_dim),
    )


class Classifier(nn.Module):
    """Label head: maps extractor features to ``output_dim`` class logits.

    No device is hard-coded; callers move the whole module with ``.to(device)``.
    """

    def __init__(self, input_dim=64, hidden_dim=32, output_dim=2, dropout_prob=0.5):
        super(Classifier, self).__init__()
        self.layers = _kan_head(input_dim, hidden_dim, output_dim, dropout_prob)

    def forward(self, x):
        return self.layers(x)


class DomainClassifier(nn.Module):
    """Domain head: maps extractor features to ``output_dim`` domain logits.

    NOTE(review): output_dim defaults to 20 (one per site), but which domain
    labels the trainer actually feeds it (site ids vs. source/target) is defined
    inside DANNTrainer — confirm the head size matches.
    No device is hard-coded; callers move the whole module with ``.to(device)``.
    """

    def __init__(self, input_dim=64, hidden_dim=32, output_dim=20, dropout_prob=0.5):
        super(DomainClassifier, self).__init__()
        self.layers = _kan_head(input_dim, hidden_dim, output_dim, dropout_prob)

    def forward(self, x):
        return self.layers(x)


def _collect_predictions(extractor, classifier, loader, device):
    """Run ``extractor`` + ``classifier`` over every batch of ``loader`` in eval mode.

    ``loader`` must yield ``(x, label, domain)`` batches (as produced with
    ``custom_collate_domain``). Inference runs under ``torch.no_grad()``.

    Returns:
        tuple of numpy arrays ``(preds, probs, labels, domains, features)``:
        argmax class, softmax of the classifier output, true labels, domain
        codes and extractor features, concatenated over all batches.
    """
    extractor.eval()
    classifier.eval()
    preds, probs, labels, domains, features = [], [], [], [], []
    with torch.no_grad():
        for data, label, domain in loader:
            batch_features = extractor(data.to(device))
            scores = classifier(batch_features)
            features.append(batch_features.cpu().numpy())
            # For 2 classes, softmax(scores)[:, 1] is monotonic in the positive-class
            # score, so it ranks correctly for AUC whether the classifier returns
            # logits or already-normalised probabilities.
            probs.append(torch.softmax(scores, dim=1).cpu().numpy())
            preds.append(scores.argmax(dim=1).cpu().numpy())
            labels.append(label.cpu().numpy())
            domains.append(domain.cpu().numpy())
    return tuple(np.concatenate(chunks) for chunks in (preds, probs, labels, domains, features))


def _binary_metrics(all_labels, all_preds, all_probs):
    """Binary (0/1) classification metrics for one evaluation pass.

    Returns:
        dict with 'conf_matrix', 'conf_mat' ([tn, fp, fn, tp]), 'acc', 'pre',
        'rec', 'spe', 'F1' and 'AUC'. AUC uses the positive-class probability
        (not the hard predictions) and is NaN when only one class is present.
    """
    # labels=[0, 1] keeps the matrix 2x2 even if a class is missing from the fold,
    # so the 4-way unpacking cannot fail.
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    tn, fp, fn, tp = conf_matrix.ravel()
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    specificity = tn / (tn + fp) if (tn + fp) else float('nan')
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    if len(np.unique(all_labels)) == 2:
        auc = roc_auc_score(all_labels, all_probs[:, 1])
    else:
        auc = float('nan')
    return {'conf_matrix': conf_matrix, 'conf_mat': [tn, fp, fn, tp], 'acc': accuracy,
            'pre': precision, 'rec': recall, 'spe': specificity, 'F1': f1, 'AUC': auc}


def kan_mamba_train_model(train_loader, val_loader, target_loader, val_loader_domain):
    """Train the KAN+Mamba DANN model for 10 one-epoch rounds, evaluating after each.

    Args:
        train_loader: labelled source-domain loader yielding (x, label).
        val_loader: loader yielding (x, label), passed to ``trainer.fit`` for validation.
        target_loader: target-domain loader used for domain adaptation.
        val_loader_domain: validation loader yielding (x, label, site), used for
            the metrics and t-SNE plots.

    Returns:
        (R, all_features, classifier): metrics dict of the LAST round, validation
        features of the last round, and the trained classifier module.
    """
    device = 'cuda:0'
    extractor = KAN_Mamba_Class().to(device)
    classifier = Classifier(input_dim=128, hidden_dim=64, output_dim=2, dropout_prob=0.5).to(device)
    domain_classifier = DomainClassifier(input_dim=128, hidden_dim=64, output_dim=20, dropout_prob=0.5).to(device)

    trainer = DANNTrainer(extractor,
                          classifier,
                          domain_classifier,
                          num_classes=2,
                          devices=1,
                          accelerator='gpu',
                          lr_scheduler_decay=0.75,
                          lr=0.0001)
    trainer.verbose = True

    R = None
    all_features = None
    for i in range(10):
        print(f'//////////  {i}  \\\\\\\\\\\\')
        # NOTE(review): each call is a separate 1-epoch fit; whether optimizer and
        # schedule state carry over between calls depends on DANNTrainer.fit — confirm.
        trainer.fit(train_loader, target_loader, val_loader, max_epochs=1)

        # ------------ TEST on the validation fold ---------------
        trainer.extractor.to(device)
        trainer.classifier.to(device)
        all_preds, all_probs, all_labels, all_domains, all_features = _collect_predictions(
            trainer.extractor, trainer.classifier, val_loader_domain, device)

        print('**********************************************************')
        print(all_labels)
        print(all_domains)

        plot_tsne_label(all_features, all_labels, epoch=i)
        plot_tsne_domain(all_features, all_domains, epoch=i)

        metrics = _binary_metrics(all_labels, all_preds, all_probs)
        tn, fp, fn, tp = metrics['conf_mat']

        print('------------------------------------')
        print("Confusion Matrix:\n", metrics['conf_matrix'])
        print([tn, fp, fn, tp])
        print(f"Accuracy: {metrics['acc']:.4f}")
        print(f"Precision: {metrics['pre']:.4f}")
        print(f"Recall: {metrics['rec']:.4f}")
        print(f"Specificity: {metrics['spe']:.4f}")
        print(f"F1 Score: {metrics['F1']:.4f}")
        print(f"AUC: {metrics['AUC']:.4f}")

        print(f'preds:{all_preds}')
        print(f'all_labels:{all_labels}')
        print('------------------------------------')

        R = {'preds': all_preds, 'labels': all_labels, 'conf_mat': [tn, fp, fn, tp],
             'acc': metrics['acc'], 'pre': metrics['pre'], 'rec': metrics['rec'],
             'spe': metrics['spe'], 'F1': metrics['F1'], 'AUC': metrics['AUC']}
    return R, all_features, trainer.classifier

In [None]:
from sklearn.model_selection import KFold
import numpy as np
from efficient_kan.src.efficient_kan.kan import KAN

# 10-fold cross-validation over subjects
cc = upper_triangular
kf = KFold(n_splits=10, shuffle=True, random_state=42)

# kan_mamba_train_model's result for every fold (RR alone only keeps the latest fold).
fold_results = []

# NOTE(review): `filtered_loader` (the target domain) holds every subject of the
# most common site, so it overlaps each fold's validation subjects — confirm this
# is intended for the evaluation protocol.
for i, (train_index, val_index) in enumerate(kf.split(cc)):
    fold = i + 1
    print(f"Fold {fold}:")

    # Split every per-subject array with the same fold indices
    train_fc, val_fc = data_aligned[train_index], data_aligned[val_index]
    train_cc, val_cc = cc[train_index], cc[val_index]
    train_labels, val_labels = labels_aligned[train_index], labels_aligned[val_index]

    train_age, val_age = age_aligned[train_index], age_aligned[val_index]
    train_gender, val_gender = gender_aligned[train_index], gender_aligned[val_index]
    train_handedness, val_handedness = handedness_aligned[train_index], handedness_aligned[val_index]
    train_fiq, val_fiq = fiq_aligned[train_index], fiq_aligned[val_index]
    train_viq, val_viq = viq_aligned[train_index], viq_aligned[val_index]
    train_piq, val_piq = piq_aligned[train_index], piq_aligned[val_index]
    train_eye, val_eye = eye_aligned[train_index], eye_aligned[val_index]
    train_site, val_site = site_aligned[train_index], site_aligned[val_index]

    print(f"Train shape: {train_fc.shape}, Validation shape: {val_fc.shape}")

    # Create Dataset objects
    train_dataset = fMRIDataset(train_fc, train_cc, train_labels, train_age, train_gender, train_site, train_fiq, train_viq, train_piq, train_eye, train_handedness)
    val_dataset = fMRIDataset(val_fc, val_cc, val_labels, val_age, val_gender, val_site, val_fiq, val_viq, val_piq, val_eye, val_handedness)
    val_domain_dataset = fMRIDataset_domain(val_fc, val_cc, val_labels, val_age, val_gender, val_site, val_fiq, val_viq, val_piq, val_eye, val_handedness)
    # Not consumed below; kept for ad-hoc inspection of training-set features.
    train_domain_dataset = fMRIDataset_domain(train_fc, train_cc, train_labels, train_age, train_gender, train_site, train_fiq, train_viq, train_piq, train_eye, train_handedness)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=custom_collate)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=custom_collate)

    val_loader_domain = DataLoader(val_domain_dataset, batch_size=8, shuffle=False, collate_fn=custom_collate_domain)
    train_loader_domain = DataLoader(train_domain_dataset, batch_size=8, shuffle=False, collate_fn=custom_collate_domain)

    # Train a fresh model on this fold and evaluate it on the fold's validation subjects
    RR = kan_mamba_train_model(train_loader, val_loader, filtered_loader, val_loader_domain)
    fold_results.append(RR)

# Cross-validated summary of each fold's final-round accuracy
fold_accuracies = [result[0]['acc'] for result in fold_results]
print(f"Mean accuracy over {len(fold_accuracies)} folds: "
      f"{np.mean(fold_accuracies):.4f} ± {np.std(fold_accuracies):.4f}")
