<a href="https://colab.research.google.com/github/kelsingo/Modeling-for-Cognitive-Reserve-Study/blob/main/Model_Implementation_Stage_1%262.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Loss Implementation

In [1]:
class CouplingLoss():
    """Composite loss combining signal-reconstruction and coupling terms.

    The helper losses are static methods so they can be called either on the
    class (``CouplingLoss.l1distance(a, b)``) or through an instance
    (``self.l1distance(a, b)``), which is how :meth:`loss` uses them.
    """

    @staticmethod
    def l1distance(true, pred):
        """Return the summed absolute difference between ``true`` and ``pred``."""
        return torch.sum(torch.abs(true - pred))

    @staticmethod
    def center_shift_loss(true, pred):
        """Return the MSE between the two inputs after mean-centering.

        Each input has its mean over the last dimension subtracted, so a
        constant offset along that dimension does not contribute.
        """
        true_centered = true - torch.mean(true, dim=-1, keepdim=True)
        pred_centered = pred - torch.mean(pred, dim=-1, keepdim=True)
        return torch.mean((pred_centered - true_centered) ** 2)

    @staticmethod
    def interval_loss(true, pred):
        """Return the MSE between consecutive differences along dimension 1.

        Both inputs are indexed as ``[:, 1:]`` / ``[:, :-1]`` and therefore
        must have at least two dimensions.
        """
        diff_true = true[:, 1:] - true[:, :-1]
        diff_pred = pred[:, 1:] - pred[:, :-1]
        return torch.mean((diff_pred - diff_true) ** 2)

    def loss(self,
              bold_true, bold_pred,
              t_true, t_pred,
              width_true, width_pred,
              amp_true, amp_pred,
              coupling_true, coupling_pred,
              delay_true, delay_pred):
        """Compute every loss term and the per-stage totals.

        Each ``*_true`` / ``*_pred`` pair is a target tensor and the matching
        prediction. ``t_true`` / ``t_pred`` must be at least 2-D because the
        interval term slices along dimension 1.

        Returns:
            dict mapping ``'Lstage1'``, ``'Lstage2'``, ``'Lstage3'``,
            ``'LBOLD'``, ``'Ltiming'``, ``'Lwidth'``, ``'Lamplitude'``,
            ``'Lcoupling'`` and ``'Ldelay'`` to scalar tensors.
        """
        # Stage 1
        LBOLD = self.l1distance(bold_true, bold_pred)

        # Ltiming = L1(t_true, t_pred) + Lcenter_shift + Linterval
        L_timing_base = self.l1distance(t_true, t_pred)
        L_center_shift = self.center_shift_loss(t_true, t_pred)
        L_interval = self.interval_loss(t_true, t_pred)
        Ltiming = L_timing_base + L_center_shift + L_interval

        # Width and amplitude losses
        Lwidth = self.l1distance(width_true, width_pred)

        # Lamplitude compares the mean amplitude over the last dimension
        Lamplitude = self.l1distance(torch.mean(amp_true, dim=-1),
                                     torch.mean(amp_pred, dim=-1))

        # The timing term is down-weighted by 0.3 relative to the others
        Lstage1 = LBOLD + 0.3 * Ltiming + Lwidth + Lamplitude

        # Stage 2
        Lcoupling = self.l1distance(coupling_true, coupling_pred)
        Ldelay = self.l1distance(delay_true, delay_pred)
        Lstage2 = Lcoupling + Ldelay

        # Stage 3 is the sum of the two earlier stages
        Lstage3 = Lstage1 + Lstage2

        return {
            'Lstage1': Lstage1,
            'Lstage2': Lstage2,
            'Lstage3': Lstage3,
            'LBOLD': LBOLD,
            'Ltiming': Ltiming,
            'Lwidth': Lwidth,
            'Lamplitude': Lamplitude,
            'Lcoupling': Lcoupling,
            'Ldelay': Ldelay
        }

# Model Implementation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba

