**Sources:**

The code used for TTA and inference calculations:

[1] https://github.com/taheeraahmed/master-thesis

**Import Libraries**

In [1]:
pip install mambavision transformers mamba_ssm timm

Note: you may need to restart the kernel to use updated packages.


In [2]:
# Standard Libraries
import os
import copy
import random
from glob import glob
from PIL import Image
import zipfile
import time
from statistics import mean, stdev
from pathlib import Path
import subprocess

# Data Manipulation Libraries
import pandas as pd
import numpy as np

# Visualization Libraries
import matplotlib.pyplot as plt

# Progress Bar
from tqdm import tqdm

# Machine Learning Libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, Subset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split

# Hugging Face transformers to load the MambaVision model
from transformers import AutoModel

  from .autonotebook import tqdm as notebook_tqdm


**Define Parameters**

In [3]:
# Paths
# Each path can be overridden through an environment variable so the notebook
# can run outside the original cluster; the defaults are the original locations.
ZIP_PATH = os.environ.get(
    'CHESTXRAY14_ZIP_PATH',
    '/cluster/home/bjorneme/projects/Data/chestX-ray14.zip'
)
EXTRACTED_PATH = os.environ.get(
    'CHESTXRAY14_EXTRACTED_PATH',
    '/cluster/home/bjorneme/projects/Data/chestX-ray14-extracted'
)
BACKBONE_PATH = os.environ.get(
    'MAMBAVISION_L_BACKBONE_PATH',
    '../../ChestX-ray14 Single Models/MambaVision_Large/mambavision_L_tta_backbone.pt'
)

# Model (Hugging Face identifier of the pre-trained backbone)
MODEL_NAME = "nvidia/MambaVision-L-21K"

# Disease Labels
# The order defines the order of the label vector returned by the dataset and
# the number of outputs of the classification head.
disease_labels = [
    'Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema',
    'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening',
    'Cardiomegaly', 'Nodule', 'Mass', 'Hernia'
]

# Other parameters
SEED = 42          # Seed for the RNGs and for the patient-level split
NUM_WORKERS = 32   # Worker processes used by the DataLoader

# Device Configuration: use the GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


**Set Seed for Reproducibility**

In [4]:
def seed_everything(seed=SEED):
    """
    Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy and PyTorch (CPU and all CUDA devices),
    exports PYTHONHASHSEED, and switches cuDNN to deterministic mode with
    benchmarking disabled.
    """
    # Python-level settings
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)

    # NumPy and PyTorch generators
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: deterministic algorithms, benchmark mode off
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed once, before any data splitting or model creation
seed_everything()

**Print Hardware Info**

In [5]:
# Ask nvidia-smi for the name and total memory of each GPU (CSV, no header row)
nvidia_smi_cmd = [
    "nvidia-smi",
    "--query-gpu=name,memory.total",
    "--format=csv,noheader",
]
raw_gpu_info = subprocess.check_output(nvidia_smi_cmd)
gpu_info = raw_gpu_info.decode()

# Report the software stack and the GPU the measurements run on
print("PyTorch:", torch.__version__)
print("CUDA", torch.version.cuda)
print(f"GPU: {gpu_info}")

PyTorch: 2.4.0+cu121
CUDA 12.1
GPU: NVIDIA A100-SXM4-80GB, 81920 MiB



# **Step 1: Load Data**

In [6]:
def extract_data(zip_path, extracted_path):
    """
    Unpack the dataset archive.

    Creates `extracted_path` (including missing parent directories) if needed
    and extracts every member of the ZIP file at `zip_path` into it.
    """
    # Ensure the destination exists; an already existing directory is fine
    os.makedirs(extracted_path, exist_ok=True)

    archive = zipfile.ZipFile(zip_path, 'r')
    with archive:
        archive.extractall(extracted_path)

# Extract ChestX-ray14 dataset
# TODO: Uncomment to extract data from zip
# extract_data(ZIP_PATH, EXTRACTED_PATH)

# **Step 2: Data Preprocessing**

In [7]:
def load_labels(csv_path, extracted_path):
    """
    Build the label dataframe.

    Reads the CSV at `csv_path`, adds one 0/1 column per entry of
    `disease_labels` and one for 'No Finding' (1 when the name occurs in the
    row's 'Finding Labels' string), and adds a 'Path' column that maps every
    'Image Index' to the PNG file with that name found under `extracted_path`.
    Rows whose image file was not found get NaN in 'Path'.
    """
    labels_df = pd.read_csv(csv_path)
    findings = labels_df['Finding Labels']

    # One binary indicator column per disease, followed by 'No Finding'
    for column in [*disease_labels, 'No Finding']:
        labels_df[column] = findings.str.contains(column).astype(int)

    # Index every PNG matching <extracted_path>/**/images/*.png by its file name
    png_pattern = os.path.join(extracted_path, '**', 'images', '*.png')
    path_by_name = {}
    for png_path in glob(png_pattern):
        path_by_name[os.path.basename(png_path)] = png_path

    # Attach the full path of each image
    labels_df['Path'] = labels_df['Image Index'].map(path_by_name)

    return labels_df

