In [1]:
import os
from torch.utils.data import Dataset
import torch
import numpy as np


class CC200Data(Dataset):
    def __init__(self, path, mapping, labels):
        super().__init__()
        self.path = path
        self.mapping = mapping
        self.folder = os.listdir(self.path)
        self.labels = self._map_labels(labels)
        self.files = [self._fix_nans(np.loadtxt(f"{path}/{i}")) for i in self.folder]
        self.region_indices = self._region_mapping()
        self.region_coeff = torch.tensor([self._region_coeffs(i) for i in self.files], dtype=torch.float32)
        self.subnetwork_coeff = [self._subnetwork_coeffs(i) for i in self.files]
        self.region_start_indices = {}
        for i in range(1, 8):
            if i == 1:
                self.region_start_indices[i] = 0
            else:
                self.region_start_indices[i] = self.region_start_indices[i - 1] + len(
                    self.region_indices[i - 1]
                )
                
    def _map_labels(self, mapping):
        labels = []
        for i in self.folder:
            if i[:-14] in mapping.keys():
                labels.append(mapping[i[:-14]])
            else:
                print(i)

        return torch.tensor(labels, dtype=torch.long)
        
    def _fix_nans(self, x):
        std = np.std(x, axis=0)
        zeroes = std == 0
    
        if zeroes.any():
            noise = np.random.normal(loc=0.0, scale=1e-6, size=(x.shape[0], zeroes.sum()))
            x[:, zeroes] = noise
    
        return x

    def _region_coeffs(self, x):
        b = np.zeros_like(x)
        y = 0
        for i in self.region_indices:
            b[:, y : y + len(self.region_indices[i])] = x[:, self.region_indices[i]]
            y += len(self.region_indices[i])
        b = b[:, 15:]
        x = np.corrcoef(b, rowvar=False)
        return x

    def _region_mapping(self):
        region_indices = {i: [] for i in range(0, 8)}
        for i, j in self.mapping.items():
            region_indices[j].append(i - 1)
        return region_indices

    def _subnetwork_coeffs(self, x):
        correlation_matrices = []

        for file_idx, (region_id, indices) in enumerate(self.region_indices.items()):
            if not indices:
                print(f"[INFO] Region {region_id} has no indices, skipping.")
                continue

            submatrix = x[:, indices]
            std = np.std(submatrix, axis=0)

            zero_std_count = np.sum(std == 0)
            if zero_std_count > 0:
                print(f"[INFO] File {self.folder[file_idx]} - Region {region_id} has {zero_std_count} constant columns.")

            try:
                correlation_matrix = np.corrcoef(submatrix, rowvar=False)
            except Exception as e:
                print(f"[ERROR] Correlation failed in region {region_id}")
                print(f"Exception: {e}")
                continue

            flat_corr = self._flatten_matrix(correlation_matrix)
            correlation_matrices.append(torch.tensor(flat_corr, dtype=torch.float32))

        return correlation_matrices

    def _flatten_matrix(self, matrix):
        idx = np.triu_indices_from(matrix, k=1)
        return matrix[idx]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        region_corr = self.region_coeff[idx]
        subnetwork_corrs = self.subnetwork_coeff[idx]
        labels = self.labels[idx]
        return region_corr, subnetwork_corrs, labels

In [6]:
!wget https://raw.githubusercontent.com/broccubali/AutisticAdventures/main/cc200_to_yeo7_mapping.csv

--2025-04-22 03:08:48--  https://raw.githubusercontent.com/broccubali/AutisticAdventures/main/cc200_to_yeo7_mapping.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1319 (1.3K) [text/plain]
Saving to: ‘cc200_to_yeo7_mapping.csv’


2025-04-22 03:08:48 (86.5 MB/s) - ‘cc200_to_yeo7_mapping.csv’ saved [1319/1319]



