# A transformer-based unified multimodal framework for Alzheimer's disease assessment

<https://www.sciencedirect.com/science/article/pii/S0010482524010643#bib38>

* **model interpretation**

Applying **Grad-CAM** 

> The gradient value of each token offers a cue for interpreting the model's decisions, illuminating the specific tokens or positions within the input image and non-image data that influenced the model's output.
>
> 
> Through assigning SHAP values to specific voxels or by mapping internal network nodes, SHAP offered insights into the contribution of individual features to the model's output.

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every patch to a vector of length ``embed_dim``.

    Args:
        img_size: Accepted for interface compatibility; not used here.
        patch_size: Side length of each square patch, in pixels.
        in_channels: Number of channels of the input image.
        embed_dim: Length of the embedding produced for each patch.
    """

    def __init__(self, img_size, patch_size, in_channels=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        # kernel == stride, so every output position covers exactly one patch.
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map ``[B, C, H, W]`` images to ``[B, N, E]`` patch embeddings."""
        feature_map = self.proj(x)  # [B, E, H/P, W/P]
        tokens = feature_map.flatten(start_dim=2)  # [B, E, N]
        return tokens.transpose(1, 2)  # [B, N, E]


class VisionTransformer(nn.Module):
    """Vision Transformer classifier built from ``nn.TransformerEncoderLayer``.

    The image is split into patches, a learnable class token is prepended,
    learnable position embeddings are added, and the sequence is passed
    through ``depth`` encoder layers. The class token's final state is
    layer-normalised and fed to a linear classification head.

    The encoder layers are created with ``batch_first=True`` because the
    token tensor is laid out as ``[B, N, E]``. With PyTorch's default
    (``batch_first=False``) dim 0 would be read as the sequence axis, so
    attention would mix samples of the batch instead of tokens of one image
    and the class token would never see any patch.

    Args:
        img_size: Height/width of the (square) input image.
        patch_size: Side length of each square patch.
        num_classes: Number of output logits.
        embed_dim: Token embedding width.
        depth: Number of encoder layers.
        num_heads: Attention heads per encoder layer.
        mlp_dim: Hidden width of each layer's feed-forward network.
        dropout_rate: Dropout used after the embeddings and inside each layer.
        in_channels: Number of input image channels (defaults to 3, which is
            what the patch embedding used implicitly before).
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_dim=3072,
        dropout_rate=0.1,
        in_channels=3,
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One position per patch plus one for the class token.
        num_patches = (img_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout_rate)

        self.transformer_blocks = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    d_model=embed_dim,
                    nhead=num_heads,
                    dim_feedforward=mlp_dim,
                    dropout=dropout_rate,
                    # Tokens are [B, N, E]; attend over N, not over the batch.
                    batch_first=True,
                )
                for _ in range(depth)
            ]
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``[B, num_classes]`` for images ``x``."""
        B = x.size(0)
        x = self.patch_embed(x)  # [B, N, E]
        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B, 1, E]
        x = torch.cat((cls_tokens, x), dim=1)  # [B, N + 1, E]
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for block in self.transformer_blocks:
            x = block(x)

        x = self.norm(x)
        # Classify from the class token only.
        return self.head(x[:, 0])


# Smoke test: build the model and confirm that one image produces one
# vector of 1000 class logits.
model = VisionTransformer(img_size=224, patch_size=16, num_classes=1000)

# Random values stand in for a real 3-channel 224x224 image.
input_tensor = torch.randn((1, 3, 224, 224))
output = model(input_tensor)
print(output.size())  # expected: torch.Size([1, 1000])

