In [104]:
import os
from torch.utils.data import Dataset
import torch
import numpy as np


class CC200Data(Dataset):
    def __init__(self, path, mapping, labels):
        super().__init__()
        self.path = path
        self.mapping = mapping
        self.folder = os.listdir(self.path)
        self.labels = self._map_labels(labels)
        self.files = [np.loadtxt(f"{path}/{i}") for i in self.folder]
        self.region_indices = self._region_mapping()
        self.region_coeff = torch.tensor([np.corrcoef(i, rowvar=False) for i in self.files], dtype=torch.float32)
        self.subnetwork_coeff = [self._subnetwork_coeffs(i) for i in self.files]

    def _map_labels(self, mapping):
        labels = []
        for i in self.folder:
            if i[:-14] in mapping.keys():
                labels.append(mapping[i[:-14]])
            else:
                print(i)

        return torch.tensor(labels, dtype=torch.long)
        
    def _region_mapping(self):
        region_indices = {i: [] for i in range(1, 8)}
        for i, j in self.mapping.items():
            if j == 0:
                continue
            region_indices[j].append(i - 1)
        return region_indices
    
    def _subnetwork_coeffs(self, x):
        correlation_matrices = []
        
        for file_idx, (region_id, indices) in enumerate(self.region_indices.items()):
            if not indices:
                print(f"[INFO] Region {region_id} has no indices, skipping.")
                continue
    
            submatrix = x[:, indices]
            std = np.std(submatrix, axis=0)
    
            zero_std_count = np.sum(std == 0)
            if zero_std_count > 0:
                print(f"[INFO] File {self.folder[file_idx]} - Region {region_id} has {zero_std_count} constant columns.")
                
            try:
                correlation_matrix = np.corrcoef(submatrix, rowvar=False)
            except Exception as e:
                print(f"[ERROR] Correlation failed in region {region_id}")
                print(f"Exception: {e}")
                continue
    
            flat_corr = self._flatten_matrix(correlation_matrix)
            correlation_matrices.append(torch.tensor(flat_corr, dtype=torch.float32))
    
        return correlation_matrices

        
    def _flatten_matrix(self, matrix):
        idx = np.triu_indices_from(matrix, k=1)
        return matrix[idx]
    
    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        region_corr = self.region_coeff[idx]
        subnetwork_corrs = self.subnetwork_coeff[idx]
        labels = self.labels[idx]
        return region_corr, subnetwork_corrs, labels

In [78]:
!wget https://raw.githubusercontent.com/broccubali/AutisticAdventures/main/cc200_to_yeo7_mapping.csv
!wget https://s3.amazonaws.com/fcp-indi/data/Projects/ABIDE_Initiative/Phenotypic_V1_0b_preprocessed1.csv

--2025-04-17 04:35:33--  https://raw.githubusercontent.com/broccubali/AutisticAdventures/main/cc200_to_yeo7_mapping.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1319 (1.3K) [text/plain]
Saving to: ‘cc200_to_yeo7_mapping.csv.1’


2025-04-17 04:35:33 (39.3 MB/s) - ‘cc200_to_yeo7_mapping.csv.1’ saved [1319/1319]

--2025-04-17 04:35:33--  https://s3.amazonaws.com/fcp-indi/data/Projects/ABIDE_Initiative/Phenotypic_V1_0b_preprocessed1.csv
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.90.182, 54.231.200.64, 16.182.38.248, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.90.182|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 449443 (439K) [application/octet-stream]
Saving to: ‘Phenotypic_V1_0b_preprocessed1.csv’

In [95]:
import pandas as pd

df = pd.read_csv('/kaggle/working/cc200_to_yeo7_mapping.csv')
cc200_to_yeo7_mapping = dict(zip(df['CC200_Region'], df['Yeo7_Network'])) 

In [96]:
df = pd.read_csv("/kaggle/working/Phenotypic_V1_0b_preprocessed1.csv")
df = df[["FILE_ID", "DX_GROUP"]]
labels_mapping = dict(zip(df["FILE_ID"], df["DX_GROUP"]))

In [105]:
dataset = CC200Data("/kaggle/input/autistic-brains/Outputs/cpac/nofilt_noglobal/rois_cc200", cc200_to_yeo7_mapping, labels_mapping)