In [5]:
! wget https://s3.amazonaws.com/fcp-indi/data/Projects/ABIDE_Initiative/Phenotypic_V1_0b_preprocessed1.csv

--2025-04-22 03:08:43--  https://s3.amazonaws.com/fcp-indi/data/Projects/ABIDE_Initiative/Phenotypic_V1_0b_preprocessed1.csv
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.184.213, 16.182.72.40, 52.216.107.174, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.184.213|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 449443 (439K) [application/octet-stream]
Saving to: ‘Phenotypic_V1_0b_preprocessed1.csv’


2025-04-22 03:08:43 (4.02 MB/s) - ‘Phenotypic_V1_0b_preprocessed1.csv’ saved [449443/449443]



In [7]:
import pandas as pd

df = pd.read_csv('cc200_to_yeo7_mapping.csv')
cc200_to_yeo7_mapping = dict(zip(df['CC200_Region'], df['Yeo7_Network'])) 

In [8]:
df = pd.read_csv("Phenotypic_V1_0b_preprocessed1.csv")
df = df[["FILE_ID", "DX_GROUP"]]
labels_mapping = dict(zip(df["FILE_ID"], df["DX_GROUP"]))

In [10]:
# had to fix this here cuz pytorch cross entropy loss needs 0 and 1 not 1 and 2
new_labels_mapping = {}
for key, value in labels_mapping.items():
    new_labels_mapping[key] = value - 1  # Subtract 1 to convert 1,2 to 0,1

# Recreate your dataset with adjusted labels
dataset = CC200Data("/kaggle/input/autistic-brains/Outputs/cpac/nofilt_noglobal/rois_cc200", cc200_to_yeo7_mapping, new_labels_mapping)


  self.region_coeff = torch.tensor([self._region_coeffs(i) for i in self.files], dtype=torch.float32)


In [11]:
from torch.utils.data import DataLoader
train_loader = DataLoader(dataset, batch_size=16, shuffle=True)

In [12]:
unique_labels = set()
for _, _, labels in train_loader:
    unique_labels.update(labels.numpy())
print(f"updated label values: {sorted(list(unique_labels))}")

updated label values: [0, 1]


In [13]:
shapes = []
a = next(iter(dataset))[1]
for i in a:
    shapes.append(i.shape[0])

# len(shapes)

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MHSA(nn.Module):
    def __init__(self, embd_dim, num_heads):
        super().__init__()
        self.embd_dim = embd_dim
        self.num_heads = num_heads
        self.head_size = self.embd_dim // self.num_heads
        self.q = nn.Linear(self.embd_dim, self.embd_dim)
        self.k = nn.Linear(self.embd_dim, self.embd_dim)
        self.v = nn.Linear(self.embd_dim, self.embd_dim)
        self.d = self.head_size ** 0.5
        self.mlp = nn.Linear(self.embd_dim, self.embd_dim)
        self.layer_norm = nn.LayerNorm(self.embd_dim)  
        
    def forward(self, x):
        batch_size, M, _ = x.shape
        norm = self.layer_norm(x)
        q = self.q(norm).view(batch_size, M, self.num_heads, self.head_size).transpose(1, 2)
        k = self.k(norm).view(batch_size, M, self.num_heads, self.head_size).transpose(1, 2)
        v = self.v(norm).view(batch_size, M, self.num_heads, self.head_size).transpose(1, 2)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.d
        attn_scores = attn_scores.masked_fill(torch.eye(M, device=x.device).bool(), float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)
        context = torch.matmul(attn_weights, v).transpose(1, 2).reshape(batch_size, M, self.embd_dim)
        out = self.mlp(context)
        return out + x, attn_weights

