In [1]:
import torch
import torch.nn as nn
import yaml
import einops
import torchvision
import os
from tqdm import tqdm
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
import numpy as np
from torchvision.transforms.functional import to_pil_image
import torch.nn.functional as F
import umap.umap_ as umap
from utils import make_coord_grid
import yaml
from datasets import make as make_dataset
from models import make as make_model
import pickle
from sklearn.manifold import TSNE
from datasets.fmri_dataloader import DataModule
from mamba_ssm import Mamba
from mamba_ssm.modules.block import Block
from mamba_ssm.ops.triton.layer_norm import RMSNorm
import math
from patchembed import PatchEmbed
from utils import make_coord_grid
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [2]:
# Show which device was selected so the run is easy to audit from the saved output.
print(device)

cuda:0


In [3]:
# Build the project data module from its YAML config (path is relative to the notebook's working directory).
data_module = DataModule('./cfgs/hcp_data_all_config_explore.yaml')

In [4]:
# Prepare the splits/datasets, then fetch the loaders used in the rest of the notebook.
data_module.setup()
train_loader = data_module.train_dataloader()
# NOTE(review): `test_loader` is built from the *validation* dataloader — confirm this is
# intentional and not meant to be a dedicated test dataloader.
test_loader = data_module.val_dataloader()

Loading cached subject dictionary from data/subj_dict/S1200_age_HCP_filtered_run1_MNI_to_TRs.pickle
Randomly determining splits and saving to ./data/splits/S1200_split100.pkl
Randomly sampling 400 of 863 subjects.
Generating dataset list from scratch...
Number of subjects for S1200 'train': 400
Randomly sampling 50 of 107 subjects.
Loading cached dataset list from data/data_tuple/age_HCP_filtered_run1_MNI_to_TRs_val_seqlen2_within1_between1.0.csv
Cache is stale (subject mismatch). Regenerating...
Generating dataset list from scratch...
Number of subjects for S1200 'val': 50
Randomly sampling 50 of 109 subjects.
Loading cached dataset list from data/data_tuple/age_HCP_filtered_run1_MNI_to_TRs_test_seqlen2_within1_between1.0.csv
Cache is stale (subject mismatch). Regenerating...
Generating dataset list from scratch...
Number of subjects for S1200 'test': 50

Total training segments: 4800
Total validation segments: 600
Total test segments: 600


In [None]:
# Grab a single batch to use as a smoke-test input for the model defined below.
# next(iter(...)) fails loudly (StopIteration) on an empty loader instead of silently
# leaving `fmri` undefined for the following cells.
batch = next(iter(test_loader))
# NOTE(review): the unpacking relies on the batch dict yielding exactly five values in
# this order — confirm against the dataset's __getitem__.
fmri, subj, target_value, tr, sex = batch.values()

In [None]:
# Inspect the batch layout before wiring it into the model.
print(fmri.shape)

