In [1]:
! pip install einops

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0


In [16]:
from abc import abstractmethod
import torch.nn as nn
import torch
import math
from typing import Optional
import warnings

class BaseAttention(nn.Module):
    @abstractmethod
    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x, context=None, mask=None):
        pass

def scaled_multihead_dot_product_attention(
    query,
    key,
    value,
    heads,
    past_key_value=None,
    softmax_scale=None,
    bias=None,
    key_padding_mask=None,
    causal=False,
    dropout=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    q = rearrange(query, 'b s (h d) -> b h s d', h=heads)
    kv_heads = 1 if multiquery else heads
    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_heads)
    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_heads)

    if past_key_value is not None:
        # attn_impl: flash & triton use kernels which expect input shape [b, s, h, d_head].
        # kv_cache is therefore stored using that shape.
        # attn_impl: torch stores the kv_cache in the ordering which is most advantageous
        # for its attn computation ie
        # keys are stored as tensors with shape [b, h, d_head, s] and
        # values are stored as tensors with shape [b, h, s, d_head]
        if len(past_key_value) != 0:
            k = torch.cat([past_key_value[0], k], dim=3)
            v = torch.cat([past_key_value[1], v], dim=2)

        past_key_value = (k, v)

    b, _, s_q, d = q.shape
    s_k = k.size(-1)

    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(d)

    attn_weight = q.matmul(k) * softmax_scale

    if bias is not None:
        # clamp to 0 necessary for torch 2.0 compile()
        _s_q = max(0, bias.size(2) - s_q)
        _s_k = max(0, bias.size(3) - s_k)
        bias = bias[:, :, _s_q:, _s_k:]

        if (bias.size(-1) != 1 and
                bias.size(-1) != s_k) or (bias.size(-2) != 1 and
                                               bias.size(-2) != s_q):
            raise RuntimeError(
                f'bias (shape: {bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.'
            )
        attn_weight = attn_weight + bias

    min_val = torch.finfo(q.dtype).min

    if key_padding_mask is not None:
        if bias is not None:
            warnings.warn(
                'Propogating key_padding_mask to the attention module ' +\
                'and applying it within the attention module can cause ' +\
                'unneccessary computation/memory usage. Consider integrating ' +\
                'into bias once and passing that to each attention ' +\
                'module instead.'
            )
        attn_weight = attn_weight.masked_fill(
            ~key_padding_mask.view((b, 1, 1, s_k)), min_val)

    if causal and (not q.size(2) == 1):
        s = max(s_q, s_k)
        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
        causal_mask = causal_mask.tril()
        causal_mask = causal_mask.to(torch.bool)
        causal_mask = ~causal_mask
        causal_mask = causal_mask[-s_q:, -s_k:]
        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k),
                                              min_val)

    attn_weight = torch.softmax(attn_weight, dim=-1)

    if dropout:
        attn_weight = torch.nn.functional.dropout(attn_weight,
                                                  p=dropout,
                                                  training=training,
                                                  inplace=True)

    out = attn_weight.to(v.dtype).matmul(v)
    out = rearrange(out, 'b h s d -> b s (h d)')

    if needs_weights:
        return out, attn_weight, past_key_value
    return out, None, past_key_value

