In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import pydicom
import numpy as np
import os
import glob
from tqdm import tqdm
import warnings
import pandas as pd
# Root folder of the competition data (note the trailing slash).
DATA_PATH = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/'
# Folder that train_candidates.csv is read from below.
WORKING_DIR = '/kaggle/input/rsna-2024-train/'
# Selects which pickled candidate list is loaded later (candidate_list_<selection>.pkl).
DATASET_SELECTION = 'sagittal'
# Used to build the "<SET_TYPE>_images" folder name inside image paths.
SET_TYPE='test'
# Candidate table; the lookup below relies on the study_id, series_id and
# series_description columns being present.
train_candidates = pd.read_csv(WORKING_DIR+'train_candidates.csv')
train_candidates.head(5)
# Number of candidate rows per (series_id, series_description) for one example study.
train_candidates[train_candidates['study_id']==1085426528].groupby(['series_id','series_description']).size()
from collections import namedtuple
import functools
import datetime
from torch.utils.data import Dataset
import random
import pickle
from torch import tensor

# Severity label -> class index.
SEVERITY_MAPPING = {"Normal/Mild":0,"Moderate":1,"Severe":2}
# Per-class weights keyed by class index (not referenced by the code shown in this notebook section).
SEVERITY_WEIGHTING = {0:1.0,1:2.0,2:4.0}
# Hand-written example candidate, keyed by row id. The dataset class below reads
# 'img_path', 'conf_xyxy', 'study_id' and the 'row_id' added further down.
# NOTE(review): 'roi_xyxy' / 'conf_xyxy' are unpacked positionally into
# get_roi_xyxy(img_path, x_left, x_right, y_bot, y_top), i.e. they are consumed
# as (x_left, x_right, y_bot, y_top) despite the "xyxy" name - confirm with the producer.
severity_candidate_dict = {'44036939_2828203845_left_neural_foraminal_narrowing_l4_l5': 
                        {'instance_no': [22,23],
                          'box_coord': [tensor([186.6035, 251.5714,  40.5323,  36.3092]),
                           tensor([185.6549, 252.8904,  41.1806,  36.0747])],
                          'box_conf': [tensor(0.5057), tensor(0.3758)],
                          'roi_xyxy': (253, 316, 358, 415),
                          'conf_xyxy': (255, 316, 358, 413),
                          'conf_instance_no': 22,
                         # TODO add these attributes in
                         # DATA_PATH already ends with '/', so this path contains a double slash.
                         'img_path':f"{DATA_PATH}/{SET_TYPE}_images/44036939/2828203845/22.dcm",
                         'study_id':"44036939",
                         'series_id':"2828203845",
                         "set_type":"Sagittal T1"
                        }
                       }
# Flatten to a list of dicts, copying each entry and storing its key under "row_id".
severity_candidate_list = [{**value,**{"row_id":key}} for key, value in severity_candidate_dict.items()]
severity_candidate_list

## Dataset

Classification Component Set:
- candidate_list: list of candidate instances with their bounding box coordinates *(the coordinates still need to be adjusted to stay within the image boundaries)*
- sorted_study_ids: the list of study_ids sorted with respect to condition severity, used for creating robust validation sets

In [None]:
# Initialise candidate_info_tuple
# Record type with one field per candidate attribute. It is not referenced by the
# code shown below, which indexes candidates like dicts.
candidate_info_tuple = namedtuple(
    'candidate_info_tuple',
    'row_id, study_id, series_id, instance_number, has_centres, centre_xy, severity, img_path, width_bbox, height_bbox'
)

# Read in relevant Classification Component Set
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only load
# files from a trusted source.
with open(f'/kaggle/input/rsna-2024-train-candidate-list/candidate_list_{DATASET_SELECTION}.pkl', 'rb') as f:
    out = pickle.load(f)
    
# The pickle is expected to be a mapping with these two keys.
candidate_list, sorted_study_ids = out['candidate_list'],out['sorted_study_ids'] 
@functools.lru_cache(maxsize=1, typed=True)
def get_image(path):
    """Read the DICOM file at ``path`` with pydicom and return the parsed dataset.

    Only the most recent result is memoised (cache size 1), so consecutive calls
    with the same path do not re-read the file.
    """
    dataset = pydicom.dcmread(path)
    return dataset