class MambaModule(nn.Module):
    """Thin ``nn.Module`` wrapper holding a single ``Mamba`` layer as ``self.mb``."""

    def __init__(self, input_dim):
        super().__init__()
        # Fixed layer hyper-parameters; only the model width is configurable.
        mamba_kwargs = dict(d_model=input_dim, d_state=16, d_conv=4, expand=2)
        self.mb = Mamba(**mamba_kwargs)

    def forward(self, x):
        """Apply the wrapped layer to ``x`` and return its output unchanged."""
        return self.mb(x)

class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the inner layer; ``4 * dim`` when ``None``.
        dropout: drop probability used by both dropout layers.
    """

    def __init__(self, dim, hidden_dim=None, dropout=0.0):
        super().__init__()
        inner_dim = dim * 4 if hidden_dim is None else hidden_dim
        stages = [
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension stays ``dim``."""
        return self.ffn(x)

class Conditional_Mamba_Encoder(nn.Module):
    """Sequence encoder with a per-ROI affine adapter on top of a shared backbone.

    The input is linearly projected to ``hidden_dim``, passed through a
    ``MambaModule``, then corrected by an ROI-specific affine map that is
    added residually before layer normalisation and dropout.
    """

    def __init__(self, input_dim, hidden_dim, n_roi, dropout):
        super().__init__()
        self.linear_proj = nn.Linear(input_dim, hidden_dim)
        self.mb = MambaModule(hidden_dim)
        # One (hidden_dim x hidden_dim) matrix and one bias vector per ROI.
        adapter_w = torch.randn(n_roi, hidden_dim, hidden_dim)
        self.roi_adapter_weights = nn.Parameter(adapter_w)
        adapter_b = torch.randn(n_roi, hidden_dim)
        self.roi_adapter_bias = nn.Parameter(adapter_b)
        self.norm = nn.LayerNorm(hidden_dim)
        self.drop = nn.Dropout(dropout)
        self.hidden_dim = hidden_dim

    def forward(self, x, roi_ids):
        """Encode ``x`` using the adapter selected by ``roi_ids``.

        Per the original shape notes, ``x`` is expected as (B*N, L, Din) and
        ``roi_ids`` as (B*N); ``roi_ids`` indexes the first axis of the
        adapter parameters. The result keeps the sequence axis and has
        ``hidden_dim`` features.
        """
        base = self.mb(self.linear_proj(x))

        # Gather each sequence's own adapter, then apply it as a batched matmul.
        picked_w = self.roi_adapter_weights[roi_ids]
        picked_b = self.roi_adapter_bias[roi_ids]
        adapted = torch.bmm(base, picked_w) + picked_b.unsqueeze(1)

        # Residual connection around the adapter, then norm + dropout.
        return self.drop(self.norm(base + adapted))

class CrossAttention(nn.Module):
    """Pre-norm cross-attention block followed by a pre-norm feed-forward block.

    Queries and keys/values are layer-normalised separately, attended with
    ``nn.MultiheadAttention``, and added back to the un-normalised query. A
    ``FeedForward`` sub-layer with its own residual connection follows.
    """

    def __init__(self, dim, heads=8, dropout=0.0, batch_first=True):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim, heads, dropout=dropout, batch_first=batch_first
        )
        self.drop1 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = FeedForward(dim, dropout=dropout)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, q, kv, attn_mask=None, key_padding_mask=None, need_weights=False):
        """Attend from ``q`` to ``kv``.

        Returns the updated query tensor, or ``(tensor, attention_weights)``
        when ``need_weights`` is true.
        """
        memory = self.norm_kv(kv)
        attended, weights = self.attn(
            query=self.norm_q(q),
            key=memory,
            value=memory,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
        )

        # Attention residual, then feed-forward residual.
        hidden = q + self.drop1(attended)
        hidden = hidden + self.drop2(self.ffn(self.norm2(hidden)))

        if need_weights:
            return hidden, weights
        return hidden