class MultiQueryAttention(BaseAttention):
    """Multi-Query self attention.

    Using torch or triton attention implemetation enables user to also use
    additive bias.

    Look for documentation
    """

    def __init__(
        self,
        d_model: int,
        heads: int,
        attn_impl: str = "torch",
        clip_qkv: Optional[float] = None,
        qk_ln: bool = False,
        softmax_scale: Optional[float] = None,
        attn_pdrop: float = 0.0,
        norm_type: str = "low_precision_layernorm",
        fc_type: str = "torch",
        verbose: int = 0,
        device: Optional[str] = None,
    ):
        super().__init__()

        self.attn_impl = attn_impl
        self.clip_qkv = clip_qkv
        self.qk_ln = qk_ln

        self.d_model = d_model
        self.heads = heads
        self.head_dim = d_model // heads
        self.softmax_scale = softmax_scale
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.head_dim)
        self.attn_dropout = attn_pdrop

        fc_kwargs = {}
        if fc_type != "te":
            fc_kwargs["device"] = device
        # - vchiley
        self.Wqkv = nn.Linear(
            d_model,
            d_model + 2 * self.head_dim,
            **fc_kwargs,
        )
        # for param init fn; enables shape based init of fused layers
        fuse_splits = (d_model, d_model + self.head_dim)
        self.Wqkv._fused = (0, fuse_splits)  # type: ignore

        if self.qk_ln:
            norm_class = nn.LayerNorm
            self.q_ln = norm_class(d_model, device=device)
            self.k_ln = norm_class(self.head_dim, device=device)

        self.attn_fn = scaled_multihead_dot_product_attention
        if torch.cuda.is_available() and verbose:
            warnings.warn(
                "Using `attn_impl: torch`. If your model does not use"
                " `alibi` or "
                + "`prefix_lm` we recommend using `attn_impl: flash`"
                " otherwise "
                + "we recommend using `attn_impl: triton`."
            )


        self.out_proj = nn.Linear(
            self.d_model,
            self.d_model,
            **fc_kwargs,
        )
        self.out_proj._is_residual = True  # type: ignore

    def forward(
        self,
        x,
        past_key_value=None,
        bias=None,
        mask=None,
        causal=True,
        needs_weights=False,
    ):
        qkv = self.Wqkv(x)

        if self.clip_qkv:
            qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)

        query, key, value = qkv.split(
            [self.d_model, self.head_dim, self.head_dim], dim=2
        )

        key_padding_mask = mask

        if self.qk_ln:
            # Applying layernorm to qk
            dtype = query.dtype
            query = self.q_ln(query).to(dtype)
            key = self.k_ln(key).to(dtype)

        context, attn_weights, past_key_value = self.attn_fn(
            query,
            key,
            value,
            self.heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            bias=bias,
            key_padding_mask=key_padding_mask,
            causal=causal,
            dropout=self.attn_dropout,
            training=self.training,
            needs_weights=needs_weights,
            multiquery=True,
        )

        return self.out_proj(context), attn_weights, past_key_value

In [19]:
from torch import nn, Tensor
from einops import rearrange, reduce



class FeedForward(nn.Module):
    def __init__(self, in_dim, hidden_dim, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(
                hidden_dim, in_dim
            ),
        )

    def forward(self, x):
        return self.net(x)

def threed_to_text(
    x: Tensor, max_seq_len: int, dim: int, flatten: bool = False
):
    """
    Converts a 3D tensor to text representation.

    Args:
        x (Tensor): The input tensor of shape (batch_size, sequence_length, input_dim).
        max_seq_len (int): The maximum sequence length of the output tensor.
        dim (int): The dimension of the intermediate tensor.
        flatten (bool, optional): Whether to flatten the intermediate tensor. Defaults to False.

    Returns:
        Tensor: The output tensor of shape (batch_size, max_seq_len, input_dim).
    """
    b, s, d = x.shape

    x = nn.Linear(d, dim)(x)

    x = rearrange(x, "b s d -> b d s")
    x = nn.Linear(s, max_seq_len)(x)
    x = rearrange(x, "b d s -> b s d")
    return x



def scatter3d_to_4d_spatial(x: Tensor, dim: int):
    """
    Scatters a 3D tensor into a 4D spatial tensor using einops.

    Args:
        x (Tensor): The input tensor of shape (b, s, d), where b is the batch size, s is the spatial dimension, and d is the feature dimension.
        dim (int): The dimension along which to scatter the tensor.

    Returns:
        Tensor: The scattered 4D spatial tensor of shape (b, (s*s1), d), where s1 is the new spatial dimension after scattering.
    """
    b, s, d = x.shape

    # Scatter the 3D tensor into a 4D spatial tensor
    x = rearrange(x, "b s d -> b (s s1) d")

    return x