In [15]:
class SubnetworkEmbedder(nn.Module):
    def __init__(self, input_dim, hidden_dim=256, output_dim=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.mlp(x)

In [16]:
class RegionEmbedder(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.region_conv = nn.Conv2d(1, 1, kernel_size=1)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )       

    def forward(self, x):
        x_conv = self.region_conv(x.unsqueeze(1)) 
        x_conv = x_conv.squeeze(1)  
        return self.mlp(x_conv) 

In [17]:
class RegionEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, embd_dim, num_heads, num_layers):
        super().__init__()
        self.reg_embd = RegionEmbedder(input_dim, hidden_dim, embd_dim)
        self.mhsa_layers = nn.ModuleList([MHSA(embd_dim, num_heads) for _ in range(num_layers)])

    def forward(self, x):
        x_reg = self.reg_embd(x)
        x_in = x_reg
        attn_weights_all = []
        for mhsa in self.mhsa_layers:
            x_in, attn_weights = mhsa(x_in)
            attn_weights_all.append(attn_weights)
        
        return x_reg + x_in, torch.stack(attn_weights_all).permute(1, 0, 2, 3, 4)

In [18]:
class SubNetworkEncoder(nn.Module):
    def __init__(self, shapes, hidden_dim, embd_dim, num_heads, num_layers):
        super().__init__()
        self.embd_dim = embd_dim
        self.mlps = nn.ModuleList([SubnetworkEmbedder(i, hidden_dim, embd_dim) for i in shapes])
        self.mhsa_layers = nn.ModuleList([MHSA(embd_dim, num_heads) for _ in range(num_layers)])

    def forward(self, x):
        batch_size = x[0].shape[0]
        x = torch.stack([mlp(f) for mlp, f in zip(self.mlps, x)], dim=1)
        attn_weights_all = []
        for mhsa in self.mhsa_layers:
            x, attn_weights = mhsa(x)
            attn_weights_all.append(attn_weights)

        return x, torch.stack(attn_weights_all).permute(1, 0, 2, 3, 4)

In [19]:
class StepOne(nn.Module):
    def __init__(self, input_dim, hidden_dim, embd_dim, num_heads, num_layers):
        super().__init__()
        self.reg_enc = RegionEncoder(input_dim, hidden_dim, embd_dim, num_heads, num_layers)
        self.subnet_enc = SubNetworkEncoder(shapes, hidden_dim, embd_dim, num_heads, num_layers)
        self.layer_norm = nn.LayerNorm(embd_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embd_dim, hidden_dim),  
            nn.ReLU(),             
            nn.Linear(hidden_dim, embd_dim)    
        )
        self.region_start_indices = list(dataset.region_start_indices.values()) + [185]

    def subNetworkAttendRegions(self, subnet_attn_map, region_attn_map):
        region_to_subnet = torch.zeros(185, dtype=torch.long)
        for subnet_id in range(7):
            start = self.region_start_indices[subnet_id]
            end = self.region_start_indices[subnet_id + 1]
            region_to_subnet[start:end] = subnet_id  
        subnet_i = region_to_subnet.view(-1, 1).expand(185, 185) 
        subnet_j = region_to_subnet.view(1, -1).expand(185, 185)  

        mask = subnet_i != subnet_j  

        attn_multiplier = subnet_attn_map[:, :, :, subnet_i, subnet_j]  
        attn_multiplier = attn_multiplier * mask

        return region_attn_map * attn_multiplier 
    
    def sinkhorn(self, attn, n_iters=5, eps=1e-6):
        attn = attn + eps  
        for _ in range(n_iters):
            attn = attn / attn.sum(dim=-1, keepdim=True)
            attn = attn / attn.sum(dim=-2, keepdim=True)
        return attn
    
    def forward(self, x):
        x0 = self.reg_enc(x[0])
        x1 = self.subnet_enc(x[1])
        o = torch.cat((x0[0], x1[0]), dim=1)
        o_norm = self.layer_norm(o)
        o_norm = self.mlp(o)
        o = o + o_norm
        print(o.shape)
        o_reg = o[:, :185, :]
        o_sub = o[:, 185:, :]
        adj_matrix = self.subNetworkAttendRegions(x1[1], x0[1])
        adj_matrix = self.sinkhorn(adj_matrix)
        return o_reg, o_sub, adj_matrix