[INFO] File UM_1_0050321_rois_cc200.1D - Region 5 has 1 constant columns.
[INFO] File UM_1_0050321_rois_cc200.1D - Region 5 has 3 constant columns.
[INFO] File UM_1_0050315_rois_cc200.1D - Region 1 has 19 constant columns.
[INFO] File Stanford_0051165_rois_cc200.1D - Region 2 has 13 constant columns.
[INFO] File NYU_0050993_rois_cc200.1D - Region 3 has 16 constant columns.
[INFO] File UCLA_1_0051211_rois_cc200.1D - Region 4 has 9 constant columns.
[INFO] File USM_0050496_rois_cc200.1D - Region 6 has 7 constant columns.
[INFO] File Yale_0050619_rois_cc200.1D - Region 7 has 15 constant columns.
[INFO] File USM_0050496_rois_cc200.1D - Region 6 has 1 constant columns.
[INFO] File UM_1_0050321_rois_cc200.1D - Region 5 has 1 constant columns.
[INFO] File UM_1_0050321_rois_cc200.1D - Region 5 has 2 constant columns.
[INFO] File UM_1_0050321_rois_cc200.1D - Region 5 has 2 constant columns.
[INFO] File USM_0050496_rois_cc200.1D - Region 6 has 3 constant columns.
[INFO] File Yale_0050619_rois_cc

In [106]:
next(iter(dataset))[2]

tensor(1)

In [11]:
shapes = []
a = next(iter(dataset))[1]
for i in a:
    shapes.append(i.shape[0])

In [12]:
from torch.utils.data import DataLoader
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MHSA(nn.Module):
    def __init__(self, embd_dim, num_heads):
        super().__init__()
        self.embd_dim = embd_dim
        self.num_heads = num_heads
        self.head_size = self.embd_dim // self.num_heads
        self.q = nn.Linear(self.embd_dim, self.embd_dim)
        self.k = nn.Linear(self.embd_dim, self.embd_dim)
        self.v = nn.Linear(self.embd_dim, self.embd_dim)
        self.d = self.head_size ** 0.5
        self.mlp = nn.Linear(self.embd_dim, self.embd_dim)
        self.layer_norm = nn.LayerNorm(self.embd_dim)  
        
    def forward(self, x):
        batch_size, M, _ = x.shape
        norm = self.layer_norm(x)
        q = self.q(norm).view(batch_size, M, self.num_heads, self.head_size).transpose(1, 2)
        k = self.k(norm).view(batch_size, M, self.num_heads, self.head_size).transpose(1, 2)
        v = self.v(norm).view(batch_size, M, self.num_heads, self.head_size).transpose(1, 2)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.d
        attn_scores = attn_scores.masked_fill(torch.eye(M, device=x.device).bool(), float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)
        context = torch.matmul(attn_weights, v).transpose(1, 2).reshape(batch_size, M, self.embd_dim)
        out = self.mlp(context)
        return out + x, attn_weights

In [14]:
class SubnetworkEmbedder(nn.Module):
    def __init__(self, input_dim, hidden_dim=256, output_dim=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.mlp(x)

In [15]:
class RegionEmbedder(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.region_conv = nn.Conv2d(1, 1, kernel_size=1)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )       

    def forward(self, x):
        x_conv = self.region_conv(x.unsqueeze(1)) 
        x_conv = x_conv.squeeze(1)  
        return self.mlp(x_conv)  

In [16]:
class RegionEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, embd_dim, num_heads, num_layers):
        super().__init__()
        self.reg_embd = RegionEmbedder(input_dim, hidden_dim, embd_dim)
        self.mhsa_layers = nn.ModuleList([MHSA(embd_dim, num_heads) for _ in range(num_layers)])

    def forward(self, x):
        x_reg = self.reg_embd(x)
        x_in = x_reg
        attn_weights_all = []
        for mhsa in self.mhsa_layers:
            x_in, attn_weights = mhsa(x_in)
            attn_weights_all.append(attn_weights)
        
        return x_reg + x_in, torch.stack(attn_weights_all)

In [17]:
class SubNetworkEncoder(nn.Module):
    def __init__(self, shapes, hidden_dim, embd_dim, num_heads, num_layers):
        super().__init__()
        self.embd_dim = embd_dim
        self.mlps = [SubnetworkEmbedder(i, hidden_dim, embd_dim) for i in shapes]
        self.mhsa_layers = nn.ModuleList([MHSA(embd_dim, num_heads) for _ in range(num_layers)])
        
    def forward(self, x):
        batch_size = x[0].shape[0]
        x = torch.stack([mlp(f) for mlp, f in zip(self.mlps, x)], dim=1)
        attn_weights_all = []
        for mhsa in self.mhsa_layers:
            x, attn_weights = mhsa(x)
            attn_weights_all.append(attn_weights)
        
        return x, torch.stack(attn_weights_all)