class EEGConvEmbeddings(nn.Module):
    def __init__(
        self,
        num_channels,
        conv_channels,
        kernel_size,
        stride=1,
        padding=0,
    ):
        """
        Initializes the EEGConvEmbeddings module.

        Args:
        - num_channels (int): Number of EEG channels in the input data.
        - conv_channels (int): Number of output channels for the convolutional layer.
        - kernel_size (int): Size of the convolutional kernel.
        - stride (int, optional): Stride of the convolution. Default: 1.
        - padding (int, optional): Padding added to both sides of the input. Default: 0.
        """
        super(EEGConvEmbeddings, self).__init__()

        self.conv1 = nn.Conv1d(
            in_channels=num_channels,
            out_channels=conv_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        # Additional layers and operations can be added here

    def forward(self, x):
        """
        Forward pass of the EEGConvEmbeddings module.

        Args:
        - x (Tensor): Input tensor of shape (batch_size, num_channels, time_samples)

        Returns:
        - Tensor: Output tensor after convolution
        """
        x = self.conv1(x)
        return x


class FMRIEmbedding(nn.Module):
    def __init__(
        self,
        in_channels=1,
        out_channels=32,
        kernel_size=3,
        stride=1,
        padding=1,
    ):
        """
        Initializes an fMRI Embedding Network.

        Args:
        - in_channels (int): Number of input channels (scans/modalities).
        - out_channels (int): Number of output channels for the convolutional layer.
        - kernel_size (int): Size of the convolutional kernels.
        - stride (int): Stride of the convolutions.
        - padding (int): Padding added to the input.

        Example:
        model = fMRIEmbeddingNet()
        x = torch.randn(1, 1, 32, 32, 32)
        input_tensor = torch.randn(8, 1, 64, 64, 64)  # 8 fMRI scans
        output_tensor = model(input_tensor)
        print(output_tensor.shape)  # torch.Size([8, 32, 64, 64, 64])


        """
        super(FMRIEmbedding, self).__init__()

        self.conv1 = nn.Conv3d(
            in_channels, out_channels, kernel_size, stride, padding
        )
        # Additional layers can be added here as needed

    def forward(self, x):
        """
        Forward pass of the fMRI Embedding Network.

        Args:
        - x (Tensor): Input tensor of shape (batch_size, in_channels, D, H, W)

        Returns:
        - Tensor: Output embedding tensor
        """
        x = self.conv1(x)
        # Additional operations can be added here as needed
        return x


class MorpheusEncoder(nn.Module):
    """
    MorpheusEncoder is a module that performs encoding on EEG data using multi-head attention and feed-forward networks.

    Args:
        dim (int): The dimension of the input data.
        heads (int): The number of attention heads.
        depth (int): The number of layers in the encoder.
        dim_head (int): The dimension of each attention head.
        dropout (int): The dropout rate.
        num_channels (int): The number of input channels in the EEG data.
        conv_channels (int): The number of output channels after the convolutional layer.
        kernel_size (int): The size of the convolutional kernel.
        stride (int, optional): The stride of the convolutional layer. Defaults to 1.
        padding (int, optional): The padding size for the convolutional layer. Defaults to 0.
        ff_mult (int, optional): The multiplier for the feed-forward network hidden dimension. Defaults to 4.

    Attributes:
        dim (int): The dimension of the input data.
        heads (int): The number of attention heads.
        depth (int): The number of layers in the encoder.
        dim_head (int): The dimension of each attention head.
        dropout (int): The dropout rate.
        num_channels (int): The number of input channels in the EEG data.
        conv_channels (int): The number of output channels after the convolutional layer.
        kernel_size (int): The size of the convolutional kernel.
        stride (int): The stride of the convolutional layer.
        padding (int): The padding size for the convolutional layer.
        ff_mult (int): The multiplier for the feed-forward network hidden dimension.
        mha (MultiheadAttention): The multi-head attention module.
        ffn (FeedForward): The feed-forward network module.
        eeg_embedding (EEGConvEmbeddings): The EEG convolutional embedding module.

    """

    def __init__(
        self,
        dim: int,
        heads: int,
        depth: int,
        dim_head: int,
        dropout: int,
        num_channels,
        conv_channels,
        kernel_size,
        stride=1,
        padding=0,
        ff_mult: int = 4,
        *args,
        **kwargs,
    ):
        super(MorpheusEncoder, self).__init__()
        self.dim = dim
        self.heads = heads
        self.depth = depth
        self.dim_head = dim_head
        self.dropout = dropout
        self.num_channels = num_channels
        self.conv_channels = conv_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.ff_mult = ff_mult

        self.mha = nn.MultiheadAttention(
            dim,
            heads,
            dropout
        )

        self.ffn = FeedForward(dim, dim, dropout, *args, **kwargs)

        self.eeg_embedding = EEGConvEmbeddings(
            num_channels, conv_channels, kernel_size, stride, padding
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the MorpheusEncoder module.

        Args:
            x (Tensor): The input tensor of shape (batch_size, seq_len, dim).

        Returns:
            Tensor: The output tensor of shape (batch_size, seq_len, dim).

        """
        x = self.eeg_embedding(x)
        # print(x.shape)

        x = self.mha(x, x, x) + x

        x = self.ffn(x) + x

        return x


class MorpheusDecoder(nn.Module):
    """
    MorpheusDecoder is a module that performs decoding in the Morpheus model.

    Args:
        dim (int): The dimension of the input and output tensors.
        heads (int): The number of attention heads.
        depth (int): The number of layers in the decoder.
        dim_head (int): The dimension of each attention head.
        dropout (int): The dropout rate.
        num_channels: The number of channels in the input tensor.
        conv_channels: The number of channels in the convolutional layers.
        kernel_size: The size of the convolutional kernel.
        in_channels: The number of input channels for the FMRI embedding.
        out_channels: The number of output channels for the FMRI embedding.
        stride (int, optional): The stride of the convolutional layers. Defaults to 1.
        padding (int, optional): The padding size for the convolutional layers. Defaults to 0.
        ff_mult (int, optional): The multiplier for the feed-forward network dimension. Defaults to 4.

    Attributes:
        dim (int): The dimension of the input and output tensors.
        heads (int): The number of attention heads.
        depth (int): The number of layers in the decoder.
        dim_head (int): The dimension of each attention head.
        dropout (int): The dropout rate.
        num_channels: The number of channels in the input tensor.
        conv_channels: The number of channels in the convolutional layers.
        kernel_size: The size of the convolutional kernel.
        stride (int): The stride of the convolutional layers.
        padding (int): The padding size for the convolutional layers.
        ff_mult (int): The multiplier for the feed-forward network dimension.
        frmi_embedding (nn.Linear): The linear layer for FRMI embedding.
        masked_attn (MultiQueryAttention): The masked attention module.
        mha (MultiheadAttention): The multihead attention module.
        frmni_embedding (FMRIEmbedding): The FMRI embedding module.
        ffn (FeedForward): The feed-forward network module.
        proj (nn.Linear): The linear layer for projection to original dimension.
        softmax (nn.Softmax): The softmax activation function.
        encoder (MorpheusEncoder): The MorpheusEncoder module.

    """

    def __init__(
        self,
        dim: int,
        heads: int,
        depth: int,
        dim_head: int,
        dropout: int,
        num_channels,
        conv_channels,
        kernel_size,
        in_channels,
        out_channels,
        stride=1,
        padding=0,
        ff_mult: int = 4,
        *args,
        **kwargs,
    ):
        super(MorpheusDecoder, self).__init__()
        self.dim = dim
        self.heads = heads
        self.depth = depth
        self.dim_head = dim_head
        self.dropout = dropout
        self.num_channels = num_channels
        self.conv_channels = conv_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.ff_mult = ff_mult

        self.frmi_embedding = nn.Linear(num_channels, dim)

        self.masked_attn = MultiQueryAttention(
            dim,
            heads,
        )

        self.mha = nn.MultiheadAttention(
            dim,
            heads,
            dropout
        )

        self.frmni_embedding = FMRIEmbedding(
            in_channels, out_channels, kernel_size, stride, padding
        )

        self.ffn = FeedForward(dim, dim, dropout, *args, **kwargs)

        self.proj = nn.Linear(dim, num_channels)

        self.softmax = nn.Softmax(1)

        self.encoder = MorpheusEncoder(
            dim,
            heads,
            depth,
            dim_head,
            dropout,
            num_channels,
            conv_channels,
            kernel_size,
            stride,
            padding,
            ff_mult,
            *args,
            **kwargs,
        )

    def forward(self, frmi: Tensor, eeg: Tensor) -> Tensor:
        """
        Perform forward pass through the MorpheusDecoder.

        Args:
            frmi (Tensor): The FRMI input tensor.
            eeg (Tensor): The EEG input tensor.

        Returns:
            Tensor: The output tensor after decoding.

        """
        # X = FRMI of shapef
        # # MRI data is represented as a 4D tensor: [batch_size, channels, depth, height, width].
        # # EEG data is represented as a 3D tensor: [batch_size, channels, time_samples].
        x = self.frmi_embedding(frmi)

        # Rearrange to text dimension
        x = reduce(x, "b c d h w -> b (h w) (c d)", "sum")

        # Rearrange tensor to be compatible with attn
        x = threed_to_text(x, self.num_channels, self.dim)

        # Masked Attention
        x, _, _ = self.masked_attn(x)

        # EEG Encoder
        eeg = self.encoder(eeg)

        # Multihead Attention
        x = self.mha(x, eeg, x) + x

        # Feed Forward
        x = self.ffn(x) + x

        # Projection to original dimension
        x = self.proj(x)

        # Softmax
        x = self.softmax(x)

        return x


class Morpheus(nn.Module):
    """
    Morpheus model implementation.

    Args:
        dim (int): Dimension of the model.
        heads (int): Number of attention heads.
        depth (int): Number of layers in the model.
        dim_head (int): Dimension of each attention head.
        dropout (int): Dropout rate.
        num_channels: Number of input channels.
        conv_channels: Number of channels in the convolutional layers.
        kernel_size: Size of the convolutional kernel.
        in_channels: Number of input channels for the convolutional layers.
        out_channels: Number of output channels for the convolutional layers.
        stride (int, optional): Stride value for the convolutional layers. Defaults to 1.
        padding (int, optional): Padding value for the convolutional layers. Defaults to 0.
        ff_mult (int, optional): Multiplier for the feed-forward layer dimension. Defaults to 4.

    Attributes:
        dim (int): Dimension of the model.
        heads (int): Number of attention heads.
        depth (int): Number of layers in the model.
        dim_head (int): Dimension of each attention head.
        dropout (int): Dropout rate.
        num_channels: Number of input channels.
        conv_channels: Number of channels in the convolutional layers.
        kernel_size: Size of the convolutional kernel.
        stride (int): Stride value for the convolutional layers.
        padding (int): Padding value for the convolutional layers.
        ff_mult (int): Multiplier for the feed-forward layer dimension.
        layers (nn.ModuleList): List of MorpheusDecoder layers.
        norm (nn.LayerNorm): Layer normalization module.

    Examples:
        >>> model = Morpheus(
        ...     dim=128,
        ...     heads=4,
        ...     depth=2,
        ...     dim_head=32,
        ...     dropout=0.1,
        ...     num_channels=32,
        ...     conv_channels=32,
        ...     kernel_size=3,
        ...     in_channels=1,
        ...     out_channels=32,
        ...     stride=1,
        ...     padding=1,
        ...     ff_mult=4,
        ... )
        >>> frmi = torch.randn(1, 1, 32, 32, 32)
        >>> eeg = torch.randn(1, 32, 128)
        >>> output = model(frmi, eeg)
        >>> print(output.shape)


    """

    def __init__(
        self,
        dim: int,
        heads: int,
        depth: int,
        dim_head: int,
        dropout: int,
        num_channels,
        conv_channels,
        kernel_size,
        in_channels,
        out_channels,
        stride=1,
        padding=0,
        ff_mult: int = 4,
        scatter: bool = False,
        *args,
        **kwargs,
    ):
        super(Morpheus, self).__init__()
        self.dim = dim
        self.heads = heads
        self.depth = depth
        self.dim_head = dim_head
        self.dropout = dropout
        self.num_channels = num_channels
        self.conv_channels = conv_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.ff_mult = ff_mult
        self.scatter = scatter

        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                MorpheusDecoder(
                    dim,
                    heads,
                    depth,
                    dim_head,
                    dropout,
                    num_channels,
                    conv_channels,
                    kernel_size,
                    in_channels,
                    out_channels,
                    stride,
                    padding,
                    ff_mult,
                    *args,
                    **kwargs,
                )
            )

        self.norm = nn.LayerNorm(num_channels)

    def forward(self, frmi: Tensor, eeg: Tensor) -> Tensor:
        """
        Forward pass of the Morpheus model.

        Args:
            frmi (Tensor): Input tensor for the frmi modality.
            eeg (Tensor): Input tensor for the eeg modality.

        Returns:
            Tensor: Output tensor of the model.

        """
        for layer in self.layers:
            x = layer(frmi, eeg)
            x = self.norm(x)

        if self.scatter:
            # Scatter the tensor to 4d spatial tensor
            s1 = x.shape[1]
            x = rearrange(x, "b (s s1) d -> b s s1 d", s1=s1)

        return x

In [20]:
model = Morpheus(
    dim=128,  # Dimension of the model
    heads=4,  # Number of attention heads
    depth=2,  # Number of transformer layers
    dim_head=32,  # Dimension of each attention head
    dropout=0.1,  # Dropout rate
    num_channels=32,  # Number of input channels
    conv_channels=32,  # Number of channels in convolutional layers
    kernel_size=3,  # Kernel size for convolutional layers
    in_channels=1,  # Number of input channels for convolutional layers
    out_channels=32,  # Number of output channels for convolutional layers
    stride=1,  # Stride for convolutional layers
    padding=1,  # Padding for convolutional layers
    ff_mult=4,  # Multiplier for feed-forward layer dimension
    scatter = False, # Whether to scatter to 4d representing spatial dimensions
)

# Creating random tensors for input data
frmi = torch.randn(1, 1, 32, 32, 32)  # Random tensor for FRMI data
eeg = torch.randn(1, 32, 128)  # Random tensor for EEG data

# Passing the input data through the model to get the output
output = model(frmi, eeg)

# Printing the shape of the output tensor
print(output.shape)

TypeError: can only concatenate tuple (not "Tensor") to tuple

In [21]:
class MorpheusDecoder(nn.Module):
    """
    MorpheusDecoder is a module that performs decoding in the Morpheus model.

    Args:
        dim (int): The dimension of the input and output tensors.
        heads (int): The number of attention heads.
        depth (int): The number of layers in the decoder.
        dim_head (int): The dimension of each attention head.
        dropout (int): The dropout rate.
        num_channels: The number of channels in the input tensor.
        conv_channels: The number of channels in the convolutional layers.
        kernel_size: The size of the convolutional kernel.
        in_channels: The number of input channels for the FMRI embedding.
        out_channels: The number of output channels for the FMRI embedding.
        stride (int, optional): The stride of the convolutional layers. Defaults to 1.
        padding (int, optional): The padding size for the convolutional layers. Defaults to 0.
        ff_mult (int, optional): The multiplier for the feed-forward network dimension. Defaults to 4.

    Attributes:
        dim (int): The dimension of the input and output tensors.
        heads (int): The number of attention heads.
        depth (int): The number of layers in the decoder.
        dim_head (int): The dimension of each attention head.
        dropout (int): The dropout rate.
        num_channels: The number of channels in the input tensor.
        conv_channels: The number of channels in the convolutional layers.
        kernel_size: The size of the convolutional kernel.
        stride (int): The stride of the convolutional layers.
        padding (int): The padding size for the convolutional layers.
        ff_mult (int): The multiplier for the feed-forward network dimension.
        frmi_embedding (nn.Linear): The linear layer for FRMI embedding.
        masked_attn (MultiQueryAttention): The masked attention module.
        mha (MultiheadAttention): The multihead attention module.
        frmni_embedding (FMRIEmbedding): The FMRI embedding module.
        ffn (FeedForward): The feed-forward network module.
        proj (nn.Linear): The linear layer for projection to original dimension.
        softmax (nn.Softmax): The softmax activation function.
        encoder (MorpheusEncoder): The MorpheusEncoder module.

    """

    def __init__(
        self,
        dim: int,
        heads: int,
        depth: int,
        dim_head: int,
        dropout: int,
        num_channels,
        conv_channels,
        kernel_size,
        in_channels,
        out_channels,
        stride=1,
        padding=0,
        ff_mult: int = 4,
        *args,
        **kwargs,
    ):
        super(MorpheusDecoder, self).__init__()
        self.dim = dim
        self.heads = heads
        self.depth = depth
        self.dim_head = dim_head
        self.dropout = dropout
        self.num_channels = num_channels
        self.conv_channels = conv_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.ff_mult = ff_mult

        self.frmi_embedding = nn.Linear(num_channels, dim)

        self.masked_attn = MultiQueryAttention(
            dim,
            heads,
        )

        self.frmni_embedding = FMRIEmbedding(
            in_channels, out_channels, kernel_size, stride, padding
        )

        self.ffn = FeedForward(dim, dim, dropout, *args, **kwargs)

        self.proj = nn.Linear(dim, num_channels)

        self.softmax = nn.Softmax(1)

    def forward(self, frmi: Tensor) -> Tensor:
        """
        Perform forward pass through the MorpheusDecoder.

        Args:
            frmi (Tensor): The FRMI input tensor.

        Returns:
            Tensor: The output tensor after decoding.

        """
        # X = FRMI of shapef
        # # MRI data is represented as a 4D tensor: [batch_size, channels, depth, height, width].
        x = self.frmi_embedding(frmi)

        # Rearrange to text dimension
        x = reduce(x, "b c d h w -> b (h w) (c d)", "sum")

        # Rearrange tensor to be compatible with attn
        x = threed_to_text(x, self.num_channels, self.dim)

        # Masked Attention
        x, _, _ = self.masked_attn(x)

        # Feed Forward
        x = self.ffn(x) + x

        # Projection to original dimension
        x = self.proj(x)

        # Softmax
        x = self.softmax(x)

        return x

In [24]:
decoder = MorpheusDecoder(
    dim=128,  # Dimension of the model
    heads=4,  # Number of attention heads
    depth=2,  # Number of transformer layers
    dim_head=32,  # Dimension of each attention head
    dropout=0.1,  # Dropout rate
    num_channels=32,  # Number of input channels
    conv_channels=32,  # Number of channels in convolutional layers
    kernel_size=3,  # Kernel size for convolutional layers
    in_channels=1,  # Number of input channels for convolutional layers
    out_channels=32,  # Number of output channels for convolutional layers
    stride=1,  # Stride for convolutional layers
    padding=1,  # Padding for convolutional layers
    ff_mult=4,  # Multiplier for feed-forward layer dimension
)

# Creating random tensors for input data
frmi = torch.randn(1, 1, 32, 32, 32)  # Random tensor for FRMI data

# Passing the input data through the model to get the output
output = decoder(frmi)

# Printing the shape of the output tensor
print(output.shape)

torch.Size([1, 32, 32])


In [25]:
!pip install nilearn

Collecting nilearn
  Downloading nilearn-0.10.3-py3-none-any.whl (10.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.4/10.4 MB[0m [31m27.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: nilearn
Successfully installed nilearn-0.10.3


In [26]:
from nilearn import datasets
from nilearn.datasets import load_mni152_template
from nilearn.image import resample_to_img
from nilearn import masking
from nilearn import image as nimg
from nilearn import plotting
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import MultiTaskLasso

In [27]:
adhd_dataset = datasets.fetch_adhd(n_subjects=2)


Added README.md to /root/nilearn_data


Dataset created in /root/nilearn_data/adhd

Downloading data from https://www.nitrc.org/frs/download.php/7781/adhd40_metadata.tgz ...


 ...done. (1 seconds, 0 min)
Extracting data from /root/nilearn_data/adhd/fbef5baff0b388a8c913a08e1d84e059/adhd40_metadata.tgz..... done.


Downloading data from https://www.nitrc.org/frs/download.php/7782/adhd40_0010042.tgz ...


Downloaded 23175168 of 44414948 bytes (52.2%,    0.9s remaining) ...done. (2 seconds, 0 min)
Extracting data from /root/nilearn_data/adhd/31769c9cee5cd55f045e62633d651f3d/adhd40_0010042.tgz..... done.


Downloading data from https://www.nitrc.org/frs/download.php/7783/adhd40_0010064.tgz ...


Downloaded 43098112 of 45583539 bytes (94.5%,    0.1s remaining) ...done. (1 seconds, 0 min)
Extracting data from /root/nilearn_data/adhd/31769c9cee5cd55f045e62633d651f3d/adhd40_0010064.tgz..... done.


In [41]:
adhd_raw_data = nimg.get_data(nimg.load_img(adhd_dataset.func[0]))

In [42]:
print(adhd_raw_data.shape)

(61, 73, 61, 176)


In [44]:
adhd_raw_data = rearrange(adhd_raw_data, "x y z t -> t z y x")

In [47]:
print(adhd_raw_data.shape)

(176, 61, 73, 61)


In [51]:
adhd_raw_data = adhd_raw_data.reshape(1, 176, 61, 73, 61)
adhd_raw_data = torch.Tensor(adhd_raw_data)

In [54]:
decoder = MorpheusDecoder(
    dim=128,  # Dimension of the model
    heads=4,  # Number of attention heads
    depth=2,  # Number of transformer layers
    dim_head=32,  # Dimension of each attention head
    dropout=0.1,  # Dropout rate
    num_channels=61,  # Number of input channels
    conv_channels=32,  # Number of channels in convolutional layers
    kernel_size=3,  # Kernel size for convolutional layers
    in_channels=1,  # Number of input channels for convolutional layers
    out_channels=32,  # Number of output channels for convolutional layers
    stride=1,  # Stride for convolutional layers
    padding=1,  # Padding for convolutional layers
    ff_mult=4,  # Multiplier for feed-forward layer dimension
)

# Passing the input data through the model to get the output
output = decoder(adhd_raw_data)

# Printing the shape of the output tensor
print(output.shape)

torch.Size([1, 61, 61])


In [55]:
output

tensor([[[0.0000e+00, 0.0000e+00, 1.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 2.8178e-01],
         [1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 5.0594e-26,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          1.2752e-43, 0.0000e+00]]], grad_fn=<SoftmaxBackward0>)