# Contents
<div style="font-size: 16px;">
    This notebook evaluates how trained severity classification models perform on test sets, with respect to each spinal diagnosed condition.
</div>


# Install libaries

In [1]:
!pip install pydicom -q
!pip install torch==2.1 torchvision==0.16 -q
!pip install -qU pycocotools
!pip install -qU wandb

# Configs

In [2]:
for name in list(globals()):
    if not name.startswith("_"):  # Avoid deleting built-in and special variables
        del globals()[name]

In [3]:
import os
import time
import random
from datetime import datetime
import numpy as np
import collections

from matplotlib import animation, rc
import pandas as pd

import matplotlib.patches as patches
import matplotlib.pyplot as plt

import tqdm
import sys
import torch

In [4]:
# Device configuration
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda', index=0)

## Install and load libraries

In [8]:
import os
import time
import random
from datetime import datetime
import numpy as np
import collections

from matplotlib import animation, rc
import pandas as pd

import matplotlib.patches as patches
import matplotlib.pyplot as plt

import tqdm
import sys
import torch

## Directories

In [6]:
# Directories
PROJECT_DIR = '/home/jupyter'
DATA_DIR = os.path.join(PROJECT_DIR, 'data')
SRC_DIR = os.path.join(PROJECT_DIR, 'src')

## Functions

In [7]:
with open(os.path.join(SRC_DIR, 'pipeline_severity_classification.py')) as file:
    exec(file.read())

# Evaluation on test sets

In [89]:
def evaluate(model, loader, criterion, device):

    model.eval()
    all_labels = []
    all_preds = []
    total_loss = 0
    n_examples = 0
    
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            n_examples += len(images)
            # print(n_examples, loss)

            # Collect predictions and labels
            _, preds = torch.max(outputs, 1)
            all_labels.extend(labels.tolist())
            all_preds.extend(preds.tolist())


    # Calculate average test loss
    avg_loss = total_loss / len(loader)

    # Calculate accuracy, precision, and recall for the validation set
    accuracy, precision, recall, f1 = accuracy_metrics(all_labels, all_preds)

    # Print validation metrics
    print(f"Validation Metrics:")
    print(f"  Loss: {avg_loss:.4f}")
    print(f"  Accuracy: {accuracy * 100:.2f}%")
    for cls in range(num_classes):
        print(f"  Class {cls}: Precision: {precision[cls]:.4f}, Recall: {recall[cls]:.4f}, F1-score: {f1[cls]:.4f}")
    
    return n_examples, avg_loss, all_labels, all_preds

In [90]:
LABELS_DICT = {
    1: "L1_L2",
    2: "L2_L3",
    3: "L3_L4",
    4: "L4_L5",
    5: "L5_S1"
}

def evaluate_on_test_set(condition, epoch):
    
    MODEL_DIR = os.path.join(PROJECT_DIR, 'models', '04_train_severity_classification', condition)

    # Read in metadata
    test_df = pd.read_csv(os.path.join(DATA_DIR, 'processed_metadata', condition, 'test.csv'))
    
    if 'SubarticularStenosis' in condition:
        dataset_test = RSNAUncroppedImageDataset(test_df)
    else: 
        # Write cropped test images
        crop_dir = os.path.join(DATA_DIR, 'test_crops', '03_test_disc_detection', condition)
        filepaths = [os.path.join(crop_dir, f"{row['study_id']}/{row['series_id']}/{LABELS_DICT[row['level_code']]}/{row['instance_number']}.pt") for _, row in test_df.iterrows()]
        test_df['cropped_image_path'] = filepaths
        # Set dataset
        dataset_test = RSNACroppedImageDataset(test_df)
        
    # Create data loder
    test_loader = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=50,
        shuffle=False
    )
    
    # Load model
    trained_model = load_model_severity_classification(state_dict=torch.load(os.path.join(f"{MODEL_DIR}/epoch_{epoch}/model_dict.pt"))).to(device)
    
    # Load criterion
    class_weights = torch.tensor([1.0, 2.0, 4.0]).clone().detach()  # Adjust weights as needed
    criterion = nn.CrossEntropyLoss(weight=class_weights).to(device)

    # Evaluate on test set
    return evaluate(trained_model, test_loader, criterion, device=device)

## Spinal Canal Stenosis (captured by Sagittal T2/STIR images)

In [91]:
CONDITION = 'SpinalCanalStenosis'
BEST_EPOCH = 3
results_0 = evaluate_on_test_set(condition = CONDITION, epoch = BEST_EPOCH)

