<h1 align="center"> Channel Adaptive Vision Transformer: How to Use </h1>

This notebook is a step-by-step guide to using the Channel Adaptive Vision Transformer (ChAdaViT) model for image classification. ChAdaViT is a vision transformer that adaptively takes as input images with varying numbers of channels and projects them into the same embedding space. This is particularly useful when working with multi-channel images, such as medical microscopy or geospatial images with multiple modalities.
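To build some intuition for "projecting any number of channels into the same embedding space", here is a toy illustration of the general idea. It is written from scratch under our own assumptions: a single-channel patch embedding shared by all channels, a learned per-channel embedding, and a CLS token whose output serves as the image embedding. `ToyChannelAdaptiveEncoder` is hypothetical. It is neither ChAdaViT's actual architecture nor the `DevChAdaViT` API.

```python
import torch
import torch.nn as nn


class ToyChannelAdaptiveEncoder(nn.Module):
    """Hypothetical toy encoder illustrating channel-adaptive tokenization.

    NOT the DevChAdaViT implementation, only a sketch of the general idea.
    """

    def __init__(self, patch_size=16, embed_dim=192, max_channels=10, img_size=224):
        super().__init__()
        self.max_channels = max_channels
        # One single-channel patch embedding shared by every channel
        self.patch_embed = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)
        num_patches = (img_size // patch_size) ** 2
        # Spatial position embedding, shared across channels
        self.pos_embed = nn.Parameter(0.02 * torch.randn(1, num_patches, embed_dim))
        # Learned embedding telling the model which channel a token comes from
        self.channel_embed = nn.Parameter(0.02 * torch.randn(max_channels, 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        layer = nn.TransformerEncoderLayer(embed_dim, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # img: (C, H, W) with any 1 <= C <= max_channels
        num_channels = img.shape[0]
        assert 1 <= num_channels <= self.max_channels
        tokens = self.patch_embed(img.unsqueeze(1))  # (C, D, H/p, W/p)
        tokens = tokens.flatten(2).transpose(1, 2)   # (C, N, D)
        tokens = tokens + self.pos_embed + self.channel_embed[:num_channels]
        tokens = tokens.reshape(1, -1, tokens.shape[-1])  # (1, C*N, D): a single sequence
        tokens = torch.cat([self.cls_token, tokens], dim=1)
        out = self.encoder(tokens)
        return out[:, 0]  # CLS embedding of shape (1, D), whatever C is


encoder = ToyChannelAdaptiveEncoder()
for c in (1, 3):
    print(c, "channel(s) ->", tuple(encoder(torch.randn(c, 224, 224)).shape))
```

Whatever the channel count, the output has the same size, because all channel tokens end up in one sequence that is summarized by a single CLS token.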

In [3]:
import torch
import torch.nn as nn
import numpy as np
import hashlib

from src.backbones.vit.dev_chada_vit import DevChAdaViT

In [2]:
# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


## Download weights
You can download the model weights from this URL: https://drive.google.com/file/d/1SUfUwerHJlf0vo9mdgM0mRn9TNZkaqXl/view?usp=drive_link   
Make sure to save the file in the same directory as this notebook and that it has read permissions.
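A small standard-library sketch for checking, and if needed fixing, those permissions before loading. `ensure_readable` is a hypothetical helper, not part of the repository.

```python
import os
import stat


def ensure_readable(path: str) -> bool:
    """Check that `path` is an existing file and try to make it readable by its owner."""
    if not os.path.isfile(path):
        raise FileNotFoundError(f"No file found at {path!r}")
    if not os.access(path, os.R_OK):
        # Add the owner read bit, like `chmod u+r <path>` (only works if you own the file)
        os.chmod(path, os.stat(path).st_mode | stat.S_IRUSR)
    return os.access(path, os.R_OK)
```

For example, `ensure_readable("weights.ckpt")` should return `True` before you move on.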

Enter the path of the weights:

In [3]:
CKPT_PATH = "weights.ckpt"

To make sure the download is complete and not corrupted, you can check the MD5 hash of the file:

In [4]:
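def check_hash(file_path, expected_hash):
    """Return True if the MD5 digest of `file_path` equals `expected_hash`."""
    md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        # Read in chunks so the whole checkpoint is never loaded into memory at once
        while chunk := f.read(4096):
            md5.update(chunk)
    return md5.hexdigest() == expected_hash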

In [None]:
check_hash(CKPT_PATH, "e8a24ac58b8e34bdce10e0024d507f2e")

## Params

In [4]:
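# Params
PATCH_SIZE = 16  # side length (in pixels) of the square image patches
EMBED_DIM = 192  # dimension of the token embeddings and of the extracted features
RETURN_ALL_TOKENS = False  # if False, a single embedding is returned per image
MAX_NUMBER_CHANNELS = 10  # maximum number of channels per input image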
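Since the collate function below feeds every channel to the model as its own single-channel image, a quick bit of arithmetic shows how the sequence length scales. This sketch assumes 224×224 inputs (the size of the random images generated later) and that each channel is split into patches independently. `IMG_SIZE` and `num_patch_tokens` are illustrative names, not part of the repository.

```python
IMG_SIZE = 224  # assumption: matches the random images generated later
PATCH_SIZE = 16


def num_patch_tokens(num_channels, img_size=IMG_SIZE, patch_size=PATCH_SIZE):
    """Patch tokens for one image, assuming each channel is patchified independently."""
    patches_per_channel = (img_size // patch_size) ** 2  # (224 // 16) ** 2 = 14 * 14 = 196
    return num_channels * patches_per_channel


for c in (1, 4, 10):
    print(f"{c:>2} channel(s) -> {num_patch_tokens(c)} patch tokens")
```

Under this assumption, an image with `MAX_NUMBER_CHANNELS = 10` channels yields 1960 patch tokens. That is ten times more than a single-channel image, so the cost of standard self-attention over the concatenated tokens grows quadratically with the channel count.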

## Load State Dict

In [6]:
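model = DevChAdaViT(
    patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM,
    return_all_tokens=RETURN_ALL_TOKENS,
    max_number_channels=MAX_NUMBER_CHANNELS,
    # Should match the number of transformer blocks in the checkpoint: with
    # strict=False below, mismatched block weights are skipped without an error
    depth=1,
)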

In [7]:
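assert CKPT_PATH.endswith(
    (".ckpt", ".pth", ".pt")
), f"Unexpected checkpoint extension: {CKPT_PATH}"
state = torch.load(CKPT_PATH, map_location="cpu")["state_dict"]

# Keep only the backbone weights and strip the "backbone." prefix so the keys match
# DevChAdaViT. Building a new dict (rather than editing `state` while iterating over it)
# ensures that keys renamed from "encoder" also get their prefix stripped.
new_state = {}
for k, v in state.items():
    if "backbone" in k:
        new_state[k.replace("backbone.", "")] = v
    elif "encoder" in k:
        new_state[k.replace("encoder", "backbone").replace("backbone.", "")] = v
state = new_state

model.load_state_dict(state, strict=False)
model.to(device)
model.eval()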

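Because the state dict is loaded with `strict=False`, keys that do not match the model are skipped without any error. `load_state_dict` returns a named tuple with `missing_keys` and `unexpected_keys`, so printing it (e.g. `print(model.load_state_dict(state, strict=False))`) is a quick way to confirm that the backbone weights were actually loaded. The toy example below uses a plain `nn.Linear` as a stand-in for the model and shows what the result looks like when one key keeps an unstripped prefix:

```python
import torch
import torch.nn as nn

toy = nn.Linear(4, 2)  # stand-in for the real model
state = {
    "weight": torch.zeros(2, 4),      # matches a parameter of `toy`
    "backbone.bias": torch.zeros(2),  # prefix was not stripped, so it cannot match
}
result = toy.load_state_dict(state, strict=False)
print("missing:", result.missing_keys)        # ['bias']
print("unexpected:", result.unexpected_keys)  # ['backbone.bias']
```

On the real model, `missing_keys` should ideally not list any backbone parameters.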

## Generate Random Images (Optional)
If you are here, you probably want to test the model on your own images :)      
Either way, you can use the following code to generate random images with varying numbers of channels and quickly check that the model works as expected.

In [8]:
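def generate_data(num_images: int, max_num_channels: int = MAX_NUMBER_CHANNELS):
    """Generate random 224x224 images, each with 1 to max_num_channels channels."""
    imgs = []
    labels = []
    for _ in range(num_images):
        # np.random.randint excludes its upper bound, hence the + 1
        num_channels = np.random.randint(1, max_num_channels + 1)
        imgs.append(torch.randn(num_channels, 224, 224))
        # Dummy binary label (torch.randint also excludes its upper bound)
        labels.append(torch.randint(0, 2, (1,)))
    data = list(zip(imgs, labels))
    return data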

In [9]:
data = generate_data(num_images=10, max_num_channels=MAX_NUMBER_CHANNELS)
imgs, labels = zip(*data)
distribution = {}
for img in imgs:
    num_channels = img.shape[0]
    distribution[num_channels] = distribution.get(num_channels, 0) + 1
print(
    f"Number of generated images: {len(imgs)} \n Distribution of number of channels: {distribution}"
)

Number of generated images: 10 
 Distribution of number of channels: {4: 2, 6: 3, 10: 1, 9: 2, 7: 1, 1: 1}


## Prepare Data

One of the key elements of the ChAdaViT model is its ability to adapt to varying numbers of channels. In this section, we prepare the data to be fed into the model. Instead of stacking the images into a fixed-shape batch, we define a custom collate function that splits every image into its individual channels, concatenates the channels of all images along the first dimension, and keeps track of how many channels belong to each image.

In [10]:
def collate_images(batch: list):
    """
    Collate a batch of images into a tensor of single channels, a tensor of labels and the number of channels per image.

    Every image is split into its individual channels, and the channels of all images are
    concatenated along the first dimension. All images must share the same height and width.

    Args:
        batch (list): A list of (img, label) tuples, where img has shape (num_channels, height, width)

    Returns:
        channels_list (torch.Tensor): A tensor of shape (total_num_channels, 1, height, width),
            where total_num_channels is the sum of the number of channels of all images
        batched_labels (torch.Tensor): A tensor of shape (batch_size,)
        num_channels_list (list): A list of the number of channels per image
    """
    num_channels_list = []
    channels_list = []
    labels_list = []

    # Iterate over the images and extract their channels one by one
    for image, label in batch:
        labels_list.append(label)
        num_channels = image.shape[0]
        num_channels_list.append(num_channels)

        for channel in range(num_channels):
            channel_image = image[channel, :, :].unsqueeze(0)  # Shape: (1, height, width)
            channels_list.append(channel_image)

    # Shape: (total_num_channels, 1, height, width)
    channels_list = torch.cat(channels_list, dim=0).unsqueeze(1)

    batched_labels = torch.tensor(labels_list)

    return channels_list, batched_labels, num_channels_list

In [11]:
collated_batch = collate_images(data)

In [12]:
collated_batch[2]

[4, 6, 10, 9, 9, 6, 7, 4, 1, 6]
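When loading your own images, a function like `collate_images` can be passed as the `collate_fn` of a `torch.utils.data.DataLoader`. The self-contained sketch below uses `flatten_channels_collate`, a hypothetical compact equivalent written for illustration and assuming integer labels. It also shows how the list of channel counts lets you split the flat tensor back into per-image stacks with `torch.split`.

```python
import torch
from torch.utils.data import DataLoader


def flatten_channels_collate(batch):
    """Hypothetical compact equivalent of `collate_images`, assuming integer labels."""
    images, labels = zip(*batch)
    num_channels_list = [img.shape[0] for img in images]
    # (C_i, H, W) -> (C_i, 1, H, W), then concatenate the channels of all images
    channels = torch.cat([img.unsqueeze(1) for img in images], dim=0)
    return channels, torch.tensor(labels), num_channels_list


# Any map-style dataset of (image, label) pairs works; a plain list is enough here
dataset = [(torch.randn(c, 32, 32), i % 2) for i, c in enumerate([3, 1, 5, 2])]
loader = DataLoader(dataset, batch_size=2, collate_fn=flatten_channels_collate)

batches = list(loader)
for channels, labels, num_channels_list in batches:
    # Recover the per-image channel stacks from the flat tensor
    per_image = torch.split(channels, num_channels_list, dim=0)
    print(tuple(channels.shape), labels.tolist(), [t.shape[0] for t in per_image])
```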

## Extract Features

In [13]:
@torch.no_grad()
def extract_features(
    model: nn.Module,
    batch: tuple,
    mixed_channels: bool,
    return_all_tokens: bool,
):
    """
    Forwards a collated batch of images through the backbone and extracts their features.

    Args:
        model (nn.Module): The model to forward the images through.
        batch (tuple): The output of `collate_images`: a tensor X of shape
            (total_num_channels, 1, height, width), the labels, and the list of
            the number of channels per image.
        mixed_channels (bool): Whether the images have mixed numbers of channels or not.
        return_all_tokens (bool): Whether to return all tokens or not.

    Returns:
        feats (torch.Tensor): The extracted features; the first dimension indexes the images.
    """
    model.eval()
    # Run on whichever device the model lives on
    device = next(model.parameters()).device

    # Overwrite the model's "mixed_channels" attribute, e.g. for evaluation on datasets
    # where all images have the same number of channels
    model.mixed_channels = mixed_channels

    X, targets, list_num_channels = batch
    X = X.to(device, non_blocking=True)

    feats = model(x=X, index=0, list_num_channels=[list_num_channels])

    if not mixed_channels:
        if return_all_tokens:
            # Group the token embeddings per image
            chunks = feats.view(sum(list_num_channels), -1, feats.shape[-1])
            chunks = torch.split(chunks, list_num_channels, dim=0)
            # Stack the per-image chunks along a new batch dimension
            # (requires every image to have the same number of channels)
            feats = torch.stack(chunks, dim=0)
        # Flatten all but the first (image) dimension into one feature vector per image
        feats = feats.flatten(start_dim=1)

    return feats
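The uniform-channel branch (`mixed_channels=False`, `return_all_tokens=True`) regroups the flat output per image. The shape walk-through below uses toy sizes and assumes, as the `view` call implies, that the backbone returns one row of tokens per channel, i.e. a tensor of shape `(sum(list_num_channels), num_tokens, dim)`:

```python
import torch

# Assumption: one row of tokens per channel, shape (sum(list_num_channels), num_tokens, dim)
list_num_channels = [3, 3]  # two images with 3 channels each
num_tokens, dim = 5, 4      # toy sizes
feats = torch.randn(sum(list_num_channels), num_tokens, dim)

chunks = feats.view(sum(list_num_channels), -1, feats.shape[-1])  # (6, 5, 4)
chunks = torch.split(chunks, list_num_channels, dim=0)           # 2 x (3, 5, 4)
stacked = torch.stack(chunks, dim=0)                             # (2, 3, 5, 4)
per_image = stacked.flatten(start_dim=1)                         # (2, 60)
print(tuple(per_image.shape))
```

`torch.stack` only works here because every image has the same number of channels. With mixed channel counts the chunks would differ in size and `torch.stack` would raise an error, which is why this branch is skipped when `mixed_channels=True`.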

In [16]:
# The random images have varying numbers of channels, hence mixed_channels=True
extracted_features = extract_features(
    model=model.to(device),
    batch=collated_batch,
    mixed_channels=True,
    return_all_tokens=RETURN_ALL_TOKENS,
)

In [17]:
assert extracted_features.shape[0] == len(
    collated_batch[2]
)  # num_embeddings == num_images, even with different number of channels
print(
    f"{extracted_features.shape[0]} embeddings of dim {extracted_features.shape[1]} were extracted."
)

10 embeddings of dim 192 were extracted.
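Since the backbone yields one fixed-size embedding per image, one simple way to use it for classification is a linear probe: a single linear layer trained on frozen features. The sketch below uses random stand-ins for the features and labels, and `NUM_CLASSES` is a hypothetical value. It illustrates the technique and is not a training recipe from the ChAdaViT authors.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

EMBED_DIM = 192
NUM_CLASSES = 2  # hypothetical number of classes

# Stand-ins for `extracted_features` (one EMBED_DIM vector per image) and their labels
features = torch.randn(10, EMBED_DIM)
targets = torch.randint(0, NUM_CLASSES, (10,))

probe = nn.Linear(EMBED_DIM, NUM_CLASSES)  # the only trainable part; the backbone stays frozen
optimizer = torch.optim.SGD(probe.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

for _ in range(100):
    optimizer.zero_grad()
    loss = loss_fn(probe(features), targets)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    predictions = probe(features).argmax(dim=1)
print("final loss:", loss.item(), "predictions:", predictions.tolist())
```

With the notebook's variables, `extracted_features.cpu()` and `collated_batch[1]` would take the place of `features` and `targets`. The random labels carry no signal, so this would only check that the pipeline runs end to end.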
