In [None]:
!pip install diskcache

In [None]:
import os, random
import numpy as np
import pandas as pd
from skimage.segmentation import watershed
import torch
import SimpleITK as sitk
from glob import glob
from typing import Tuple, List, NamedTuple
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import diskcache
import functools

from luna16_util import Config, augment_candidates_3d, augment_candidates_2d, show_nodule, create_3d_tomograph

## UTILITY CLASSES

In [None]:
class XyzTuple(NamedTuple):
    """Physical-space point ordered (x, y, z); the value type returned by ``Ct.irc2xyz``."""

    x: float  # physical x coordinate
    y: float  # physical y coordinate
    z: float  # physical z coordinate

class IrcTuple(NamedTuple):
    """Voxel coordinate ordered (index, row, col), matching the axis order of ``Ct.hu_a``."""

    index: int  # position along the first array axis (slice)
    row: int    # position along the second array axis
    col: int    # position along the third array axis

class Ct:
    """A single LUNA16 CT series, windowed (and optionally normalized) on load.

    Attributes:
        series_uid: identifier of the series.
        ct_mhd: the SimpleITK image read from the ``.mhd`` header.
        hu_a: float32 voxel array indexed as (index, row, col). It holds the values
            returned by ``apply_window`` (clipped, possibly rescaled), not raw intensities.
        origin_xyz: physical position of voxel (0, 0, 0).
        vxSize_xyz: voxel spacing along x, y, z.
        direction_xyz: 3x3 direction matrix (row-major reshape of ``GetDirection()``).
    """

    # (center, width) window presets applied to the raw voxel values
    # (presumably Hounsfield units); values outside center +/- width // 2 are clipped.
    WINDOW_SETTINGS = {
        'full_range': {'center': 0, 'width': 2000},
        'brain': {'center': 40, 'width': 80},
        'soft_tissue': {'center': 50, 'width': 400},
        'lung': {'center': -600, 'width': 1500},
        'bone': {'center': 400, 'width': 1500},
        'mediastinum': {'center': 50, 'width': 350},
    }

    def __init__(self, series_uid, window, normalize, subset=0):
        """
        Load a CT series from the LUNA16 Kaggle input directory and window it.

        Args:
            series_uid (str): Unique identifier for the CT series (the ``.mhd`` file stem).
            window (str): Key of ``WINDOW_SETTINGS`` selecting the intensity window.
            normalize (bool): If true, rescale the windowed values to roughly [0, 1].
            subset (int): LUNA16 subset directory to search.

        Raises:
            FileNotFoundError: If no ``.mhd`` file for ``series_uid`` exists in the subset.
        """
        pattern = f'/kaggle/input/luna16/subset{subset}/subset*/{series_uid}.mhd'
        matches = glob(pattern)
        if not matches:
            raise FileNotFoundError(f'No .mhd file for series {series_uid} matching {pattern}')
        self.ct_mhd = sitk.ReadImage(matches[0])

        # Process CT array
        ct_a = np.array(sitk.GetArrayFromImage(self.ct_mhd), dtype=np.float32)
        ct_a = self.apply_window(ct_a, window, normalize)

        self.series_uid = series_uid
        self.hu_a = ct_a
        self.origin_xyz = np.array(self.ct_mhd.GetOrigin())
        self.vxSize_xyz = np.array(self.ct_mhd.GetSpacing())
        self.direction_xyz = np.array(self.ct_mhd.GetDirection()).reshape(3, 3)

    def irc2xyz(self, coord_irc: Tuple[int, int, int]) -> XyzTuple:
        """Convert an (index, row, col) voxel coordinate to physical (x, y, z).

        Computes ``direction @ (cri * spacing) + origin`` with ``cri`` = (col, row, index).
        """
        cri_a = np.array(coord_irc)[::-1]
        coords_xyz = (self.direction_xyz @ (cri_a * self.vxSize_xyz)) + self.origin_xyz
        return XyzTuple(*coords_xyz)

    def xyz2irc(self, coord_xyz: Tuple[float, float, float]) -> IrcTuple:
        """Convert a physical (x, y, z) coordinate to the nearest (index, row, col) voxel.

        This is the exact inverse of ``irc2xyz``: it solves
        ``direction @ (cri * spacing) = xyz - origin`` for ``cri``. The previous
        ``offset @ inv(direction)`` form equals ``inv(direction).T @ offset``, which only
        matches when the direction matrix is symmetric (e.g. identity or axis flips).
        """
        offset = np.asarray(coord_xyz, dtype=float) - self.origin_xyz
        cri_a = np.linalg.solve(self.direction_xyz, offset) / self.vxSize_xyz
        cri_a = np.round(cri_a)
        return IrcTuple(int(cri_a[2]), int(cri_a[1]), int(cri_a[0]))

    def apply_window(self, ct_scan, window_name, normalize=True):
        """
        Clip ``ct_scan`` to a named intensity window and optionally rescale it.

        Args:
            ct_scan (np.ndarray): Voxel values to window.
            window_name (str): Key of ``WINDOW_SETTINGS``.
            normalize (bool): If true, map the window linearly onto roughly [0, 1].

        Returns:
            np.ndarray: The clipped (and possibly rescaled) array.

        Raises:
            KeyError: If ``window_name`` is not a known preset.
        """
        settings = self.WINDOW_SETTINGS[window_name]
        center = settings['center']
        width = settings['width']

        min_bound = center - width // 2
        max_bound = center + width // 2

        ct_scan = np.clip(ct_scan, min_bound, max_bound)
        if normalize:
            # The epsilon only guards against division by a zero width.
            ct_scan = (ct_scan - min_bound) / (width + 1e-8)

        return ct_scan

