In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import albumentations as A
import pydicom
import os
from PIL import Image
from preprocessing.segmantation_inference import SegmentaionInference
from preprocessing.detection_inference import DetectionInference, transforms
import torch.nn.functional as F
from pathlib import Path
import re
import matplotlib.pyplot as plt
from preprocessing.Cross_Reference_Axial import CrossReferenceAxial

class CFG():
    """Global configuration for the inference notebook.

    The helper objects below are created as class attributes, i.e. when this class statement
    runs; presumably this loads the segmentation, detection and cross-reference models from disk.
    """
    AUG_PROB = 0.75
    NOT_DEBUG = True
    AUG = True
    Axial_shape = (152, 152)       # (rows, cols) of every axial crop channel
    Sagittal_shape = (152, 152)    # (rows, cols) of every sagittal crop channel
    channel_size_sagittal_t1 = 12
    channel_size_sagittal_t2 = 9
    channel_size_sagittal = channel_size_sagittal_t1 + channel_size_sagittal_t2
    channel_size_axial = 9
    train_path = "train_images"
    # Paths are built with os.path.join so they resolve on every OS; Windows-only backslash
    # literals such as r"weights\simple_unet.onnx" are not valid paths on Linux/macOS.
    segmentation = SegmentaionInference(model_path=os.path.join("weights", "simple_unet.onnx"))
    DetectionInference = DetectionInference(model_path=os.path.join("weights", "axial_detection_resnet18.pth"),
                                            transforms=transforms)
    # Trailing separator kept, as in the original image_dir value.
    cross_reference = CrossReferenceAxial(image_dir=os.path.join("test_images", ""))
    # NOTE(review): np.nan as a dict key only matches the np.nan object itself (NaN != NaN),
    # so NaN values read by pandas will generally not be found; map them explicitly (pd.isna) instead.
    label2id = {'Normal/Mild': 0, 'Moderate': 1, 'Severe': 2, np.nan: -100}
    category2id = {"L1": 1, "L2": 2, "L3": 3, "L4": 4, "L5": 5,
                   "L5-S1": 11, "L4-L5": 12, "L3-L4": 13, "L2-L3": 14, "L1-L2": 15}
    skip_study_id = [2492114990, 2780132468, 3008676218]
    two_classes_category = {11: 'L5-S1', 12: 'L4-L5', 13: 'L3-L4', 14: 'L2-L3', 15: 'L1-L2'}

# Single shared configuration object read by the dataset code below.
cfg = CFG()

# Test-time transform: intensity normalisation only (no augmentation).
transforms_val = A.Compose(transforms=[A.Normalize(mean=[0.485], std=[0.229])])