# The labels CSV is read from the extracted dataset directory
labels_csv_path = os.path.join(EXTRACTED_PATH, 'Data_Entry_2017.csv')

# Build the dataframe with binary disease columns and image paths
df = load_labels(csv_path=labels_csv_path, extracted_path=EXTRACTED_PATH)

**Split Dataset**

In [8]:
# Split on patient IDs, not on images, so that all images of one patient end
# up in the same subset
unique_patients = df['Patient ID'].unique()

# First hold out 20% of the patients for testing, then 12.5% of the remaining
# patients for validation (0.8 * 0.125 = 10% of all patients)
train_val_patients, test_patients = train_test_split(
    unique_patients, test_size=0.2, random_state=SEED
)
train_patients, val_patients = train_test_split(
    train_val_patients, test_size=0.125, random_state=SEED
)


def _rows_for_patients(patient_ids):
    """Return the rows of `df` belonging to `patient_ids`, with a fresh index."""
    belongs = df['Patient ID'].isin(patient_ids)
    return df[belongs].reset_index(drop=True)


# Materialise one dataframe per subset
train_df = _rows_for_patients(train_patients)
val_df = _rows_for_patients(val_patients)
test_df = _rows_for_patients(test_patients)

# Report image and patient counts per subset
for subset_label, subset_df, subset_patients in (
    ("Train dataset size", train_df, train_patients),
    ("Validation size", val_df, val_patients),
    ("Test size", test_df, test_patients),
):
    print(f"{subset_label}: {len(subset_df)}. Number of unique patients: {len(subset_patients)}")

Train dataset size: 78614. Number of unique patients: 21563
Validation size: 11212. Number of unique patients: 3081
Test size: 22294. Number of unique patients: 6161


**Define Dataset for Chest X-ray images**

In [9]:
class ChestXrayDataset(Dataset):
    """
    Dataset of Chest X-ray images backed by a dataframe.

    Args:
        df: Dataframe with a 'Path' column (image file path) and one column
            per entry of `disease_labels`.
        transform: Optional callable applied to the loaded image. When it is
            None the image is returned exactly as read from disk.
    """
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows (images) in the dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """
        Return `(image, label)` for the row at position `idx`.

        `label` is a float32 NumPy array with one entry per disease in
        `disease_labels`, in that order.
        """
        # Look the row up once and reuse it for both the path and the labels
        row = self.df.iloc[idx]

        # Get image and labels
        image = plt.imread(row['Path'])
        label = row[disease_labels].values.astype(np.float32)

        # Apply the transformation only when one was supplied; the constructor
        # default is None, which must not be called
        if self.transform is not None:
            image = self.transform(image)

        return image, label

**Define Transformations**

In [10]:
# Per-channel normalisation statistics (ImageNet mean and std)
mean_transform = [0.485, 0.456, 0.406]
std_transform = [0.229, 0.224, 0.225]

# Reusable per-crop operations
_crop_to_tensor = transforms.ToTensor()
_crop_normalize = transforms.Normalize(mean_transform, std_transform)


def _stack_crops_as_tensors(crops):
    """Convert every crop to a tensor and stack them along a new first axis."""
    return torch.stack([_crop_to_tensor(crop) for crop in crops])


def _stack_normalized_crops(crops):
    """Normalize every crop and stack the results along the first axis."""
    return torch.stack([_crop_normalize(crop) for crop in crops])


# Test-time pipeline: every image is expanded into ten 224x224 crops
test_transforms = transforms.Compose([
    transforms.ToPILImage(),                      # array -> PIL image
    transforms.Grayscale(num_output_channels=3),  # replicate to 3 channels
    transforms.Resize((256,256)),                 # resize to 256x256
    transforms.TenCrop(224),                      # create 10 crops
    transforms.Lambda(_stack_crops_as_tensors),   # crops -> stacked tensor
    transforms.Lambda(_stack_normalized_crops),   # normalize each crop
])

**Create Test Dataset**

In [11]:
# Test split wrapped with the ten-crop test transform
test_dataset = ChestXrayDataset(df=test_df, transform=test_transforms)

# The measurements only use the first 1000 test images
first_indices = list(range(1000))
reduced_dataset = Subset(test_dataset, first_indices)

**Create Test DataLoader**

In [12]:
# One image (with all of its crops) per batch, in dataset order
test_loader = DataLoader(
    reduced_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

# **Step 3: Build the Model**

In [13]:
# Define the model
class MultiLabelClassifier(nn.Module):
    """
    Multi-Label Classification Model using MambaVision as the base model.

    Args:
        device: Device the whole model (backbone and new head) is placed on.
        model_name: Hugging Face identifier of the pre-trained backbone.
        num_classes: Number of output logits; defaults to the number of
            entries in `disease_labels`.
    """
    def __init__(self, device, model_name="nvidia/MambaVision-T2-1K", num_classes=len(disease_labels)):
        super().__init__()

        # Load pre-trained MambaVision model.
        # NOTE: trust_remote_code=True executes Python code shipped with the
        # model repository; only use identifiers from a trusted source.
        self.base_model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)

        # Replace the classification head to match the number of disease labels.
        # The new layer is created after the backbone was moved, so it has to be
        # moved explicitly; otherwise it stays on the CPU.
        in_features = self.base_model.model.head.in_features
        self.base_model.model.head = nn.Linear(in_features, num_classes).to(device)

    def forward(self, x):
        """Return the head's output for input `x`."""
        # The base model returns a pair; only its first element is fed to the head
        avg_pool, _ = self.base_model(x)
        return self.base_model.model.head(avg_pool)