In [None]:
class HGCN(nn.Module):
    def __init__(self, input_dim, output_dim, num_layers=4, num_heads=16):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.num_heads = num_heads

        # One linear layer per head per layer: (L, H, in, out)
        self.W = nn.Parameter(torch.randn(num_layers, num_heads, input_dim, output_dim))
        self.activation = nn.ReLU()

    def forward(self, features, attention_maps):
        """
        features: [B, N, F_in]
        attention_maps: [B, L, H, N, N] — soft adjacency or incidence maps
        """
        B, L, H, N, _ = attention_maps.shape
        F_in, F_out = self.input_dim, self.output_dim

        # Apply W to input features: [B, 1, 1, N, F_in] x [L, H, F_in, F_out]
        # -> output: [B, L, H, N, F_out]
        features_exp = features[:, None, None, :, :]  # [B, 1, 1, N, F_in]
        weights = self.W[None, :, :, :, :]            # [1, L, H, F_in, F_out]
        transformed = torch.matmul(features_exp, weights)  # [B, L, H, N, F_out]

        # Apply hypergraph attention maps
        # attention_maps: [B, L, H, N, N]
        # transformed:     [B, L, H, N, F_out]
        output = torch.matmul(attention_maps, transformed)  # [B, L, H, N, F_out]
        output = self.activation(output)

        # Concatenate heads → [B, L, N, H * F_out]
        output = output.permute(0, 1, 3, 2, 4).reshape(B, L, N, H * F_out)

        # Concatenate layers → [B, N, L * H * F_out]
        output = output.permute(0, 2, 1, 3).reshape(B, N, L * H * F_out)

        return output