## ClassifierDataset

In [None]:
class ClassifierDataset(Dataset):
    """LUNA16 candidate-classification dataset.

    Each item is ``(crop, label, seriesuid, irc_center)``:
        - a ``(1, *width)`` float tensor cut around one candidate;
        - its label (column ``self.label_name``) as a long tensor;
        - the series UID;
        - the crop centre in (index, row, col) voxel coordinates.

    Crops are memoized on disk with diskcache.
    """

    def __init__(self, seriesuid: str = None, width: Tuple[int, int, int] = (32, 48, 48), data=None,
                 mode: str = 'train', cache_path: str = 'memoize_cache', config: Config = None, subset: int = 0):
        """
        Initialize LUNA16 Dataset

        Args:
            seriesuid (str, optional): Keep only candidates of this series (ignored when ``data`` is given).
            width (Tuple[int, int, int]): Crop size along (index, row, col).
            data (pd.DataFrame, optional): Pre-built metadata used as-is instead of the CSVs;
                needs ``seriesuid``, ``coordX``, ``coordY``, ``coordZ`` and ``class`` columns.
            mode (str): 'train' / 'val' keep rows whose position % 10 is != 0 / == 0;
                'cache' keeps every row sorted by series.
            cache_path (str): Directory of the diskcache holding memoized crops.
            config (Config): Provides ``window``, ``normalize`` and ``balanced``.
            subset (int): LUNA16 subset whose scans are used.

        Raises:
            ValueError: If ``config`` is not given.
        """
        if config is None:
            raise ValueError('ClassifierDataset requires a config (window, normalize, balanced)')
        self.data = data
        self.width = width
        self.mode = mode
        self.subset = subset
        self.seriesuid = seriesuid
        self.window = config.window
        self.normalize = config.normalize
        self.meta_data, self.label_name = self._get_metadata()
        # Unbalanced copy kept so every rebalancing (see balanced_decay) starts from the
        # original candidates instead of compounding over an already oversampled frame.
        self._raw_meta_data = self.meta_data
        if config.balanced and self.mode == 'train':
            self.meta_data = self._balance_metadata(self._raw_meta_data, self.label_name, config.balanced)
        self.cache = diskcache.Cache(cache_path, size_limit=3e11)

    def __len__(self) -> int:
        return len(self.meta_data)

    def __getitem__(self, idx: int):
        row = self.meta_data.loc[idx]
        seriesuid = row.seriesuid
        xyz_center = row[['coordX', 'coordY', 'coordZ']].tolist()

        # Every input that changes the crop is part of the cache key, so switching
        # window / normalize / subset can never return a crop computed for other settings.
        @self.cache.memoize()
        def _cached_getitem(seriesuid, xyz_center, width, window, normalize, subset):
            ct = self._get_ct(seriesuid, window, normalize, subset)
            candidate, irc_center = self.get_candidate(ct, xyz_center, width)
            return candidate, (irc_center.index, irc_center.row, irc_center.col)

        candidate, irc_center = _cached_getitem(seriesuid, xyz_center, self.width,
                                                self.window, self.normalize, self.subset)

        return (
            torch.from_numpy(candidate)[None],
            torch.tensor(row[self.label_name], dtype=torch.long),
            seriesuid,
            torch.tensor(irc_center)
        )

    @staticmethod
    @functools.lru_cache(maxsize=1)
    def _get_ct(seriesuid, window, normalize, subset):
        # Static so the LRU cache is not keyed on (and does not keep alive) the dataset;
        # subset is forwarded so scans outside subset0 are found.
        return Ct(seriesuid, window, normalize, subset)

    def get_candidate(self, ct_scan: Ct, xyz_center: List[float], width: Tuple[int, int, int]) -> Tuple[np.ndarray, IrcTuple]:
        """
        Extract a ``width``-sized crop centred on a physical coordinate.

        Parts of the crop that fall outside the scan stay zero. The scan data stays
        aligned with the centre even when the crop sticks out on the low side.

        Args:
            ct_scan (Ct): Loaded scan providing ``hu_a`` and ``xyz2irc``.
            xyz_center (List[float]): Candidate centre in physical (x, y, z) coordinates.
            width (Tuple[int, int, int]): Crop size along (index, row, col).

        Returns:
            Tuple of the float32 crop and the centre in voxel (index, row, col) coordinates.
        """
        irc_center = ct_scan.xyz2irc(xyz_center)
        candidate_wrap = np.zeros(width, dtype='float32')

        src_slices, dst_slices = [], []
        for center, size, dim in zip(irc_center, width, ct_scan.hu_a.shape):
            start = center - size // 2
            end = start + size
            src_start, src_end = max(start, 0), min(end, dim)
            if src_start >= src_end:
                # The crop lies entirely outside the scan along this axis.
                return candidate_wrap, irc_center
            src_slices.append(slice(src_start, src_end))
            dst_slices.append(slice(src_start - start, src_end - start))

        candidate_wrap[tuple(dst_slices)] = ct_scan.hu_a[tuple(src_slices)]
        return (
            candidate_wrap,
            irc_center
        )

    def sampler(self, class_weight=(1, 1)):
        """
        Build a WeightedRandomSampler that draws each class proportionally to ``class_weight``.

        Args:
            class_weight (Sequence[float]): Relative weight per class. Entry ``i`` applies to
                the i-th smallest label (for 0/1 labels: class 0, then class 1).

        Returns:
            WeightedRandomSampler drawing ``len(self)`` indices with replacement.
        """
        labels = self.meta_data[self.label_name]
        counts = labels.value_counts()
        # Each sample gets weight w_c / n_c, so every class c has total mass w_c.
        per_class = {label: class_weight[position] / counts[label]
                     for position, label in enumerate(sorted(counts.index))}
        sample_weights = labels.map(per_class).to_numpy(dtype='float32')
        return WeightedRandomSampler(weights=sample_weights, num_samples=len(self.meta_data), replacement=True)

    def _balance_metadata(self, df, target_column, ratio=1):
        """
        Oversample the least frequent class to ``ratio`` times the most frequent class.

        Only the most and least frequent classes are kept. With fewer than two classes,
        ``df`` is returned unchanged (re-indexed).
        """
        class_counts = df[target_column].value_counts()
        if len(class_counts) < 2:
            return df.reset_index(drop=True)

        majority_class = class_counts.index[0]
        minority_class = class_counts.index[-1]

        majority_df = df[df[target_column] == majority_class]
        minority_df = df[df[target_column] == minority_class]

        samples_to_add = int(ratio * len(majority_df))
        oversampled_minority = minority_df.sample(n=samples_to_add, replace=True)
        balanced_df = pd.concat([majority_df, oversampled_minority], ignore_index=True)

        return balanced_df

    def balanced_decay(self, epoch):
        """
        Re-balance from the original candidates with a minority ratio of ``1 / epoch ** 0.3``.

        Epochs below 1 are treated as 1 to avoid dividing by zero.
        """
        ratio = 1 / max(epoch, 1) ** .3
        self.meta_data = self._balance_metadata(self._raw_meta_data, self.label_name, ratio)

    def _get_metadata(self) -> Tuple[pd.DataFrame, str]:
        """
        Process and merge candidate and annotation data

        Returns:
            (metadata DataFrame, name of the label column). The frame is ``self.data``
            unchanged when it was provided.
        """
        if self.data is not None:
            return self.data, 'class'

        candidates = pd.read_csv('/kaggle/input/luna16/candidates.csv')
        annotations = pd.read_csv('/kaggle/input/luna16/annotations.csv')

        # Filter for available CT scans
        mhd_list = glob(f'/kaggle/input/luna16/subset{self.subset}/subset*/*.mhd')
        presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}

        candidates = candidates[candidates['seriesuid'].isin(presentOnDisk_set)]
        annotations = annotations[annotations['seriesuid'].isin(presentOnDisk_set)]

        # Merge and process data
        result = pd.merge(candidates, annotations, on=['seriesuid'], how='left')
        result['diameter_mm'] = result['diameter_mm'].fillna(0)

        # Calculate distances (candidate-to-annotation distance relative to diameter)
        nodule_coords = result[['coordX_x', 'coordY_x', 'coordZ_x']].values
        center = result[['coordX_y', 'coordY_y', 'coordZ_y']].values
        distances = np.linalg.norm(nodule_coords - center, axis=1)

        result['distance'] = distances / (result['diameter_mm'] + 1e-15)
        result['distance'] = result['distance'].fillna(100)

        # Keep, per candidate, only its closest annotation.
        result = (result
                  .sort_values('distance')
                  .groupby(['seriesuid', 'coordX_x', 'coordY_x', 'coordZ_x'])
                  .first()
                  .reset_index())

        result = result.rename(columns={
            'coordX_x': 'coordX',
            'coordY_x': 'coordY',
            'coordZ_x': 'coordZ'
        })

        result.loc[result['distance'] > 0.3, 'diameter_mm'] = 0
        result = (result[list(candidates.columns) + ['diameter_mm']]
                 .sort_values('diameter_mm', ascending=False)
                 .reset_index(drop=True))

        if self.seriesuid:
            result = result[result.seriesuid == self.seriesuid]

        # Split data into train/validation
        if self.mode == 'val':
            result = result[result.index % 10 == 0].reset_index(drop=True)
        elif self.mode == 'train':
            result = result[result.index % 10 != 0].reset_index(drop=True)
        elif self.mode == 'cache':
            result = result.sort_values('seriesuid').reset_index(drop=True)

        return result, 'class'