# img_path is hashable, so calls can be memoised; get_image() is cached as well,
# so repeated look-ups of the same image should not add compute.
@functools.lru_cache(maxsize=1, typed=True)
def get_roi(img_path, x_centre, y_centre, width = 50, height = 40):
    """Crop a box centred on (x_centre, y_centre) from the image at ``img_path``.

    The box spans ``width // 2`` pixels either side of the centre horizontally and
    ``height // 2`` vertically, clamped to the image bounds. The crop is cast to
    uint8 without any rescaling.
    """
    pixels = get_image(img_path).pixel_array
    n_rows, n_cols = pixels.shape

    half_w = width // 2
    half_h = height // 2

    col_start = int(max(0, x_centre - half_w))
    col_stop = int(min(n_cols, x_centre + half_w))
    row_start = int(max(0, y_centre - half_h))
    row_stop = int(min(n_rows, y_centre + half_h))

    return pixels[row_start:row_stop, col_start:col_stop].astype(np.uint8)
# img_path is hashable, so calls can be memoised; get_image() is cached as well,
# so repeated look-ups of the same image should not add compute.
@functools.lru_cache(maxsize=1, typed=True)
def get_roi_xyxy(img_path, x_left, x_right, y_bot, y_top):
    """Crop rows ``y_bot:y_top`` and columns ``x_left:x_right`` from the image.

    Note the positional order (x_left, x_right, y_bot, y_top). No clamping is
    applied here; the bounds are used as plain slice indices. The crop is cast to
    uint8 without any rescaling.
    """
    pixels = get_image(img_path).pixel_array
    row_span = slice(y_bot, y_top)
    col_span = slice(x_left, x_right)
    return pixels[row_span, col_span].astype(np.uint8)

### Plot images

In [None]:
def dicom_transforms(img: np.ndarray):
    """Min-max scale ``img`` to 8-bit and return a 3-channel RGB array for display.

    Parameters:
    - img (np.ndarray): pixel data, either single-channel (H, W) / (H, W, 1) or
      3-channel in BGR order.

    Returns:
    - np.ndarray: uint8 RGB image.
    """
    IMG_normalized = cv2.normalize(img, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)

    # Single-channel input must be expanded with GRAY2RGB: the BGR->RGB
    # conversion only accepts 3- or 4-channel input and raises otherwise.
    is_single_channel = IMG_normalized.ndim == 2 or IMG_normalized.shape[-1] == 1
    if is_single_channel:
        return cv2.cvtColor(IMG_normalized, cv2.COLOR_GRAY2RGB)

    # Colour input: swap BGR to RGB for correct colour display in matplotlib.
    return cv2.cvtColor(IMG_normalized, cv2.COLOR_BGR2RGB)
import matplotlib.pyplot as plt
import cv2
import os
import glob
import math
import re

# Sort-key helper: pulls the integers out of a file name.
def extract_numbers(file_path):
    """Return every run of digits in the last '/'-separated component of ``file_path`` as ints."""
    basename = file_path.rsplit('/', 1)[-1]
    return [int(digits) for digits in re.findall(r'\d+', basename)]


