In [2]:
import os
from torch.utils.data import Dataset
import torch
import numpy as np


class CC200Data(Dataset):
    def __init__(self, path, mapping):
        super().__init__()
        self.path = path
        self.mapping = mapping
        self.folder = os.listdir(self.path)
        self.files = [np.loadtxt(f"{path}/{i}") for i in self.folder]
        self.region_coeff = torch.tensor([np.corrcoef(i, rowvar=False) for i in self.files], dtype=torch.float32)
        self.subnetwork_coeff = [self._subnetwork_coeffs(i) for i in self.files]
        
    def _region_mapping(self, x):
        region_indices = {i: [] for i in range(1, 8)}
        for i, j in self.mapping.items():
            if j == 0:
                continue
            region_indices[j].append(i - 1)
        return region_indices

    def _subnetwork_coeffs(self, x):
        region_indices = self._region_mapping(x)
        correlation_matrices = []
        for i in region_indices.keys():
            m = x[:, region_indices[i]]
            correlation_matrix = np.corrcoef(m, rowvar=False)
            correlation_matrices.append(torch.tensor(self._flatten_matrix(correlation_matrix), dtype=torch.float32))
        return correlation_matrices
        
    def _flatten_matrix(self, matrix):
        idx = np.triu_indices_from(matrix, k=1)
        return matrix[idx]
    
    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        region_corr = self.region_coeff[idx]
        subnetwork_corrs = self.subnetwork_coeff[idx]
        return region_corr, subnetwork_corrs

In [3]:
!wget https://raw.githubusercontent.com/broccubali/AutisticAdventures/main/cc200_to_yeo7_mapping.csv

--2025-04-17 01:12:44--  https://raw.githubusercontent.com/broccubali/AutisticAdventures/main/cc200_to_yeo7_mapping.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1319 (1.3K) [text/plain]
Saving to: ‘cc200_to_yeo7_mapping.csv’


2025-04-17 01:12:44 (59.7 MB/s) - ‘cc200_to_yeo7_mapping.csv’ saved [1319/1319]



In [4]:
import pandas as pd

df = pd.read_csv('/kaggle/working/cc200_to_yeo7_mapping.csv')
cc200_to_yeo7_mapping = dict(zip(df['CC200_Region'], df['Yeo7_Network'])) 

In [5]:
dataset = CC200Data("/kaggle/input/autistic-brains/Outputs/cpac/nofilt_noglobal/rois_cc200", cc200_to_yeo7_mapping)

  c /= stddev[:, None]
  c /= stddev[:, None]
  c /= stddev[None, :]
  c /= stddev[None, :]
  self.region_coeff = torch.tensor([np.corrcoef(i, rowvar=False) for i in self.files], dtype=torch.float32)


In [6]:
shapes = []
a = next(iter(dataset))[1]
for i in a:
    shapes.append(i.shape[0])

In [7]:
from torch.utils.data import DataLoader
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MHSA(nn.Module):
    def __init__(self, embd_dim, num_heads):
        super().__init__()
        self.embd_dim = embd_dim
        self.num_heads = num_heads
        self.head_size = self.embd_dim // self.num_heads
        self.q = nn.Linear(self.embd_dim, self.embd_dim)
        self.k = nn.Linear(self.embd_dim, self.embd_dim)
        self.v = nn.Linear(self.embd_dim, self.embd_dim)
        self.d = self.head_size ** 0.5
        self.mlp = nn.Linear(self.embd_dim, self.embd_dim)
        self.layer_norm = nn.LayerNorm(self.embd_dim)  
        
    def forward(self, x):
        batch_size, M, _ = x.shape
        norm = self.layer_norm(x)
        q = self.q(norm).view(batch_size, M, self.num_heads, self.head_size).transpose(1, 2)
        k = self.k(norm).view(batch_size, M, self.num_heads, self.head_size).transpose(1, 2)
        v = self.v(norm).view(batch_size, M, self.num_heads, self.head_size).transpose(1, 2)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.d
        attn_scores = attn_scores.masked_fill(torch.eye(M, device=x.device).bool(), float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)
        context = torch.matmul(attn_weights, v).transpose(1, 2).reshape(batch_size, M, self.embd_dim)
        out = self.mlp(context)
        return out + x, attn_weights

In [25]:
class SubnetworkEmbedder(nn.Module):
    def __init__(self, input_dim, hidden_dim=256, output_dim=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.mlp(x)

In [26]:
class RegionEmbedder(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.region_conv = nn.Conv2d(1, 1, kernel_size=1)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )       

    def forward(self, x):
        x_conv = self.region_conv(x.unsqueeze(1)) 
        x_conv = x_conv.squeeze(1)  
        return self.mlp(x_conv)  

In [30]:
class RegionEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, embd_dim, num_heads, num_layers):
        super().__init__()
        self.reg_embd = RegionEmbedder(input_dim, hidden_dim, embd_dim)
        self.mhsa_layers = nn.ModuleList([MHSA(embd_dim, num_heads) for _ in range(num_layers)])

    def forward(self, x):
        x_reg = self.reg_embd(x)
        x_in = x_reg
        attn_weights_all = []
        for mhsa in self.mhsa_layers:
            x_in, attn_weights = mhsa(x_in)
            attn_weights_all.append(attn_weights)
        
        return x_reg + x_in, torch.stack(attn_weights_all)

In [31]:
class SubNetworkEncoder(nn.Module):
    def __init__(self, shapes, hidden_dim, embd_dim, num_heads, num_layers):
        super().__init__()
        self.embd_dim = embd_dim
        self.mlps = [SubnetworkEmbedder(i, hidden_dim, embd_dim) for i in shapes]
        self.mhsa_layers = nn.ModuleList([MHSA(embd_dim, num_heads) for _ in range(num_layers)])
        
    def forward(self, x):
        batch_size = x[0].shape[0]
        x = torch.stack([mlp(f) for mlp, f in zip(self.mlps, x)], dim=1)
        attn_weights_all = []
        for mhsa in self.mhsa_layers:
            x, attn_weights = mhsa(x)
            attn_weights_all.append(attn_weights)
        
        return x, torch.stack(attn_weights_all)

In [34]:
model = SubNetworkEncoder(shapes, 256, 128, 8, 4)
a = next(iter(train_loader))[1]
model(a)[1].shape

torch.Size([4, 64, 8, 7, 7])

In [37]:
b = next(iter(train_loader))[0]
model1 = RegionEncoder(200, 256, 128, 8, 4)
model1(b)[1].shape

torch.Size([4, 64, 8, 200, 200])

In [40]:
class StepOne(nn.Module):
    def __init__(self, input_dim, hidden_dim, embd_dim, num_heads, num_layers):
        super().__init__()
        self.reg_enc = RegionEncoder(input_dim, hidden_dim, embd_dim, num_heads, num_layers)
        self.subnet_enc = SubNetworkEncoder(shapes, hidden_dim, embd_dim, num_heads, num_layers)

    def forward(self, x):
        x0 = self.reg_enc(x[0])
        x1 = self.subnet_enc(x[1])
        return x0, x1

In [41]:
model = StepOne(200, 256, 128, 8, 4)
model(next(iter(train_loader)))[1][1].shape

torch.Size([4, 64, 8, 7, 7])