torch.Size([1, 1000])


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbedding(nn.Module):
    """Turn an image into a sequence of patch embeddings.

    The image is tiled into non-overlapping ``patch_size`` x ``patch_size``
    squares by a convolution with matching kernel size and stride; each
    square becomes one ``embed_dim``-wide token.

    Args:
        img_size: Side length of the (square) input image; used only to
            compute ``num_patches``.
        patch_size: Side length of each square patch.
        in_channels: Number of channels of the input image.
        embed_dim: Width of each patch embedding.
    """

    def __init__(self, img_size, patch_size, in_channels=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        # Patches that do not fit completely are dropped by the floor division.
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side**2
        self.conv = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map ``[B, C, H, W]`` images to ``[B, N, E]`` tokens."""
        # B: batch size, E: embed_dim, P: patch size, N: patches per image
        embedded = self.conv(x)  # [B, E, H/P, W/P]
        return embedded.flatten(2).permute(0, 2, 1)  # [B, E, N] -> [B, N, E]

In [5]:
# Sanity check: a 224x224 image cut into 16x16 patches should give
# (224 // 16) ** 2 = 196 tokens, each with the default width of 768.
img_size = 224
patch_size = 16
pe = PatchEmbedding(img_size=img_size, patch_size=patch_size)
x = torch.randn((1, 3, img_size, img_size))
print(pe(x).shape)  # expected: torch.Size([1, 196, 768])

torch.Size([1, 196, 768])


In [6]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over ``[B, N, C]`` tokens.

    Queries, keys and values come from a single fused linear projection.
    Dropout is applied to the attention weights and to the output
    projection.

    Args:
        embed_dim: Token width ``C``; must be divisible by ``num_heads``.
        num_heads: Number of attention heads; must be positive.
        dropout: Dropout probability for both dropout layers.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``. Such a configuration could never run: the
            reshape in ``forward`` would fail, so it is rejected up front.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads (got {num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # 1 / sqrt(head_dim) scaling of the attention scores.
        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.fc = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        """Return the attended tokens, same shape ``[B, N, C]`` as ``x``."""
        B, N, C = x.size()
        # Fused projection -> [3, B, heads, N, head_dim] so q/k/v can be
        # sliced off the leading axis.
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, heads, N, N]
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back into the channel axis: [B, N, C].
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.fc(x)
        x = self.proj_drop(x)
        return x

In [7]:
# Instantiate a patch embedding for 224x224 images with 16x16 patches
# (nothing is run through it in this cell).
pe = PatchEmbedding(img_size=224, patch_size=16)

In [32]:
class ViTMLP(nn.Module):
    """Position-wise feed-forward network of a transformer block.

    Expands each token from ``embed_dim`` to ``mlp_dim``, applies GELU,
    and projects back to ``embed_dim``; dropout follows the activation
    and the final projection.

    Args:
        embed_dim: Input and output token width.
        mlp_dim: Hidden width.
        dropout: Dropout probability for both dropout layers.
    """

    def __init__(self, embed_dim, mlp_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_features=embed_dim, out_features=mlp_dim)
        self.fc2 = nn.Linear(in_features=mlp_dim, out_features=embed_dim)
        self.act = nn.GELU()
        self.drop1 = nn.Dropout(p=dropout)
        self.drop2 = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply the two-layer MLP; the output has the same shape as ``x``."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))

In [9]:
class ViTBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Self-attention and the MLP are each applied to a layer-normalised copy
    of the tokens and added back through a residual connection.

    Args:
        embed_dim: Token width.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the MLP.
        dropout: Dropout probability passed to attention and MLP.
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(
            embed_dim=embed_dim, num_heads=num_heads, dropout=dropout
        )
        self.mlp = ViTMLP(embed_dim=embed_dim, mlp_dim=mlp_dim, dropout=dropout)

    def forward(self, x):
        """Return the tokens after the attention and MLP residual steps."""
        after_attention = x + self.attn(self.norm1(x))
        return after_attention + self.mlp(self.norm2(after_attention))

In [10]:
class ViT(nn.Module):
    """Vision Transformer classifier assembled from ``ViTBlock`` layers.

    Images are embedded as patches, a learnable class token is prepended,
    learnable position embeddings are added, and the sequence runs through
    ``depth`` blocks. The class token's final state goes through a
    LayerNorm + Linear head.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Number of input channels.
        num_classes: Number of output logits.
        embed_dim: Token width.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_dim: Hidden width of each block's MLP.
        blk_dropout: Dropout probability inside the blocks.
        emb_dropout: Dropout probability applied to the embedded sequence.
        lr: Accepted but not used by this class.
    """

    def __init__(
        self,
        img_size,
        patch_size,
        in_channels,
        num_classes,
        embed_dim,
        depth,
        num_heads,
        mlp_dim,
        blk_dropout=0.1,
        emb_dropout=0.1,
        lr=1e-4,
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )
        # One position per patch plus one for the class token.
        seq_len = self.patch_embed.num_patches + 1
        self.pos_embed = nn.Parameter(torch.randn(1, seq_len, embed_dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.dropout = nn.Dropout(emb_dropout)

        encoder_layers = []
        for _ in range(depth):
            encoder_layers.append(
                ViTBlock(embed_dim, num_heads, mlp_dim, blk_dropout)
            )
        self.blocks = nn.ModuleList(encoder_layers)

        self.head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``[B, num_classes]`` for images ``x``."""
        batch_size = x.shape[0]
        tokens = self.patch_embed(x)  # [B, N, E]
        cls = self.cls_token.expand(batch_size, -1, -1)  # [B, 1, E]
        tokens = torch.cat((cls, tokens), dim=1) + self.pos_embed
        tokens = self.dropout(tokens)

        for layer in self.blocks:
            tokens = layer(tokens)

        # Only the class token feeds the classification head.
        return self.head(tokens[:, 0])

In [19]:
class Trainer:
    """Minimal supervised training loop using cross-entropy loss and Adam.

    Args:
        model: Module mapping a batch of inputs to class logits.
        data: Iterable yielding ``(inputs, labels)`` batches. It is iterated
            once per epoch and does not need to support ``len()``.
        num_epochs: Number of passes over ``data``.
        learning_rate: Learning rate handed to Adam.
    """

    def __init__(self, model, data, num_epochs=10, learning_rate=1e-4):
        self.model = model
        self.data = data
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def train(self):
        """Train the model and print the mean loss and accuracy per epoch.

        Accuracy is computed from the same forward pass used for the loss,
        i.e. with the model in training mode and before that batch's
        optimizer step. An epoch that yields no batches reports ``nan`` for
        both metrics instead of raising ``ZeroDivisionError``.
        """
        self.model.train()
        for epoch in range(self.num_epochs):
            total_loss = 0.0
            correct = 0
            total = 0
            # Count batches as they arrive so iterables without __len__
            # (generators, streaming loaders) are supported.
            num_batches = 0
            for inputs, labels in self.data:
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                # detach() replaces the deprecated `.data` access.
                predicted = outputs.detach().argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                num_batches += 1

            avg_loss = total_loss / num_batches if num_batches else float("nan")
            accuracy = 100 * correct / total if total else float("nan")
            print(f'Epoch [{epoch + 1}/{self.num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')


In [11]:
# Shape check: one random 3x224x224 image in, one vector of 1000 logits out.
img_size = 224
patch_size = 16
in_channels = 3
num_classes = 1000
embed_dim = 768
depth = 12
num_heads = 12
mlp_dim = 3072
vit = ViT(
    img_size=img_size,
    patch_size=patch_size,
    in_channels=in_channels,
    num_classes=num_classes,
    embed_dim=embed_dim,
    depth=depth,
    num_heads=num_heads,
    mlp_dim=mlp_dim,
)
x = torch.randn((1, in_channels, img_size, img_size))
print(vit(x).shape)  # expected: torch.Size([1, 1000])

ViTMLP shape: torch.Size([1, 197, 768])
ViTMLP shape: torch.Size([1, 197, 768])
ViTMLP shape: torch.Size([1, 197, 768])
ViTMLP shape: torch.Size([1, 197, 768])
ViTMLP shape: torch.Size([1, 197, 768])
ViTMLP shape: torch.Size([1, 197, 768])
ViTMLP shape: torch.Size([1, 197, 768])
ViTMLP shape: torch.Size([1, 197, 768])
ViTMLP shape: torch.Size([1, 197, 768])
ViTMLP shape: torch.Size([1, 197, 768])
ViTMLP shape: torch.Size([1, 197, 768])
ViTMLP shape: torch.Size([1, 197, 768])
torch.Size([1, 1000])


In [15]:
# load temp data
import nibabel as nib
import numpy as np

nii_file = [
    "../temp/temp/I35933.nii",
    "../temp/temp/I50468.nii",
    "../temp/temp/I64631.nii",
]
# Stack the volumes into one ndarray before handing them to torch:
# torch.tensor() on a list of ndarrays is very slow and emits a UserWarning.
# All volumes must share the same shape for the stack to succeed.
data = torch.from_numpy(
    np.stack([nib.load(path).get_fdata() for path in nii_file])
).float()
# Move the last volume axis in front of the two in-plane axes so it can be
# used as the channel dimension: [num_volumes, D, H, W].
data = data.permute(0, 3, 1, 2)


  data = torch.tensor(data).float()


In [16]:
# Display the shape of the loaded tensor as the cell output.
data.size()

torch.Size([3, 160, 192, 192])

In [30]:
# Build a stand-in for a data loader: the same (inputs, labels) batch
# repeated 10 times, so every epoch runs 10 identical steps.
inputs = data
labels = torch.tensor([0, 1, 2])  # one class id per volume - presumably placeholders
data_batch = [(inputs, labels)] * 10

In [33]:
# Configure the ViT for the loaded volumes: the 160-long axis is used as
# the channel dimension of a 192x192 image, with 3 output classes.
img_size = 192
patch_size = 16
in_channels = 160
num_classes = 3
embed_dim = 768
depth = 12
num_heads = 12
mlp_dim = 3072
vit = ViT(
    img_size=img_size,
    patch_size=patch_size,
    in_channels=in_channels,
    num_classes=num_classes,
    embed_dim=embed_dim,
    depth=depth,
    num_heads=num_heads,
    mlp_dim=mlp_dim,
)
# Train with the Trainer defaults (10 epochs, learning rate 1e-4).
trainer = Trainer(model=vit, data=data_batch)
trainer.train()

Epoch [1/10], Step [10/10], Loss: 1.7117
Epoch [2/10], Step [10/10], Loss: 1.1826
Epoch [3/10], Step [10/10], Loss: 0.9688
Epoch [4/10], Step [10/10], Loss: 0.6730
Epoch [5/10], Step [10/10], Loss: 0.0609
Epoch [6/10], Step [10/10], Loss: 0.0019
Epoch [7/10], Step [10/10], Loss: 0.0001
Epoch [8/10], Step [10/10], Loss: 0.0001
Epoch [9/10], Step [10/10], Loss: 0.0001
Epoch [10/10], Step [10/10], Loss: 0.0001
