# Contents
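<div style="font-size: 16px;">
    This notebook evaluates how the trained disc detection models perform on the test sets for each diagnosed spinal condition. For each condition we use the checkpoint from the 'best epoch', i.e. the epoch with the highest validation precision/recall. It then crops the predicted disc-level boxes on the test sets so they can be used for testing severity classification.
</div>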


# Install libraries

In [1]:
!pip install pydicom -q
!pip install torch==2.1 torchvision==0.16 -q
!pip install -qU pycocotools
!pip install -qU wandb

# Configs

In [2]:
# Clear user-defined variables so the notebook starts from a clean namespace.
# Names provided by IPython itself are kept: deleting `get_ipython` breaks any later
# `!`/`%` command in this kernel, and `In`/`Out` hold the execution history.
# The loop variable is underscore-prefixed so it is never itself a deletion candidate.
_IPYTHON_NAMES = {"In", "Out", "get_ipython", "exit", "quit", "open"}
for _name in [n for n in list(globals()) if not n.startswith("_") and n not in _IPYTHON_NAMES]:
    del globals()[_name]

In [3]:
import os
import time
import random
from datetime import datetime
import numpy as np
import collections

from matplotlib import animation, rc
import pandas as pd

import matplotlib.patches as patches
import matplotlib.pyplot as plt

import tqdm
import sys
import torch

In [4]:
# Device configuration: the first CUDA GPU when one is available, otherwise the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

## Load libraries

In [5]:
import os
import time
import random
from datetime import datetime
import numpy as np
import collections

from matplotlib import animation, rc
import pandas as pd

import matplotlib.patches as patches
import matplotlib.pyplot as plt

import tqdm
import sys
import torch

## Directories

In [6]:
# Directories
PROJECT_DIR = '/home/jupyter'
DATA_DIR = os.path.join(PROJECT_DIR, 'data')
SRC_DIR = os.path.join(PROJECT_DIR, 'src')

## Functions

In [7]:
# Execute the shared disc-detection pipeline module in this notebook's namespace.
# Presumably it provides the names used below but not defined here
# (RSNAMultipleBBoxesDataset, load_model_disc_detection, evaluate, utils, torchvision) — verify.
# Compiling with the real path (instead of exec'ing a bare string) makes tracebacks and
# debuggers point at lines of pipeline_disc_detection.py rather than "<string>".
_pipeline_path = os.path.join(SRC_DIR, 'pipeline_disc_detection.py')
with open(_pipeline_path, encoding='utf-8') as _pipeline_file:
    _pipeline_source = _pipeline_file.read()
exec(compile(_pipeline_source, _pipeline_path, 'exec'))

# Evaluation on test sets

In [8]:
def evaluate_on_test_set(condition, epoch, config):
    """Run bbox evaluation of a trained disc-detection model on one condition's test split.

    Args:
        condition: Condition name. Selects the test metadata
            (``DATA_DIR/processed_metadata/<condition>/test.csv``) and the model directory
            (``PROJECT_DIR/models/02_train_disc_detection/<condition>``).
        epoch: Epoch whose checkpoint ``epoch_<epoch>/model_dict.pt`` is loaded.
        config: Dict providing ``box_w``, ``box_h_l1_l4`` and ``box_h_l5`` (box geometry
            passed to the dataset) and ``batch_size``. Other keys are ignored.

    Returns:
        None; the metrics are printed by ``evaluate``.
    """
    model_dir = os.path.join(PROJECT_DIR, 'models', '02_train_disc_detection', condition)

    # Read in metadata
    test_df = pd.read_csv(os.path.join(DATA_DIR, 'processed_metadata', condition, 'test.csv'))

    # Create data loader. Evaluation does not need shuffling; a fixed order keeps runs reproducible.
    dataset_test = RSNAMultipleBBoxesDataset(
        test_df, w=config['box_w'], h_l1_l4=config['box_h_l1_l4'], h_l5=config['box_h_l5']
    )
    test_loader = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=config['batch_size'],
        shuffle=False,
        collate_fn=utils.collate_fn,
        # os.cpu_count() may return None, which DataLoader does not accept
        num_workers=os.cpu_count() or 0,
    )

    # Load model. map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    checkpoint_path = os.path.join(model_dir, f'epoch_{epoch}', 'model_dict.pt')
    trained_model = load_model_disc_detection(
        state_dict=torch.load(checkpoint_path, map_location=device)
    ).to(device)

    # Evaluate on test set (this evaluate function is from https://raw.githubusercontent.com/pytorch/vision/main/references/detection/engine.py)
    evaluate(trained_model, test_loader, device=device)

## Spinal Canal Stenosis (captured by Sagittal T2/STIR images)

In [9]:
CONDITION = 'SpinalCanalStenosis'
BEST_EPOCH = 1  # epoch with the highest validation precision/recall

CONFIG = dict(
    # num_epochs, lr, lr_step_size and lr_gamma are training settings kept for reference;
    # evaluate_on_test_set only reads batch_size and the box_* entries.
    num_epochs=3,
    batch_size=10,
    lr=0.0001,
    lr_step_size=3,
    lr_gamma=0.1,
    box_w = 70, # width of the bounding boxes
    box_h_l1_l4 = 30, # height of the boxes for levels from L1/L2 to L4/L5
    box_h_l5 = 40 # height of the boxes for level L5/S1
)

evaluate_on_test_set(condition = CONDITION, epoch = BEST_EPOCH, config = CONFIG)

creating index...
index created!
Test:  [ 0/38]  eta: 0:02:02  model_time: 2.4348 (2.4348)  evaluator_time: 0.0258 (0.0258)  time: 3.2106  data: 0.7474  max mem: 5491
Test:  [37/38]  eta: 0:00:01  model_time: 1.1387 (1.1582)  evaluator_time: 0.0244 (0.0250)  time: 1.1959  data: 0.0271  max mem: 5499
Test: Total time: 0:00:46 (1.2321 s / it)
Averaged stats: model_time: 1.1387 (1.1582)  evaluator_time: 0.0244 (0.0250)
Accumulating evaluation results...
DONE (t=0.23s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.764
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.942
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.924
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.764
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ Io

## Left Neural Foraminal Narrowing (captured by Sagittal T1 images)

In [10]:
CONDITION = 'LeftNeuralForaminalNarrowing'
BEST_EPOCH = 1  # epoch with the highest validation precision/recall

CONFIG = dict(
    # num_epochs, lr, lr_step_size and lr_gamma are training settings kept for reference;
    # evaluate_on_test_set only reads batch_size and the box_* entries.
    num_epochs=3,
    batch_size=10,
    lr=0.0001,
    lr_step_size=3,
    lr_gamma=0.1,
    box_w = 70, # width of the bounding boxes
    box_h_l1_l4 = 30, # height of the boxes for levels from L1/L2 to L4/L5
    box_h_l5 = 40 # height of the boxes for level L5/S1
)

evaluate_on_test_set(condition = CONDITION, epoch = BEST_EPOCH, config = CONFIG)

creating index...
index created!
Test:  [ 0/62]  eta: 0:01:57  model_time: 1.2168 (1.2168)  evaluator_time: 0.0272 (0.0272)  time: 1.9001  data: 0.6537  max mem: 5499
Test:  [61/62]  eta: 0:00:01  model_time: 1.3449 (1.2649)  evaluator_time: 0.0256 (0.0307)  time: 1.3978  data: 0.0266  max mem: 5499
Test: Total time: 0:01:23 (1.3389 s / it)
Averaged stats: model_time: 1.3449 (1.2649)  evaluator_time: 0.0256 (0.0307)
Accumulating evaluation results...
DONE (t=0.35s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.695
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.943
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.862
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.695
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ Io

## Right Neural Foraminal Narrowing (captured by Sagittal T1 images)

In [11]:
CONDITION = 'RightNeuralForaminalNarrowing'
BEST_EPOCH = 1  # epoch with the highest validation precision/recall

CONFIG = dict(
    # num_epochs, lr, lr_step_size and lr_gamma are training settings kept for reference;
    # evaluate_on_test_set only reads batch_size and the box_* entries.
    num_epochs=3,
    batch_size=10,
    lr=0.0001,
    lr_step_size=3,
    lr_gamma=0.1,
    box_w = 70, # width of the bounding boxes
    box_h_l1_l4 = 30, # height of the boxes for levels from L1/L2 to L4/L5
    box_h_l5 = 40 # height of the boxes for level L5/S1
)

evaluate_on_test_set(condition = CONDITION, epoch = BEST_EPOCH, config = CONFIG)

creating index...
index created!
Test:  [ 0/62]  eta: 0:02:10  model_time: 1.3058 (1.3058)  evaluator_time: 0.0277 (0.0277)  time: 2.0987  data: 0.7626  max mem: 5499
Test:  [61/62]  eta: 0:00:01  model_time: 1.3005 (1.3117)  evaluator_time: 0.0255 (0.0302)  time: 1.3269  data: 0.0260  max mem: 5499
Test: Total time: 0:01:25 (1.3842 s / it)
Averaged stats: model_time: 1.3005 (1.3117)  evaluator_time: 0.0255 (0.0302)
Accumulating evaluation results...
DONE (t=0.35s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.721
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.963
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.894
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.721
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ Io

# Crop predicted boxes for testing severity classification

In [12]:
# Maps detection label ids (1-based) to disc level names used for titles and crop folders
LABELS_DICT = dict(enumerate(["L1_L2", "L2_L3", "L3_L4", "L4_L5", "L5_S1"], start=1))

In [13]:
def get_best_boxes(pred):
    """Keep only the highest-scoring predicted box for each label.

    Args:
        pred: Dict of parallel tensors ``'boxes'``, ``'labels'`` and ``'scores'``
            (e.g. one element of the list returned by ``trained_model(images)``).

    Returns:
        Dict of plain-Python lists: ``'boxes'`` (each box via ``tolist()``), ``'labels'``
        (ints) and ``'scores'`` (floats), one entry per distinct label, ordered by where each
        label first appears in ``pred``. On equal scores the earlier box is kept.
    """
    winners = {}
    for box, label, score in zip(pred['boxes'], pred['labels'], pred['scores']):
        label_id = label.item()
        score_value = score.item()
        incumbent = winners.get(label_id)
        if incumbent is None or score_value > incumbent['score']:
            winners[label_id] = {'box': box.tolist(), 'score': score_value}

    return {
        'boxes': [entry['box'] for entry in winners.values()],
        'labels': list(winners),
        'scores': [entry['score'] for entry in winners.values()],
    }

def crop_bbox(image, bbox):
    """Crop ``image`` to the region of an ``(x0, y0, x1, y1)`` box.

    The top-left corner and the size are converted with ``int()``, i.e. truncated toward zero.
    Height and width are truncated from the float extents ``y1 - y0`` and ``x1 - x0``, not
    from the truncated corners.
    NOTE(review): if nearest-pixel rounding was intended, use ``round()`` instead of ``int()``.

    Returns:
        The crop produced by ``torchvision.transforms.functional.crop``.
    """
    left, top, right, bottom = bbox
    return torchvision.transforms.functional.crop(
        image,
        top=int(top),
        left=int(left),
        height=int(bottom - top),
        width=int(right - left),
    )


def plot_crop(image, bboxes):
    """Show the crop of each detected disc level, one level per row.

    Args:
        image: Image to crop with ``crop_bbox``. Only channel 0 of each crop is shown.
            NOTE(review): presumably a (C, H, W) CPU tensor — confirm.
        bboxes: Dict with parallel ``'boxes'``, ``'labels'`` and ``'scores'`` lists, as returned
            by ``get_best_boxes``. Labels are 1-based keys of ``LABELS_DICT``.

    Rows for levels without a box are left blank.
    """
    n_levels = len(LABELS_DICT)
    # The figure height scales with the number of rows. subplots_adjust(top > 1) would instead
    # place the axes outside the figure canvas. squeeze=False keeps `axes` indexable for any n.
    fig, axes = plt.subplots(nrows=n_levels, ncols=1, figsize=(4, 2 * n_levels), squeeze=False)
    axes = axes[:, 0]

    # Hide all frames up front so undetected levels show as blank rows, not empty axes
    for level_ax in axes:
        level_ax.set_axis_off()

    for bbox, label_id, score in zip(bboxes['boxes'], bboxes['labels'], bboxes['scores']):
        level_name = LABELS_DICT[label_id]  # look up first so an unknown label fails early
        level_ax = axes[label_id - 1]  # label ids are 1-based
        cropped_img = crop_bbox(image, bbox)[0, :]  # first channel only
        level_ax.imshow(cropped_img, cmap="bone")
        level_ax.set_title(f"{level_name} ({score:.2f})")

    fig.tight_layout()
        

def save_crop(image, bboxes, target, crop_dir):
    """Crop every box in ``bboxes`` out of ``image`` and save each crop with ``torch.save``.

    One file is written per box, to
    ``{crop_dir}/{study_id}/{series_id}/{level}/{instance_number}.pt``. Here ``level`` is the
    box label mapped through ``LABELS_DICT``. Missing directories are created.

    Args:
        image: Image passed to ``crop_bbox``.
        bboxes: Dict with parallel ``'boxes'`` and ``'labels'`` lists (e.g. from ``get_best_boxes``).
        target: Dict providing ``'series_id'``, ``'study_id'`` and ``'instance_number'``.
            NOTE(review): these are formatted directly into the path; if any is a tensor, its
            repr ends up in the directory name — confirm they are plain scalars.
        crop_dir: Root output directory.
    """
    series_id, study_id, instance_number = (
        target['series_id'],
        target['study_id'],
        target['instance_number'],
    )

    for idx, bbox in enumerate(bboxes['boxes']):
        level_name = LABELS_DICT[bboxes['labels'][idx]]
        level_dir = f'{crop_dir}/{study_id}/{series_id}/{level_name}'
        os.makedirs(level_dir, exist_ok=True)
        crop_path = os.path.join(level_dir, f'{instance_number}.pt')
        torch.save(crop_bbox(image, bbox), crop_path)

In [14]:
def crop_and_save_predicted_boxes(condition, epoch, config, limit = None):
    """Detect disc levels on a condition's test split and save one crop per detected level.

    For every test image, the highest-scoring box per label is kept (``get_best_boxes``).
    Its crop is then written by ``save_crop`` under
    ``DATA_DIR/test_crops/03_test_disc_detection/<condition>``.

    Args:
        condition: Condition name. Selects the test metadata, the model directory and the
            output crop directory.
        epoch: Epoch whose checkpoint ``epoch_<epoch>/model_dict.pt`` is loaded.
        config: Dict providing ``box_w``, ``box_h_l1_l4``, ``box_h_l5`` and ``batch_size``.
        limit: Optional maximum number of test batches to process (e.g. for a quick check).
            ``None`` (default) processes the whole test set.
    """
    model_dir = os.path.join(PROJECT_DIR, 'models', '02_train_disc_detection', condition)
    # Use the `condition` argument (not a global) so crops land in the directory of the
    # condition actually being processed.
    crop_dir = os.path.join(DATA_DIR, 'test_crops', '03_test_disc_detection', condition)
    os.makedirs(crop_dir, exist_ok=True)

    # Read in metadata
    test_df = pd.read_csv(os.path.join(DATA_DIR, 'processed_metadata', condition, 'test.csv'))

    # Create data loader. A fixed order makes runs (and `limit`) reproducible.
    dataset_test = RSNAMultipleBBoxesDataset(
        test_df, w=config['box_w'], h_l1_l4=config['box_h_l1_l4'], h_l5=config['box_h_l5']
    )
    test_loader = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=config['batch_size'],
        shuffle=False,
        collate_fn=utils.collate_fn,
        # os.cpu_count() may return None, which DataLoader does not accept
        num_workers=os.cpu_count() or 0,
    )

    # Load model. map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    checkpoint_path = os.path.join(model_dir, f'epoch_{epoch}', 'model_dict.pt')
    trained_model = load_model_disc_detection(
        state_dict=torch.load(checkpoint_path, map_location=device)
    ).to(device)

    trained_model.eval()
    with torch.inference_mode():
        # Crop and save images
        for batch_idx, (images, targets) in enumerate(tqdm.tqdm(test_loader)):
            if limit is not None and batch_idx >= limit:
                break
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
            predictions = trained_model(images)

            for image, prediction, target in zip(images, predictions, targets):
                bboxes = get_best_boxes(prediction)
                save_crop(image.cpu(), bboxes, target, crop_dir=crop_dir)

## Spinal Canal Stenosis (captured by Sagittal T2/STIR images)

In [15]:
CONDITION = 'SpinalCanalStenosis'
BEST_EPOCH = 1  # epoch with the highest validation precision/recall

CONFIG = dict(
    # num_epochs, lr, lr_step_size and lr_gamma are training settings kept for reference;
    # crop_and_save_predicted_boxes only reads batch_size and the box_* entries.
    num_epochs=3,
    batch_size=10,
    lr=0.0001,
    lr_step_size=3,
    lr_gamma=0.1,
    box_w = 70, # width of the bounding boxes
    box_h_l1_l4 = 30, # height of the boxes for levels from L1/L2 to L4/L5
    box_h_l5 = 40 # height of the boxes for level L5/S1
)

crop_and_save_predicted_boxes(condition = CONDITION, epoch = BEST_EPOCH, config = CONFIG)

100%|██████████| 38/38 [01:00<00:00,  1.59s/it]


## Left Neural Foraminal Narrowing (captured by Sagittal T1 images)

In [16]:
CONDITION = 'LeftNeuralForaminalNarrowing'
BEST_EPOCH = 1  # epoch with the highest validation precision/recall

CONFIG = dict(
    # num_epochs, lr, lr_step_size and lr_gamma are training settings kept for reference;
    # crop_and_save_predicted_boxes only reads batch_size and the box_* entries.
    num_epochs=3,
    batch_size=10,
    lr=0.0001,
    lr_step_size=3,
    lr_gamma=0.1,
    box_w = 70, # width of the bounding boxes
    box_h_l1_l4 = 30, # height of the boxes for levels from L1/L2 to L4/L5
    box_h_l5 = 40 # height of the boxes for level L5/S1
)

crop_and_save_predicted_boxes(condition = CONDITION, epoch = BEST_EPOCH, config = CONFIG)

100%|██████████| 62/62 [01:36<00:00,  1.56s/it]


## Right Neural Foraminal Narrowing (captured by Sagittal T1 images)

In [17]:
CONDITION = 'RightNeuralForaminalNarrowing'
BEST_EPOCH = 1  # epoch with the highest validation precision/recall

CONFIG = dict(
    # num_epochs, lr, lr_step_size and lr_gamma are training settings kept for reference;
    # crop_and_save_predicted_boxes only reads batch_size and the box_* entries.
    num_epochs=3,
    batch_size=10,
    lr=0.0001,
    lr_step_size=3,
    lr_gamma=0.1,
    box_w = 70, # width of the bounding boxes
    box_h_l1_l4 = 30, # height of the boxes for levels from L1/L2 to L4/L5
    box_h_l5 = 40 # height of the boxes for level L5/S1
)

crop_and_save_predicted_boxes(condition = CONDITION, epoch = BEST_EPOCH, config = CONFIG)

100%|██████████| 62/62 [01:34<00:00,  1.52s/it]