In [None]:
class HRF_Generator(nn.Module):
    """Predict six HRF parameters per signal and render them as a time series.

    Features from a ``Conditional_Mamba_Encoder`` are projected to six values
    per time step, averaged over time, and turned into a difference of two
    gamma-shaped curves evaluated on an integer time grid.
    """

    def __init__(self, input_dim, hidden_dim, n_roi, dropout):
        super().__init__()
        self.Conditional_mb = Conditional_Mamba_Encoder(input_dim, hidden_dim, n_roi, dropout)
        self.linear_proj = nn.Linear(hidden_dim, 6)

    def calculate_hrf(self, params, t):
        """Evaluate the double-gamma HRF for every row of ``params``.

        Args:
            params: tensor whose columns 0..5 are, in order: peak time scale,
                undershoot time scale, peak amplitude, undershoot amplitude,
                peak shape, undershoot shape (all unconstrained raw values).
            t: 1-D tensor of time points.

        Returns:
            Tensor of shape ``(params.shape[0], t.shape[0])``.
        """
        t = t.unsqueeze(0) + 1e-9  # (1, L); offset keeps log(t) finite at t = 0

        # Softplus keeps time scales positive and shapes strictly above 1.
        tp = F.softplus(params[:, 0].unsqueeze(1)) + 1e-9  # Scale 1 (peak time)
        tu = F.softplus(params[:, 1].unsqueeze(1)) + 1e-9  # Scale 2 (undershoot time)
        A = params[:, 2].unsqueeze(1)                      # Amplitude 1 (peak)
        au = params[:, 3].unsqueeze(1)                     # Amplitude 2 (undershoot)
        a1 = F.softplus(params[:, 4].unsqueeze(1)) + 1     # Shape 1
        a2 = F.softplus(params[:, 5].unsqueeze(1)) + 1     # Shape 2

        def gamma_term(amplitude, scale, shape):
            # amplitude * (t/scale)^(shape-1) * exp(-t/scale), evaluated in log
            # space. Multiplying pow() by exp() directly yields inf * 0 = NaN
            # when the scale is tiny; summing the exponents first avoids that.
            ratio = t / scale
            return amplitude * torch.exp((shape - 1) * torch.log(ratio) - ratio)

        # Peak response minus undershoot.
        hrf = gamma_term(A, tp, a1) - gamma_term(au, tu, a2)
        return hrf  # (B*N, L)

    def forward(self, x, roi_ids):
        """Return one HRF time series per input signal.

        Per the original shape notes, ``x`` is expected as (B*N, L, Din) and
        ``roi_ids`` as (B*N); the output has shape (B*N, L, 1).
        """
        # 1: Get features from Conditional Mamba
        features = self.Conditional_mb(x, roi_ids)

        # 2: Project features to parameter space: (B*N, L, Dh) -> (B*N, L, 6)
        param_features = self.linear_proj(features)

        # 3: Pool over time to get 6 params per signal: (B*N, L, 6) -> (B*N, 6)
        hrf_params = torch.mean(param_features, dim=1)

        # 4: Time grid 0..L-1 on the same device and dtype as the input
        seq_len = x.shape[1]
        time_vector = torch.arange(seq_len, device=x.device, dtype=x.dtype)

        # 5: Render the parameters as a time series: (B*N, 6) -> (B*N, L)
        hrf_timeseries = self.calculate_hrf(hrf_params, time_vector)

        # 6: Add a trailing feature axis: (B*N, L, 1)
        return hrf_timeseries.unsqueeze(-1)