In [None]:
# hyper_parameters = {
#         'model_type': 'classification',  # or 'segmentation'
#         'window': 'full_range',
#         'normalize': True,
#         'batch_norm': False,
#         'batch_size': 64,
#         'num_workers': 4,
#         'cache_in': False,
#         'visualize': True,
#         'epochs': 10,
#         'n_metrics': 3,
#         'balanced': 1,
#         'enable_balanced_decay': True,
#         'augment': {
#             'flip': True,
#             'offset': 0.1,
#             'scale': 0.2,
#             'rotate': True,
#             'noise': 0.1,
#             'mixup': 0.4
#     }
# }
# config = Config(hyper_parameters)
# luna = ClassifierDataset(mode='val', config=config)
# for x, y, _, irc_center in DataLoader(luna, batch_size=1, shuffle=True):
#     x, y = augment_candidates_3d(x, y, config.augment)
#     show_nodule(x.squeeze(1), y, irc_center)
#     create_3d_tomograph(x, slice_spacing=1, colormap='plasma', transparency=1.0, threshold=0)
#     break

## SegmenterDataset

In [None]:
# Module-level disk cache shared by the memoized SegmenterDataset methods below.
cache = diskcache.Cache(directory='memoize_cache', size_limit=3e11)

class SegmenterDataset(Dataset):
    """LUNA16 nodule-segmentation dataset of 3-slice inputs and single-slice masks.

    Each item is ``(slices, mask, seriesuid, irc_center)``:
        - three consecutive slices centred on ``slice_idx``;
        - the middle slice's mask with a leading channel axis;
        - the series UID;
        - in 'train'/'val' mode, the nodule centre in (index, row, col) voxel coordinates
          ([0, 0, 0] otherwise).

    Masks come from an intensity-threshold heuristic (``_calculate_mask``). The expensive
    per-nodule / per-series arrays are memoized in the module-level diskcache ``cache``,
    keyed on every argument that affects them.
    """

    def __init__(self, seriesuid: str = None, img_size: Tuple[int, int] = (96, 96), crop_size: Tuple[int, int] = (64, 64),
                 mode: str = 'train', config: Config = None, subset=0, threshold=.2):
        """
        Initialize LUNA16 Dataset

        Args:
            seriesuid (str, optional): Restrict to this series.
            img_size (Tuple[int, int]): (rows, cols) window cut around each nodule in 'train' mode.
            crop_size (Tuple[int, int]): (rows, cols) random crop taken from that window in 'train' mode.
            mode (str): 'train', 'val', 'full_val' or 'inference'.
            config (Config): Provides ``window`` and ``normalize``.
            subset (int): LUNA16 subset whose scans are used.
            threshold (float): Windowed intensity at or above which voxels count as nodule
                tissue when building masks.

        Raises:
            ValueError: If ``config`` is not given.
        """
        if config is None:
            raise ValueError('SegmenterDataset requires a config (window, normalize)')
        self.threshold = threshold
        self.subset = subset
        self.crop_size = crop_size
        self.img_size = img_size
        self.mode = mode
        self.seriesuid = seriesuid
        self.window = config.window
        self.normalize = config.normalize
        # One row per nodule (per series in 'inference' mode), before slice expansion.
        # Kept so full-volume masks are always built from every nodule of a series.
        self._nodule_meta_data = self._get_metadata()
        self.meta_data = self.populate_annotations(self._nodule_meta_data)

    def __len__(self) -> int:
        return len(self.meta_data)

    def __getitem__(self, idx: int):
        if self.mode not in ('train', 'val', 'full_val', 'inference'):
            raise ValueError(f"Unsupported mode {self.mode!r}; expected 'train', 'val', 'full_val' or 'inference'")
        row = self.meta_data.loc[idx]
        seriesuid = row.seriesuid
        x, y, z = row[['coordX', 'coordY', 'coordZ']].tolist()
        slice_idx = int(row.slice_idx)

        if self.mode in ('full_val', 'inference'):
            nodule, full_mask = self.get_full_nodule_mask(seriesuid, self._nodule_meta_data, self.window, self.normalize)
            irc_center = [0, 0, 0]
        else:
            nodule, full_mask, irc_center = self.get_nodule_mask(seriesuid, x, y, z, self.mode, self.img_size,
                                                                 self.window, self.normalize, self.subset)

        if self.mode == 'train':
            h_img, w_img = self.img_size
            h_crop, w_crop = self.crop_size
            # Jitter over every start that keeps the crop inside the window; the centre crop
            # sits at the midpoint of this range, so the shift is symmetric around the nodule.
            h_start = random.randint(0, h_img - h_crop)
            w_start = random.randint(0, w_img - w_crop)
            nodule_slice = nodule[slice_idx - 1:slice_idx + 2, h_start:h_start + h_crop, w_start:w_start + w_crop]
            mask = full_mask[slice_idx, h_start:h_start + h_crop, w_start:w_start + w_crop]
        else:
            nodule_slice = nodule[slice_idx - 1:slice_idx + 2]
            if full_mask is None:
                mask = np.zeros((nodule_slice.shape[1], nodule_slice.shape[2]), dtype='float32')
            else:
                mask = full_mask[slice_idx]

        return (
            torch.from_numpy(nodule_slice),
            torch.tensor(mask, dtype=torch.float)[None],
            seriesuid,
            torch.tensor(irc_center)
        )

    def _grow_radius(self, hu_a, center, axis, max_radius=16):
        """Return the first offset along ``axis`` at which either neighbour of ``center``
        is below ``self.threshold`` or outside the volume; ``max_radius - 1`` if none is."""
        dim = hu_a.shape[axis]
        for radius in range(max_radius):
            low, high = list(center), list(center)
            low[axis] -= radius
            high[axis] += radius
            if low[axis] < 0 or high[axis] >= dim:
                return radius
            if hu_a[tuple(high)] < self.threshold or hu_a[tuple(low)] < self.threshold:
                return radius
        return max_radius - 1

    def _calculate_mask(self, x, y, z, ct):
        """
        Build a threshold mask for the nodule at physical (x, y, z).

        A box is grown from the centre along each axis (``_grow_radius``) and extended by one
        voxel on both sides. Voxels in the box that are >= ``self.threshold`` are marked.

        Returns:
            (mask, start, end): ``mask`` covers the full slices ``start:end`` of ``ct.hu_a``
            (bounds clamped to the volume).
        """
        irc = ct.xyz2irc((x, y, z))
        center = (irc.index, irc.row, irc.col)
        i = self._grow_radius(ct.hu_a, center, axis=0)
        j = self._grow_radius(ct.hu_a, center, axis=1)
        k = self._grow_radius(ct.hu_a, center, axis=2)

        start = max(irc.index - i - 1, 0)
        end = min(irc.index + i + 2, ct.hu_a.shape[0])
        nodule = ct.hu_a[start:end]
        mask = np.zeros_like(nodule, dtype=int)
        # Symmetric extent of radius + 1 on every axis (matches the index-axis slice above).
        mask[:,
             max(irc.row - j - 1, 0):irc.row + j + 2,
             max(irc.col - k - 1, 0):irc.col + k + 2] = 1
        mask = mask * (nodule >= self.threshold)
        return mask, start, end

    @staticmethod
    def _crop_around(volume, row, col, h, w):
        """Cut an ``(h, w)`` window centred on (row, col) from every slice of ``volume``;
        parts outside the volume are zero-filled so the result is always ``(depth, h, w)``."""
        out = np.zeros((volume.shape[0], h, w), dtype=volume.dtype)
        top, left = row - h // 2, col - w // 2
        r0, r1 = max(top, 0), min(top + h, volume.shape[1])
        c0, c1 = max(left, 0), min(left + w, volume.shape[2])
        if r0 < r1 and c0 < c1:
            out[:, r0 - top:r1 - top, c0 - left:c1 - left] = volume[:, r0:r1, c0:c1]
        return out

    def get_full_nodule_mask(self, seriesuid, meta_data, window, normalize):
        """
        Return ``(volume, full_mask)`` for a whole series.

        In 'inference' mode ``full_mask`` is None. Otherwise it is the union of the masks of
        every nodule of ``seriesuid`` listed in ``meta_data``.
        """
        if self.mode == 'inference':
            nodule_xyz = None
        else:
            coords = meta_data.loc[meta_data.seriesuid == seriesuid, ['coordX', 'coordY', 'coordZ']]
            nodule_xyz = tuple(tuple(xyz) for xyz in coords.to_numpy(dtype=float).tolist())
        return self._cached_full_nodule_mask(seriesuid, nodule_xyz, window, normalize, self.subset, self.threshold)

    @cache.memoize(ignore=(0,))
    def _cached_full_nodule_mask(self, seriesuid, nodule_xyz, window, normalize, subset, threshold):
        # The cache key includes the nodule coordinates and every setting that changes the
        # output; ``threshold`` is only part of the key (``_calculate_mask`` reads self.threshold).
        ct = SegmenterDataset._get_ct(seriesuid, window, normalize, subset)
        if nodule_xyz is None:
            return ct.hu_a, None
        full_mask = np.zeros_like(ct.hu_a, dtype=int)
        for x, y, z in nodule_xyz:
            mask, start, end = self._calculate_mask(x, y, z, ct)
            # Union, so nodules sharing slices don't erase each other's masks.
            full_mask[start:end] = np.maximum(full_mask[start:end], mask)
        return ct.hu_a, full_mask

    def get_nodule_mask(self, seriesuid, x, y, z, mode, img_size, window, normalize, subset):
        """
        Return ``(nodule, mask, irc_center)`` for a single nodule.

        ``nodule`` and ``mask`` span the slices found by ``_calculate_mask``. In 'train' mode
        they are additionally cut to an ``img_size`` window around the centre, zero-padded
        at the image border.
        """
        return self._cached_nodule_mask(seriesuid, x, y, z, mode, tuple(img_size),
                                        window, normalize, subset, self.threshold)

    @cache.memoize(ignore=(0,))
    def _cached_nodule_mask(self, seriesuid, x, y, z, mode, img_size, window, normalize, subset, threshold):
        # Every argument except self is part of the cache key, so train/val results,
        # window sizes and intensity settings never share an entry.
        ct = SegmenterDataset._get_ct(seriesuid, window, normalize, subset)
        irc_center = ct.xyz2irc((x, y, z))
        mask, start, end = self._calculate_mask(x, y, z, ct)
        nodule = ct.hu_a[start:end]

        if mode == 'train':
            h, w = img_size
            nodule = self._crop_around(nodule, irc_center.row, irc_center.col, h, w)
            mask = self._crop_around(mask, irc_center.row, irc_center.col, h, w)

        return nodule, mask, irc_center

    @staticmethod
    @functools.lru_cache(maxsize=1)
    def _get_ct(seriesuid, window, normalize, subset):
        return Ct(seriesuid, window, normalize, subset)

    def populate_annotations(self, metadata):
        """
        Expand ``metadata`` into one row per usable centre slice (column ``slice_idx``).

        'full_val' uses one row per series; other modes use one row per input row. The first
        and last slice are skipped so ``slice_idx - 1 : slice_idx + 2`` is always 3 deep.
        Rows with no usable slice are dropped. ``metadata`` itself is not modified.
        """
        if self.mode == 'full_val':
            rows = pd.DataFrame(metadata.groupby('seriesuid').first()).reset_index()
        else:
            rows = metadata.copy()

        slice_idx = []
        for _, row in rows.iterrows():
            if self.mode in ('inference', 'full_val'):
                volume, _ = self.get_full_nodule_mask(row.seriesuid, metadata, self.window, self.normalize)
            else:
                x, y, z = row[['coordX', 'coordY', 'coordZ']].tolist()
                volume, _, _ = self.get_nodule_mask(row.seriesuid, x, y, z, self.mode, self.img_size,
                                                    self.window, self.normalize, self.subset)
            slice_idx.append(list(range(1, volume.shape[0] - 1)))

        rows['slice_idx'] = slice_idx
        # explode turns an empty list into NaN; such rows have no usable slice.
        return (rows
                .explode('slice_idx')
                .dropna(subset=['slice_idx'])
                .reset_index(drop=True))

    def _get_metadata(self) -> pd.DataFrame:
        """
        Process and merge candidate and annotation data

        In 'inference' mode, returns one row per series on disk with empty coordinates.
        Otherwise returns one row per ``class == 1`` candidate, matched to its closest
        annotation. 'train' / 'val' split by series: every 10th positive series goes to 'val'.

        Returns:
            DataFrame with a fresh RangeIndex.
        """
        # Filter for available CT scans
        mhd_list = glob(f'/kaggle/input/luna16/subset{self.subset}/subset*/*.mhd')
        presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
        if self.mode == 'inference':
            result = pd.DataFrame({'seriesuid': list(presentOnDisk_set)})
            result[['coordX', 'coordY', 'coordZ', 'class', 'diameter_mm']] = None
            if self.seriesuid:
                result = result[result.seriesuid == self.seriesuid]
            return result.reset_index(drop=True)
        candidates = pd.read_csv('/kaggle/input/luna16/candidates.csv')
        annotations = pd.read_csv('/kaggle/input/luna16/annotations.csv')

        candidates = candidates[candidates['seriesuid'].isin(presentOnDisk_set)]
        annotations = annotations[annotations['seriesuid'].isin(presentOnDisk_set)]

        # Merge and process data
        result = pd.merge(candidates, annotations, on=['seriesuid'], how='left')
        result['diameter_mm'] = result['diameter_mm'].fillna(0)

        # Calculate distances (candidate-to-annotation distance relative to diameter)
        nodule_coords = result[['coordX_x', 'coordY_x', 'coordZ_x']].values
        center = result[['coordX_y', 'coordY_y', 'coordZ_y']].values
        distances = np.linalg.norm(nodule_coords - center, axis=1)

        result['distance'] = distances / (result['diameter_mm'] + 1e-15)
        result['distance'] = result['distance'].fillna(100)

        # Keep, per candidate, only its closest annotation.
        result = (result
                  .sort_values('distance')
                  .groupby(['seriesuid', 'coordX_x', 'coordY_x', 'coordZ_x'])
                  .first()
                  .reset_index())

        result = result.rename(columns={
            'coordX_x': 'coordX',
            'coordY_x': 'coordY',
            'coordZ_x': 'coordZ'
        })

        result.loc[result['distance'] > 0.3, 'diameter_mm'] = 0
        result = (result[list(candidates.columns) + ['diameter_mm']]
                 .sort_values('diameter_mm', ascending=False)
                 .reset_index(drop=True))

        result = result[result['class'] == 1]
        seriesuid_pos = result.seriesuid.unique()

        # Split data into train/validation by series
        if self.mode in ['val']:
            indices = np.arange(len(seriesuid_pos)) % 10 == 0
            seriesuid_pos = seriesuid_pos[indices]
            result = result.loc[result.seriesuid.isin(seriesuid_pos)]
            result = result.sort_values('seriesuid')

        elif self.mode in ['full_val']:
            result = result.sort_values('seriesuid')

        elif self.mode == 'train':
            indices = np.arange(len(seriesuid_pos)) % 10 != 0
            seriesuid_pos = seriesuid_pos[indices]
            result = result.loc[result.seriesuid.isin(seriesuid_pos)]
            result = result.sort_values('seriesuid')

        if self.seriesuid:
            result = result[result.seriesuid == self.seriesuid]

        return result.reset_index(drop=True)