In [9]:
class LAINRDecoder(nn.Module):
    """Coordinate decoder driven by latent tokens.

    Query coordinates are encoded with sin/cos features at several frequency
    bands. The query-band encoding cross-attends to the latent tokens (with a
    distance-based attention bias) to produce one modulation vector per query
    point, which then modulates each band's hidden features. The per-band
    outputs are summed.

    Args:
        feature_dim: Size of the sin/cos encoding; must be divisible by
            ``2 * input_dim``.
        input_dim: Number of coordinate components per query point.
        output_dim: Number of predicted channels per query point.
        sigma_q: Upper frequency (``10**log10(sigma_q)``) of the query band.
        sigma_ls: One upper frequency per decoder layer/band.
        n_patches: Number of input patches; ``int(sqrt(n_patches))`` is used as
            the side of the 2-D patch grid for the attention bias.
        hidden_dim: Width of all hidden layers.
        context_dim: Accepted for interface compatibility; not used.
    """

    def __init__(self, feature_dim, input_dim, output_dim, sigma_q, sigma_ls, n_patches, hidden_dim = 256, context_dim = None):
        super().__init__()
        # The encoding produced by calc_gamma has 2 * input_dim * n entries; it only
        # matches the nn.Linear(feature_dim, ...) layers when the division is exact.
        if feature_dim % (2 * input_dim) != 0:
            raise ValueError(
                f"feature_dim ({feature_dim}) must be divisible by 2 * input_dim ({2 * input_dim})"
            )

        self.layer_num = len(sigma_ls)
        self.n = feature_dim//(2*input_dim)
        # Frequencies are kept as plain tensors (not buffers) and moved to the
        # input's device on use inside calc_gamma.
        self.omegas = torch.logspace(1, math.log10(sigma_q), self.n)
        self.patch_num = int(math.sqrt(n_patches))

        # Sharpness of the distance-based attention bias.
        self.alpha = 10.0

        self.omegas_l = [torch.logspace(1, math.log10(sigma), self.n) for sigma in sigma_ls]

        # NOTE: sub-modules are created in the original order so that seeded
        # initialisation stays reproducible.
        self.query_lin = nn.Linear(feature_dim, hidden_dim)
        self.modulation_ca = SharedTokenCrossAttention(query_dim = hidden_dim, heads=2)

        self.bandwidth_lins = nn.ModuleList([
            nn.Linear(feature_dim, hidden_dim) for _ in range(self.layer_num)
        ])

        self.modulation_lins = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim) for _ in range(self.layer_num)
        ])

        self.hv_lins = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim) for _ in range(len(sigma_ls) - 1)
        ])

        self.out_lins = nn.ModuleList([
            nn.Linear(hidden_dim, output_dim) for _ in range(len(sigma_ls))
        ])
        self.act = nn.ReLU()

    def calc_gamma(self, x, omegas):
        """Encode points ``x`` of shape (P, D) as sin/cos features of shape (P, 2 * D * F).

        ``omegas`` is a 1-D tensor of F frequencies; it is moved to ``x.device``.
        """
        n_points = x.shape[0]
        coords = x.unsqueeze(-1)  # (P, D, 1)
        omegas = omegas.view(1, 1, -1).to(x.device)  # (1, 1, F)

        arg = torch.pi * coords * omegas  # (P, D, F)
        gamma = torch.cat([torch.sin(arg), torch.cos(arg)], dim=-1)  # (P, D, 2F)

        return gamma.view(n_points, -1)

    def get_patch_index(self, grid, H, W):
        """Map points to a flat index on an H x W patch grid, guaranteed in [0, H*W - 1].

        Only the first two coordinate columns of ``grid`` are used.
        NOTE(review): the notebook passes 4-component coordinates, so the last two
        components do not influence the index — confirm this is intended.
        """
        y = grid[:, 0]
        x = grid[:, 1]
        # Clamp so that a coordinate of exactly 1.0 (or slightly outside [0, 1], e.g.
        # after jittering the grid) still lands on a valid patch.
        row = (y * H).to(torch.int32).clamp(0, H - 1)
        col = (x * W).to(torch.int32).clamp(0, W - 1)
        return row * W + col

    def approximate_relative_distances(self, target_index, H, W, m):
        """Return the attention bias ``-alpha * (t - s)**2`` with shape (m, P).

        ``t`` is each point's patch index normalised by ``H * W`` and ``s`` are the
        ``m`` token positions ``(i + 0.5) / m``.
        """
        N = H * W
        t = target_index / N  # (P,)
        # Built on the same device as the indices so no host/device mixing is needed.
        token_positions = torch.tensor(
            [(i + 0.5) / m for i in range(m)], device=target_index.device
        )  # (m,)
        # Broadcasting replaces the per-token Python loop: (1, P) - (m, 1) -> (m, P).
        return -1 * self.alpha * (t.unsqueeze(0) - token_positions.unsqueeze(1)) ** 2

    def forward(self, x, tokens):
        """Decode values at query coordinates.

        Args:
            x: Coordinates of shape (B, *query_shape, input_dim). Only ``x[0]`` is
                used to build encodings and the bias, i.e. every batch item is
                assumed to share the same coordinate grid.
            tokens: Latent tokens of shape (B, L, hidden_dim).

        Returns:
            Tensor of shape (B, *query_shape, output_dim).
        """
        B, query_shape = x.shape[0], x.shape[1: -1]
        # reshape (rather than view) also accepts non-contiguous coordinate tensors.
        x = x.reshape(B, -1, x.shape[-1])  # (B, P, input_dim)
        grid = x[0]

        # Distance-based attention bias between every query point and every token.
        indexes = self.get_patch_index(grid, self.patch_num, self.patch_num)
        rel_distances = self.approximate_relative_distances(
            indexes, self.patch_num, self.patch_num, tokens.shape[1]
        )  # (L, P)
        bias = einops.repeat(rel_distances.transpose(1, 0), 'l n -> b l n', b=B)  # (B, P, L)

        # Query-band encoding -> one modulation vector per query point.
        x_q = einops.repeat(self.calc_gamma(grid, self.omegas), 'l d -> b l d', b=B)
        x_q = self.act(self.query_lin(x_q))
        modulation_vector = self.modulation_ca(x_q, context = tokens, bias = bias)

        # Per-band features, each shifted by a projection of the modulation vector.
        modulations_l = []
        for k in range(self.layer_num):
            x_l = einops.repeat(self.calc_gamma(grid, self.omegas_l[k]), 'l d -> b l d', b=B)
            h_l = self.act(self.bandwidth_lins[k](x_l))
            modulations_l.append(self.act(h_l + self.modulation_lins[k](modulation_vector)))

        # Chain the bands: each hidden state mixes the next band with the previous state.
        h_v = [modulations_l[0]]
        for i in range(self.layer_num - 1):
            h_v.append(self.act(self.hv_lins[i](modulations_l[i+1] + h_v[i])))

        out = sum(self.out_lins[i](h_v[i]) for i in range(self.layer_num))
        return out.view(B, *query_shape, -1)  # (B, *query_shape, output_dim)