Validation Metrics:
  Loss: 0.3682
  Accuracy: 90.64%
  Class 0: Precision: 0.9695, Recall: 0.9605, F1-score: 0.9650
  Class 1: Precision: 0.4123, Recall: 0.4352, F1-score: 0.4234
  Class 2: Precision: 0.5634, Recall: 0.6154, F1-score: 0.5882


## Left Neural Foraminal Narrowing (captured by Sagittal T1 images)

In [92]:
CONDITION = 'LeftNeuralForaminalNarrowing'
BEST_EPOCH = 4
results_1 = evaluate_on_test_set(condition = CONDITION, epoch = BEST_EPOCH)

Validation Metrics:
  Loss: 0.6593
  Accuracy: 82.03%
  Class 0: Precision: 0.8926, Recall: 0.9159, F1-score: 0.9041
  Class 1: Precision: 0.5315, Recall: 0.4945, F1-score: 0.5123
  Class 2: Precision: 0.5349, Recall: 0.4259, F1-score: 0.4742


## Right Neural Foraminal Narrowing (captured by Sagittal T1 images)

In [93]:
CONDITION = 'RightNeuralForaminalNarrowing'
BEST_EPOCH = 5
results_2 = evaluate_on_test_set(condition = CONDITION, epoch = BEST_EPOCH)

Validation Metrics:
  Loss: 0.5793
  Accuracy: 79.40%
  Class 0: Precision: 0.9313, Recall: 0.8465, F1-score: 0.8869
  Class 1: Precision: 0.4592, Recall: 0.6377, F1-score: 0.5340
  Class 2: Precision: 0.4182, Recall: 0.4340, F1-score: 0.4259


## Left Subarticular Stenosis (captured by Axial T1 images)

In [94]:
CONDITION = 'LeftSubarticularStenosis'
BEST_EPOCH = 1
results_3 = evaluate_on_test_set(condition = CONDITION, epoch = BEST_EPOCH)

Validation Metrics:
  Loss: 0.6665
  Accuracy: 75.36%
  Class 0: Precision: 0.8762, Recall: 0.8745, F1-score: 0.8754
  Class 1: Precision: 0.4310, Recall: 0.4325, F1-score: 0.4318
  Class 2: Precision: 0.4960, Recall: 0.5000, F1-score: 0.4980


## Right SubarticularStenosis (captured by Axial T1 images)

In [95]:
CONDITION = 'RightSubarticularStenosis'
BEST_EPOCH = 1
results_4 = evaluate_on_test_set(condition = CONDITION, epoch = BEST_EPOCH)

Validation Metrics:
  Loss: 0.7375
  Accuracy: 71.88%
  Class 0: Precision: 0.8942, Recall: 0.8304, F1-score: 0.8611
  Class 1: Precision: 0.3488, Recall: 0.3750, F1-score: 0.3614
  Class 2: Precision: 0.4318, Recall: 0.5846, F1-score: 0.4967


# Results of the entire test set

In [102]:
def accuracy_metrics_for_complete_set(results):
    
    n_examples = [r[0] for r in results]
    losses     = [r[1] for r in results]
    all_labels = sum([r[2] for r in results], [])
    all_preds  = sum([r[3] for r in results], [])
        
    # Calculate accuracy, precision, and recall for the validation set
    accuracy, precision, recall, f1 = accuracy_metrics(all_labels, all_preds)
    avg_loss = sum([l * n for l, n in zip(n_examples, losses)]) / sum(n_examples)
    
    # Print validation metrics
    print(f"Validation Metrics:")
    print(f"  Loss: {avg_loss:.4f}")
    print(f"  Accuracy: {accuracy * 100:.2f}%")
    for cls in range(num_classes):
        print(f"  Class {cls}: Precision: {precision[cls]:.4f}, Recall: {recall[cls]:.4f}, F1-score: {f1[cls]:.4f}")

In [103]:
results = [results_0, results_1, results_2, results_3, results_4]
accuracy_metrics_for_complete_set(results)

Validation Metrics:
  Loss: 0.5847
  Accuracy: 80.90%
  Class 0: Precision: 0.9174, Recall: 0.8955, F1-score: 0.9063
  Class 1: Precision: 0.4511, Recall: 0.4938, F1-score: 0.4715
  Class 2: Precision: 0.4869, Recall: 0.5152, F1-score: 0.5007
