In [1]:
import numpy as np
import torch
from torchvision.transforms import v2
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, confusion_matrix, roc_curve

from preprocess import read_dicom_series, n4_bias_correction, get_bbox, crop
from model import OVFNet

  from .autonotebook import tqdm as notebook_tqdm


# Sample Data

In [8]:
# One entry per patient: every list below is index-aligned with `patient_id`.
patient_id = ["083", "084"]

# Folder holding each patient's T1 DICOM series.
dicom_path = [
    "./sample_dicom/083_T1",
    "./sample_dicom/084_T1"
]

# Keypoint data need to be pre-defined (they are not detected in this notebook).
# 'keyframe' is the slice index (axis 0 of the loaded volume) used as the centre of the
# 3-slice window; 'keypoints' are passed to get_bbox to derive the crop box
# (presumably (x, y) pixel coordinates on the keyframe -- TODO confirm against get_bbox).
keypoint_data = [
    {'keyframe': 6, 'keypoints': [[397.64, 189.88], [368.43, 176.72], [385.3, 243.77], [348.27, 238.84], [363.49, 209.63], [389.41, 219.5]]},
    {'keyframe': 7, 'keypoints': [[143.39, 198.16], [145.86, 252.46], [180.82, 247.11], [169.31, 196.1], [155.32, 225.31], [170.13, 227.78]]}
]

# Raw clinical covariates; Age and BMD are z-scored in a later cell.
metadata = [
    {"Age": 57, "Gender": 1, "BMD": -3.4, "PreFractureDrugIntake": 0, "PostFractureDrugIntake": 1},
    {"Age": 69, "Gender": 0, "BMD": -4.3, "PreFractureDrugIntake": 0, "PostFractureDrugIntake": 1}
]

# Ground-truth binary labels.
Y = [0, 1]

# Later cells index these lists by position and zip them together (zip silently truncates),
# so a length mismatch would misalign patients instead of failing loudly.
_n_patients = len(patient_id)
assert all(len(seq) == _n_patients for seq in (dicom_path, keypoint_data, metadata, Y)), \
    "patient_id, dicom_path, keypoint_data, metadata and Y must have the same length"

# Model Checkpoint

In [6]:
# 1 = average the positive-class probability over the keyframe and its two neighbouring
# slices; 0 = score the keyframe alone. Leave it at 1. Note the preprocessing cell always
# reads keyframe-1..keyframe+1, so three consecutive DICOM frames are required either way.
use_ensemble = 1
# 1 = also feed the clinical metadata to the model; 0 = image only. Unlike ensembling,
# using clinical data is not strictly better in performance. Checkpoints exist for both.
use_clinical = 0

# Decision thresholds on the positive-class probability, determined in the training stage,
# keyed by the `use_clinical` setting of the matching checkpoint.
OPTIMAL_THRESHOLDS = {0: 0.5644413232803345, 1: 0.6101016998291016}
# Indexing a list with e.g. -1 used to silently pick the wrong threshold; reject it instead.
if use_clinical not in OPTIMAL_THRESHOLDS:
    raise ValueError(f"use_clinical must be 0 or 1, got {use_clinical!r}")
# int() keeps the filename correct if a bool (True/False) is supplied.
ckpt_path = f"./best_checkpoint/use_clinical_{int(use_clinical)}.pth"
optimal_threshold = OPTIMAL_THRESHOLDS[use_clinical]

In [7]:
# Build the classifier for the chosen input mode and load its trained weights.
network = OVFNet(use_clinical=use_clinical)
# Fall back to CPU when no GPU is present so the notebook still runs end-to-end.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
network = network.to(device)

# The checkpoint is a plain state_dict (it goes straight into load_state_dict), so
# weights_only=True is sufficient and avoids unpickling arbitrary objects from disk.
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
network.load_state_dict(checkpoint)

# Switch to inference mode (e.g. dropout off); the last expression echoes the module repr.
network.eval()

OVFNet(
  (image_encoder): _LoRA_ViT_timm(
    (lora_vit): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (patch_drop): Identity()
      (norm_pre): Identity()
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): _LoRA_qkv_timm(
              (qkv): Linear(in_features=768, out_features=2304, bias=True)
              (linear_a_q): Linear(in_features=768, out_features=16, bias=False)
              (linear_b_q): Linear(in_features=16, out_features=768, bias=False)
              (linear_a_k): Linear(in_features=768, out_features=16, bias=False)
              (linear_b_k): Linear(in_features=16, out_features=768, bias=False)
              (linear_a_v): Linear(in_features=768, out_features=16, bias=False)
           

# Preprocess

In [9]:
# Resize to the ViT input size and normalise with BiomedCLIP's RGB statistics.
transforms = v2.Compose([
    v2.ToImage(),
    v2.Resize(size=(224, 224), interpolation=v2.InterpolationMode.BICUBIC),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073), # BiomedCLIP statistics: https://arxiv.org/abs/2303.00915
        std=(0.26862954, 0.26130258, 0.27577711))  # BiomedCLIP statistics: https://arxiv.org/abs/2303.00915
])


def _to_uint8_rgb(frame, low_q=0.05, high_q=0.95):
    """Turn one cropped 2-D slice into an (H, W, 3) uint8 image.

    Intensities are clipped to the [low_q, high_q] quantiles, min-max scaled to 0..255,
    and the grayscale channel is replicated into 3 identical channels. A constant slice
    (zero range after clipping) becomes all zeros instead of NaN from a 0/0 division.
    """
    lo, hi = np.quantile(frame, [low_q, high_q])
    frame = np.clip(frame, lo, hi)
    span = frame.max() - frame.min()
    if span > 0:
        frame = (frame - frame.min()) / span
    else:
        frame = np.zeros_like(frame, dtype=np.float64)
    frame = (frame * 255).astype(np.uint8)
    # The 3 duplicates grayscale into RGB channels; it is unrelated to the 3 ensemble frames.
    return np.tile(frame[..., None], (1, 1, 3))