class CustomDataset(Dataset):
    """Test-time dataset producing per-level sagittal and axial image stacks for one study.

    ``__getitem__`` returns ``(study_id, sagittal_l1_l2, axial_l1_l2, sagittal_l2_l3, axial_l2_l3,
    ..., sagittal_l5_s1, axial_l5_s1)``. Each sagittal tensor is channel-first with
    ``cfg.channel_size_sagittal + cfg.channel_size_axial`` channels of size ``cfg.Sagittal_shape``;
    each axial tensor has ``cfg.channel_size_axial`` channels of size ``cfg.Axial_shape``.

    Relies on the module-level ``cfg`` for shapes, channel counts and the segmentation,
    detection and cross-reference helpers.
    """

    # (vertebra label, intervertebral level label) pairs, in the order the stacks are returned.
    _LEVELS = (
        ("L1", "L1-L2"),
        ("L2", "L2-L3"),
        ("L3", "L3-L4"),
        ("L4", "L4-L5"),
        ("L5", "L5-S1"),
    )
    # Value this class treats as "no box found" for a segmentation category.
    _MISSING_BOX = [(-1, -1, -1, -1)]

    def __init__(self, study_ids, labels_path, test_path, transform):
        """
        Args:
            study_ids: sequence of study ids; ``__getitem__(index)`` reads ``study_ids[index]``.
            labels_path: CSV with at least ``study_id``, ``series_id`` and ``series_description`` columns.
            test_path: image root laid out as ``<test_path>/<study_id>/<series_id>/<file>``.
            transform: callable used as ``transform(image=arr)['image']``, or a falsy value to skip.
        """
        self.study_ids = study_ids
        self.df_description = pd.read_csv(labels_path)
        self.transform = transform
        self.test_path = test_path

    def __len__(self):
        return len(self.study_ids)

    @staticmethod
    def plot(stack, x=5, y=6):
        """Debugging aid: show the channels of an ``(H, W, C)`` stack on an ``x`` by ``y`` grid."""
        fig, axes = plt.subplots(x, y, figsize=(15, 9))
        # np.ravel also copes with a single Axes (x == y == 1).
        for channel, ax in enumerate(np.ravel(axes)):
            if channel < stack.shape[-1]:
                ax.imshow(stack[..., channel], cmap='gray')
            else:
                ax.axis('off')  # more grid cells than channels
        plt.tight_layout()
        plt.show()

    def load_dicom(self, path):
        """Read a DICOM slice, clip to its 1st-99th percentile, resize to 512x512 and scale to uint8 [0, 255]."""
        pixels = pydicom.dcmread(path).pixel_array
        pixels = pixels.clip(np.percentile(pixels, 1), np.percentile(pixels, 99))
        pixels = np.array(self.resize_image(pixels, (512, 512)))
        pixels = (pixels - pixels.min()) / (pixels.max() - pixels.min() + 1e-6) * 255
        return pixels.astype(np.uint8)

    @staticmethod
    def pad_images_list(images_list, max_len):
        """Resize ``images_list`` to ``max_len`` items by repeating elements.

        Every element is repeated ``max_len // n`` times; the ``max_len % n`` leftover repeats go to
        the elements around the middle. When ``n > max_len`` this keeps only the central elements.

        Raises:
            ValueError: if ``images_list`` is empty.
        """
        if len(images_list) == 0:
            raise ValueError("images_list is empty")
        if len(images_list) == max_len:
            return images_list

        n = len(images_list)
        min_repeats, extra = divmod(max_len, n)
        # Central window [start_extra, end_extra) receives one extra repeat.
        start_extra = n // 2 - extra // 2
        end_extra = start_extra + extra

        output_list = []
        for i, item in enumerate(images_list):
            repeats = min_repeats + 1 if start_extra <= i < end_extra else min_repeats
            output_list.extend([item] * repeats)
        return output_list

    @staticmethod
    def center_crop(pixel_array, bboxes):
        """Crop ``pixel_array`` to the union of the first box of every category.

        Boxes equal to ``(-1, -1, -1, -1)`` are ignored; missing extents fall back to the image
        borders. The left edge is moved 10 px inwards and the right edge extended by 30 px.

        NOTE(review): boxes are unpacked here as ``(x, y, w, h)`` while the other crop helpers
        unpack ``(x, y, h, w)``; confirm which order the segmentation helper returns.
        """
        min_x, max_x = 999999, -1
        min_y, max_y = 999999, -1
        for x, y, w, h in (boxes[0] for boxes in bboxes.values()):
            if x == -1 and y == -1 and h == -1 and w == -1:
                continue
            min_x = min(min_x, x)
            max_x = max(max_x, x + w)
            min_y = min(min_y, y)
            max_y = max(max_y, y + h)
        if min_x == 999999:
            min_x = 0
        if min_y == 999999:
            min_y = 0
        if max_x == -1:
            max_x = pixel_array.shape[1]  # image width
        if max_y == -1:
            max_y = pixel_array.shape[0]  # image height
        image = Image.fromarray(pixel_array)
        return np.array(image.crop((min_x + 10, min_y, max_x + 30, max_y)))

    @staticmethod
    def center_crop_by_categorys(original_dicom, bboxes, category, second_category):
        """Crop to the union of the first boxes of two categories, with a small margin.

        Coordinates equal to -1 are ignored when taking the minima. The input is returned
        unchanged when both minima are -1 or the union is narrower or shorter than 35 px.
        NOTE(review): boxes are unpacked as ``(x, y, h, w)``; confirm against the segmentation helper.
        """
        x, y, h, w = bboxes[category][0]
        x2, y2, h2, w2 = bboxes[second_category][0]

        def min_ignore_neg_one(a, b):
            # Minimum of a and b, treating -1 as "missing".
            if a == -1:
                return b
            if b == -1:
                return a
            return min(a, b)

        min_x = min_ignore_neg_one(x, x2)
        min_y = min_ignore_neg_one(y, y2)
        max_x = max(x + w, x2 + w2)
        max_y = max(y + h, y2 + h2)
        if (min_x == -1 and min_y == -1) or (max_x - min_x < 35) or (max_y - min_y < 35):
            return original_dicom
        image = Image.fromarray(original_dicom)
        return np.array(image.crop((min_x - 10, min_y - 10, max_x + 30, max_y)))

    def center_crop_by_category(self, pixel_array, bboxes, category):
        """Crop a window anchored on the first box of ``category``.

        The window is ``cfg.Sagittal_shape[0]`` wide, starting 20 px left of the box x, and
        ``2 * (cfg.Sagittal_shape[0] // 2)`` tall, centred vertically on the box y.
        """
        x, y, h, w = bboxes[category][0]
        margin = cfg.Sagittal_shape[0] // 2
        image = Image.fromarray(pixel_array)
        return np.array(image.crop((x - 20, y - margin, x + cfg.Sagittal_shape[0] - 20, y + margin)))

    @staticmethod
    def unpad_images_list(images_list, max_len):
        """Trim to ``max_len`` items, alternately dropping from the end and from the start.

        Returns a new list and leaves the input untouched, so lists shared between calls
        (e.g. the sagittal file lists reused for every level) are not shrunk as a side effect.
        """
        trimmed = list(images_list)
        drop_from_end = True
        while len(trimmed) > max_len:
            trimmed.pop(-1 if drop_from_end else 0)
            drop_from_end = not drop_from_end
        return trimmed

    @staticmethod
    def resize_image(pixel_array, new_size):
        """Resize to ``new_size`` = (rows, cols).

        Returns the input array unchanged when its shape already equals ``new_size``,
        otherwise a PIL image (callers convert with ``np.array`` as needed).
        """
        if pixel_array.shape == new_size:
            return pixel_array
        image = Image.fromarray(pixel_array)
        return image.resize((new_size[1], new_size[0]))

    @staticmethod
    def extract_number(filename):
        """Return the first run of digits in ``filename`` as an int, or 0 if there is none."""
        match = re.search(r'\d+', filename)
        return int(match.group()) if match else 0

    def divide_Axiel(self, sub_set):
        """Split a study's Axial T2 slices into five contiguous groups labelled L1..L5.

        Files are sorted with ``extract_number`` and split as evenly as possible; the first
        ``len % 5`` groups get one extra file. Returns a DataFrame with ``path``/``class_id`` columns.
        """
        study_id = sub_set['study_id'].iloc[0]
        series_id_axial = sub_set['series_id'].loc[sub_set['series_description'] == "Axial T2"].iloc[0]
        # ``test_path`` is the only image root this dataset stores.
        series_dir = os.path.join(self.test_path, str(study_id), str(series_id_axial))
        files = sorted(os.listdir(series_dir), key=self.extract_number)
        base, remainder = divmod(len(files), 5)

        records = []
        start = 0
        for i, class_id in enumerate(["L1", "L2", "L3", "L4", "L5"]):
            end = start + base + (1 if i < remainder else 0)
            records.extend([os.path.join(series_dir, name), class_id] for name in files[start:end])
            start = end
        return pd.DataFrame(records, columns=['path', 'class_id'])

    def crop_axial_center(self, image, bbox):
        """Return three (left, middle, right) crops around the detected region of an axial slice.

        Args:
            image: PIL image or array (arrays are converted with ``Image.fromarray``).
            bbox: iterable of 4-tuples; only the minimum first (x) and second (y) values are used.

        With boxes, each crop is 152x152 with its top at ``min_y - 50``: left ends at ``min_x``,
        middle is centred on it, right starts at it. Without boxes, the crops are cut from a
        centred 304x304 window.
        """
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        boxes = list(bbox)

        if not boxes:
            crop_size = 304
            width, height = image.size
            left = (width - crop_size) // 2
            top = (height - crop_size) // 2
            window = image.crop((left, top, left + crop_size, top + crop_size))
            # NOTE(review): these crops are 152 wide x 304 tall (unlike the 152x152 crops below), so
            # they are squashed when resized to cfg.Axial_shape; left unchanged - presumably matches
            # the training preprocessing, confirm.
            return (
                window.crop((0, 0, 152, crop_size)),
                window.crop((76, 0, 228, crop_size)),
                window.crop((152, 0, 304, crop_size)),
            )

        min_x = min(box[0] for box in boxes)
        min_y = min(box[1] for box in boxes)
        half = 304 // 2
        return (
            image.crop((min_x - half, min_y - 50, min_x, min_y + 102)),
            image.crop((min_x - 76, min_y - 50, min_x + 76, min_y + 102)),
            image.crop((min_x, min_y - 50, min_x + half, min_y + 102)),
        )

    def crop_sagittal_center(self, file, Sagittal_bboxes_scaled, Sagittal_path, category):
        """Load one sagittal slice and crop it around the box of level ``category`` (e.g. "L1-L2").

        The crop is centred on a zero canvas at least ``cfg.Sagittal_shape`` in size and the top-left
        ``cfg.Sagittal_shape`` region is returned as uint8.
        """
        original_dicom = self.load_dicom(os.path.join(Sagittal_path, file))
        crop = self.center_crop_by_category(original_dicom, Sagittal_bboxes_scaled, cfg.category2id[category])
        crop_rows, crop_cols = crop.shape[0], crop.shape[1]
        canvas = np.zeros((max(cfg.Sagittal_shape[0], crop_rows), max(cfg.Sagittal_shape[1], crop_cols)))
        row0 = (canvas.shape[0] - crop_rows) // 2
        col0 = (canvas.shape[1] - crop_cols) // 2
        canvas[row0:row0 + crop_rows, col0:col0 + crop_cols] = crop
        return canvas[:cfg.Sagittal_shape[0], :cfg.Sagittal_shape[1]].astype(np.uint8)

    def create_stack(self, sagittal_stack, axial_stack,
                     Sagittal_T2_files, Sagittal_T2_bboxes, Sagittal_T2_path,
                     Sagittal_T1_files, Sagittal_T1_bboxes, Sagittal_T1_path,
                     category, two_classes_category, df_classes):
        """Fill ``sagittal_stack`` and ``axial_stack`` in place for one spinal level.

        Sagittal channel layout, in write order:
          * Sagittal T1 crops: second half of the series (reversed) then first half (reversed),
            ``cfg.channel_size_sagittal_t1 // 2`` each;
          * ``cfg.channel_size_sagittal_t2`` Sagittal T2/STIR crops;
          * if axial slices are found: 9 axial crops (left x3, middle x3, right x3), which are
            also written to ``axial_stack``. Otherwise those channels stay as they were.

        Args:
            sagittal_stack, axial_stack: ``(H, W, C)`` uint8 arrays to fill.
            Sagittal_T2_files, Sagittal_T1_files: sorted file names inside the series directories.
            Sagittal_T2_bboxes, Sagittal_T1_bboxes: per-category boxes from ``cfg.segmentation``.
            Sagittal_T2_path, Sagittal_T1_path: series directories.
            category: vertebra label (e.g. "L1"), the fallback axial lookup key.
            two_classes_category: level label (e.g. "L1-L2") for the sagittal crop and axial lookup.
            df_classes: DataFrame with ``path`` and ``class_id`` columns assigning axial files to labels.

        Returns:
            ``(sagittal_stack, axial_stack)`` - the same arrays that were passed in.
        """
        k = 0  # next free sagittal channel
        half_t1 = cfg.channel_size_sagittal_t1 // 2

        # Sagittal T1: the slice order is mirrored, so each half of the series is one side.
        right_t1_files = Sagittal_T1_files[:len(Sagittal_T1_files) // 2]
        left_t1_files = Sagittal_T1_files[len(Sagittal_T1_files) // 2:]
        if len(right_t1_files) < half_t1:
            right_t1_files = self.pad_images_list(right_t1_files, half_t1)
        elif len(right_t1_files) > half_t1:
            right_t1_files = right_t1_files[-half_t1:]  # slices adjacent to the series middle
        if len(left_t1_files) < half_t1:
            left_t1_files = self.pad_images_list(left_t1_files, half_t1)
        elif len(left_t1_files) > half_t1:
            left_t1_files = left_t1_files[:half_t1]  # slices adjacent to the series middle

        # NOTE(review): boxes are rescaled with scale_bboxes((512, 512), native DICOM shape), but the
        # crops are taken from load_dicom output, which is resized to 512x512; confirm the space.
        mid_t1 = pydicom.dcmread(os.path.join(Sagittal_T1_path, Sagittal_T1_files[len(Sagittal_T1_files) // 2])).pixel_array
        t1_bboxes = cfg.segmentation.scale_bboxes(Sagittal_T1_bboxes, (512, 512), mid_t1.shape)
        for file in list(left_t1_files[::-1]) + list(right_t1_files[::-1]):
            sagittal_stack[..., k] = self.crop_sagittal_center(file, t1_bboxes, Sagittal_T1_path, two_classes_category)
            k += 1

        # Sagittal T2/STIR: bring the series to exactly cfg.channel_size_sagittal_t2 slices.
        t2_files = Sagittal_T2_files
        if len(t2_files) < cfg.channel_size_sagittal_t2:
            t2_files = self.pad_images_list(t2_files, cfg.channel_size_sagittal_t2)
        elif len(t2_files) > cfg.channel_size_sagittal_t2:
            t2_files = self.unpad_images_list(t2_files, cfg.channel_size_sagittal_t2)
        mid_t2 = pydicom.dcmread(os.path.join(Sagittal_T2_path, t2_files[len(t2_files) // 2])).pixel_array
        t2_bboxes = cfg.segmentation.scale_bboxes(Sagittal_T2_bboxes, (512, 512), mid_t2.shape)
        for file in t2_files:
            sagittal_stack[..., k] = self.crop_sagittal_center(file, t2_bboxes, Sagittal_T2_path, two_classes_category)
            k += 1

        # Axial slices assigned to this level (falling back to the vertebra label).
        axial_files = df_classes['path'].loc[df_classes['class_id'] == two_classes_category].unique().tolist()
        if not axial_files:
            axial_files = df_classes['path'].loc[df_classes['class_id'] == category].unique().tolist()
        # NOTE(review): extract_number uses the FIRST digit run, which for a full path is probably the
        # study id, making this sort a no-op; sorting by basename may be intended but would change the
        # slice order relative to training - confirm before changing.
        axial_files = sorted(axial_files, key=self.extract_number)
        if not axial_files:
            return sagittal_stack, axial_stack

        n_axial = 3  # axial slices per level; each yields a left, middle and right crop
        if len(axial_files) < n_axial:
            axial_files = self.pad_images_list(axial_files, n_axial)
        elif len(axial_files) > n_axial:
            axial_files = self.unpad_images_list(axial_files, n_axial)

        for slot, path in enumerate(axial_files):
            pixels = pydicom.dcmread(path).pixel_array
            pixels = pixels.clip(np.percentile(pixels, 1), np.percentile(pixels, 99))
            bbox = cfg.DetectionInference.inference(pixels, 512, 512)
            pixels = (pixels - pixels.min()) / (pixels.max() - pixels.min() + 1e-6) * 255
            resized = self.resize_image(pixels.astype(np.uint8), (512, 512))
            if not isinstance(resized, Image.Image):
                resized = Image.fromarray(resized)
            # NOTE(review): detection ran on the un-flipped slice but crops come from the flipped one;
            # confirm the box coordinates account for the flip.
            flipped = resized.transpose(Image.FLIP_LEFT_RIGHT)
            for group, crop in enumerate(self.crop_axial_center(flipped, bbox)):
                channel = group * n_axial + slot  # left -> 0..2, middle -> 3..5, right -> 6..8
                crop = np.array(self.resize_image(np.array(crop), (cfg.Axial_shape[0], cfg.Axial_shape[1]))).astype(np.uint8)
                axial_stack[..., channel] = crop
                sagittal_stack[..., k + channel] = crop

        return sagittal_stack, axial_stack

    @staticmethod
    def _is_dict_structure_correct(d):
        """True when ``d`` has exactly the 10 category keys and every value is ``[(-1, -1, -1, -1)]``."""
        required_keys = {1, 2, 3, 4, 5, 11, 12, 13, 14, 15}
        if set(d.keys()) != required_keys:
            return False
        for key in required_keys:
            if not (isinstance(d[key], list) and len(d[key]) == 1 and d[key][0] == (-1, -1, -1, -1)):
                return False
        return True

    @staticmethod
    def _is_all_black(image_array):
        return np.all(image_array == 0)

    @staticmethod
    def _count_neg_ones(bboxes):
        """Number of ``(-1, -1, -1, -1)`` placeholder boxes across all categories."""
        return sum(vals.count((-1, -1, -1, -1)) for vals in bboxes.values())

    def _series_path(self, sub_set, study_id, description):
        """Directory of the first series of ``study_id`` whose ``series_description`` equals ``description``."""
        series_id = sub_set["series_id"].loc[sub_set["series_description"] == description].iloc[0]
        return os.path.join(self.test_path, str(study_id), str(series_id))

    @staticmethod
    def _search_indices(num_files, radius=5):
        """Slice indices within ``radius`` of the middle slice, clamped to ``[0, num_files)``.

        Without clamping, short series produce negative indices (which silently wrap to the end of
        the list) and indices past the end (IndexError).
        """
        middle = num_files // 2
        return range(max(0, middle - radius), min(num_files, middle + radius))

    def _find_sagittal_bboxes(self, series_path, files):
        """Segment the middle slice; if any placeholder box remains, fill it from neighbouring slices.

        A neighbour's value replaces a category that is absent or equal to ``[(-1, -1, -1, -1)]``.
        The search stops once ``_count_neg_ones`` reports no placeholders.
        """
        bboxes = cfg.segmentation.inference(os.path.join(series_path, files[len(files) // 2]))
        if self._count_neg_ones(bboxes) == 0:
            return bboxes
        for idx in self._search_indices(len(files)):
            candidate = cfg.segmentation.inference(os.path.join(series_path, files[idx]))
            for key, value in candidate.items():
                if value != self._MISSING_BOX and (key not in bboxes or bboxes[key] == self._MISSING_BOX):
                    bboxes[key] = value
            if self._count_neg_ones(bboxes) == 0:
                break
        return bboxes

    def __getitem__(self, index):
        """Build the five (sagittal, axial) stack pairs for ``self.study_ids[index]``."""
        study_id = self.study_ids[index]
        sub_set = self.df_description.loc[self.df_description.study_id == study_id]
        sagittal_t1_path = self._series_path(sub_set, study_id, "Sagittal T1")
        sagittal_t2_path = self._series_path(sub_set, study_id, "Sagittal T2/STIR")

        sagittal_t1_files = sorted(os.listdir(sagittal_t1_path), key=self.extract_number)
        sagittal_t1_bboxes = self._find_sagittal_bboxes(sagittal_t1_path, sagittal_t1_files)
        sagittal_t2_files = sorted(os.listdir(sagittal_t2_path), key=self.extract_number)
        sagittal_t2_bboxes = self._find_sagittal_bboxes(sagittal_t2_path, sagittal_t2_files)

        df_classes = cfg.cross_reference.get_cross_reference_for_Axial(sub_set, "test")

        n_axial = cfg.channel_size_axial
        sagittal_stacks, axial_stacks = [], []
        for category, two_classes_category in self._LEVELS:
            sagittal = np.zeros((cfg.Sagittal_shape[0], cfg.Sagittal_shape[1], cfg.channel_size_sagittal + n_axial),
                                dtype=np.uint8)
            axial = np.zeros((cfg.Axial_shape[0], cfg.Axial_shape[1], n_axial), dtype=np.uint8)
            sagittal, axial = self.create_stack(sagittal, axial,
                                                sagittal_t2_files, sagittal_t2_bboxes, sagittal_t2_path,
                                                sagittal_t1_files, sagittal_t1_bboxes, sagittal_t1_path,
                                                category, two_classes_category, df_classes)
            sagittal_stacks.append(sagittal)
            axial_stacks.append(axial)

        # Levels whose axial channels were never filled; they are reset to 0 after normalisation
        # so the model sees zeros rather than the normalised value of black pixels.
        no_axial = [bool(np.all(np.isclose(s[:, :, -n_axial:], 0.0))) for s in sagittal_stacks]

        if self.transform:
            # All sagittal stacks first, then all axial stacks (same call order as before).
            sagittal_stacks = [self.transform(image=s)['image'] for s in sagittal_stacks]
            axial_stacks = [self.transform(image=a)['image'] for a in axial_stacks]

        for stack, empty in zip(sagittal_stacks, no_axial):
            if empty:
                stack[:, :, -n_axial:] = 0

        sagittal_tensors = [torch.tensor(s).permute(2, 0, 1) for s in sagittal_stacks]
        axial_tensors = [torch.tensor(a).permute(2, 0, 1) for a in axial_stacks]

        result = [study_id]
        for sagittal, axial in zip(sagittal_tensors, axial_tensors):
            result.extend((sagittal, axial))
        return tuple(result)
            


def TestLoader(study_ids: list, labels_path: Path, test_path: Path) -> CustomDataset:
    """Build the test-time ``CustomDataset`` with the validation transform.

    Despite its name this returns a ``Dataset`` (not a ``DataLoader``); wrap it in
    ``torch.utils.data.DataLoader`` for batching.

    Args:
        study_ids: study ids to iterate over.
        labels_path: series-description CSV read by ``CustomDataset``.
        test_path: image root directory.
    """
    return CustomDataset(study_ids, labels_path, test_path, transforms_val)



In [2]:
# Condition names, in the order used for submission row ids.
CONDITIONS = ['spinal_canal_stenosis'] + [
    f'{side}_{finding}'
    for finding in ('neural_foraminal_narrowing', 'subarticular_stenosis')
    for side in ('left', 'right')
]

# Intervertebral levels L1/L2 ... L5/S1, in submission order.
LEVELS = [f'l{n}_l{n + 1}' for n in range(1, 5)] + ['l5_s1']

In [3]:
from model_code.custom_model import CustomRain
import torch
import pandas as pd
from torch.utils.data import DataLoader
from tqdm import tqdm

def reorder_labels(labels, num_conditions=5, num_classes=3):
    """Reorder level-major logits into condition-major order.

    ``labels`` is the concatenation of one block per level, each block holding
    ``num_conditions * num_classes`` values (condition-major within the block). The result holds,
    for each condition, the ``num_classes`` values of every level in turn - the row order of the
    submission (conditions outer, levels inner).

    Args:
        labels: 1-D tensor whose length is a multiple of ``num_conditions * num_classes``.
        num_conditions: conditions per level (default 5).
        num_classes: values per (condition, level) pair (default 3).

    Returns:
        A new 1-D tensor with the same dtype and device as ``labels``.

    Raises:
        ValueError: if ``labels`` is not 1-D or its length is not a multiple of the block size
            (previously this left uninitialised values from ``torch.empty_like`` in the output).
    """
    block = num_conditions * num_classes
    if labels.dim() != 1 or labels.numel() % block != 0:
        raise ValueError(
            f"expected a 1-D tensor with a multiple of {block} elements, got shape {tuple(labels.shape)}"
        )
    num_levels = labels.numel() // block
    # (level, condition, class) -> (condition, level, class); clone so the result never aliases the input.
    return labels.reshape(num_levels, num_conditions, num_classes).transpose(0, 1).reshape(-1).clone()


class Settings:
    """Inference settings for this notebook: model checkpoint, data locations and output layout."""

    # --- model ---
    model_path = "RainDrop_0.4910611528158188_fold_1.pt"
    number_of_classes = 15       # first argument of CustomRain
    pretrain = False             # second argument of CustomRain

    # --- data ---
    path = "test_images"
    test_path = "test_images"
    description_path = "test_series_descriptions.csv"
    batch_size = 1

    # --- output ---
    N_LABELS = 25                # prediction rows per study (5 conditions x 5 levels)
    LABELS = ['normal_mild', 'moderate', 'severe']

settings = Settings()
# Pick the device once and use it everywhere; a hard-coded .cuda() crashes on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = CustomRain(settings.number_of_classes, settings.pretrain)
# map_location lets a GPU-saved checkpoint load on CPU.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
model.load_state_dict(torch.load(settings.model_path, map_location=device))
model = model.eval().to(device)

test_df = pd.read_csv(settings.description_path)
study_ids = list(test_df['study_id'].unique())
data_loader = TestLoader(study_ids, settings.description_path, settings.test_path)  # a Dataset, despite the name
test_loader = DataLoader(data_loader, batch_size=settings.batch_size, shuffle=False)

submissions = pd.DataFrame()
row_names = []
y_preds = []
with torch.no_grad():
    for (study_id, sagittal_l1_l2, axial_l1_l2, sagittal_l2_l3,
         axial_l2_l3, sagittal_l3_l4, axial_l3_l4, sagittal_l4_l5, axial_l4_l5, sagittal_l5_s1,
         axial_l5_s1) in tqdm(test_loader):
        # Only the sagittal stacks are fed to the model (they already carry the axial crops in
        # their last channels), so the axial tensors are left on the CPU.
        sagittal_l1_l2 = sagittal_l1_l2.to(device)
        sagittal_l2_l3 = sagittal_l2_l3.to(device)
        sagittal_l3_l4 = sagittal_l3_l4.to(device)
        sagittal_l4_l5 = sagittal_l4_l5.to(device)
        sagittal_l5_s1 = sagittal_l5_s1.to(device)

        level_stacks = (sagittal_l1_l2, sagittal_l2_l3, sagittal_l3_l4, sagittal_l4_l5, sagittal_l5_s1)
        # Second model argument: 0..4 in LEVELS order (presumably a level index).
        level_outputs = [model(stack, torch.tensor(level, device=device)).squeeze()
                         for level, stack in enumerate(level_stacks)]
        output = reorder_labels(torch.cat(level_outputs, dim=0))

        # Row ids are condition-major, matching the order produced by reorder_labels.
        sid = str(study_id.tolist()[0])  # assumes batch_size == 1
        row_names.extend(f"{sid}_{cond}_{level}" for cond in CONDITIONS for level in LEVELS)

        probs = output.float().reshape(settings.N_LABELS, 3).softmax(dim=1).cpu().numpy()
        y_preds.append(probs.astype(np.float64))

y_preds = np.concatenate(y_preds, axis=0) if y_preds else np.zeros((0, len(settings.LABELS)))

submissions['row_id'] = row_names
submissions[settings.LABELS] = y_preds
submissions

100%|██████████| 1/1 [00:09<00:00,  9.75s/it]


Unnamed: 0,row_id,normal_mild,moderate,severe
0,44036939_spinal_canal_stenosis_l1_l2,0.289629,0.473262,0.237108
1,44036939_spinal_canal_stenosis_l2_l3,0.158324,0.461493,0.380182
2,44036939_spinal_canal_stenosis_l3_l4,0.160917,0.473151,0.365932
3,44036939_spinal_canal_stenosis_l4_l5,0.309416,0.471928,0.218656
4,44036939_spinal_canal_stenosis_l5_s1,0.866586,0.074696,0.058718
5,44036939_left_neural_foraminal_narrowing_l1_l2,0.33206,0.575108,0.092832
6,44036939_left_neural_foraminal_narrowing_l2_l3,0.339537,0.592243,0.06822
7,44036939_left_neural_foraminal_narrowing_l3_l4,0.089024,0.640141,0.270835
8,44036939_left_neural_foraminal_narrowing_l4_l5,0.1069,0.610414,0.282685
9,44036939_left_neural_foraminal_narrowing_l5_s1,0.072774,0.40835,0.518875


In [4]:
# Sanity check on the last batch left over from the inference loop: per-level sum of each sagittal stack.
tuple(np.array(stack.cpu()).sum() for stack in (sagittal_l1_l2, sagittal_l2_l3, sagittal_l3_l4, sagittal_l4_l5, sagittal_l5_s1))

(-378140.03, -328560.66, -293460.0, -311064.72, -128999.19)