# Notebook Setup and Execution 

## Environment: 
- NVIDIA GH200 96GB GPU. If you are using an H100, consider reducing the batch size to 4.
- Training the model on the full dataset is very slow. We recommend using a sample of the data.
- CPU inference is equally slow, so a GPU is recommended for both training and evaluation.
- We provide the Jupyter notebook gemma3_dynamic_quant.ipynb. You can run the notebook as is, because its first cell installs the necessary libraries.

## Below are instructions and requirements for setting up and running the notebook: 
###  Environment & Dependencies: 
- Python 3.9 or above, with PyTorch and TorchVision installed 
- Hugging Face Transformers library (version 4.x) to load the Gemma 3 model. The first notebook cell installs the Gemma 3 release tag (v4.49.0-Gemma-3) directly from the Transformers source repository.
- huggingface_hub library (optional) for authenticating and downloading the model.
- OpenCV for image processing.
- SciPy for image filtering (ndimage.sobel computes the Sobel edges used for edge density).
- scikit-image for computing Local Binary Patterns. 
- scikit-learn for the ROC AUC metric.
- iterative-stratification for the multilabel stratified train/test split.
- Pandas for data handling and reading csv. 
- Tqdm for progress bars. 
- Matplotlib for plotting results. 
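
Before running the notebook, it can help to confirm that the packages listed above are importable. The following is a minimal sketch using only the standard library. The helper name `missing_modules` is hypothetical, and the module names are assumed from the notebook's import cell (for example, OpenCV imports as `cv2` and scikit-image as `skimage`).

```python
import importlib.util

# Import names assumed from the notebook's import cell.
REQUIRED = [
    "torch", "torchvision", "transformers", "huggingface_hub", "cv2",
    "scipy", "skimage", "sklearn", "pandas", "tqdm", "matplotlib", "iterstrat",
]

def missing_modules(names):
    """Return the top-level modules in `names` that cannot be found."""
    return [n for n in names if importlib.util.find_spec(n) is None]

if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    print("Missing packages:", missing if missing else "none")
```

An empty result means every listed package can be located. It does not check versions.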

### Data Preparation: 
- Download MIMIC-CXR-JPG images and labels. You need PhysioNet credentials to access the dataset (https://physionet.org/content/mimic-cxr-jpg/2.1.0/). 
- Update the notebook configuration paths: 
  - image_dir should point to the root folder containing the JPG files.
  - filenames_path should point to a text file or CSV listing the image filenames you intend to use.
  - metadata_path should point to the metadata CSV file.
  - chexpert_path and negbio_path should point to the CSV files with the CheXpert and NegBio labels.
- Make sure you have sufficient disk space for the data and the model.
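
A quick way to catch a wrong path early is to check every configured location before building the dataset. This is a small sketch, not part of the notebook. `find_missing_paths` is a hypothetical helper, and the commented usage assumes the attribute names of the notebook's Config class.

```python
import os

def find_missing_paths(paths):
    """Return the names of config entries whose path does not exist on disk."""
    return [name for name, path in paths.items() if not os.path.exists(path)]

# Hypothetical usage with the notebook's Config class:
# paths = {
#     "image_dir": Config.image_dir,
#     "filenames_path": Config.filenames_path,
#     "metadata_path": Config.metadata_path,
#     "chexpert_path": Config.chexpert_path,
#     "negbio_path": Config.negbio_path,
# }
# print(find_missing_paths(paths))
```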
 
### Running the Notebook: 
- Start the Jupyter environment and open gemma3_dynamic_quant.ipynb. 
- Step through the notebook cells in order. The notebook is organized into sections.

## Set up

This was trained on an H200 GPU on Lambda Cloud.

### Install libraries

In [1]:
!pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3

%cd /home/ubuntu/data

!pip install scikit-image
!pip install opencv-python-headless
!pip install --upgrade pandas
!pip install "numpy<2"
!pip install iterative-stratification


### AUTHENTICATE TO HUGGINGFACE

In [None]:
from huggingface_hub import login
HF_TOKEN = ""  # paste your Hugging Face access token here (do not commit it)
login(token=HF_TOKEN)
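
As an alternative to pasting the token into the notebook, the token can be read from an environment variable. The sketch below is an assumption about how one might do that. `read_hf_token` is a hypothetical helper, and the variable name HF_TOKEN simply mirrors the one used in the cell above.

```python
import os

def read_hf_token(env=os.environ):
    """Return the token stored in HF_TOKEN, or None if it is unset or empty."""
    token = env.get("HF_TOKEN", "").strip()
    return token or None

# Hypothetical usage inside the notebook:
# from huggingface_hub import login
# token = read_hf_token()
# if token is not None:
#     login(token=token)
```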



### Imports

In [3]:
import time
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True 
import cv2
from scipy import ndimage, stats
from skimage import feature

from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler
import torchvision.transforms as transforms
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from torchvision.models import vit_b_16

from sklearn.metrics import roc_auc_score
from collections import Counter
import os, copy, gc, torch, random
torch.backends.quantized.engine = 'qnnpack'
from torch.ao.quantization import quantize_dynamic
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit




### BASIC CONFIGURATION SET UP

In [4]:
class Config:

    filenames_path = "/home/ubuntu/data/chex/IMAGE_FILENAMES_UPDATE"
    image_dir = "/home/ubuntu/data/chex/mimic-cxr-jpg/2.1.0/"
    metadata_path = "/home/ubuntu/data/chex/mimic-cxr-2.0.0-metadata.csv"
    chexpert_path = "/home/ubuntu/data/chex/mimic-cxr-2.0.0-chexpert.csv"
    negbio_path = "/home/ubuntu/data/chex/mimic-cxr-2.0.0-negbio.csv"

    batch_size = 8
    img_size = 224
    num_classes = 14
    max_seq_length = 256

    # The three settings below are used by the uniform quantization experiment
    num_epochs = 1
    lr = 1e-5
    uniform_bits = 4

### MULTIMODAL MODEL DEFINITION WITH QUANTIZATION

In [5]:
class MultimodalCheXpertModel(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        # Image encoder: ViT-B/16 from torchvision
        self.image_encoder = vit_b_16(pretrained=True)
        self.image_encoder.heads = torch.nn.Identity()  # Remove classification head
        self.image_proj = torch.nn.Linear(768, 512)

        # Text encoder: Gemma3ForConditionalGeneration
        self.text_encoder = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
        hidden_size_text = self.text_encoder.config.text_config.hidden_size
        self.text_proj = torch.nn.Linear(hidden_size_text, 512)

        # Classifier head for text features (defined here but not used in forward)
        self.classifier_head = torch.nn.Linear(512, config.num_classes)

        # Multimodal fusion classifier
        self.classifier = torch.nn.Linear(1024, config.num_classes)

    def forward(self, inputs, quant_level=None):
        img_features = self.image_encoder(inputs["image"])
        img_features = self.image_proj(img_features)

        text_output = self.text_encoder(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            output_hidden_states=True,
            return_dict=True
        )
        text_features = text_output.hidden_states[-1][:, 0]
        projected_text_features = self.text_proj(text_features)

        # Fuse modalities by concatenating image and text features 
        combined = torch.cat((img_features, projected_text_features), dim=1)
        
        # Quantize the fused features; the bit width depends on the complexity level
        if quant_level is not None:
            level_to_bits = {0: 1, 1: 2, 2: 3, 3: 4}
            bitwidth = level_to_bits.get(quant_level, 4)
            combined = quantize_tensor(combined, bitwidth)
        
        return self.classifier(combined)

### DATASET DEFINITION AND PREPROCESSING

In [6]:
class CXRDataset(Dataset):
    def __init__(self, config):
        self.config = config

        # Loading filenames & metadata
        self.filenames = pd.read_csv(config.filenames_path, header=None, names=["filename"])
        self.filenames["dicom_id"] = (
            self.filenames["filename"]
            .str.split("/")
            .str[-1]
            .str.split(".")
            .str[0]
        )
        self.metadata = pd.read_csv(config.metadata_path)
        self.metadata["dicom_id"] = self.metadata["dicom_id"].astype(str)

        # Merging filenames and metadata
        merged = self.filenames.merge(self.metadata, on="dicom_id", how="inner")
        merged["subject_id"] = pd.to_numeric(merged["subject_id"])
        merged["study_id"]   = pd.to_numeric(merged["study_id"])

        # Merging CheXpert labels
        self.chexpert = pd.read_csv(config.chexpert_path)
        self.chexpert[["subject_id","study_id"]] = self.chexpert[["subject_id","study_id"]].apply(pd.to_numeric)
        merged = merged.merge(self.chexpert, on=["subject_id","study_id"], how="inner")

        # Merging NegBio labels
        self.negbio = pd.read_csv(config.negbio_path)
        self.negbio[["subject_id","study_id"]] = self.negbio[["subject_id","study_id"]].apply(pd.to_numeric)
        data = merged.merge(self.negbio, on=["subject_id","study_id"], suffixes=("", "_negbio"), how="inner")

        if data.empty:
            raise ValueError("Data merging failed: check DICOM IDs and label files.")

        # Identify the multi-label columns and clean them once:
        # missing (NaN) and uncertain (-1) labels are both mapped to 0
        self.label_cols = [
            col for col in self.chexpert.columns
            if col not in ("subject_id", "study_id")
        ]
        data[self.label_cols] = (
            data[self.label_cols]
            .fillna(0)
            .replace(-1, 0)
        )

        # Filtering out image files that are missing
        data = data[data["filename"].apply(
            lambda fn: os.path.exists(os.path.join(config.image_dir, fn))
        )]

        self.data = data.reset_index(drop=True)

        self.transform = transforms.Compose([
            transforms.Resize((config.img_size, config.img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
        ])
        self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Load the image and convert it to RGB
        img = Image.open(os.path.join(self.config.image_dir, row["filename"])).convert("RGB")
        image = self.transform(img)

        # Build the text prompt and tokenize it
        text_prompt = (
            f"Findings: {row['PerformedProcedureStepDescription']}  "
            f"View: {row['ViewPosition']}  "
            f"Orientation: {row.get('PatientOrientationCodeSequence_CodeMeaning','Unknown')}"
        )
        enc = self.tokenizer(
            text_prompt,
            max_length=self.config.max_seq_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        input_ids      = enc["input_ids"].squeeze(0)
        attention_mask = enc["attention_mask"].squeeze(0)

        labels = torch.FloatTensor(
            row[self.label_cols].values.astype(float)
        )

        # text complexity
        text_complexity = compute_text_entropy(text_prompt)

        return {
            "image": image,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "text_prompt": text_prompt,
            "text_complexity": text_complexity
        }
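
Two small steps in the dataset class are easy to misread: how the dicom_id is derived from a filename, and how labels are cleaned. The sketch below restates both in plain Python for illustration. The helper names are hypothetical and the file path is a made-up example, not a real MIMIC-CXR file.

```python
import math

def dicom_id_from_filename(filename):
    """Take the last path component and drop everything after the first dot."""
    return filename.split("/")[-1].split(".")[0]

def clean_label(value):
    """Map missing (NaN/None) and uncertain (-1) labels to 0, keep the rest."""
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return 0.0
    return 0.0 if value == -1 else float(value)

# The path below is a made-up example.
print(dicom_id_from_filename("files/p10/p10000001/s50000001/abc123.jpg"))  # abc123
print([clean_label(v) for v in [1.0, float("nan"), -1.0, 0.0]])  # [1.0, 0.0, 0.0, 0.0]
```

Mapping uncertain labels to 0 treats them as negatives, which is the choice made by the fillna(0) and replace(-1, 0) calls above.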

## COMPLEXITY METRICS

### COMPLEXITY METRIC FUNCTIONS FOR IMAGES

In [7]:
def compute_shannon_entropy(pil_image):
    gray = pil_image.convert("L")
    np_gray = np.array(gray)
    histogram, _ = np.histogram(np_gray, bins=256, range=(0, 255))
    prob = histogram / (histogram.sum() + 1e-7)
    entropy_value = -np.sum(prob * np.log2(prob + 1e-7))
    return float(entropy_value)

def compute_edge_density(pil_image, threshold=20):
    gray = np.array(pil_image.convert("L"), dtype=np.float32)
    dx = ndimage.sobel(gray, axis=0)
    dy = ndimage.sobel(gray, axis=1)
    mag = np.hypot(dx, dy)
    edge_pixels = np.sum(mag > threshold)
    density = edge_pixels / gray.size
    return float(density)

def compute_intensity_variation(pil_image):
    gray = np.array(pil_image.convert("L"), dtype=np.float32)
    return float(np.std(gray))

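def compute_fractal_dimension(pil_image, threshold=128):
    gray = np.array(pil_image.convert("L"))
    binary = gray < threshold

    def boxcount(Z, k):
        # Sum the binary image over k x k boxes
        S = np.add.reduceat(
            np.add.reduceat(Z, np.arange(0, Z.shape[0], k), axis=0),
            np.arange(0, Z.shape[1], k), axis=1)
        # Count the boxes that are partially (neither empty nor fully) filled
        return np.sum((S > 0) & (S < k*k))

    # Crop to the largest power-of-two square that fits in the image
    p = min(binary.shape)
    n = int(2**np.floor(np.log2(p)))
    Z = binary[:n, :n]
    sizes = 2**np.arange(int(np.log2(n)), 1, -1)
    counts = np.array([boxcount(Z, size) for size in sizes])

    # Skip scales with a zero count, otherwise np.log(0) makes the fit invalid
    valid = counts > 0
    if valid.sum() < 2:
        return 0.0
    coeffs = np.polyfit(np.log(sizes[valid]), np.log(counts[valid]), 1)
    return float(-coeffs[0])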

def compute_lbp_complexity(pil_image, radius=1, n_points=8):
    gray = np.array(pil_image.convert("L"))
    lbp = feature.local_binary_pattern(gray, n_points, radius, method="uniform")
    hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, n_points + 3), range=(0, n_points + 2))
    hist = hist.astype("float")
    s = hist.sum() + 1e-7
    hist /= s
    lbp_entropy = -np.sum(hist * np.log2(hist + 1e-7))
    if not np.isfinite(lbp_entropy):
        lbp_entropy = 0.0
    return float(lbp_entropy)
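
To build intuition for the entropy-based metrics, here is a plain-Python sketch of the same histogram entropy formula, including the 1e-7 smoothing term. It is an illustration on toy histograms, not a replacement for the NumPy functions above, and `shannon_entropy` is a hypothetical name.

```python
import math

def shannon_entropy(counts, eps=1e-7):
    """Entropy in bits of a histogram, with the same epsilon smoothing as above."""
    total = sum(counts) + eps
    probs = [c / total for c in counts]
    return -sum(p * math.log2(p + eps) for p in probs)

flat_image = [16, 0, 0, 0]   # all 16 pixels share one gray level
mixed_image = [4, 4, 4, 4]   # pixels spread evenly over four gray levels
print(shannon_entropy(flat_image))   # close to 0 bits
print(shannon_entropy(mixed_image))  # close to 2 bits
```

A histogram concentrated in one bin carries almost no information, while an even spread over four bins needs about log2(4) = 2 bits.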

### COMPLEXITY METRIC FUNCTION FOR TEXT - entropy based on character frequencies

In [8]:
def compute_text_entropy(text):
    text = text.lower()
    freq = {}
    for char in text:
        if char.isalnum():
            freq[char] = freq.get(char, 0) + 1
    total = sum(freq.values())
    if total == 0:
        return 0.0
    entropy = -sum((count/total) * np.log2(count/total + 1e-7) for count in freq.values())
    return float(entropy)
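
The character-level entropy can be checked by hand on short strings. The sketch below restates the function with the standard library only (a hypothetical `text_entropy`, assumed to mirror compute_text_entropy) and applies it to three toy inputs.

```python
import math
from collections import Counter

def text_entropy(text, eps=1e-7):
    """Entropy in bits over the alphanumeric characters of `text` (case-insensitive)."""
    counts = Counter(ch for ch in text.lower() if ch.isalnum())
    total = sum(counts.values())
    if total == 0:
        return 0.0
    return -sum((c / total) * math.log2(c / total + eps) for c in counts.values())

print(text_entropy("AAAA"))     # one distinct character: close to 0 bits
print(text_entropy("aabb"))     # two equally frequent characters: close to 1 bit
print(text_entropy("!!! ???"))  # no alphanumeric characters: exactly 0.0
```

Because the prompts share a fixed template ("Findings: ... View: ... Orientation: ..."), their entropies may fall in a fairly narrow range.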


### DATA DESCRIPTION & STATISTICS

In [None]:
config = Config()
full_ds = CXRDataset(config)

sample_size = len(full_ds)
indices = random.sample(range(len(full_ds)), sample_size)

records = []
for idx in indices:
    row = full_ds.data.iloc[idx]
    img_path = os.path.join(config.image_dir, row["filename"])
    image = Image.open(img_path).convert("RGB")
    
    sh = compute_shannon_entropy(image)
    ed = compute_edge_density(image)
    iv = compute_intensity_variation(image)
    fd = compute_fractal_dimension(image)
    lbp = compute_lbp_complexity(image)
    
    text_prompt = (
        f"Findings: {row['PerformedProcedureStepDescription']} "
        f"View: {row['ViewPosition']} "
        f"Orientation: {row.get('PatientOrientationCodeSequence_CodeMeaning','Unknown')}"
    )
    te = compute_text_entropy(text_prompt)
    
    records.append({
        "shannon_entropy": sh,
        "edge_density": ed,
        "intensity_variation": iv,
        "fractal_dimension": fd,
        "lbp_entropy": lbp,
        "text_entropy": te
    })

df = pd.DataFrame(records)



#### Plot distributions

In [None]:
for col in df.columns:
    plt.figure()
    plt.hist(df[col], bins=30)
    plt.title(f"Distribution of {col.replace('_', ' ').title()}")
    plt.xlabel(col.replace('_', ' ').title())
    plt.ylabel("Frequency")
    plt.tight_layout()

#### Scatter: image vs text entropy

In [None]:
plt.figure()
plt.scatter(df["shannon_entropy"], df["text_entropy"])
plt.title("Image Shannon Entropy vs Text Entropy")
plt.xlabel("Image Shannon Entropy")
plt.ylabel("Text Entropy")
plt.tight_layout()

#### Label prevalence bar chart

In [None]:
label_counts = full_ds.data[full_ds.label_cols].sum()
plt.figure(figsize=(12,4))
plt.bar(label_counts.index, label_counts.values)
plt.xticks(rotation=90)
plt.title("CheXpert Label Prevalence")
plt.ylabel("Count")
plt.tight_layout()
plt.show()

## QUANTIZATION

### QUANTIZATION FUNCTION (for scalar complexity values)

In [None]:
def quantize_complexity(value, min_val, max_val, levels=4):
    if max_val == min_val:
        return 0
    step = (max_val - min_val) / levels
    level = int((value - min_val) / step)
    if level >= levels:
        level = levels - 1
    return level
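
A short worked example shows how scores map to levels and, through the level_to_bits table in the model's forward pass, to bit widths. The function is repeated so the snippet runs on its own. The scores are made-up values chosen for easy arithmetic.

```python
def quantize_complexity(value, min_val, max_val, levels=4):
    if max_val == min_val:
        return 0
    step = (max_val - min_val) / levels
    level = int((value - min_val) / step)
    if level >= levels:
        level = levels - 1  # the maximum value falls into the top level
    return level

# Same mapping as level_to_bits in the model's forward pass
level_to_bits = {0: 1, 1: 2, 2: 3, 3: 4}

scores = [2.0, 3.5, 4.0, 5.9, 6.0]  # made-up complexity scores
mn, mx = min(scores), max(scores)
levels = [quantize_complexity(s, mn, mx) for s in scores]
print(levels)                              # [0, 1, 2, 3, 3]
print([level_to_bits[l] for l in levels])  # [1, 2, 3, 4, 4]
```

With a range of 2.0 to 6.0 and four levels, each level spans 1.0, so more complex samples keep more bits.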

### UNIFORM TENSOR QUANTIZATION FUNCTION

In [None]:
def quantize_tensor(tensor, num_bits):
    min_val = tensor.min()
    max_val = tensor.max()
    qmin = 0
    qmax = 2**num_bits - 1
    if max_val == min_val:
        return tensor
    scale = (max_val - min_val) / (qmax - qmin)
    tensor_q = torch.round((tensor - min_val) / scale)
    tensor_dq = tensor_q * scale + min_val
    return tensor_dq
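
The effect of this quantize-then-dequantize step is easier to see on a handful of numbers. The sketch below is a plain-Python restatement on a list (a hypothetical `quantize_list`, assumed to follow the same arithmetic as quantize_tensor), using 2 bits so that only four distinct values survive.

```python
def quantize_list(values, num_bits):
    """Uniformly quantize floats to 2**num_bits levels, then map them back to floats."""
    mn, mx = min(values), max(values)
    if mx == mn:
        return list(values)
    qmax = 2 ** num_bits - 1
    scale = (mx - mn) / qmax
    return [round((v - mn) / scale) * scale + mn for v in values]

values = [0.0, 0.2, 0.7, 1.0]
dequantized = quantize_list(values, num_bits=2)
print([round(v, 3) for v in dequantized])  # [0.0, 0.333, 0.667, 1.0]
```

Each value moves to the nearest of the four grid points, so the error per element is at most half the step size (here 1/6). The output is still floating point: the function simulates low-bit storage rather than changing the dtype.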

## MAIN SCRIPT: TRAINING, COMPLEXITY ANALYSIS, INFERENCE, AND COMPARISON GRAPH

In [None]:
if __name__ == "__main__":
    config  = Config()
    full_ds = CXRDataset(config)
    N = len(full_ds)

    # print(full_ds.data.columns.tolist())
  
    label_cols = [
        c for c in full_ds.chexpert.columns 
        if c not in ("subject_id", "study_id")
    ]

    # extracting multi‑label array for stratification
    Y = full_ds.data[label_cols].values
    
    splitter = MultilabelStratifiedShuffleSplit(
        n_splits=1, test_size=0.10, random_state=42
    )
    train_idx, test_idx = next(splitter.split(X=np.zeros(N), y=Y))
    
    train_ds = torch.utils.data.Subset(full_ds, train_idx)
    test_ds  = torch.utils.data.Subset(full_ds, test_idx)
    
    class_counts  = Y.sum(axis=0)
    class_weights = 1.0 / (class_counts + 1e-6)
    sample_weights = (Y * class_weights).sum(axis=1)
    train_weights  = sample_weights[train_idx]

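    # Oversample rare labels: each sample's weight is the sum of the inverse
    # frequencies of its positive labels (samples with no positive label get weight 0)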
    sampler = WeightedRandomSampler(
        weights=train_weights,
        num_samples=len(train_weights),
        replacement=True
    )
    
    train_loader    = DataLoader(
        train_ds,
        batch_size=config.batch_size,
        sampler=sampler,
        num_workers=4
    )
    val_loader      = DataLoader(
        test_ds,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=4
    )
    analysis_loader = DataLoader(
        test_ds,
        batch_size=1,
        shuffle=False,
        num_workers=2
    )
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # initializing model, optimizer, loss 
    model     = MultimodalCheXpertModel(config).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    criterion = torch.nn.BCEWithLogitsLoss()
    
    # Training and Validation loops 
    num_epochs = 5
    train_losses, train_accs = [], []
    val_losses,   val_accs   = [], []
    
    for epoch in range(1, num_epochs+1):
        model.train()
        tl, ta = [], []
        for batch in train_loader:
            batch = {k:(v.to(device) if isinstance(v, torch.Tensor) else v)
                     for k,v in batch.items()}
            out   = model(batch)
            loss  = criterion(out, batch["labels"])
            preds = (torch.sigmoid(out) > 0.5).float()
            acc   = (preds == batch["labels"]).float().mean().item()
    
            tl.append(loss.item()); ta.append(acc)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step(); optimizer.zero_grad()
    
        train_losses.append(np.mean(tl))
        train_accs.append(np.mean(ta))
    
        # Validation
        model.eval()
        vl, va = [], []
        with torch.no_grad():
            for batch in val_loader:
                batch = {k:(v.to(device) if isinstance(v, torch.Tensor) else v)
                         for k,v in batch.items()}
                out   = model(batch)
                loss  = criterion(out, batch["labels"])
                preds = (torch.sigmoid(out) > 0.5).float()
                acc   = (preds == batch["labels"]).float().mean().item()
    
                vl.append(loss.item()); va.append(acc)
    
        val_losses.append(np.mean(vl))
        val_accs.append(np.mean(va))
    
        print(
            f"Epoch {epoch} ▶ "
            f"train_loss={train_losses[-1]:.4f}, train_acc={train_accs[-1]:.4f}  |  "
            f"val_loss={val_losses[-1]:.4f},   val_acc={val_accs[-1]:.4f}"
        )
    
    # Plots
    epochs = range(1, num_epochs+1)
    plt.figure(figsize=(12,5))
    
    plt.subplot(1,2,1)
    plt.plot(epochs, train_losses, label="Train")
    plt.plot(epochs, val_losses,   label="Test")
    plt.title("Loss vs Epoch"); plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend()
    
    plt.subplot(1,2,2)
    plt.plot(epochs, train_accs, label="Train")
    plt.plot(epochs, val_accs,   label="Test")
    plt.title("Hamming Accuracy vs Epoch"); plt.xlabel("Epoch"); plt.legend()
    
    plt.tight_layout()
    plt.show()
    
    # Complexity analysis on test split
    unnorm = transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std =[1/0.229,    1/0.224,    1/0.225]
    )
    
    img_ent, txt_ent, combined = [], [], []
    for sample in tqdm(analysis_loader, desc="Computing Complexity"):
        img_t = sample["image"].squeeze(0)
        img_u = unnorm(img_t).clamp(0,1)
        img_p = transforms.ToPILImage()(img_u)
        ie    = compute_shannon_entropy(img_p)
        img_ent.append(ie)
    
        # Text entropy
        te = sample["text_complexity"]
        te = te.item() if torch.is_tensor(te) else te
        txt_ent.append(te)
    
        combined.append((ie + te) / 2)
    
    mn, mx = min(combined), max(combined)
    quant_levels = [
        quantize_complexity(x, mn, mx, levels=4) for x in combined
    ]
    
    # Inference metrics on test split
    model.eval()
    y_true, p_full, p_q = [], [], []
    reg_times, q_times = [], []
    reg_accs, q_accs = [], []
    errors_by_level = {i: [] for i in range(4)}

    # Synchronize only on a GPU; torch.cuda.synchronize() raises an error when CUDA is unavailable
    def sync():
        if device.type == "cuda":
            torch.cuda.synchronize()

    with torch.no_grad():
        for sample, lvl in zip(analysis_loader, quant_levels):
            batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v)
                     for k, v in sample.items()}
            y = batch["labels"]

            # Full-precision evaluation (no feature quantization)
            sync(); t0 = time.time()
            o0 = model(batch, quant_level=None)
            sync(); reg_times.append(time.time() - t0)

            # Evaluation with quantized fused features
            sync(); t1 = time.time()
            o1 = model(batch, quant_level=lvl)
            sync(); q_times.append(time.time() - t1)

            p0 = torch.sigmoid(o0).cpu().numpy()
            p1 = torch.sigmoid(o1).cpu().numpy()
            y_true.append(y.cpu().numpy()); p_full.append(p0); p_q.append(p1)

            pr0 = (p0 > 0.5).astype(int); pr1 = (p1 > 0.5).astype(int)
            reg_accs.append((pr0 == y.cpu().numpy()).mean())
            q_accs.append((pr1 == y.cpu().numpy()).mean())

            err = np.abs(o0.cpu().numpy() - o1.cpu().numpy()).mean()
            errors_by_level[lvl].append(err)

    # Per-class ROC-AUC (classes with only one label value in the test split stay NaN)
    Y  = np.vstack(y_true)
    PF = np.vstack(p_full)
    PQ = np.vstack(p_q)

    auc_full  = np.full(len(label_cols), np.nan, dtype=float)
    auc_quant = np.full(len(label_cols), np.nan, dtype=float)

    for i in range(len(label_cols)):
        col_true = Y[:, i]
        if len(np.unique(col_true)) == 2:
            auc_full[i]  = roc_auc_score(col_true, PF[:, i])
            auc_quant[i] = roc_auc_score(col_true, PQ[:, i])

    # Compute the mean-5 and macro-14 AUROC, ignoring NaNs, so the results can be
    # compared with other papers (although those papers also included question-answering features)
    chex5 = ["Atelectasis","Cardiomegaly","Consolidation","Edema","Pleural Effusion"]
    idx5  = [label_cols.index(l) for l in chex5]
    mean5_full  = np.nanmean(auc_full[idx5])
    mean5_quant = np.nanmean(auc_quant[idx5])
    macro_full  = np.nanmean(auc_full)
    macro_quant = np.nanmean(auc_quant)

    print(f"Mean‑5 AUROC:  full={mean5_full:.3f}, quant={mean5_quant:.3f}")
    print(f"Macro‑14 AUROC: full={macro_full:.3f}, quant={macro_quant:.3f}")
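
The sample weights passed to WeightedRandomSampler in the script above can be reproduced on a toy label matrix. This is a plain-Python sketch of the same arithmetic (inverse class frequency, summed over each sample's positive labels). The function name and the toy labels are illustrative assumptions.

```python
def sample_weights(label_rows, eps=1e-6):
    """Weight each sample by the summed inverse frequency of its positive labels."""
    num_classes = len(label_rows[0])
    class_counts = [sum(row[c] for row in label_rows) for c in range(num_classes)]
    class_weights = [1.0 / (count + eps) for count in class_counts]
    return [sum(y * w for y, w in zip(row, class_weights)) for row in label_rows]

# Toy labels: class 0 is common, class 1 is rare, the last row has no positive label.
Y = [[1, 0], [1, 0], [1, 1], [0, 0]]
print([round(w, 3) for w in sample_weights(Y)])  # [0.333, 0.333, 1.333, 0.0]
```

The row with the rare label gets the largest weight. The all-negative row gets weight 0, which means a weighted sampler never draws it. If all-negative studies should appear during training, a small floor on the weights would be one option.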


## Model‑size & weight‑dtype comparison: FP32/BF16 vs INT8‑weights

In [None]:
def model_stats(model, name="model"):
    n_params     = sum(p.numel() for p in model.parameters())
    n_trainable  = sum(p.numel() for p in model.parameters() if p.requires_grad)
    weight_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    return {
        "name"       : name,
        "params (M)" : f"{n_params/1e6:7.2f}",
        "trainable M": f"{n_trainable/1e6:7.2f}",
        "mem (MB)"   : f"{weight_bytes/1024**2:8.1f}",
    }

def print_table(*rows):
    cols = list(rows[0].keys())
    widths = {c: max(len(c), *(len(str(r[c])) for r in rows)) for c in cols}
    fmt = "  ".join(f"{{:{w}}}" for w in widths.values())
    print(fmt.format(*cols))
    print(fmt.format(*("-"*w for w in widths.values())))
    for r in rows:
        print(fmt.format(*(r[c] for c in cols)))
    print()

fp_stats = model_stats(model, "Full‑precision")

print("\nCreating int8 dynamic‑quantised copy")
quantised_model = quantize_dynamic(
    copy.deepcopy(model).cpu(), 
    {torch.nn.Linear},
    dtype=torch.qint8
)
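# Note: dynamically quantised Linear layers store their weights as packed parameters
# rather than nn.Parameter objects, so model_stats() does not count them.
# The checkpoint file sizes printed below are the more reliable comparison.
int8_stats = model_stats(quantised_model, "INT8‑weights")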

ckpt_fp = "checkpoint_full.pt"
ckpt_i8 = "checkpoint_int8.pt"
torch.save(model.state_dict(),      ckpt_fp)
torch.save(quantised_model.state_dict(), ckpt_i8)

for p in (ckpt_fp, ckpt_i8):
    size_mb = os.path.getsize(p)/(1024**2)
    print(f"{p:<20} {size_mb:8.1f} MB")

print("\n===  Model‑size comparison  ===")
print_table(fp_stats, int8_stats)

# Free the CPU RAM used by the quantised copy before the next training run
del quantised_model
gc.collect()
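
As a rough sanity check on the numbers printed by this cell, the storage needed for the weights alone can be estimated from the parameter count and the bit width. The sketch below is simple arithmetic under that assumption. It ignores scales, zero points, and any layers left unquantized, and the parameter count is illustrative rather than measured.

```python
def weight_megabytes(num_params, bits_per_weight):
    """Approximate storage for the weights alone, in units of 1024**2 bytes."""
    return num_params * bits_per_weight / 8 / 1024**2

n = 1_000_000  # illustrative parameter count, not a measurement of this model
for label, bits in [("FP32", 32), ("BF16", 16), ("INT8", 8), ("4-bit", 4)]:
    print(f"{label:>5}: {weight_megabytes(n, bits):6.2f} MB")
```

Going from FP32 to INT8 weights should shrink the quantized layers by roughly a factor of four. The whole-model ratio is smaller when only the Linear layers are quantized.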


# 2 bits quant

### REDEFINING CONFIG FOR UNIFORM QUANTIZATION

In [None]:
class Config:

    filenames_path = "/home/ubuntu/data/chex/IMAGE_FILENAMES_UPDATE"
    image_dir = "/home/ubuntu/data/chex/mimic-cxr-jpg/2.1.0/"
    metadata_path = "/home/ubuntu/data/chex/mimic-cxr-2.0.0-metadata.csv"
    chexpert_path = "/home/ubuntu/data/chex/mimic-cxr-2.0.0-chexpert.csv"
    negbio_path = "/home/ubuntu/data/chex/mimic-cxr-2.0.0-negbio.csv"

    batch_size = 4
    img_size = 224
    num_classes = 14
    max_seq_length = 256

    # The three settings below are used by the uniform quantization experiment
    num_epochs = 1
    lr = 1e-5
    uniform_bits = 4

## Uniform Quantization
Uniformly quantize a float tensor to num_bits and dequantize it back to floating point. The function returns a new tensor; it does not modify its input in place.

In [None]:
def quantize_tensor(tensor: torch.Tensor, num_bits: int) -> torch.Tensor:
    mn, mx = tensor.min(), tensor.max()
    if mx == mn:
        return tensor
    qmin, qmax = 0, 2**num_bits - 1
    scale = (mx - mn) / (qmax - qmin)
    q = ( (tensor - mn) / scale ).round().clamp(qmin, qmax)
    return q * scale + mn

## CHEXPERT MML SET UP

In [None]:
class MultimodalCheXpertModel(torch.nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        # Image encoder (classification head removed), text encoder, and a classifier sized by the config
        self.img_enc  = vit_b_16(pretrained=True)
        self.img_enc.heads = torch.nn.Identity()
        self.img_proj = torch.nn.Linear(768,512)
        self.txt_enc  = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
        hs = self.txt_enc.config.text_config.hidden_size
        self.txt_proj  = torch.nn.Linear(hs,512)
        self.classifier = torch.nn.Linear(1024, cfg.num_classes)

    def forward(self, batch):
        xi = self.img_enc(batch["image"])
        xi = self.img_proj(xi)
        to = self.txt_enc(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            output_hidden_states=True, return_dict=True
        )
        xt = to.hidden_states[-1][:,0]
        xt = self.txt_proj(xt)
        fused = torch.cat([xi, xt], dim=1)
        return self.classifier(fused)

## Metrics and helpers

In [None]:
def model_stats(model, name:str):
    ps = list(model.parameters())
    total = sum(p.numel() for p in ps)
    train = sum(p.numel() for p in ps if p.requires_grad)
    mem   = sum(p.numel()*p.element_size() for p in ps)/1024**2
    return {"name":name,
            "params (M)":f"{total/1e6:7.2f}",
            "trainable M":f"{train/1e6:7.2f}",
            "mem (MB)":f"{mem:8.1f}"}

def print_table(*rows):
    keys   = list(rows[0].keys())
    widths = {k:max(len(k),*(len(str(r[k])) for r in rows)) for k in keys}
    fmt    = "  ".join(f"{{:{widths[k]}}}" for k in keys)
    print(fmt.format(*keys))
    print(fmt.format(*("-"*widths[k] for k in keys)))
    for r in rows: print(fmt.format(*(r[k] for k in keys)))
    print()


## Main function for uniform quantization

In [None]:
def main():
    cfg    = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ds = CXRDataset(cfg)
    dl = DataLoader(ds, batch_size=cfg.batch_size, shuffle=True, num_workers=4)

    # defining model, opt, loss
    model     = MultimodalCheXpertModel(cfg).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    criterion = torch.nn.BCEWithLogitsLoss()

    model.train()
    for batch in tqdm(dl, desc="Training Epoch"):
        batch = {k:(v.to(device) if isinstance(v,torch.Tensor) else v)
                 for k,v in batch.items()}
        out  = model(batch)
        loss = criterion(out, batch["labels"])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step(); optimizer.zero_grad()

    # Move a copy of the model to the CPU and replace any NaN/Inf weights with finite values
    model_cpu = copy.deepcopy(model).cpu()
    for _, p in model_cpu.named_parameters():
        with torch.no_grad():
            p.data = torch.nan_to_num(p.data,
                                      nan=0.0,
                                      posinf=torch.finfo(p.dtype).max,
                                      neginf=torch.finfo(p.dtype).min)

    # Record the statistics of the full-precision model and save its checkpoint
    fp_stats = model_stats(model_cpu, "Full-precision")
    torch.save(model_cpu.state_dict(), "checkpoint_fp32.pt")

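    # Build the static uniform-quantized copy. The weights are quantized and then
    # dequantized, so they are still stored as floats and the reported size does not shrink.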
    model_uq = copy.deepcopy(model_cpu)
    B = cfg.uniform_bits
    for m in model_uq.modules():
        if isinstance(m, torch.nn.Linear):
            with torch.no_grad():
                m.weight.data = quantize_tensor(m.weight.data, B)
                if m.bias is not None:
                    m.bias.data   = quantize_tensor(m.bias.data, B)

    uq_stats = model_stats(model_uq, f"Uniform-{B}bit")
    torch.save(model_uq.state_dict(), f"checkpoint_uq{B}.pt")

    # Print the size comparison
    print("\n=== Model-size comparison ===")
    print_table(fp_stats, uq_stats)

    # Finally, compute the evaluation metrics on the CPU. Note that this takes a long time.
    model_cpu.eval(); model_uq.eval()
    y_true, p_fp, p_uq = [], [], []
    h_fp, h_uq, mae_e = [], [], []

    # Gradients are not needed for evaluation, so disable them to save memory
    with torch.no_grad():
        for batch in tqdm(dl, desc="Final eval"):
            batch = {k:(v.cpu() if isinstance(v,torch.Tensor) else v)
                     for k,v in batch.items()}
            Y    = batch["labels"].numpy()
            o_fp = model_cpu(batch)
            o_u  = model_uq(batch)

            prob_fp = torch.sigmoid(o_fp).detach().numpy()
            prob_uq = torch.sigmoid(o_u).detach().numpy()

            y_true.append(Y)
            p_fp.append(prob_fp)
            p_uq.append(prob_uq)

            pred_fp = (prob_fp > 0.5).astype(int)
            pred_uq = (prob_uq > 0.5).astype(int)
            h_fp.append((pred_fp == Y).mean())
            h_uq.append((pred_uq == Y).mean())

            mae_e.append(np.abs(o_fp.detach().numpy() - o_u.detach().numpy()).mean())

    # Stack and compute AUC
    Y_all = np.vstack(y_true)
    PF    = np.vstack(p_fp)
    PU    = np.vstack(p_uq)
    auc_fp = roc_auc_score(Y_all, PF, average=None)
    auc_u  = roc_auc_score(Y_all, PU, average=None)

    core5    = [ ds.label_cols.index(l) for l in
                 ["Atelectasis","Cardiomegaly","Consolidation","Edema","Pleural Effusion"] ]
    mean5fp  = auc_fp[core5].mean()
    mean5uq  = auc_u[ core5].mean()
    macro_fp = auc_fp.mean()
    macro_u  = auc_u.mean()
    h_fp_all = np.mean(h_fp)
    h_uq_all = np.mean(h_uq)
    mae_all  = np.mean(mae_e)

    print("\n=== Final metrics ===")
    print(f"Hamming-Acc    | FP32 = {h_fp_all:.4f},  UQ = {h_uq_all:.4f}")
    print(f"Mean-5  AUROC  | FP32 = {mean5fp:.4f},     UQ = {mean5uq:.4f}")
    print(f"Macro-14 AUROC | FP32 = {macro_fp:.4f},     UQ = {macro_u:.4f}")
    print(f"Logit-MAE      |                    {mae_all:.4f}")

    # cleaning the model from the CPU
    del model_cpu, model_uq
    gc.collect()

if __name__ == "__main__":
    main()
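
For readers who want to see what the per-class AUROC and the NaN-aware averages measure, here is a small standard-library sketch. It is not scikit-learn's implementation. It uses the pairwise definition (the probability that a random positive scores above a random negative), and both helper names are hypothetical.

```python
def auroc(y_true, scores):
    """Probability that a random positive outscores a random negative (ties count 0.5)."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    if not pos or not neg:
        return float("nan")  # undefined when only one class is present
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def nan_mean(values):
    """Mean of the values that are not NaN."""
    kept = [v for v in values if v == v]  # NaN is the only value not equal to itself
    return sum(kept) / len(kept) if kept else float("nan")

per_class = [
    auroc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]),  # 3 of 4 pairs ranked correctly
    auroc([1, 1, 1, 1], [0.2, 0.3, 0.6, 0.9]),   # NaN: no negatives in this class
]
print(per_class[0], nan_mean(per_class))  # 0.75 0.75
```

Skipping undefined classes, as the first script does with np.nanmean, keeps the macro average usable on small samples. Calling roc_auc_score with average=None on all columns at once, as main() does, is expected to fail if any class has only one label value in the evaluated data.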