In [25]:
class StepOneWithHGCN(nn.Module):
    def __init__(self, input_dim, hidden_dim, embd_dim, num_heads, num_layers, num_classes=2):
        super().__init__()
        # Copy paste
        self.reg_enc = RegionEncoder(input_dim, hidden_dim, embd_dim, num_heads, num_layers)
        self.subnet_enc = SubNetworkEncoder(shapes, hidden_dim, embd_dim, num_heads, num_layers)
        self.layer_norm = nn.LayerNorm(embd_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embd_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embd_dim)
        )
        self.region_start_indices = list(dataset.region_start_indices.values()) + [185]
        self.hgcn = HGCN(input_dim=embd_dim, output_dim=embd_dim//num_heads)
        hgcn_output_dim = (embd_dim//num_heads) * num_heads * num_layers
        self.classifier = nn.Sequential(
            nn.Linear(hgcn_output_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, num_classes)
        )
    
    def subNetworkAttendRegions(self, subnet_attn_map, region_attn_map):
        region_to_subnet = torch.zeros(185, dtype=torch.long)
        for subnet_id in range(7):
            start = self.region_start_indices[subnet_id]
            end = self.region_start_indices[subnet_id + 1]
            region_to_subnet[start:end] = subnet_id
        subnet_i = region_to_subnet.view(-1, 1).expand(185, 185)
        subnet_j = region_to_subnet.view(1, -1).expand(185, 185)
        mask = subnet_i != subnet_j
        attn_multiplier = subnet_attn_map[:, :, :, subnet_i, subnet_j]
        mask = mask.to("cuda")
        attn_multiplier = attn_multiplier * mask
        return region_attn_map * attn_multiplier
    
    def sinkhorn(self, attn, n_iters=5, eps=1e-6):
        attn = attn + eps
        for _ in range(n_iters):
            attn = attn / attn.sum(dim=-1, keepdim=True)
            attn = attn / attn.sum(dim=-2, keepdim=True)
        return attn
    
    def forward(self, x):
        x0, region_attn = self.reg_enc(x[0])
        x1, subnet_attn = self.subnet_enc(x[1])

        o = torch.cat((x0, x1), dim=1)
        o_norm = self.layer_norm(o)
        o_norm = self.mlp(o_norm)
        o = o + o_norm
        o_reg = o[:, :185, :]  # First 185 nodes are regions
        o_sub = o[:, 185:, :]  # Remaining nodes are subnetworks
        # Process attention maps to create the combined attention 
        # with shape [batch_size, num_layers (4), num_heads (16), 185, 185]
        # combined_attn = self.subNetworkAttendRegions(subnet_attn, region_attn)
        # combined_attn = self.sinkhorn(combined_attn)
        
        # # Pass through HGCN - only process the region features with the combined attention
        # hgcn_output = self.hgcn(o_reg, combined_attn)

        # Option 1: combined attention
        # combined_attn = self.subNetworkAttendRegions(subnet_attn, region_attn)
        # combined_attn = self.sinkhorn(combined_attn)
        
        # Option 2: sub-net attention
        region_attn = self.sinkhorn(region_attn)
        combined_attn = region_attn

        hgcn_output = self.hgcn(o_reg, combined_attn) # leave unchanged

        # Global average pooling for classification
        pooled_output = hgcn_output.mean(dim=1)  # [batch_size, output_dim*num_heads*num_layers]
        # Final classif
        logits = self.classifier(pooled_output)
        return logits

In [26]:
from tqdm import tqdm

model = StepOneWithHGCN(
    input_dim=185,
    hidden_dim=256,
    embd_dim=128,
    num_heads=16,
    num_layers=4,
    num_classes=2
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()

num_epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    # tqdm for batch progress
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for batch_idx, (region_data, subnetwork_data, labels) in pbar:
        region_data = region_data.to(device)
        subnetwork_data = [subnet.to(device) for subnet in subnetwork_data]
        labels = labels.to(device)

        optimizer.zero_grad()

        x = (region_data, subnetwork_data)
        logits = model(x)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(logits.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        pbar.set_postfix({
            'Loss': f'{running_loss / (batch_idx + 1):.4f}',
            'Acc': f'{100 * correct / total:.2f}%'
        })

torch.save(model.state_dict(), 'brain_network_model.pth')


Epoch 1/50: 100%|██████████| 56/56 [00:08<00:00,  6.53it/s, Loss=0.7821, Acc=50.57%]
Epoch 2/50: 100%|██████████| 56/56 [00:07<00:00,  7.24it/s, Loss=0.7567, Acc=52.38%]
Epoch 3/50: 100%|██████████| 56/56 [00:07<00:00,  7.25it/s, Loss=0.7362, Acc=51.36%]
Epoch 4/50: 100%|██████████| 56/56 [00:07<00:00,  7.25it/s, Loss=0.7141, Acc=54.52%]
Epoch 5/50: 100%|██████████| 56/56 [00:07<00:00,  7.23it/s, Loss=0.7184, Acc=50.68%]
Epoch 6/50: 100%|██████████| 56/56 [00:07<00:00,  7.23it/s, Loss=0.7250, Acc=50.11%]
Epoch 7/50: 100%|██████████| 56/56 [00:07<00:00,  7.22it/s, Loss=0.7178, Acc=54.07%]
Epoch 8/50: 100%|██████████| 56/56 [00:07<00:00,  7.21it/s, Loss=0.7127, Acc=52.04%]
Epoch 9/50: 100%|██████████| 56/56 [00:07<00:00,  7.20it/s, Loss=0.7093, Acc=51.70%]
Epoch 10/50: 100%|██████████| 56/56 [00:07<00:00,  7.17it/s, Loss=0.6998, Acc=51.58%]
Epoch 11/50: 100%|██████████| 56/56 [00:07<00:00,  7.16it/s, Loss=0.7011, Acc=48.42%]
Epoch 12/50: 100%|██████████| 56/56 [00:07<00:00,  7.15it/s, Lo