In [36]:
model = SubNetworkEncoder(shapes, 256, 128, 8, 4)
a = next(iter(train_loader))[1]
model(a)[1].shape

torch.Size([4, 64, 8, 7, 7])

In [37]:
b = next(iter(train_loader))[0]
model1 = RegionEncoder(200, 256, 128, 8, 4)
model1(b)[1].shape

torch.Size([4, 64, 8, 200, 200])

In [81]:
class StepOne(nn.Module):
    def __init__(self, input_dim, hidden_dim, embd_dim, num_heads, num_layers):
        super().__init__()
        self.reg_enc = RegionEncoder(input_dim, hidden_dim, embd_dim, num_heads, num_layers)
        self.subnet_enc = SubNetworkEncoder(shapes, hidden_dim, embd_dim, num_heads, num_layers)
        self.layer_norm = nn.LayerNorm(embd_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embd_dim, hidden_dim),  
            nn.ReLU(),             
            nn.Linear(hidden_dim, embd_dim)    
        )


    def forward(self, x):
        x0 = self.reg_enc(x[0])
        x1 = self.subnet_enc(x[1])
        o = torch.cat((x0[0], x1[0]), dim=1)
        o_norm = self.layer_norm(o)
        o_norm = self.mlp(o)
        o = o + o_norm
        o_reg = o[:, :200, :]
        o_sub = o[:, 200:, :]
        return o_reg, o_sub

In [83]:
model = StepOne(200, 256, 128, 16, 4)
model(next(iter(train_loader)))[1].shape

torch.Size([64, 7, 128])

In [72]:
x = model(next(iter(train_loader)))[1][1]
x.shape

torch.Size([4, 64, 16, 7, 7])

In [73]:
for i in range(7):
    print(x[0][0][0].sum(dim=1))
    print(x[0][0][0].sum(dim=0))

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
tensor([1.0004, 1.0825, 0.9789, 1.2091, 0.7094, 1.0550, 0.9647],
       grad_fn=<SumBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
tensor([1.0004, 1.0825, 0.9789, 1.2091, 0.7094, 1.0550, 0.9647],
       grad_fn=<SumBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
tensor([1.0004, 1.0825, 0.9789, 1.2091, 0.7094, 1.0550, 0.9647],
       grad_fn=<SumBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
tensor([1.0004, 1.0825, 0.9789, 1.2091, 0.7094, 1.0550, 0.9647],
       grad_fn=<SumBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
tensor([1.0004, 1.0825, 0.9789, 1.2091, 0.7094, 1.0550, 0.9647],
       grad_fn=<SumBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 

In [74]:
def sinkhorn(attn, n_iters=5, eps=1e-6):
    attn = attn + eps  
    for _ in range(n_iters):
        attn = attn / attn.sum(dim=-1, keepdim=True)
        attn = attn / attn.sum(dim=-2, keepdim=True)
    return attn
y = sinkhorn(x)

In [75]:
for i in range(7):
    print(y[0][0][0].sum(dim=1))
    print(y[0][0][0].sum(dim=0))

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 

In [76]:
s = y[:, 0, :, :, :]
s.shape

torch.Size([4, 16, 7, 7])

In [77]:
s[0, 0, :, :]

tensor([[9.5824e-07, 1.5369e-01, 9.9605e-02, 1.4148e-01, 2.6500e-01, 1.7412e-01,
         1.6610e-01],
        [1.5972e-01, 9.0825e-07, 2.1339e-01, 1.4781e-01, 1.3955e-01, 1.4719e-01,
         1.9234e-01],
        [2.0664e-01, 1.6504e-01, 1.0493e-06, 2.0958e-01, 1.3758e-01, 1.6063e-01,
         1.2053e-01],
        [1.6250e-01, 1.2655e-01, 1.7902e-01, 7.9477e-07, 1.5100e-01, 2.1836e-01,
         1.6256e-01],
        [1.5115e-01, 1.7011e-01, 1.9677e-01, 1.3087e-01, 1.4827e-06, 1.6547e-01,
         1.8562e-01],
        [1.2713e-01, 2.1097e-01, 1.8881e-01, 1.4846e-01, 1.5178e-01, 9.4508e-07,
         1.7285e-01],
        [1.9286e-01, 1.7364e-01, 1.2239e-01, 2.2180e-01, 1.5509e-01, 1.3422e-01,
         1.0651e-06]], grad_fn=<SliceBackward0>)