In [None]:
# hyper_parameters = {
#         'model_type': 'segmentation',  # or 'segmentation'
#         'window': 'full_range',
#         'normalize': True,
#         'batch_norm': False,
#         'batch_size': 64,
#         'num_workers': 4,
#         'cache_in': False,
#         'visualize': True,
#         'epochs': 10,
#         'n_metrics': 3,
#         'balanced': 1,
#         'enable_balanced_decay': True,
#         'augment': {
#             'flip': True,
#             'offset': 0.1,
#             'scale': 0.2,
#             'rotate': True,
#             'noise': 0.1,
#             'mixup': 0.4
#     }
# }
# config = Config(hyper_parameters)

# # validation seriesuid
# seriesuid_val =  '1.3.6.1.4.1.14519.5.2.1.6279.6001.213140617640021803112060161074'
# # luna = SegmenterDataset(seriesuid=seriesuid_val, mode='val', config=config)

# # inference seriesuid
# # seriesuid_inf =  '1.3.6.1.4.1.14519.5.2.1.6279.6001.100684836163890911914061745866'
# # luna = SegmenterDataset(seriesuid=seriesuid_inf, mode='inference', config=config, subset=1)

# # train seriesuid
# # seriesuid_train = '1.3.6.1.4.1.14519.5.2.1.6279.6001.128023902651233986592378348912'
# # seriesuid_train = '1.3.6.1.4.1.14519.5.2.1.6279.6001.134996872583497382954024478441'
# seriesuid_train = '1.3.6.1.4.1.14519.5.2.1.6279.6001.137763212752154081977261297097'
# # seriesuid_train = '1.3.6.1.4.1.14519.5.2.1.6279.6001.137763212752154081977261297097'
# luna = SegmenterDataset(seriesuid=seriesuid_train, mode='train', config=config)
# # luna = SegmenterDataset(mode='train', config=config)

# count = 0
# for x, y, _, irc_center in DataLoader(luna, batch_size=1, shuffle=False):
#     if count%10 == 0:
#         # x, y = augment_candidates_2d(x,y,config.augment)
#         show_nodule(x, 1, irc_center)
#         show_nodule(y, 1, irc_center)
#         # create_3d_tomograph(x, slice_spacing=1, colormap='plasma', transparency=1.0, threshold=0)
#     count += 1