def plot_images(image_dir,pattern="*"):
    """
    Load the DICOM images matching a pattern and show them in a 4-column grid.

    Parameters:
    - image_dir (str): Directory containing the images.
    - pattern (str): Glob pattern, relative to image_dir, selecting the files.

    Returns:
    - list: the loaded DICOM datasets, ordered by the numbers in their file
      names. Empty when nothing matches (no figure is created in that case).
    """
    n_cols = 4

    matched = glob.glob(os.path.join(image_dir, pattern))
    if not matched:
        # plt.subplots() rejects a grid with zero rows, so stop early.
        return []
    image_paths = sorted(matched, key=extract_numbers)

    num_rows = math.ceil(len(image_paths) / n_cols)
    # squeeze=False always yields a 2-D array of axes, so flatten() works for a
    # single row as well (wrapping the 1-row result in a list broke ax[i]).
    _fig, ax = plt.subplots(num_rows, n_cols, figsize=(12, 12), squeeze=False)
    ax = ax.flatten()

    imgs = []
    for i, image_path in enumerate(image_paths):
        # Load image
        img = get_image(image_path)
        imgs.append(img)

        IMG_normalized = cv2.normalize(img.pixel_array, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        # Single-channel slices need GRAY2RGB; BGR2RGB only accepts colour input.
        if IMG_normalized.ndim == 2 or IMG_normalized.shape[-1] == 1:
            IMG_normalized = cv2.cvtColor(IMG_normalized, cv2.COLOR_GRAY2RGB)
        else:
            IMG_normalized = cv2.cvtColor(IMG_normalized, cv2.COLOR_BGR2RGB)

        ax[i].imshow(IMG_normalized)

    # Hide the empty panels left over in the last row.
    for unused_ax in ax[len(image_paths):]:
        unused_ax.axis("off")

    # Show the result
    plt.show()

    return imgs
        
# Example usage: plot the files under one <study>/<series> folder of the training images.
# NOTE: this rebinds `out`, which earlier held the unpickled candidate data.
out = plot_images(image_dir="/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_images/",
           pattern="1085426528/1518511736/*")


### Dataset

In [None]:
class InferenceLumbarSeverityDataset(Dataset):
    
    def __init__(self,
                val_stride=0,
                val_set=False,
                study_id=None,
                candidate_list=None,
                transform=None,
                sample: int | None = None,
                rand_bbox: bool = False,
                sorted_candidate_list=None, 
                save_to_drive: bool = False,
                debug_mode: bool = False):
        
        self.sample = sample
        self.rand_bbox = rand_bbox
        self.save_to_drive = save_to_drive
        self.debug_mode = debug_mode
        self.val_set = val_set
        
        if candidate_list:
            self.candidate_list = candidate_list
        else:
            raise Exception("Must provide candidate_list.")
        
        # Default to a stratified sampled of candidates by study_id (option to provide custom sorted list on other
        # stratifications).
        if sorted_candidate_list:
            self.sorted_candidate_list = sorted_candidate_list.copy()
        else:
            study_ids = np.unique([x['study_id'] for x in self.candidate_list])
            self.sorted_candidate_list = np.random.shuffle(study_ids)        

        if study_id:
            self.candidate_list = [x for x in self.candidate_list if x['study_id'] == study_id]
        
        if self.val_set:
            assert val_stride > 0, val_stride
            val_study_ids = self.sorted_candidate_list[::val_stride]
            self.candidate_list = [x for x in self.candidate_list if str(x['study_id']) in val_study_ids]
            assert self.candidate_list
        elif val_stride > 0:
            del self.sorted_candidate_list[::val_stride]                        
            self.candidate_list = [x for x in self.candidate_list if str(x['study_id']) in self.sorted_candidate_list]            
            assert self.candidate_list
    
        self.transform = transform
    
    def __len__(self):
        if self.sample:
            return min(self.sample, len(self.candidate_list))
        else:
            return len(self.candidate_list)
    
    def __getitem__(self,ndx):
        """
        return 
        """
        
        candidate = self.candidate_list[ndx]
        
        if self.rand_bbox:    
            width = randint(40,60)
            height = randint(50,70)
            candidate = self.candidate_list[ndx].udpate({'width_bbox':width, 'height_bbox':height})
            self.candidate_list[ndx] = candidate

        roi = get_roi_xyxy(candidate['img_path'], *candidate['conf_xyxy']) # TODO unpack a tuple in a function
    
        if self.transform:
            try:
                roi = self.transform(roi)        
            except Exception as e:
                print(f"Exception raised for {candidate.row_id} \n\n {e}")                
        
        return roi, candidate['row_id'], candidate
        
    def get_items_by_study(self, study_id):
        """
        Get items for a specific study.
        """
        study_candidates = [i for i, x in enumerate(self.candidate_list) if x.study_id == study_id]
        result = [self.__getitem__(i) for i in study_candidates]

        return result        

    def get_items_by_row_id(self, row_id, debug=False):
        """
        Get items for a specific study.
        """
        study_candidates = [i for i, x in enumerate(self.candidate_list) if x.row_id == row_id]
        if debug:
            try:
                result = [self.__getitem__(i) for i in study_candidates]
            except Exception:
                return get_roi_xyxy(candidate.img_path, *candidate.conf_xyxy), candidate
        else:
            result = [self.__getitem__(i) for i in study_candidates]

        return result     
# Smoke test: build the dataset from the hand-written example candidate, without a transform.
dataset = InferenceLumbarSeverityDataset(candidate_list=severity_candidate_list)
# NOTE: the disabled check below unpacks (roi, sev, row_id), but __getitem__
# returns (roi, row_id, candidate); adjust the unpacking before re-enabling it.
# for i in range(len(dataset)):
#     roi, sev, row_id = dataset[i]
#     x, y = roi.shape
    
#     if x <= 0 or y <= 0:
#         print(row_id)
# Fetch the first item (reads the DICOM file referenced by the candidate).
dataset[0]

### Model and Config

In [None]:
import torch
from torch import nn
from PIL import Image
from torchvision import models
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2
import torch.nn.functional as F
import timm

class Config:
    def __init__(self):
        # Training parameters
        self.batch_size = 32
        self.learning_rate = 5e-4
        self.num_epochs = 20
        self.trained_epochs = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        
        # Dataset
        self.roi_width_range = (40,60)
        self.roi_height_range = (50,70)
        
        # Model parameters
        self.model_type = 'resnet50.a1_in1k'
        self.pretrained = True
        self.num_classes = 3
        self.fc = 'fc'
        self.trainable_layers = ['layer4']
        
        # Optimization parameters
        self.optimizer = 'adam'
        self.weight_decay = 1e-4
        self.momentum = 0.9
        
        # Scheduler parameters
        self.use_scheduler = True
        self.scheduler = 'cosine_annealing'
        self.T_max = 50
        
        # Early stopping
        self.early_stop = True
        self.early_stop_patience = 5
        
        # Mixed precision training
        self.use_amp = True
        self.clip_grad = True
        self.clip_value = 1.0
        
        # Augmentation
        self.horizontal_flip = 0.5
        self.random_affine = True
        
        # Logging parameters
        self.log_interval = 10
        self.tensorboard_log_dir = './logs'
        
    def update(self, **kwargs):
        """ Update configuration parameters. """
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                raise AttributeError(f"{key} is not a valid attribute of Config")
        return self  # Return the entire Config object
                
    def get(self,value: str, default: any = Ellipsis):
        if hasattr(self,value):
            return getattr(self,value)
        else:
            if default is Ellipsis:
                raise ValueError(f"Attribute '{value}' does not exist and no default value was provided.")
            return default


# Example usage
# Module-level default configuration (a fresh Config is created again in the inference cell).
config = Config()

# # Update parameters if needed
# config.update(batch_size=64, learning_rate=0.0005, num_epochs=100)

# Example model definition (same as before)
class LumbarNet(nn.Module):
    """timm backbone with a configurable set of trainable layers and a new classifier head.

    Parameters:
    - config: configuration object exposing ``get(name, default)``. Reads
      'pretrained', 'model_type', 'trainable_layers', 'fc' and 'num_classes'.
    """

    def __init__(self, config: "Config"):
        super().__init__()

        self.config = config

        pretrained = config.get("pretrained")
        model_type = config.get("model_type")
        # Keep the original linear head for now; its in_features is needed below.
        self.model = timm.create_model(model_type, pretrained=pretrained)

        # Freeze everything, then re-enable gradients for the requested layers.
        trainable_layers = config.get('trainable_layers',[])
        if trainable_layers:
            self.set_layer_requires_grad(trainable_layers)
        else:
            self.freeze_all(self.model)

        # Replace the classifier head; a fresh nn.Linear is trainable by default.
        fc = config.get('fc',None)
        if fc:
            old_head = getattr(self.model, fc)
            new_fc = nn.Linear(old_head.in_features, config.get('num_classes'))
            setattr(self.model, fc, new_fc)

    def forward(self, x):
        """Run ``x`` through the wrapped backbone."""
        return self.model(x)

    @staticmethod
    def freeze_all(model):
        """Disable gradients for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

    def set_layer_requires_grad(self, layers_to_unfreeze):
        """Freeze the backbone, then enable gradients for the named child modules.

        Names that are not attributes of the backbone are skipped with a warning;
        a child that is itself a BatchNorm2d is left frozen.
        """
        # Freeze all layers first
        self.freeze_all(self.model)

        # Unfreeze the specified layers
        for layer_name in layers_to_unfreeze:
            layer = getattr(self.model, layer_name, None)
            # Compare against None explicitly: container modules define __len__,
            # so plain truthiness would depend on their length.
            if layer is None:
                warnings.warn(f"Layer '{layer_name}' not found on the backbone; nothing was unfrozen for it.")
                continue
            if isinstance(layer, nn.BatchNorm2d):
                continue
            for param in layer.parameters():
                param.requires_grad = True
class ToRGB(object):
    """Transform that returns its input converted to RGB mode.

    NumPy arrays are first wrapped with ``Image.fromarray``; any other input is
    expected to offer a ``convert`` method already.
    """

    def __call__(self, img):
        pil_img = Image.fromarray(img) if isinstance(img, np.ndarray) else img
        return pil_img.convert('RGB')
def get_standard_transforms(model: nn.Module):
    """
    Return timm's evaluation transforms for ``model`` as a list, without CenterCrop.

    Intended for pretrained models. CenterCrop is dropped because it is likely to
    hinder performance when the ROIs are already small.
    """
    resolved_cfg = timm.data.resolve_model_data_config(model)
    eval_pipeline = timm.data.create_transform(**resolved_cfg, is_training=False)

    kept_steps = []
    for step in eval_pipeline.transforms:
        if isinstance(step, transforms.CenterCrop):
            continue
        kept_steps.append(step)

    return kept_steps
def get_transforms(config: dict, model: nn.Module):
    """Build the augmented and the plain inference transform pipelines.

    Reads 'pretrained', 'horizontal_flip' and 'random_affine' from ``config`` via
    ``config.get``; ``model`` is accepted but not used.

    Returns a tuple ``(composed_transforms, inference_transforms)`` of Compose
    objects: the first optionally adds random flip / affine steps, the second is
    always ToRGB -> Resize(256, 256) -> ToTensor -> Normalize.
    """
    inference_transforms = [
        ToRGB(),
        transforms.Resize(size=(256,256), interpolation=transforms.InterpolationMode.BICUBIC, max_size=None),
        transforms.ToTensor(),
        transforms.Normalize(mean=torch.tensor([0.4850, 0.4560, 0.4060]), std=torch.tensor([0.2290, 0.2240, 0.2250])),
    ]

    if config.get("pretrained",False):
        # Same transform instances as the inference pipeline, in a separate list.
        pipeline = list(inference_transforms)
    else:
        # TODO standard minmax transforms to dataset
        pipeline = [ToRGB()]

    flip_probability = config.get("horizontal_flip", False)
    if flip_probability:
        pipeline.append(v2.RandomHorizontalFlip(p=flip_probability))

    if config.get("random_affine", False):
        affine = v2.RandomAffine(degrees=(-20, 20), translate=(0.0,0.25), scale=(0.75, 1))
        pipeline.append(v2.RandomApply([affine], p=0.25))

    return transforms.Compose(pipeline), transforms.Compose(inference_transforms)

## Evaluation

In [111]:
def inference(model, loader, device):
    """Run ``model`` over ``loader`` and collect per-row predictions.

    Parameters:
    - model: module mapping a batch of inputs to class logits of shape (batch, classes).
    - loader: iterable of ``(inputs, row_ids, candidates)`` batches.
    - device: device the inputs are moved to (the model must already be there).

    Returns:
    - results: one list per batch of ``(row_id, probability tensor)`` pairs.
    - labels: dict row_id -> predicted class index (argmax of the logits).
    - probs_list: dict row_id -> softmax probabilities as a NumPy array.

    The model is switched to eval mode for the duration of the call so that
    BatchNorm / Dropout behave deterministically; its previous mode is restored
    afterwards.
    """
    labels = {}
    results = []
    probs_list = {}

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, row_ids, _candidates in loader:
                outputs = model(inputs.to(device))

                logits = outputs.cpu()
                probs = F.softmax(outputs, dim=1).cpu()

                # Materialise the pairs: a bare zip() could only be iterated once.
                results.append(list(zip(row_ids, probs)))

                for row_id, logit_row, prob_row in zip(row_ids, logits, probs):
                    labels[row_id] = int(np.argmax(logit_row.numpy()))
                    probs_list[row_id] = prob_row.numpy()
    finally:
        model.train(was_training)

    return results, labels, probs_list

#### Inference script

In [None]:
import pickle
import timm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
best_model_state = torch.load("/kaggle/input/severity-classification/pytorch/resnet-sagittal-t1/1/resnet50.a1_in1k_2024-089_13-24-01_earlystop.pth",
                             map_location = device)

# Every weight is overwritten by the checkpoint below, so skip fetching the
# pretrained timm weights (which would also need network access).
config = Config().update(pretrained=False)
model = LumbarNet(config)
model.load_state_dict(best_model_state['model'])
model = model.to(device)
# Inference must run in eval mode so BatchNorm uses its stored statistics.
model.eval()

# Only the plain inference pipeline is needed here; it does not depend on `pretrained`.
_, transform = get_transforms(config,model)


def _collate_keep_candidates(batch):
    """Stack the ROI tensors; keep row ids and candidate dicts as plain lists.

    Default collation would try to merge the candidate dicts field by field and
    fails when list fields (e.g. 'instance_no') differ in length between candidates.
    """
    rois, row_ids, candidates = zip(*batch)
    return torch.stack(rois), list(row_ids), list(candidates)


# Create dataset
dataset = InferenceLumbarSeverityDataset(candidate_list=severity_candidate_list,
                                         transform=transform)

data_loader = DataLoader(dataset, batch_size=16, collate_fn=_collate_keep_candidates)

results, labels, probs_list = inference(model, data_loader, device)
probs_list
# Attach the predictions to each candidate (updates the dicts in place).
for x in severity_candidate_list:
    row_id = x['row_id']
    x.update({"probs_x": probs_list[row_id], "severity": labels[row_id]})
# NOTE: `results` is rebound here from the raw inference output to a DataFrame.
results = pd.DataFrame(severity_candidate_list)
results.groupby('severity').size()