In [None]:
class BOLD_Deconvolver(nn.Module):
    """Map a BOLD sequence to a gamma-shaped, LFP-like time series.

    Encoder features are summarised by ``num_peaks`` learnable queries via
    cross-attention, averaged over the queries, projected to ``num_params``
    values, and rendered as ``A * (t/tau)^(n-1) * exp(-t/tau)``.
    """

    def __init__(self, input_dim, hidden_dim, n_roi, num_peaks, num_params, dropout):
        super().__init__()
        self.conditional_mamba = Conditional_Mamba_Encoder(input_dim, hidden_dim, n_roi, dropout)
        # A single set of queries shared by every sequence in the batch.
        self.learnable_query = nn.Parameter(torch.randn(1, num_peaks, hidden_dim))
        self.cross_attention = CrossAttention(dim=hidden_dim, heads=8, dropout=dropout)
        self.param_head = nn.Linear(hidden_dim, num_params)

    def calculate_lfp_signal(self, params, t):
        """Evaluate the gamma curve for every row of ``params``.

        Columns 0, 1 and 2 of ``params`` are read as amplitude, raw time scale
        and raw shape; ``t`` is a 1-D tensor of time points. Returns a tensor
        of shape ``(params.shape[0], t.shape[0])``.
        """
        grid = t.unsqueeze(0) + 1e-9  # (1, L); offset keeps the log finite at t = 0

        amplitude = params[:, 0].unsqueeze(1)
        # Softplus keeps the time scale positive and the shape above 1.
        time_scale = F.softplus(params[:, 1].unsqueeze(1)) + 1e-9
        shape = F.softplus(params[:, 2].unsqueeze(1)) + 1

        # Sum the exponents in log space before a single exp().
        ratio = grid / time_scale
        exponent = (shape - 1) * torch.log(ratio) + (-ratio)
        return amplitude * torch.exp(exponent)  # (B*N, L)

    def forward(self, x, roi_ids):
        """Return one deconvolved time series per input signal.

        Per the original shape notes, ``x`` is expected as (B*N, L, Din) and
        ``roi_ids`` as (B*N); the output has shape (B*N, L, 1).
        """
        n_signals, seq_len = x.shape[0], x.shape[1]

        # 1: Encode the sequence: (B*N, L, Din) -> (B*N, L, H)
        encoded = self.conditional_mamba(x, roi_ids)

        # 2: Let the shared queries attend to the encoded sequence
        #    -> (B*N, num_peaks, H)
        queries = self.learnable_query.expand(n_signals, -1, -1)
        attended = self.cross_attention(q=queries, kv=encoded)

        # 3: Average over the queries, then project to curve parameters
        lfp_params = self.param_head(torch.mean(attended, dim=1))

        # 4: Render the curve on the grid 0..L-1
        time_vector = torch.arange(seq_len, device=x.device, dtype=x.dtype)
        deconv_signal = self.calculate_lfp_signal(lfp_params, time_vector)

        # 5: Add a trailing feature axis: (B*N, L, 1)
        return deconv_signal.unsqueeze(-1)

In [None]:
class Stage1Model(nn.Module):
    """Reconstruct BOLD as (deconvolved neural signal) * (predicted HRF).

    Two branches process the same input: ``HRF_Generator`` predicts a
    per-signal HRF kernel and ``BOLD_Deconvolver`` predicts a per-signal
    neural time series. The two are then convolved causally, so the output at
    time ``t`` only depends on the neural signal at times ``<= t``.
    """

    def __init__(self, input_dim, hidden_dim, n_roi, dropout,
                 num_peaks=5, lfp_params=3):
        super().__init__()
        self.hrf_generator = HRF_Generator(input_dim, hidden_dim, n_roi, dropout)
        self.bold_deconvolver = BOLD_Deconvolver(input_dim, hidden_dim, n_roi, num_peaks, lfp_params, dropout)

    def forward(self, x, roi_ids):
        """Return the reconstructed BOLD signal.

        Args:
            x: tensor of shape (B, N, L, Din).
            roi_ids: tensor reshaped to (B*N), one ROI index per signal.

        Returns:
            Tensor of shape (B, N, L, 1).
        """
        # 0: Get dimensions
        B, N, L, Din = x.shape
        n_signals = B * N

        # 1: Flatten batch and ROI axes: (B, N, L, Din) -> (B*N, L, Din)
        x_flat = x.reshape(n_signals, L, Din)
        ids_flat = roi_ids.reshape(n_signals)

        # 2: Pass through both branches, each returning (B*N, L, 1)
        hrf_signal = self.hrf_generator(x_flat, ids_flat)
        deconv_signal = self.bold_deconvolver(x_flat, ids_flat)

        # 3: Convolve each neural signal with its own HRF.
        # A grouped conv1d pairs channel i of the input with filter i, so the
        # signals must sit on the CHANNEL axis: input (1, B*N, L) and weight
        # (B*N, 1, K) with groups=B*N.
        neural = deconv_signal.reshape(1, n_signals, L)
        kernel = hrf_signal.reshape(n_signals, 1, -1)

        # F.conv1d computes cross-correlation, so flip the kernel to get a true
        # convolution, and left-pad by K-1 so that y[t] = sum_k h[k] * x[t-k]
        # (causal, output length stays L).
        kernel = kernel.flip(-1)
        neural = F.pad(neural, (kernel.shape[-1] - 1, 0))
        reconstructed_flat = F.conv1d(neural, kernel, groups=n_signals)  # (1, B*N, L)

        # 4: Reshape back to (B, N, L, 1)
        reconstructed_bold = reconstructed_flat.reshape(B, N, L, 1)

        return reconstructed_bold

