In [16]:
import torch
import torch.nn as nn
from src.models.vit import ViT

## Define Model

In [2]:
# Configuration of the ViT backbone; every entry is forwarded to ViT as a
# keyword argument.
model_params = dict(
    img_size=96,
    patch_size=12,
    hidden_size=768,
    mlp_dim=3072,
    num_layers=12,
    num_heads=12,
    in_chans=3,
    dropout_rate=0.0,
    spatial_dims=3,
    patch_embed="conv",
    pos_embed="sincos",
    classification=False,
    num_classes=2,
    qkv_bias=False,
    norm_layer=nn.LayerNorm,
    post_activation="Tanh",
)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Instantiate the network and place it on the selected device.
# NOTE(review): no checkpoint is loaded in the visible cells — confirm
# whether ViT restores pretrained weights on its own.
model = ViT(**model_params)
model = model.to(device)

# Evaluation mode; as the last expression of the cell this also displays the
# module structure.
model.eval()

ViT(
  (patch_embedding): PatchEmbeddingBlock(
    (patch_embeddings): Conv3d(3, 768, kernel_size=(12, 12, 12), stride=(12, 12, 12))
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (blocks): ModuleList(
    (0-11): 12 x AttentionBlock(
      (mlp): MLPBlock(
        (linear1): Linear(in_features=768, out_features=3072, bias=True)
        (linear2): Linear(in_features=3072, out_features=768, bias=True)
        (fn): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (drop2): Dropout(p=0.0, inplace=False)
      )
      (att_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ffn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): SelfAttention(
        (qkv): Linear(in_features=768, out_features=2304, bias=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
)

In [3]:
def count_parameters(model):
    """Return the total number of scalar elements over all parameters of ``model``.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        int: Sum of ``numel()`` over every yielded parameter; 0 if there are none.
    """
    n_elements = 0
    for tensor in model.parameters():
        n_elements += tensor.numel()
    return n_elements
    
# Count the parameters of the ViT instantiated above and report the total.
num_params = count_parameters(model)
print(f"Total parameters: {num_params}")

Total parameters: 89404416


## Create Dataloader

In [4]:
from monai import data
from monai import transforms

from src.utils.misc import create_dataset
from src.data.transforms import MultipleWindowScaleStack

In [5]:
# Spatial size every volume is resized to before being fed to the model.
roi = [96, 96, 96]

# One pair per intensity window handed to MultipleWindowScaleStack.
# NOTE(review): the pairs look like CT (center, width) windows — confirm
# against src.data.transforms.
window_sizes = [(40, 80), (80, 200), (600, 2800)]
windowing_tran = MultipleWindowScaleStack(keys=["image"], window_sizes=window_sizes)

# Preprocessing pipeline applied to the "image" entry of each sample dict:
# load -> channel-first -> RAS orientation -> 1.0 spacing on each axis ->
# foreground crop -> resize to `roi` -> windowing -> tensor conversion.
trans = transforms.Compose([
    transforms.LoadImaged(keys=["image"], image_only=True, allow_missing_keys=True),
    transforms.EnsureChannelFirstd(keys=["image"], allow_missing_keys=True),
    transforms.Orientationd(keys=["image"], axcodes="RAS", allow_missing_keys=True),
    transforms.Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=3, allow_missing_keys=True),
    transforms.CropForegroundd(
        keys=["image"], source_key="image", allow_smaller=False, allow_missing_keys=True
    ),
    transforms.Resized(keys=["image"], spatial_size=tuple(roi), allow_missing_keys=True),
    windowing_tran,
    transforms.ToTensord(keys=["image"], allow_missing_keys=True),
])

In [7]:
# Four example NIfTI files from the public RSNA dataset, each wrapped in the
# {"image": path} dictionary layout used by the dataset below.
sample_images = [
    {"image": nifti_path}
    for nifti_path in (
        "../sample_nifti/ID_c7cc66d672.nii.gz",
        "../sample_nifti/ID_7a2902d550.nii.gz",
        "../sample_nifti/ID_a186fdc315.nii.gz",
        "../sample_nifti/ID_53d3458f56.nii.gz",
    )
]

In [8]:
# Build the dataset and dataloader, then fetch one batch of scans.
batch_size = 4

test_ds = data.Dataset(
    data=sample_images,
    transform=trans,
)

test_loader = data.DataLoader(
    dataset=test_ds,
    batch_size=batch_size,
    num_workers=1,
    # Page-locked host memory only speeds up host-to-GPU copies. `device`
    # may be the CPU fallback chosen earlier, where pinning is useless and
    # makes PyTorch emit a warning, so enable it only for CUDA.
    pin_memory=(device.type == "cuda"),
    shuffle=False,
)

# Take the first (and, with four samples and batch_size=4, only) batch and
# move its "image" entry to the device the model lives on.
x_img = next(iter(test_loader))['image']
x_img = x_img.to(device)

In [9]:
# Sanity check: show the shape of the batch pulled from the dataloader.
print(f"Shape for current scan: {x_img.shape}")

Shape for current scan: torch.Size([4, 3, 96, 96, 96])


## Extract Feature

In [10]:
# The model returns the last-layer representation and the representations of
# all layers, for every token; the last-layer output has shape
# [batch_size, num_tokens, num_features].
# This is inference only, so run the forward pass under torch.no_grad():
# otherwise autograd records the graph and keeps the intermediate activations
# of the whole network alive for a backward pass that never happens.
with torch.no_grad():
    last_layer_out, all_layers_out = model(x_img)

In [11]:
# Sanity check: show the shape of the last-layer token representations.
print(f"Shape of last layer output feature: {last_layer_out.shape}")

Shape of last layer output feature: torch.Size([4, 513, 768])


## Classification

In [12]:
from src.models.classifier import LinearClassifier

In [13]:
# Configuration of the linear classifier. It gets its own name so that the
# ViT configuration stored in `model_params` is not overwritten: reusing that
# name would leave a dict that can no longer rebuild the ViT if an earlier
# cell is re-run.
classifier_params = {
    "dim": 768,  # input feature size; equals the ViT "hidden_size" used above
    "num_classes": 2,
}

# Define linear classifier and place it on the same device as the ViT
classifier = LinearClassifier(**classifier_params).to(device)

# Evaluation mode; as the last expression of the cell this also displays the
# module structure.
classifier.eval()

LinearClassifier(
  (bn): BatchNorm1d(768, eps=1e-06, momentum=0.1, affine=False, track_running_stats=True)
  (linear): Linear(in_features=768, out_features=2, bias=True)
)

In [14]:
# Take the [CLS] token representation (index 0 along the token axis) and map
# it to logits of shape [batch_size, num_classes] with the linear classifier.
# Inference only: torch.no_grad() keeps autograd from recording the
# classifier forward pass.
with torch.no_grad():
    cls_token_out = last_layer_out[:, 0, :]
    logits = classifier(cls_token_out)

In [15]:
# Sanity check: show the shape of the classifier logits.
print(f"Shape of logits: {logits.shape}")

Shape of logits: torch.Size([4, 2])