# helpers
def exists(val):
    """Return True for any value other than ``None`` (falsy values such as 0 count as existing)."""
    return not (val is None)

def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if exists(val):
        return val
    return d

def cache_fn(f):
    """Wrap ``f`` so that its first result is stored and returned on later calls.

    The cache holds a single value and does not depend on the arguments: once a
    non-``None`` result has been stored, every cached call returns it regardless
    of what is passed in. Pass ``_cache=False`` to bypass the cache and call ``f``
    directly (the stored value is left untouched). A ``None`` result is never
    considered cached, so ``f`` is called again until it returns something else.
    """
    # Imported here because the notebook's import cell does not bring `wraps` into
    # scope; without it the decorator below raises NameError on first use.
    from functools import wraps

    cache = None

    @wraps(f)
    def cached_fn(*args, _cache = True, **kwargs):
        nonlocal cache
        if not _cache:
            return f(*args, **kwargs)
        if cache is not None:
            return cache
        cache = f(*args, **kwargs)
        return cache

    return cached_fn

# helper classes
class PreNorm(nn.Module):
    """Layer-normalise the input (and, optionally, the ``context`` kwarg) before calling ``fn``.

    Args:
        dim: Feature size of the input passed to ``forward``.
        fn: Callable invoked as ``fn(normed_x, **kwargs)``.
        context_dim: When given, a second LayerNorm of this size is applied to
            ``kwargs['context']``; ``forward`` then requires that keyword.
    """

    def __init__(self, dim, fn, context_dim = None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        if exists(context_dim):
            self.norm_context = nn.LayerNorm(context_dim)
        else:
            self.norm_context = None

    def forward(self, x, **kwargs):
        normed_x = self.norm(x)
        if not exists(self.norm_context):
            return self.fn(normed_x, **kwargs)
        # Replace the caller's context with its normalised version before forwarding.
        kwargs['context'] = self.norm_context(kwargs['context'])
        return self.fn(normed_x, **kwargs)

class GEGLU(nn.Module):
    """Gated GELU: split the last dimension in half and gate the first half with GELU of the second."""

    def forward(self, x):
        value, gate = x.chunk(chunks=2, dim=-1)
        activated_gate = F.gelu(gate)
        return value * activated_gate

class FeedForward(nn.Module):
    """Position-wise MLP: expand to ``2 * dim * mult``, gate with GEGLU, project back to ``dim``."""

    def __init__(self, dim, mult = 4):
        super().__init__()
        inner_dim = dim * mult
        # The first projection is twice as wide because GEGLU halves the last dimension.
        layers = [
            nn.Linear(dim, inner_dim * 2),
            GEGLU(),
            nn.Linear(inner_dim, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

class SharedTokenCrossAttention(nn.Module):
    """Multi-head cross-attention from per-point queries to a shared set of tokens.

    Args:
        query_dim: Feature size of the queries (also the output size).
        context_dim: Feature size of the context tokens; defaults to ``query_dim``.
        heads: Number of attention heads.
        dim_head: Feature size of each head.
    """

    def __init__(self, query_dim, context_dim=None, heads=2, dim_head=64):
        super().__init__()
        if context_dim is None:
            context_dim = query_dim
        inner_dim = dim_head * heads
        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context, bias=None):
        """Attend from ``x`` to ``context``.

        Args:
            x: Queries of shape (B, HW, query_dim), one per query point.
            context: Shared tokens of shape (B, L, context_dim).
            bias: Optional additive attention bias of shape (B, HW, L); the same
                bias is applied to every head.

        Returns:
            Tensor of shape (B, HW, query_dim).

        Raises:
            ValueError: If ``bias`` is given but is not 3-dimensional.
        """
        B, HW, _ = x.shape
        H, Dh = self.heads, self.dim_head

        q = self.to_q(x)                                # (B, HW, H*Dh)
        k, v = self.to_kv(context).chunk(2, dim=-1)     # each (B, L, H*Dh)

        # Split heads.
        q = q.view(B, HW, H, Dh).transpose(1, 2)        # (B, H, HW, Dh)
        k = k.view(B, -1, H, Dh).transpose(1, 2)        # (B, H, L, Dh)
        v = v.view(B, -1, H, Dh).transpose(1, 2)        # (B, H, L, Dh)

        sim = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # (B, H, HW, L)
        # Identity check: `bias != None` on a tensor goes through tensor comparison.
        if bias is not None:
            if bias.dim() != 3:
                raise ValueError(
                    f"bias must have shape (B, HW, L), got {tuple(bias.shape)}"
                )
            # Broadcast over the head axis instead of materialising one copy per head.
            sim = sim + bias.unsqueeze(1)
        attn = sim.softmax(dim=-1)
        out = torch.matmul(attn, v)                     # (B, H, HW, Dh)

        # Merge heads and project back to the query size.
        out = out.transpose(1, 2).contiguous().view(B, HW, H * Dh)  # (B, HW, H*Dh)
        return self.to_out(out)                         # (B, HW, query_dim)

In [23]:
class BiMamba(torch.nn.Module):
    """Bidirectional mixer built from two independent ``Mamba`` layers.

    One layer reads the input as given, the other reads it flipped along dim 1
    (assumed to be the sequence axis — TODO confirm); the second output is flipped
    back and the two are averaged.
    """

    def __init__(self, dim = 512):
        super().__init__()

        # Separate weights per direction; forward layer is created first.
        self.f_mamba = Mamba(d_model = dim)
        self.r_mamba = Mamba(d_model = dim)

    def forward(self, x, **kwargs):
        # Any extra keyword arguments are forwarded unchanged to both layers.
        forward_out = self.f_mamba(x, **kwargs)

        reversed_in = torch.flip(x, dims=[1])
        reversed_out = self.r_mamba(reversed_in, **kwargs)
        backward_out = torch.flip(reversed_out, dims=[1])

        return (forward_out + backward_out)/2
   
class MambaEncoder(torch.nn.Module):
    """Stack of ``depth`` mamba_ssm ``Block`` layers, each with a ``BiMamba`` mixer and an MLP.

    Args:
        depth: Number of blocks.
        dim: Token feature size.
        ff_dim: Hidden size of each block's MLP; any falsy value selects ``4 * dim``.
        dropout: Dropout probability used inside the MLP.
    """

    def __init__(self, depth = 6, dim = 768, ff_dim = None, dropout=0.):
        super().__init__()
        self.ff_dim = ff_dim if ff_dim else 4*dim
        token_dim = dim

        def build_mixer(dim):
            return BiMamba(dim)

        def build_mlp(dim):
            return torch.nn.Sequential(
                nn.Linear(dim, self.ff_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(self.ff_dim, dim),
                nn.Dropout(dropout),
            )

        layers = []
        for _ in range(depth):
            layers.append(
                Block(
                    dim=token_dim,
                    mixer_cls=build_mixer,
                    mlp_cls=build_mlp,
                    norm_cls=nn.LayerNorm,  # or RMSNorm
                    fused_add_norm=False,
                )
            )
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        # Each block receives the previous block's hidden state and residual stream.
        hidden, residual = x, None
        for block in self.blocks:
            hidden, residual = block(hidden, residual=residual)
        # NOTE(review): only the last hidden state is returned; the final `residual`
        # is discarded and no final norm is applied — confirm this is intentional.
        return hidden
        
def init_wb(shape):
    """Create a freshly initialised weight-plus-bias matrix of shape ``(shape[0], shape[1])``.

    The first ``shape[0] - 1`` rows hold a weight drawn with
    ``kaiming_uniform_(a=sqrt(5))``; the last row holds a bias drawn uniformly
    from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]`` where ``fan_in = shape[0] - 1``.
    The result is detached from autograd.
    """
    n_in, n_out = shape[0] - 1, shape[1]

    # Weight is sampled first, then the bias, so seeded results are reproducible.
    weight = torch.empty(n_out, n_in)
    nn.init.kaiming_uniform_(weight, a=math.sqrt(5))

    # For a 2-D weight the fan-in is simply its number of columns.
    fan_in = weight.shape[1]
    if fan_in > 0:
        bound = 1 / math.sqrt(fan_in)
    else:
        bound = 0
    bias = torch.empty(n_out, 1)
    nn.init.uniform_(bias, -bound, bound)

    stacked = torch.cat([weight, bias], dim=1)  # (n_out, n_in + 1)
    return stacked.t().detach()

class MambaInr(torch.nn.Module):
    """Encode a volume into latent tokens with Mamba and decode it with a coordinate network.

    The input is patchified, learnable latent tokens ("lps") are interleaved with
    the patch tokens, the sequence is run through ``mamba_encoder``, and the
    outputs at the latent-token positions are handed to ``hyponet`` together with
    the query coordinates.

    Args:
        hyponet: Decoder called as ``hyponet(coord, latent_tokens)``.
        mamba_encoder: Sequence encoder called as ``mamba_encoder(tokens)``.
        tokenizer: Accepted for interface compatibility; not used.
        num_lp: Number of learnable latent tokens.
        type: Placement of the latent tokens in the sequence: ``'equidistant'``,
            ``'middle'`` or ``'n_group'``.
        n_group: Group size for ``type='n_group'``; must divide ``num_lp``.
        latent_token_len: Stored on the instance; not otherwise used here.
    """

    def __init__(self, hyponet, mamba_encoder, tokenizer = None, num_lp = 128, type = 'equidistant', n_group = 1, latent_token_len = 64):
        super().__init__()

        # Token width; hard-coded to match the patch embedding below.
        self.dim = 64
        self.latent_token_len = latent_token_len
        self.hyponet = hyponet
        self.mamba_encoder = mamba_encoder

        # Number of patch tokens; 432 = 6*6*6*2 for the image/patch sizes below
        # (assumes non-overlapping patches — TODO confirm against PatchEmbed).
        self.input_len = 432
        self.type = type
        self.num_lp = num_lp
        self.patchifier = PatchEmbed(img_size=(96, 96, 96, 2),
        patch_size=(16, 16, 16, 1), embed_dim = 64)
        self.n_group = n_group

        self.lps = nn.Parameter(torch.randn(self.num_lp, self.dim))
        self.lp_idxs = None
        self.perm = None
        # Sets both `lp_idxs` and the matching interleave permutation `perm`.
        self.set_lp_idxs(self.input_len, type = self.type, n = self.n_group)

    def set_lp_idxs(self, seq_len, type = 'equidistant', n = 1):
        """Choose where the latent tokens sit in the interleaved sequence.

        Updates ``self.lp_idxs`` (a LongTensor of ``num_lp`` positions) and
        rebuilds ``self.perm`` from it, so that ``add_lp`` places latent token j
        at ``lp_idxs[j]`` and ``extract_lp_tokens`` reads it back from there.

        Raises:
            ValueError: If ``type`` is unknown, or ``n`` does not divide ``num_lp``
                for ``type='n_group'``.
        """
        total_len = seq_len + self.num_lp
        if type == 'equidistant':
            insert_idxs = torch.linspace(0, total_len - 1, steps=self.num_lp).long()
        elif type == 'middle':
            # One contiguous run starting at the middle of the patch sequence.
            insert_idxs = torch.arange(self.num_lp) + seq_len // 2
        elif type == 'n_group':
            if self.num_lp % n != 0:
                raise ValueError("n must divide number of lps evenly")
            # Evenly spaced group starts, each followed by n consecutive slots.
            starts = torch.linspace(0, total_len - n, steps=self.num_lp // n).long()
            insert_idxs = (starts.unsqueeze(1) + torch.arange(n)).reshape(-1)
        else:
            raise ValueError(f"unknown latent-token placement type: {type!r}")

        self.lp_idxs = insert_idxs
        # Keep the permutation in sync with the chosen positions; previously the
        # permutation was always equidistant, so 'middle' / 'n_group' read back
        # patch tokens instead of latent tokens.
        self.perm = self.compute_interleave_permutation(
            seq_len, self.num_lp, insert_idxs=insert_idxs
        )

    def add_lp(self, x):
        """Interleave the latent tokens into ``x`` of shape (B, L, D); returns (B, L + num_lp, D)."""
        batch_size = x.shape[0]
        lps = self.lps.unsqueeze(0).expand(batch_size, -1, -1)  # (B, N, D)
        x_full = torch.cat([x, lps], dim=1)  # (B, L + N, D): patches first, then latents
        return x_full[:, self.perm]  # (B, L + N, D), interleaved

    def extract_lp_tokens(self, x):
        """Gather the sequence positions that hold latent tokens; returns (B, num_lp, D)."""
        return x[:, self.lp_idxs]

    def compute_interleave_permutation(self, seq_len, n_insert, insert_idxs = None):
        """Build the gather index that interleaves ``n_insert`` extra tokens into a sequence.

        The result ``perm`` has length ``seq_len + n_insert``. Applied to the
        concatenation ``[sequence tokens, extra tokens]`` it puts extra token j at
        position ``insert_idxs[j]`` and fills the remaining positions with the
        sequence tokens in their original order.

        Args:
            seq_len: Number of sequence tokens.
            n_insert: Number of extra tokens.
            insert_idxs: Target positions of the extra tokens. Defaults to
                equidistant positions over the full length.

        Raises:
            ValueError: If ``insert_idxs`` does not hold ``n_insert`` distinct
                positions inside ``[0, seq_len + n_insert)``.
        """
        total_len = seq_len + n_insert
        if insert_idxs is None:
            insert_idxs = torch.linspace(0, total_len - 1, steps=n_insert).long()
        else:
            insert_idxs = torch.as_tensor(insert_idxs, dtype=torch.long)

        if insert_idxs.numel() != n_insert or torch.unique(insert_idxs).numel() != n_insert:
            raise ValueError("insert_idxs must contain n_insert distinct positions")
        if n_insert > 0 and (int(insert_idxs.min()) < 0 or int(insert_idxs.max()) >= total_len):
            raise ValueError("insert_idxs must lie inside [0, seq_len + n_insert)")

        perm = torch.full((total_len,), -1, dtype=torch.long)
        perm[insert_idxs] = torch.arange(seq_len, total_len)
        # Every slot still marked -1 receives the next sequence token, in order.
        perm[perm == -1] = torch.arange(seq_len)

        return perm

    def scan(self, x):
        """Flatten patch tokens (B, pD, pH, pW, pT, D) into a sequence (B, pT*pD*pH*pW, D).

        The pT axis is moved in front of the spatial axes, so it varies slowest in
        the flattened order.
        """
        B, pD, pH, pW, pT, D = x.shape
        x = x.permute(0, 4, 1, 2, 3, 5)  # (B, pT, pD, pH, pW, D)
        return x.reshape(B, pD*pH*pW*pT, D)

    def forward(self, data, coord):
        """Reconstruct values at ``coord`` from ``data``.

        Args:
            data: Input volume batch; cast to float before patchification
                (assumed layout (B, C, Z, H, W, T) — TODO confirm).
            coord: Query coordinates, passed to ``hyponet`` unchanged.

        Returns:
            Whatever ``hyponet(coord, latent_tokens)`` returns.
        """
        all_data = data.float()
        dtokens = self.patchifier(all_data)  # (B, pD, pH, pW, pT, D)
        dtokens = self.scan(dtokens)  # (B, pD*pH*pW*pT, D)
        all_tokens = self.add_lp(dtokens)
        mamba_out = self.mamba_encoder(all_tokens)
        mamba_out = self.extract_lp_tokens(mamba_out)
        pred = self.hyponet(coord, mamba_out)

        return pred


In [26]:
# Assemble the model. The decoder/encoder widths (64) and n_patches (432) must agree with
# the values hard-coded inside MambaInr (self.dim = 64, self.input_len = 432).
hyponet = LAINRDecoder(feature_dim = 64, input_dim = 4, output_dim = 1, sigma_q = 16, sigma_ls = [128, 32], hidden_dim = 64, n_patches = 432).to(device) #6x6x6x2
mamba_encoder = MambaEncoder(dim = 64).to(device)
MambaINR = MambaInr(hyponet, mamba_encoder).to(device)

torch.Size([1, 6, 6, 6, 2, 64])


In [27]:
# Smoke test: run two samples from the batch through the model.
gt = fmri[:2].to(device)
B = gt.shape[0]
# One coordinate per element of the last four axes of `gt`, in the range (0, 1).
coord = make_coord_grid(gt.shape[-4:], (0, 1), device=gt.device).to(device)
#coord = self.add_gaussian_noise_to_grid(coord)
# Every batch item uses the same coordinate grid.
coord = einops.repeat(coord, 'z h w t d -> b z h w t d', b=B)
print(coord.shape)

pred = MambaINR(gt, coord)

torch.Size([2, 96, 96, 96, 2, 4])
torch.float32


In [28]:
# The prediction should have one output channel per query coordinate.
print(pred.shape)

torch.Size([2, 96, 96, 96, 2, 1])