In [None]:
class CausalityMapper(nn.Module):
    """Score every ordered pair of nodes with two independent MLP heads.

    Node features are averaged over time, every ordered pair (i, j) is formed
    by concatenating node i's and node j's features, and each pair is mapped
    to a coupling strength and a delay timing.
    """

    def __init__(self, hidden_dim, mlp_hidden_dim):
        super().__init__()

        # Pair features are the concatenation of two node vectors.
        pair_dim = 2 * hidden_dim

        def make_head():
            # Linear -> ReLU -> Linear producing one scalar per pair.
            return nn.Sequential(
                nn.Linear(pair_dim, mlp_hidden_dim),
                nn.ReLU(),
                nn.Linear(mlp_hidden_dim, 1),
            )

        self.mlp_strength = make_head()  # coupling strengths
        self.mlp_delay = make_head()     # delay timings

    def forward(self, features):
        """Return ``(coupling_strengths, delay_timings)``, each of shape (B, N, N).

        ``features`` is a 4-D tensor (B, N, L, H); axis 2 is averaged away.
        Entry ``[b, i, j]`` of each output is computed from node i's features
        followed by node j's features.
        """
        # 1. Mean pool over time: (B, N, L, H) -> (B, N, H)
        pooled = features.mean(dim=2)
        n_nodes = pooled.shape[1]

        # 2. Broadcast to all ordered pairs and concatenate: (B, N, N, 2*H)
        row_side = pooled[:, :, None, :].expand(-1, -1, n_nodes, -1)  # varies with i
        col_side = pooled[:, None, :, :].expand(-1, n_nodes, -1, -1)  # varies with j
        pairs = torch.cat((row_side, col_side), dim=-1)

        # 3. Score each pair and drop the trailing singleton axis
        coupling_strengths = self.mlp_strength(pairs).squeeze(-1)
        delay_timings = self.mlp_delay(pairs).squeeze(-1)

        return coupling_strengths, delay_timings

In [None]:
class Stage2Model(nn.Module):
    """Encode deconvolved signals and map them to pairwise coupling outputs.

    ``mlp_hidden_dim`` falls back to ``hidden_dim`` when it is ``None``.
    """

    def __init__(self, input_dim, hidden_dim, n_roi, dropout, mlp_hidden_dim=None):
        super().__init__()

        head_width = hidden_dim if mlp_hidden_dim is None else mlp_hidden_dim

        self.n_roi = n_roi
        self.conditional_mamba = Conditional_Mamba_Encoder(input_dim, hidden_dim, n_roi, dropout)
        self.causality_mapper = CausalityMapper(hidden_dim, head_width)

    def forward(self, deconv_bold_flat, roi_ids_flat):
        """Return ``(coupling_strengths, delay_timings)`` from the causality mapper.

        The encoder output (B*N, L, H) is split into (B, N, L, H) with
        ``N = self.n_roi``, so the leading dimension must be a multiple of
        ``n_roi`` for the reshape to succeed.
        """
        # 1. Encode each flattened signal: -> (B*N, L, H)
        encoded = self.conditional_mamba(deconv_bold_flat, roi_ids_flat)

        # 2. Recover the batch axis from the fixed ROI count
        total, seq_len, width = encoded.shape
        n_nodes = self.n_roi
        batch = total // n_nodes
        grouped = encoded.reshape(batch, n_nodes, seq_len, width)

        # 3. Pairwise mapping: (B, N, L, H) -> (B, N, N), (B, N, N)
        coupling_strengths, delay_timings = self.causality_mapper(grouped)
        return coupling_strengths, delay_timings