# Initialize the Model
mambavision_L_model = MultiLabelClassifier(device, MODEL_NAME)

# Wrap in DataParallel before loading the weights.
# NOTE(review): presumably the checkpoint was saved from a DataParallel model,
# so its keys need this wrapper to match - confirm against the training code.
mambavision_L_model = nn.DataParallel(mambavision_L_model).to(device)

# Load the model
# map_location keeps loading consistent with the `device` chosen above, so a
# checkpoint saved from a GPU also loads when the notebook falls back to CPU.
mambavision_L_model.load_state_dict(torch.load(
    BACKBONE_PATH,
    map_location=device,
    weights_only=True
))

2025-05-16 22:31:39.545155: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747427499.557012 3999355 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747427499.561149 3999355 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-16 22:31:39.581978: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


<All keys matched successfully>

# **Step 4: Calculate Inference Metrics**

**Warm-up**

In [14]:
# Warm up in the same configuration as the timed run: evaluation mode and no
# autograd. Running the warm-up in training mode would let train-mode layers
# (e.g. BatchNorm or Dropout, if the model has any) behave differently and
# would build autograd graphs that the measurement itself never uses.
mambavision_L_model.eval()

dummy = torch.randn(1,3,224,224).to(device)
with torch.no_grad():
    for _ in range(100):
        _ = mambavision_L_model(dummy)

# Wait for the queued warm-up kernels to finish so they cannot overlap with
# the timed section
if device.type == "cuda":
    torch.cuda.synchronize()

**Inference GPU Runtime and Peak Memory Usage**

In [15]:
# Reset the peak-memory counters so the numbers below cover only this loop
torch.cuda.reset_peak_memory_stats(device)

# List to store the GPU time of each forward pass (ms)
batch_times = []

# Number of source images processed; each image contributes `ncrops` crops
num_images = 0

# CUDA events used to record the start and end of each forward pass on the GPU
start = torch.cuda.Event(enable_timing=True)
end   = torch.cuda.Event(enable_timing=True)

# Progress bar
progress_bar = tqdm(test_loader, desc="Evaluating on Test Set")

# Set to evaluation mode
mambavision_L_model.eval()

# Disable gradients for evaluation
with torch.no_grad():
    for inputs, labels in progress_bar:

        # Retrieve input sizes
        batch_size, ncrops, C, H, W = inputs.size()
        num_images += batch_size

        # Move to device (not part of the timed region)
        inputs, labels = inputs.to(device), labels.to(device)

        # Change to [batch_size * ncrops, C, H, W]
        inputs = inputs.view(-1, C, H, W)

        # Start GPU timer
        start.record()

        # Forward pass over all crops of the batch
        outputs = mambavision_L_model(inputs)

        # Stop GPU timer
        end.record()

        # Wait until the kernels are finished so the elapsed time is valid
        torch.cuda.synchronize()

        # Save time in ms
        batch_times.append(start.elapsed_time(end))

# Query the same device whose counters were reset above
mem_allocated  = torch.cuda.max_memory_allocated(device)
mem_reserved = torch.cuda.max_memory_reserved(device)

# stdev() needs at least two samples
time_std = stdev(batch_times) if len(batch_times) > 1 else 0.0

# Print Inference runtime
print(f"Inference runtime: {mean(batch_times):.2f} ± {time_std:.2f} ms")

# Throughput counts images rather than batches, so it stays correct when the
# DataLoader batch size is not 1
print(f"Throughput: {(num_images/sum(batch_times))*1000:.0f} img/s")


print()

# Print GPU Memory
print(f"Peak GPU memory allocated: {mem_allocated / 2**30:.2f} GB")
print(f"Peak GPU memory reserved: {mem_reserved / 2**30:.2f} GB")

Evaluating on Test Set: 100%|██████████| 1000/1000 [00:43<00:00, 22.82it/s]

Inference runtime: 39.25 ± 1.40 ms
Throughput: 25 img/s

Peak GPU memory allocated: 1.20 GB
Peak GPU memory reserved: 1.94 GB





**Model Size**

In [16]:
# Size of the saved checkpoint file on disk, in MB (1 MB = 1,000,000 bytes)
checkpoint_file = Path(BACKBONE_PATH)
size_in_bytes = checkpoint_file.stat().st_size
size = size_in_bytes / 1000000
print(f"Model size (MB): {size:.2f}")

Model size (MB): 905.85
