## **License and Usage Notice**
---

**MedLiT-seed** and all accompanying source code in this notebook are released under the **GNU Affero General Public License v3.0 (AGPL-3.0)**.

You are free to:
- **Use** the code,
- **Modify** it,
- **Redistribute** it,
- **Build upon** it for academic or commercial purposes,

**provided that** any distributed or publicly deployed derivative work **must also be released under the AGPL-3.0** and must provide end users with access to the corresponding source code, including any modifications.

This includes cases where the code is run as part of a network-accessible service (e.g., APIs, web applications, inference servers), as required by the **Affero clause**.

Please note:
- The MedLiT-seed model weights are provided strictly for **research and educational** use, and may be subject to additional terms from the hosting institution.
- This notebook loads pre-trained weights from a user-provided directory via the variable `weight_path`. You are responsible for ensuring that the model weights you download and use comply with any applicable licensing or data-use restrictions.

By using this notebook, you acknowledge that you have read and understood the license terms.

For details, see: https://www.gnu.org/licenses/agpl-3.0.en.html


---

In [None]:
!pip install medmnist

Collecting medmnist
  Downloading medmnist-3.0.2-py3-none-any.whl.metadata (14 kB)
Collecting fire (from medmnist)
  Downloading fire-0.7.1-py3-none-any.whl.metadata (5.8 kB)
Downloading medmnist-3.0.2-py3-none-any.whl (25 kB)
Downloading fire-0.7.1-py3-none-any.whl (115 kB)
Installing collected packages: fire, medmnist
Successfully installed fire-0.7.1 medmnist-3.0.2


## **1. Import the necessary libraries**
---

In [None]:
import einops as ein
import numpy as np
import math
import matplotlib
import matplotlib.pyplot as plt
import os
import pandas as pd
import random
import torch
import shutil
import torch.nn as nn
import torch.nn.functional as F


from medmnist import INFO, Evaluator
from medmnist import (
    OrganCMNIST, BloodMNIST, DermaMNIST, OrganAMNIST, OrganSMNIST,
    PathMNIST, PneumoniaMNIST, RetinaMNIST, BreastMNIST, TissueMNIST,
    OCTMNIST, ChestMNIST
)
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
from torchvision.transforms import v2
from torchvision.utils import make_grid
from torch.distributions.normal import Normal
from torch.utils.data import DataLoader
from tqdm import tqdm
from google.colab import drive


print("Versions of key libraries")
print("---")
print("torch:      ", torch.__version__)
print("numpy:      ", np.__version__)
print("matplotlib: ", matplotlib.__version__)


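def setup_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs so that runs are repeatable."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False    # autotuned kernels can vary between runs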

Versions of key libraries
---
torch:       2.5.1
numpy:       2.0.1
matplotlib:  3.10.3


## **2. Hyperparameter setup**
---

In [None]:
class Parser():
    def __init__(self):

        self.seed         = 42         # Deterministic

        # Input setup
        # ----------------------------------
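        self.img_size     = (256, 256)    # image size (height, width); assumed square below
        self.patch_size   = 16            # side length of each square patch
        self.num_patches  = (self.img_size[0] // self.patch_size) ** 2 + 1
                                          # 16 x 16 patches + 1 class token
        self.patch_length = (self.patch_size ** 2) * 3
                                          # pixels per patch x 3 channels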
        self.num_classes  = 12
        self.batch_size   = 128

        # Encoder setup
        # ----------------------------------
        self.num_heads    = 6
        self.num_experts  = 3
        self.k            = 2
        self.embed_size   = 216
        self.hidden_size  = 27
        self.num_groups   = 3
        self.dropout_rate = 0.1
        self.out_proj     = 216

        self.num_layers   = 9
        self.num_blocks   = 3
        self.hidden_ratio = 3
        self.lyrs_per_block  = self.num_layers // self.num_blocks   # integer count of layers per block


args    = Parser()
device  = 'cuda' if torch.cuda.is_available() else 'cpu'
setup_seed(args.seed)
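
Before building the model, it can help to check the sizes these settings imply. The sketch below recomputes them in plain Python. The names `channels`, `grid`, `head_size` and `heads_per_group` are illustrative and do not appear in `Parser`; the last two mirror the `embed_size // num_heads` and `num_heads // num_groups` expressions used in section 4, and the 3-channel assumption is read off the `* 3` in `patch_length`.

```python
# Values copied from the Parser class above
img_size, patch_size = (256, 256), 16
num_heads, num_groups, embed_size = 6, 3, 216
channels = 3                                   # assumption: RGB input

grid            = img_size[0] // patch_size    # patches per side
num_patches     = grid ** 2 + 1                # +1 for the class token
patch_length    = patch_size ** 2 * channels   # flattened pixels per patch
head_size       = embed_size // num_heads      # width of one attention head
heads_per_group = num_heads // num_groups      # query heads sharing one key/value pair

# The integer divisions above are only exact when these hold
assert img_size[0] % patch_size == 0
assert embed_size % num_heads == 0 and num_heads % num_groups == 0

print(grid, num_patches, patch_length, head_size, heads_per_group)
# prints: 16 257 768 36 2
```

If any of the divisibility checks fails after changing a hyperparameter, the corresponding `//` in the model code would silently drop the remainder.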

## **3. Modules required to build Mixture-of-Experts layer**
---

In [None]:
class FeedForward(nn.Module):
    def __init__(self,
                 embed_size,
                 hidden_size,
                 output_size,
                 dropout_rate):
        super().__init__()
        self.FeedForward  = nn.Sequential(nn.Linear(in_features=embed_size,
                                                    out_features=hidden_size),
                                          nn.ReLU(),
                                          nn.Linear(in_features=hidden_size,
                                                    out_features=output_size),
                                          nn.Dropout(dropout_rate))

    def forward(self, x):
        return self.FeedForward(x)
                                    # 'x' has a shape of (batchSize, num_embeds, output_size)

In [None]:
class FFNSwiGLU(nn.Module):
    def __init__(self,
                 embed_size,
                 hidden_size,
                 output_size,
                 dropout_rate=0.0,
                 V=None,
                 W2=None):
        super(FFNSwiGLU, self).__init__()


        self.W      = nn.Linear(in_features=embed_size,
                                out_features=hidden_size)

        self.V      = V
        self.W2     = W2
        if self.V is None:
            self.V  = nn.Linear(in_features=embed_size,
                                out_features=hidden_size)
        if self.W2 is None:
            self.W2 = nn.Linear(in_features=hidden_size,
                                out_features=output_size)
        self.dropout= nn.Dropout(dropout_rate)

    def forward(self, x):
        xW        = self.W(x)
        xV        = self.V(x)

        output    = self.W2(F.silu(xW) * xV)

        return self.dropout(output)
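
The forward pass above computes `W2(SiLU(xW) * xV)`. As a rough illustration, the following plain-Python sketch evaluates that expression for a single input vector. It is a simplification under stated assumptions: biases and dropout are left out (the `nn.Linear` layers above do carry biases), weights are stored as `[in][out]` nested lists, and `silu`, `matvec` and `ffn_swiglu` are hypothetical helper names.

```python
import math


def silu(z):
    """SiLU (swish) activation: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))


def matvec(x, weight):
    """Multiply a vector by a weight matrix stored as weight[in][out]."""
    return [sum(xi * weight[i][j] for i, xi in enumerate(x))
            for j in range(len(weight[0]))]


def ffn_swiglu(x, W, V, W2):
    """Bias-free SwiGLU feed-forward: W2(SiLU(xW) * xV)."""
    gate   = [silu(a) for a in matvec(x, W)]    # gating branch
    value  = matvec(x, V)                       # linear branch
    hidden = [g * v for g, v in zip(gate, value)]
    return matvec(hidden, W2)


# Toy weights (assumed for illustration): identity projections, summing output layer
identity = [[1.0, 0.0],
            [0.0, 1.0]]
out = ffn_swiglu([1.0, 2.0], W=identity, V=identity, W2=[[1.0], [1.0]])
print(out)
```

With identity projections the result reduces to `silu(1) * 1 + silu(2) * 2`, which makes the gating visible: each hidden unit is the linear branch scaled by a smooth, input-dependent gate.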

In [None]:
class DataDispatcher(object):
    def __init__(self,
                 num_experts=4):

        self.num_experts      = num_experts
        self.gates            = None
        self.batch_index      = None
        self.part_sizes       = None
        self.score_for_expert = None

    def dispatch(self,
                 data,
                 gates):
                                            # 'data' has a shape of (batch_size, length)
                                            # 'gates' has a shape of (batch_size, num_experts)
        self.gates            = gates
        experts               = torch.nonzero(self.gates)
                                            # Get the (row, expert) index pair of every non-zero gate,
                                            # i.e. the expert(s) each row of `data` is routed to.
                                            # `experts` has a shape of (batch_size * k, 2)
        (sorted_experts,
        sorted_experts_idx)  = experts.sort(dim=0)
                                            # Sort each column of `experts` independently,
                                            # so that the rows (in `data`) that are routed to
                                            # the same expert end up together for later slicing

                                            # for `sorted_experts_idx`, what matters is the second column,
                                            # whose entries tell where each value in the second column
                                            # of `sorted_experts` came from

                                            # both `sorted_experts` and `sorted_experts_idx` have
                                            # a shape of (batch_size * k, 2)

        (_, expert_index)     = sorted_experts.split(1, dim=1)
                                            # Split `sorted_experts` by columns and keep only the second column
                                            # `expert_index` has a shape of (batch_size * k, 1)

        self.batch_index      = experts[sorted_experts_idx[:, 1], 0]
                                            # Get the index of row (from `data`) that will go to
                                            # the expert specified in 'expert_index'
                                            # `batch_index` has a shape of batch_size * k (1d array)

        self.part_sizes       = (self.gates > 0).sum(0).tolist()
                                            # Count the number of inputs that will go to each expert
                                            # 'part_sizes' is a list of the number of rows that will
                                            # go to each expert

                                            # Based on the indices in `batch_index`, we retrieve
                                            # the corresponding rows in `gates`, as a result
                                            # `gates_retrieved` has a shape of (batch_size * k, num_expert)
        gates_retrieved       = self.gates[self.batch_index]
        self.score_for_expert = torch.gather(input=gates_retrieved,
                                           dim=1,
                                           index=expert_index)
                                            # Then for each row, take the score that will be used for calculation
                                            # for the expert specified in `expert_index`
                                            # `score_for_expert` has a shape of (batch_size * k, 1)

                                            # Prepare the data to be dispatched to each expert
        data_retrieved        = data[self.batch_index]
                                            # `data_retrieved` has a shape of (batch_size * k, length)

                                            # Split the data into a list of `num_expert` tensors
        data_dispatched       = torch.split(data_retrieved,
                                          self.part_sizes,
                                          dim=0)

        return data_dispatched


    def combine(self,
                expert_output,
                multiply_by_gates=True):
                                            # `expert_output` is a list of torch tensors. The amount
                                            # of torch tensors is equal to `num_experts`
        merged          = torch.cat(expert_output, dim=0)

        if multiply_by_gates:
            merged      = merged.mul(self.score_for_expert)
                                            # `merged` has a shape of (batch_size * k, length)
        zeros           = torch.zeros(self.gates.size(0),
                                    expert_output[0].size(1),
                                    dtype=merged.dtype,
                                    device=merged.device)
                                            # create a zero tensor with the same shape as the input data;
                                            # matching `merged.dtype` keeps `index_add` valid when the
                                            # experts run in reduced precision
                                            # NOTE: unlike David Rau's implementation,
                                            # we do not set 'requires_grad' to True here

        combined        = zeros.index_add(dim=0,
                                        index=self.batch_index,
                                        source=merged)
                                            # for each position i, add row i of `merged` to the row of
                                            # `zeros` given by `batch_index[i]`, so the k expert outputs
                                            # of every input row are summed

        return combined
                                            # both `zeros` and `combined` have a shape of
                                            # (batch_size, length)
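
The dispatch/combine round trip above is easier to follow on a tiny example. The sketch below is a plain-Python restatement that uses lists instead of tensors. It is not the class itself: the functions `dispatch` and `combine`, the toy `gates`, and the `scales` standing in for experts are all assumptions made for illustration. Each row is copied to every expert with a non-zero gate, and the expert outputs are then weighted by their gates and summed back onto the row they came from.

```python
def dispatch(data, gates):
    """Group the rows of `data` by expert and remember how to undo the grouping."""
    parts, batch_index, scores = [], [], []
    for e in range(len(gates[0])):
        rows = [b for b, g in enumerate(gates) if g[e] > 0]
        parts.append([data[b] for b in rows])          # inputs for expert e
        batch_index.extend(rows)                       # source row of each dispatched copy
        scores.extend(gates[b][e] for b in rows)       # gate value of each dispatched copy
    return parts, batch_index, scores


def combine(expert_outputs, batch_index, scores, num_rows):
    """Scale every expert output by its gate and add it back onto its source row."""
    merged = [row for part in expert_outputs for row in part]
    combined = [[0.0] * len(merged[0]) for _ in range(num_rows)]
    for row, b, s in zip(merged, batch_index, scores):
        for j, value in enumerate(row):
            combined[b][j] += s * value
    return combined


data  = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]           # 3 rows, length 2
gates = [[0.75, 0.0, 0.25],                            # k = 2 non-zero gates per row
         [0.0, 0.5, 0.5],
         [0.25, 0.75, 0.0]]
scales = [1.0, 10.0, 100.0]                            # stand-in experts: multiply by a constant

parts, batch_index, scores = dispatch(data, gates)
outputs = [[[v * s for v in row] for row in part] for part, s in zip(parts, scales)]
combined = combine(outputs, batch_index, scores, num_rows=len(data))

print(batch_index)   # [0, 2, 1, 2, 0, 1]
print(combined)      # [[25.75, 25.75], [110.0, 110.0], [23.25, 23.25]]
```

Row 0, for instance, goes to experts 0 and 2, so its output is `0.75 * 1 * 1.0 + 0.25 * 100 * 1.0 = 25.75`.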

In [None]:
class MixtureOfExperts(nn.Module):
    def __init__(self,
                 embed_size,                # embed_length
                 hidden_size,
                 dropout_rate=0.0,
                 num_experts=4,
                 k=2,
                 expert_type='FFN',
                 noisy_gating=True,
                 noise_epsilon=1e-2,
                 loss_coef=1e-2):
      super().__init__()
      self.num_experts  = num_experts
      self.k            = k
      self.embed_size   = embed_size
      self.hidden_size  = hidden_size
      self.dropout_rate = dropout_rate
      self.noisy_gating = noisy_gating
      self.noise_epsilon= noise_epsilon
      self.loss_coef    = loss_coef
      self.H_x          = None
      self.w_gate       = nn.Parameter(torch.randn(embed_size,
                                                   num_experts),
                                       requires_grad=True)
                                            # `w_gate` has a shape of (embed_size, num_experts)
      self.w_noise      = nn.Parameter(torch.randn(embed_size,
                                                   num_experts),
                                       requires_grad=True)
                                            # `w_noise` has a shape of (embed_size, num_experts)

      self.dispatcher   = DataDispatcher(num_experts=self.num_experts)
      self.expert_type  = expert_type

      if self.expert_type == 'FFN':
        self.experts      = nn.ModuleList([FeedForward(self.embed_size,
                                                       self.hidden_size,
                                                       self.embed_size,
                                                       self.dropout_rate) for i in range(self.num_experts)])
      elif self.expert_type == 'FFNSwiGLU':
        self.experts      = nn.ModuleList([FFNSwiGLU(self.embed_size,
                                                     self.hidden_size,
                                                     self.embed_size,
                                                     self.dropout_rate) for i in range(self.num_experts)])
      elif self.expert_type == 'FFNSwiGLUShared':
        V                 = nn.Linear(in_features=self.embed_size,
                                      out_features=self.hidden_size)
        W2                = nn.Linear(in_features=self.hidden_size,
                                      out_features=self.embed_size)

        self.experts      = nn.ModuleList([FFNSwiGLU(self.embed_size,
                                                     self.hidden_size,
                                                     self.embed_size,
                                                     self.dropout_rate,
                                                     V,
                                                     W2) for i in range(self.num_experts)])

      self.softplus     = nn.Softplus()
      self.softmax      = nn.Softmax(dim=1)
      self.normal       = Normal(0.0, 1.0)

      assert self.k <= self.num_experts


    # ----------------------
    def cv_squared(self, x):
                                            # calculate the squared coefficient of variation
                                            # x has to be a 1d array
      eps       = 1e-10

      if x.shape[0] == 1:
          return torch.tensor([0],
                              device=x.device,
                              dtype=x.dtype)
                                            # The case where the array has only 1 value
      else:
          return x.var() /( x.mean()**2 + eps)


    #----------------------
    def to_2d(self, x):
      return ein.rearrange(x, 'b r c -> (b r) c')


    def to_3d(self, x, r):
      return ein.rearrange(x, '(b r) c -> b r c', r=r)



    # ----------------------
    def prob_in_top_k(self,
                      x_W_g,
                      logits,
                      noise_std,
                      top_logits):
      batch             = x_W_g.size(0)
      m                 = top_logits.size(1)
      top_logits_flatten= top_logits.flatten()
                                          # `top_logits_flatten` has a shape of
                                          # (batch_size * num_patches * (k + 1),)

      threshold_positions_if_in = torch.arange(batch,
                                               device=x_W_g.device) * m + self.k
                                          # `threshold_positions_if_in` has a shape of
                                          # (batch_size * num_patches,)
      threshold_if_in             = torch.unsqueeze(torch.gather(top_logits_flatten,
                                                                 0,
                                                                 threshold_positions_if_in),
                                                    1)
      is_in                       = torch.gt(logits, threshold_if_in)
                                          # `is_in` has a shape of
                                          # (batch_size * num_patches, num_experts)

      threshold_positions_if_out  = threshold_positions_if_in - 1
      threshold_if_out            = torch.unsqueeze(torch.gather(top_logits_flatten,
                                                                 0,
                                                                 threshold_positions_if_out),
                                                    1)

      sub_in        = x_W_g - threshold_if_in
      sub_out       = x_W_g - threshold_if_out
      sub           = torch.where(is_in, sub_in, sub_out)
                                          # `sub_in`, `sub_out` and `sub` have a shape of
                                          # (batch_size * num_patches, num_experts)

      prob_of_sub   = self.normal.cdf(sub/noise_std)
                                          # `prob_of_sub` has a shape of
                                          # (batch_size * num_patches, num_experts)

      return prob_of_sub

    # ----------------------
    def noisy_top_k_gating(self,
                           x,
                           train=False,
                           noise_epsilon=1e-2):
                                            # in this implementation, the shape of x
                                            # is (batch_size*num_patches, embed_length)
      x_W_g           = x @ self.w_gate
                                            # `x_W_g` has a shape of (batch_size * num_patches, num_experts)

      if self.noisy_gating and train:
          x_W_noise   = x @ self.w_noise
                                            # `x_W_noise` has a shape of (batch_size * num_patches, num_experts)
          noise_std   = self.softplus(x_W_noise) + noise_epsilon
          logits      = x_W_g + torch.randn_like(x_W_g) * noise_std
      else:
          logits      = x_W_g
                                            # `noise_std`, `logits`
                                            # all have a shape of (batch_size * num_patches, num_experts)
      self.H_x        = logits

      (top_logits,
       top_indices)   = logits.topk(min(self.k + 1, self.num_experts), dim=1)
                                            # `top_logits` is required for calculating the load-balancing loss
                                            # `top_logits` and `top_indices` have a shape of
                                            # (batch_size * num_patches, k + 1)
      top_k_logits    = top_logits[:, :self.k]
      top_k_indices   = top_indices[:, :self.k]
      top_k_gates     = self.softmax(top_k_logits)
                                            # `top_k_logits`, `top_k_indices` and `top_k_gates` have a shape of
                                            # (batch_size * num_patches, k)
      zeros           = torch.zeros_like(logits)
                                            # NOTE: unlike David Rau's implementation,
                                            # we do not set 'requires_grad' to True here
      gates           = zeros.to(x.dtype).scatter(1,
                                                  top_k_indices,
                                                  top_k_gates)
                                             # `zeros` and `gates` have a shape of
                                             # (batch_size * num_patches, num_experts)

      if self.noisy_gating and train and (self.num_experts - self.k > 0):
          load        = self.prob_in_top_k(x_W_g,
                                           logits,
                                           noise_std,
                                           top_logits).sum(0)
                                            # `load` has a shape of (num_experts,)
      else:
          load        = (gates > 0).sum(0)
          load        = load.float()
                                            # `load` has a shape of (num_experts,)

      return (gates, load)


    # ----------------------
    def forward(self,
                x,
                train=False,
                previous_loss=0):
                                            # `x` has a shape of
                                            # (batch_size, num_patches, embed_length)

      r              = x.shape[1]           # required for self.to_3d()
      x_2d           = self.to_2d(x)
      (gates,
        load)        = self.noisy_top_k_gating(x_2d,
                                               train,
                                               noise_epsilon=self.noise_epsilon)

      expert_inputs = self.dispatcher.dispatch(x_2d, gates)
      expert_outputs= [expert(inp) for expert, inp in zip(self.experts, expert_inputs)]
      y             = self.dispatcher.combine(expert_outputs)

      if train:
          importance= gates.sum(0)
          loss      = self.loss_coef*(self.cv_squared(importance)+self.cv_squared(load))
          loss      = loss+previous_loss

          return (self.to_3d(y,r), loss)
      else:
          return self.to_3d(y,r)
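
The gating and the auxiliary loss above can be traced by hand on a few tokens. The sketch below is a plain-Python approximation under stated assumptions: it omits the gating noise, and it counts `load` as the number of tokens routed to each expert (the `else` branch of `noisy_top_k_gating`), not the smooth `prob_in_top_k` estimate used during noisy training. The function `top_k_gates` and the toy `logits` are hypothetical; `cv_squared` follows the method of the same name and uses the unbiased variance, as `Tensor.var()` does by default.

```python
import math


def top_k_gates(logits, k):
    """Softmax over the k largest logits; every other expert gets a gate of 0."""
    top = sorted(range(len(logits)), key=lambda e: logits[e], reverse=True)[:k]
    peak = max(logits[e] for e in top)
    exps = {e: math.exp(logits[e] - peak) for e in top}   # shifted for numerical stability
    total = sum(exps.values())
    gates = [0.0] * len(logits)
    for e in top:
        gates[e] = exps[e] / total
    return gates


def cv_squared(values, eps=1e-10):
    """Squared coefficient of variation, using the unbiased variance."""
    n = len(values)
    if n == 1:
        return 0.0
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / (n - 1)
    return var / (mean ** 2 + eps)


# Hypothetical logits for 4 tokens and 3 experts
logits = [[2.0, 1.0, 0.0],
          [0.0, 3.0, 1.0],
          [1.0, 0.0, 2.0],
          [2.0, 2.5, 0.0]]
k, loss_coef = 2, 1e-2

gates = [top_k_gates(row, k) for row in logits]
importance = [sum(g[e] for g in gates) for e in range(3)]        # summed gate values per expert
load = [float(sum(g[e] > 0 for g in gates)) for e in range(3)]   # tokens routed to each expert
aux_loss = loss_coef * (cv_squared(importance) + cv_squared(load))

print(load)          # [3.0, 3.0, 2.0]
print(aux_loss)
```

Both terms are zero only when every expert receives the same total gate weight and the same number of tokens, so the loss pushes the router towards balanced use of the experts.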



## **4. Modules and functions required for MAE-ViT Classifier**
---
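
With the defaults from section 2, the encoder built in this section does not use the same expert configuration in every layer: `create_transformer_blocks` shrinks the expert hidden size linearly with depth and adds one expert after each block of layers. The sketch below replays that arithmetic in plain Python without building any modules; `layer_schedule` is a hypothetical helper written only for this illustration.

```python
import math


def layer_schedule(num_layers=9, num_blocks=3, hidden_size=27,
                   hidden_ratio=3, num_experts=3):
    """Return (expert_hidden_size, num_experts) for every encoder layer."""
    lyrs_per_block = math.ceil(num_layers / num_blocks)
    max_hidden = hidden_size * hidden_ratio
    schedule = []
    for i in range(num_layers):
        # Extra width decays linearly from (max_hidden - hidden_size) down to 0
        extra = int((num_layers - 1 - i) * (max_hidden - hidden_size) / (num_layers - 1))
        schedule.append((hidden_size + extra, num_experts))
        if (i + 1) % lyrs_per_block == 0:
            num_experts += 1          # one more expert from the next block onwards
    return schedule


for layer, (hidden, experts) in enumerate(layer_schedule()):
    print(f"layer {layer}: hidden_size={hidden}, num_experts={experts}")
```

Under these defaults the first layer uses 3 experts of width 81 and the last layer uses 5 experts of width 27, while `k` stays at 2 throughout.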

In [None]:
# 1.
# -------------------------
def get_patch_embeddings(images,
                         patch_size):

    unfold  = F.unfold(images,
                       kernel_size=patch_size,
                       stride=patch_size)
                                      # `unfold` has a shape of
                                      # [batch_size, embed_length, num_patches]

    embeds  = unfold.transpose(1, 2)
                                      # `embeds` has a shape of
                                      # [batch_size, num_patches, embed_length]
    return embeds

# 2.
# -------------------------
def get_patch_w_class_embed(images,
                            patch_size,
                            cls='zeros'):
    embeds            = get_patch_embeddings(images,
                                             patch_size)

    if cls == 'zeros':
        cls_embeds    = torch.zeros(embeds.shape[0],
                                    1,
                                    embeds.shape[-1],
                                    device=embeds.device)
                                      # all-zero class token, appended after the last patch
        embeds        = torch.cat((embeds, cls_embeds),
                                  dim=1)
                                      # `embeds` has a shape of
                                      # [batch_size, num_patches + 1, embed_length]

    return embeds


# 3.
# -------------------------
class GQAttentionHead(nn.Module):
    def __init__(self,
                 head_size,
                 embed_size,
                 heads_per_group,
                 attn_mask=None,
                 dropout_p=0.0,
                 is_causal=False):
        super().__init__()

        self.attn_mask= attn_mask
        self.dropout_p= dropout_p
        self.is_causal= is_causal

        self.key      = nn.Linear(in_features=embed_size,
                                  out_features=head_size,
                                  bias=False)

        self.queries  = nn.ModuleList([nn.Linear(in_features=embed_size,
                                                 out_features=head_size,
                                                 bias=False) for _ in range(heads_per_group)])

        self.value    = nn.Linear(in_features=embed_size,
                                  out_features=head_size,
                                  bias=False)


    def forward(self, x):

        k               = self.key(x)     # 'k' has a shape of (batchSize, num_embeds, head_size)
        v               = self.value(x)

        attn_out        = torch.cat([F.scaled_dot_product_attention(q(x),
                                                                    k,
                                                                    v,
                                                                    attn_mask=self.attn_mask,
                                                                    dropout_p=self.dropout_p,
                                                                    is_causal=self.is_causal) for q in self.queries], dim=-1)
                                          # 'attn_out' now has a shape of (batchSize, num_embeds, heads_per_group*head_size)
        return attn_out


# 4.
# -------------------------
class GroupedQueryAttention(nn.Module):
    def __init__(self,
                 num_heads,
                 embed_size,
                 dropout_rate,
                 num_groups,
                 attn_mask=None,
                 dropout_p=0.0,
                 is_causal=False):
        super().__init__()
        self.Heads    = nn.ModuleList([GQAttentionHead(embed_size // num_heads,
                                                       embed_size,
                                                       num_heads // num_groups,     # heads_per_group
                                                       attn_mask,
                                                       dropout_p,
                                                       is_causal) for _ in range(num_groups)])
                                    # embed_size // num_heads gives head_size

        self.proj     = nn.Linear(in_features=num_heads*(embed_size // num_heads),
                                  out_features=embed_size,
                                  bias=False)
        self.dropout  = nn.Dropout(dropout_rate)

                                    # The input 'x' should have a shape of (batchSize, num_embeds, embed_size)
    def forward(self, x):
        attn_out      = torch.cat([h(x) for h in self.Heads], dim=-1)
                                    # 'attn_out' now has a shape of (batchSize, num_embeds, num_heads*head_size)
        attn_out      = self.proj(attn_out)
                                    # 'attn_out' now has a shape of (batchSize, num_embeds, embed_size)
        return self.dropout(attn_out)

# 5.
# -------------------------
class TransformerMoEBlock(nn.Module):
    def __init__(self,
                 num_heads,
                 embed_size,
                 hidden_size,
                 num_groups,
                 num_experts,
                 k,
                 dropout_rate,
                 expert_type='FFN'):
        super().__init__()
        self.GQAttn               = GroupedQueryAttention(num_heads=num_heads,
                                                          embed_size=embed_size,
                                                          dropout_rate=dropout_rate,
                                                          num_groups=num_groups)
        self.MoE                  = MixtureOfExperts(embed_size=embed_size,
                                                     hidden_size=hidden_size,
                                                     dropout_rate=dropout_rate,
                                                     num_experts=num_experts,
                                                     k=k,
                                                     expert_type=expert_type)

        self.LayerNorm1           = nn.LayerNorm(normalized_shape=embed_size)
        self.LayerNorm2           = nn.LayerNorm(normalized_shape=embed_size)

                                    # The input 'x' should have a shape of (batchSize, num_embeds, embed_size)
    def forward(self,
                inputs):
        x               = inputs[0]
        train           = inputs[1]
        previous_loss   = inputs[2]

        x       = self.LayerNorm1(x)
        x       = x + self.GQAttn(x)
                                    # The output 'x' has a shape of (batchSize, num_embeds, embed_size)
        if not train:
            x       = x + self.MoE(self.LayerNorm2(x))
                                    # The output 'x' has a shape of (batchSize, num_embeds, embed_size)
            return (x, train, 0)
        else:
            moe_out = self.MoE(self.LayerNorm2(x),
                                   train,
                                   previous_loss)
            x       = x + moe_out[0]
            return (x, train, moe_out[1])
                                    # the returned 'x' has a shape of (batchSize, num_embeds, embed_size)
                                    # 'moe_out[1]' is the accumulated load-balancing loss


# 6.
# -------------------------
                                        # The input to `MaskEncoder` is of shape
                                        # (batch_size, num_patches, patch_length)
class MaskEncoder(nn.Module):
    def __init__(self,
                 num_heads,
                 num_layers,
                 num_patches,
                 patch_length,
                 embed_size,
                 hidden_size,
                 num_groups,
                 num_experts,
                 k,
                 out_proj,
                 dropout_rate,
                 num_blocks):
        super().__init__()
        self.num_blocks       = num_blocks
        self.num_layers       = num_layers
        self.num_heads        = num_heads
        self.embed_size       = embed_size
        self.hidden_size      = hidden_size
        self.num_groups       = num_groups
        self.num_experts      = num_experts
        self.k                = k
        self.dropout_rate     = dropout_rate
        self.MaskEmbed        = nn.Parameter(torch.randn(embed_size),
                                             requires_grad=True)
        self.Proj             = nn.Linear(in_features=patch_length,
                                          out_features=embed_size,
                                          bias=False)
        self.PositionEmbeds   = nn.Embedding(num_embeddings=num_patches,
                                             embedding_dim=embed_size)

        self.TransformerBlocks= self.create_transformer_blocks()
        self.LayerNorm        = nn.LayerNorm(normalized_shape=embed_size)
        self.OutProj          = nn.Linear(in_features=embed_size,
                                          out_features=out_proj,
                                          bias=False)
        self.register_buffer('full_pos_idx',
                             torch.arange(num_patches))



    def create_transformer_blocks(self,
                                  hidden_ratio=args.hidden_ratio):
        num_experts           = self.num_experts
        k                     = self.k
        lyrs_per_block        = math.ceil(self.num_layers/self.num_blocks)
        blocks                = nn.ModuleList()
        max_hidden_size       = self.hidden_size * hidden_ratio

        for i in range(self.num_layers):
            additional_hidden = int(((self.num_layers-1)-i)*(max_hidden_size-self.hidden_size)/(self.num_layers-1))
            blocks.append(TransformerMoEBlock(num_heads=self.num_heads,
                                              embed_size=self.embed_size,
                                              hidden_size=self.hidden_size+additional_hidden,
                                              num_groups=self.num_groups,
                                              num_experts=num_experts,
                                              k=k,
                                              dropout_rate=self.dropout_rate,
                                              expert_type='FFNSwiGLUShared'))
            if (i+1) % lyrs_per_block == 0:
                num_experts   += 1

        return nn.Sequential(*blocks)


    def forward(self,
                x,
                moe_train=False,
                mae_train=False):
        if not mae_train:
            x       = self.Proj(x) + self.PositionEmbeds(self.full_pos_idx)

            if not moe_train:
                return self.LayerNorm(self.TransformerBlocks((x,
                                                              moe_train,
                                                              0))[0])
            else:
                y   = self.TransformerBlocks((x,
                                              moe_train,
                                              0))
                return (self.LayerNorm(y[0]),
                        y[2])
        else:
            masked_patches    = x[0]
            unmasked_patches  = x[1]
            masked_indices    = x[2]
            unmasked_indices  = x[3]

            unmasked_embeds   = self.Proj(unmasked_patches) + self.PositionEmbeds(unmasked_indices)
            y                 = self.TransformerBlocks((unmasked_embeds,
                                                       mae_train,
                                                       0))
            unmasked_embeds   = self.LayerNorm(y[0])

            masked_embeds     = self.LayerNorm(self.PositionEmbeds(masked_indices) + self.MaskEmbed)

            return (self.OutProj(masked_embeds),
                    self.OutProj(unmasked_embeds),
                    y[2])                         # 'y[2]' gives us the loss from mixture of experts



# 7.
# -------------------------
class ViTClassifier(nn.Module):
    def __init__(self, encoder, embed_size, num_classes):
        super().__init__()
        self.Encoder = encoder
        self.Dropout = nn.Dropout(0.2)
        self.Head = nn.Linear(in_features=embed_size, out_features=num_classes, bias=True)

    def forward(self, x, moe_train=False):
        if moe_train:
            (x, l) = self.Encoder(x, moe_train=moe_train)

            if torch.isnan(x).any() or torch.isnan(l).any():
                raise ValueError("Encoder output or auxiliary loss l contains NaN.")

            x = self.Head(self.Dropout(x[:, -1, :]))
            return (x, l)

        else:
            x = self.Encoder(x)
            x = self.Head(x[:, -1, :])
            return x


# 8.
# -------------------------
                                                      # trainable parameter counting function for pytorch
def count_trainable_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
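
The loop in `create_transformer_blocks` gives each layer its own hidden size and expert count: the hidden size tapers linearly from `hidden_size * hidden_ratio` in the first layer to `hidden_size` in the last, and the expert count grows by one after every `ceil(num_layers / num_blocks)` layers. The sketch below reproduces only that arithmetic in plain Python so the schedule can be inspected without building the model. The helper name `layer_schedule` and the example values are hypothetical and are not the notebook's `args`; the `max(num_layers - 1, 1)` guard protects the single-layer case.

```python
import math


def layer_schedule(num_layers, num_blocks, hidden_size, hidden_ratio, num_experts):
    """Return one (hidden_size, num_experts) pair per layer."""
    lyrs_per_block = math.ceil(num_layers / num_blocks)
    max_hidden_size = hidden_size * hidden_ratio
    denom = max(num_layers - 1, 1)            # guard for num_layers == 1
    schedule = []

    for i in range(num_layers):
        # Extra width shrinks linearly and reaches zero at the last layer
        extra = int((num_layers - 1 - i) * (max_hidden_size - hidden_size) / denom)
        schedule.append((hidden_size + extra, num_experts))
        # One more expert after each completed block of layers
        if (i + 1) % lyrs_per_block == 0:
            num_experts += 1

    return schedule


# Illustrative values only
for layer, (hidden, experts) in enumerate(layer_schedule(4, 2, 100, 2, 4)):
    print(f"layer {layer}: hidden_size={hidden}, num_experts={experts}")
```

With these illustrative values the first two layers use 4 experts and the last two use 5, while the hidden size falls from 200 to 100.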


## **5. Build ViT classifier**
---

In [None]:
class FinetuneHead(nn.Module):
    def __init__(self, embed_dim, num_classes, task_type="multi-class"):
        super().__init__()
        self.fc = nn.Linear(embed_dim, num_classes)
        self.task_type = task_type

    def forward(self, x):
        return self.fc(x)


def build_model(task_type):
    encoder = MaskEncoder(
        args.num_heads, args.num_layers, args.num_patches, args.patch_length,
        args.embed_size, args.hidden_size, args.num_groups, args.num_experts,
        args.k, args.out_proj, args.dropout_rate, args.num_blocks
    )

    model = ViTClassifier(encoder, args.embed_size, args.num_classes)
    model.Head = FinetuneHead(embed_dim=args.embed_size, num_classes=args.num_classes, task_type=task_type)
    return model

## **6. Setup data pipeline**
---

In [None]:
drive.mount('/content/gdrive')

data_source       = None
weight_path       = '/content/gdrive/My Drive/your_weights_folder'
data_path         = 'data'
                                            # Create the data directory
os.makedirs(data_path, exist_ok=True)

if data_source is not None:
    shutil.copytree(data_source,
                    data_path,
                    dirs_exist_ok=True)
    data_download = False
else:
    data_download = True


transform = v2.Compose([
    v2.Resize(args.img_size),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

subset_class_map = {
    'chest': ChestMNIST,
    'path': PathMNIST,
    'pneumonia': PneumoniaMNIST,
    'retina': RetinaMNIST,
    'breast': BreastMNIST,
    'tissue': TissueMNIST,
    'oct': OCTMNIST,
    'organc': OrganCMNIST,
    'organs': OrganSMNIST,
    'organa': OrganAMNIST,
    'blood': BloodMNIST,
    'derma': DermaMNIST
}
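
In the `transform` pipeline, `v2.ToDtype(torch.float32, scale=True)` rescales 8-bit pixel values from [0, 255] to [0, 1], and `v2.Normalize` with a mean and standard deviation of 0.5 then maps them to [-1, 1]. A minimal sketch of that arithmetic for a single pixel value, in plain Python (the helper `normalize_pixel` is hypothetical and not part of torchvision):

```python
def normalize_pixel(value, mean=0.5, std=0.5):
    """Map one 8-bit pixel value to the range produced by the transform."""
    scaled = value / 255.0            # effect of ToDtype(..., scale=True) on uint8 input
    return (scaled - mean) / std      # effect of Normalize(mean, std)


for v in (0, 128, 255):
    print(v, normalize_pixel(v))
```

Black (0) maps to -1.0 and white (255) maps to 1.0, so mid-gray lands close to zero.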

## **7. Test finetuned model on each subset**
---

In [None]:
def get_loss_function(task_type):
    if task_type == "multi-label":
        return nn.BCEWithLogitsLoss()
    else:
        return nn.CrossEntropyLoss()


def get_predictions(logits, task_type, threshold=None):
    if task_type == "multi-label":
        if threshold is None:
            threshold = 0.5
        return (torch.sigmoid(logits) > threshold).float()
    else:
        return torch.argmax(logits, dim=1)


trained_subsets   = ["chest",
                     "retina",
                     "breast",
                     "tissue",
                     "oct",
                     "organs",
                     "organa",
                     "organc",
                     "blood",
                     "derma",
                     "pneumonia",
                     "path"]

subsets           = ['retina',
                     'chest',
                     'breast',
                     'tissue',
                     'oct',
                     'organs',
                     'organa',
                     'organc',
                     'blood',
                     'derma',
                     'pneumonia',
                     'path']

multilabel_subsets= {"chest"}
threshold         = None                              # for multi-label dataset

for subset in subsets:
  model_path      = os.path.join(weight_path, (subset + "_best_model.pth"))
  task_type       = "multi-label" if subset in multilabel_subsets else "multi-class"

  if subset not in trained_subsets:
    print(f"This model was not trained on a \"{subset}\" subset")
  else:
    print("")
    print(f"Testing model for subset {subset}")
    print("------------------------")
    print(f"Model weight: {model_path}")

                                                      # Create test dataloader
    dataset_class = subset_class_map[subset]
    test_dataset  = dataset_class(split="test",
                                  download=data_download,
                                  transform=transform,
                                  size=224,
                                  root=data_path)

    test_loader   = DataLoader(test_dataset,
                               batch_size=args.batch_size,
                               shuffle=False)
                                                      # Identify and set no. of labels for a given subset,
                                                      # set threshold for prediction function
    test_labels   = test_dataset.labels
    if task_type == "multi-label":
        args.num_classes  = test_labels.shape[1]
                                                      # Use the given threshold, else default to 0.25 per class
        best_threshold    = threshold if threshold is not None else torch.tensor([0.25] * args.num_classes).to(device)
    else:
        args.num_classes  = len(set(test_dataset.labels.flatten().tolist()))
        best_threshold    = None

                                                      # Build model
    model                 = build_model(task_type).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.eval()

                                                      # Run model
    criterion                 = get_loss_function(task_type)
    test_loss, test_accuracy  = [], []
    y_preds, y_trues, y_probs = [], [], []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing:    "):
            x           = get_patch_w_class_embed(images=images,
                                                  patch_size=args.patch_size).to(device)
            if task_type == "multi-label":
                labels  = labels.float().to(device)
                                                      # `labels`: (batch_size, num_classes)
            else:
                labels  = labels.squeeze(1).long().to(device) if labels.ndim == 2 else labels.long().to(device)
                                                      # `labels`: (batch_size, )

            logits      = model(x, moe_train=False)
            logits      = logits[0] if isinstance(logits, tuple) else logits
                                                      # `logits`: (batch_size, num_classes)
            loss        = criterion(logits, labels)
            test_loss.append(loss.item())

            preds       = get_predictions(logits.detach(), task_type, threshold=best_threshold)
                                                      # `preds`: (batch_size, ) for multi-class,
                                                      #          (batch_size, num_classes) for multi-label
            probs       = torch.sigmoid(logits) if task_type == "multi-label" else torch.softmax(logits, dim=1)
                                                      # `probs`: (batch_size, num_classes), AUC calculation
            y_preds.append(preds.cpu().numpy())
            y_trues.append(labels.cpu().numpy())
            y_probs.append(probs.cpu().numpy())

    y_preds       = np.concatenate(y_preds)
    y_trues       = np.concatenate(y_trues)
    y_probs       = np.concatenate(y_probs)
    avg_loss      = sum(test_loss) / len(test_loss)
                                                      # Element-wise accuracy; for multi-label this averages
                                                      # over every (sample, class) entry
    avg_acc       = np.mean(y_preds == y_trues)

    if task_type == "multi-class":
        num_classes = y_probs.shape[1]
        if num_classes == 2:
                                                      # Binary: use the probability of the positive class
            auc = roc_auc_score(y_trues, y_probs[:, 1])
        else:
                                                      # Multi-class: one-vs-rest, macro-averaged
            auc = roc_auc_score(y_trues, y_probs, multi_class='ovr', average='macro')
    else:
                                                      # Multi-label: macro AUC over per-class sigmoid probabilities,
                                                      # not over thresholded predictions
        auc = roc_auc_score(y_trues, y_probs, average='macro')

    print("")
    print(f"[Test] {subset} - Accuracy: {avg_acc:.4f}")
    print(f"[Test] {subset} - Loss    : {avg_loss:.4f}")
    print(f"[Test] {subset} - AUC     : {auc:.4f}")




Testing model for subset retina
------------------------
Model weight: ./all-runs/e216_Ip_IMf_batch256/retina_best_model.pth
Using downloaded and verified file: ./data/retinamnist_224.npz


Testing:    : 100%|██████████| 4/4 [00:01<00:00,  2.02it/s]



[Test] retina - Accuracy: 0.5975
[Test] retina - Loss    : 1.0451
[Test] retina - AUC     : 0.8006

Testing model for subset chest
------------------------
Model weight: ./all-runs/e216_Ip_IMf_batch256/chest_best_model.pth
Using downloaded and verified file: ./data/chestmnist_224.npz


Testing:    : 100%|██████████| 176/176 [01:19<00:00,  2.22it/s]



[Test] chest - Accuracy: 0.9339
[Test] chest - Loss    : 0.1737
[Test] chest - AUC     : 0.5277

Testing model for subset breast
------------------------
Model weight: ./all-runs/e216_Ip_IMf_batch256/breast_best_model.pth
Using downloaded and verified file: ./data/breastmnist_224.npz


Testing:    : 100%|██████████| 2/2 [00:00<00:00,  3.40it/s]



[Test] breast - Accuracy: 0.8974
[Test] breast - Loss    : 0.4204
[Test] breast - AUC     : 0.9175

Testing model for subset tissue
------------------------
Model weight: ./all-runs/e216_Ip_IMf_batch256/tissue_best_model.pth
Using downloaded and verified file: ./data/tissuemnist_224.npz


Testing:    : 100%|██████████| 370/370 [02:47<00:00,  2.21it/s]



[Test] tissue - Accuracy: 0.7207
[Test] tissue - Loss    : 0.8173
[Test] tissue - AUC     : 0.9379

Testing model for subset oct
------------------------
Model weight: ./all-runs/e216_Ip_IMf_batch256/oct_best_model.pth
Using downloaded and verified file: ./data/octmnist_224.npz


Testing:    : 100%|██████████| 8/8 [00:03<00:00,  2.11it/s]



[Test] oct - Accuracy: 0.8590
[Test] oct - Loss    : 0.4123
[Test] oct - AUC     : 0.9744

Testing model for subset organs
------------------------
Model weight: ./all-runs/e216_Ip_IMf_batch256/organs_best_model.pth
Using downloaded and verified file: ./data/organsmnist_224.npz


Testing:    : 100%|██████████| 69/69 [00:31<00:00,  2.18it/s]



[Test] organs - Accuracy: 0.8161
[Test] organs - Loss    : 0.5960
[Test] organs - AUC     : 0.9822

Testing model for subset organa
------------------------
Model weight: ./all-runs/e216_Ip_IMf_batch256/organa_best_model.pth
Using downloaded and verified file: ./data/organamnist_224.npz


Testing:    : 100%|██████████| 139/139 [01:05<00:00,  2.13it/s]



[Test] organa - Accuracy: 0.9646
[Test] organa - Loss    : 0.1371
[Test] organa - AUC     : 0.9991

Testing model for subset organc
------------------------
Model weight: ./all-runs/e216_Ip_IMf_batch256/organc_best_model.pth
Using downloaded and verified file: ./data/organcmnist_224.npz


Testing:    : 100%|██████████| 65/65 [00:30<00:00,  2.14it/s]



[Test] organc - Accuracy: 0.9458
[Test] organc - Loss    : 0.1797
[Test] organc - AUC     : 0.9976

Testing model for subset blood
------------------------
Model weight: ./all-runs/e216_Ip_IMf_batch256/blood_best_model.pth
Using downloaded and verified file: ./data/bloodmnist_224.npz


Testing:    : 100%|██████████| 27/27 [00:15<00:00,  1.69it/s]



[Test] blood - Accuracy: 0.9860
[Test] blood - Loss    : 0.0514
[Test] blood - AUC     : 0.9993

Testing model for subset derma
------------------------
Model weight: ./all-runs/e216_Ip_IMf_batch256/derma_best_model.pth
Using downloaded and verified file: ./data/dermamnist_224.npz


Testing:    : 100%|██████████| 16/16 [00:08<00:00,  1.85it/s]



[Test] derma - Accuracy: 0.8299
[Test] derma - Loss    : 0.6066
[Test] derma - AUC     : 0.9653

Testing model for subset pneumonia
------------------------
Model weight: ./all-runs/e216_Ip_IMf_batch256/pneumonia_best_model.pth
Using downloaded and verified file: ./data/pneumoniamnist_224.npz


Testing:    : 100%|██████████| 5/5 [00:02<00:00,  2.25it/s]



[Test] pneumonia - Accuracy: 0.8910
[Test] pneumonia - Loss    : 0.4483
[Test] pneumonia - AUC     : 0.9850

Testing model for subset path
------------------------
Model weight: ./all-runs/e216_Ip_IMf_batch256/path_best_model.pth
Using downloaded and verified file: ./data/pathmnist_224.npz


Testing:    : 100%|██████████| 57/57 [00:31<00:00,  1.83it/s]


[Test] path - Accuracy: 0.9532
[Test] path - Loss    : 0.1637
[Test] path - AUC     : 0.9976
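
AUC is a ranking metric: it estimates the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one. It is therefore normally computed from continuous scores such as `y_probs`, because thresholded 0/1 predictions discard most of the ranking. The sketch below illustrates this with a small pure-Python pairwise AUC; the helper `binary_auc` and the toy labels and scores are hypothetical and are not taken from any MedMNIST subset.

```python
def binary_auc(y_true, scores):
    """Pairwise AUC: share of (positive, negative) pairs ranked correctly, ties count 0.5."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")

    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Toy data: positives are ranked above negatives, but every score is below 0.5
y_true = [0, 0, 0, 1, 1]
probs  = [0.05, 0.10, 0.20, 0.30, 0.40]
hard   = [1.0 if p > 0.5 else 0.0 for p in probs]    # thresholded predictions

print("AUC from probabilities:", binary_auc(y_true, probs))
print("AUC from hard labels  :", binary_auc(y_true, hard))
```

In this toy case the probabilities rank every positive above every negative, yet all of them fall below the 0.5 threshold, so the hard labels are identical and the AUC computed from them drops to chance level.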