X = []

for i, pid in enumerate(patient_id):
    img = read_dicom_series(dicom_path[i])
    img = n4_bias_correction(img)

    keyframe = keypoint_data[i]["keyframe"]
    # keyframe - 1 must be a real slice: a negative index would silently wrap to the end.
    if keyframe < 1:
        raise ValueError(f"Patient {pid}: keyframe {keyframe} has no preceding slice for the 3-frame window")
    bbox = get_bbox(np.array(keypoint_data[i]["keypoints"]), expand_ratio=1.5)

    # Keep keyframe-1, keyframe, keyframe+1 for the probability ensemble; index 1 of the
    # resulting stack is the keyframe itself.
    frames = [
        transforms(_to_uint8_rgb(crop(img[j, :, :], bbox)))
        for j in (keyframe - 1, keyframe, keyframe + 1)
    ]
    X.append(torch.stack(frames))

In [10]:
# Pre-calculated normalisation statistics (presumably from the training cohort -- confirm).
age_mean = 72.77551020408163
age_std = 9.378789520714859
bmd_mean = -3.154285714285714
bmd_std = 1.0773905938902932

# Build z-scored copies instead of editing `metadata` in place: the in-place version
# re-normalised already-normalised values every time this cell was re-run.
metadata_norm = [
    {**md, 'Age': (md['Age'] - age_mean) / age_std, 'BMD': (md['BMD'] - bmd_mean) / bmd_std}
    for md in metadata
]

# Column order of the per-patient clinical vector fed to the model
# (presumably must match the order used in training).
columns = ['Age', 'Gender', 'BMD', 'PreFractureDrugIntake', 'PostFractureDrugIntake']
M = [torch.tensor([md[c] for c in columns]) for md in metadata_norm]

In [11]:
# Labels as an integer tensor so they can be zipped with X and M and moved to `device`.
# as_tensor (unlike torch.tensor) neither copies nor warns when Y is already a tensor,
# so re-running this cell is harmless.
Y = torch.as_tensor(Y)

In [12]:
with torch.no_grad():
    outputs = []
    for frames, clinical, label in zip(X, M, Y):
        frames = frames.to(device)
        clinical = clinical.unsqueeze(0).to(device)  # add a batch dimension of 1
        label = label.to(device)

        if not use_ensemble:
            # Score the keyframe only (middle entry of the 3-frame stack).
            prob = torch.sigmoid(network(frames[1, None], clinical).squeeze())
        else:
            # Mean of the per-frame positive-class probabilities over the 3 slices.
            per_frame = [torch.sigmoid(network(frames[k, None], clinical).squeeze()) for k in range(3)]
            prob = torch.stack(per_frame, dim=0).mean(dim=0)

        outputs.append({'positive_prob': prob, 'y_true': label})

In [13]:
# Show each patient's predicted probability next to its ground-truth label.
for record in outputs:
    print(record)

{'positive_prob': tensor(0.1072, device='cuda:0'), 'y_true': tensor(0, device='cuda:0')}
{'positive_prob': tensor(0.7573, device='cuda:0'), 'y_true': tensor(1, device='cuda:0')}


In [14]:
metrics_all = []
pr_all = []

# Flatten to 1-D explicitly: the former .squeeze() produced a 0-d array for a single
# patient, which the sklearn metrics reject.
y_prob = torch.stack([output['positive_prob'] for output in outputs]).reshape(-1).cpu().numpy()
y_true = torch.stack([output['y_true'] for output in outputs]).reshape(-1).cpu().numpy()

# AUC is undefined when only one class is present (e.g. a single demo patient);
# report NaN instead of letting roc_auc_score raise.
if np.unique(y_true).size > 1:
    auc = roc_auc_score(y_true, y_prob)
else:
    auc = float('nan')
fpr, tpr, _ = roc_curve(y_true, y_prob)

# `optimal_threshold` is the value fixed in the training stage. It is deliberately not
# re-tuned here (e.g. via Youden's J, argmax(tpr - fpr)), which would make these metrics optimistic.
preds_binary = (y_prob > optimal_threshold).astype(int)
f1 = f1_score(y_true, preds_binary)
accuracy = accuracy_score(y_true, preds_binary)

# labels=[0, 1] guarantees a 2x2 matrix even when only one class occurs, so the
# four-way unpacking cannot fail.
conf_matrix = confusion_matrix(y_true, preds_binary, labels=[0, 1])
TN, FP, FN, TP = conf_matrix.ravel()
specificity = TN / (TN + FP) if (TN + FP) > 0 else 0
sensitivity = TP / (TP + FN) if (TP + FN) > 0 else 0

metrics = {'auc': auc, 'accuracy': accuracy, 'specificity': specificity, 'sensitivity': sensitivity, 'f1': f1}
pr = {'tpr': tpr, 'fpr': fpr}

metrics_all.append(metrics)
pr_all.append(pr)

In [15]:
# Print the metric dict of each evaluation run (uses the loop variable; the old code
# printed the last run's `metrics` on every iteration).
for metric in metrics_all:
    print(metric)

{'auc': np.float64(1.0), 'accuracy': 1.0, 'specificity': np.float64(1.0), 'sensitivity': np.float64(1.0), 'f1': 1.0}


In [16]:
# One ROC curve (tpr/fpr arrays) per evaluation run.
for curve in pr_all:
    print(curve)

{'tpr': array([0., 1., 1.]), 'fpr': array([0., 0., 1.])}
