In [1]:
from functools import lru_cache
import glob
import logging
import math
import os
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F

from nodule_classifier.dsetsG import CTScan, create_df_candidates_info
from util.utilG import getCacheHandle, unzipped_path, xyz2irc
import SimpleITK as sitk

# log = logging.getLogger('ggggg')
# # log.setLevel(logging.WARN)
# # log.setLevel(logging.INFO)
# log.setLevel(logging.DEBUG)

# Cache handle obtained from getCacheHandle('segmentation'); its `memoize`
# decorator is applied to CTScan_seg.get_Ct_I_axis_info further down.
disk_cache = getCacheHandle('segmentation')


def get_candidate_info(series_uid):
    """Return every candidate row recorded for one CT series.

    Args:
        series_uid: index label of the series in the frame returned by
            ``create_df_candidates_info()``.

    Returns:
        A list with one inner list per candidate row (the row's column values
        in column order). Callers in this file read element 0 of a row as the
        xyz coordinate and element 1 as the nodule flag.

    Raises:
        KeyError: if ``series_uid`` is not present in the index.
    """
    df_candidates = create_df_candidates_info()
    # Select with a one-element list: `.loc[label]` collapses to a Series when
    # exactly one row matches, which would flatten the result into a single
    # row instead of a list of rows. `.loc[[label]]` always yields a DataFrame.
    return df_candidates.loc[[series_uid]].values.tolist()

def get_series_uids_list():
    """Return the distinct index labels (series UIDs) of the candidates frame.

    The labels are returned as a plain Python list, in the order in which they
    first appear in the index of ``create_df_candidates_info()``.
    """
    candidate_index = create_df_candidates_info().index
    return candidate_index.unique().tolist()

class CTScan_seg(CTScan):
    """One CT series together with a boolean voxel mask of its positive candidates.

    Attributes set by ``__init__``:
        seriesuid: the series UID passed in.
        ct_img_arr: float32 array of the volume as returned by SimpleITK. The
            values are NOT clipped here; consumers clip when they need to.
        origin_xyz, vxSize_xyz, direction_matrix: origin, spacing and 3x3
            direction matrix read from the image.
        positiveInfo_list: candidate rows whose element at index 1 is truthy.
        positive_mask: boolean array with the same shape as ``ct_img_arr``.
        positive_indexes: indices along axis 0 (the "I" of IRC) of the slices
            that contain at least one True voxel in ``positive_mask``.
    """

    def __init__(self, series_uid) -> None:
        """Load the .mhd file of ``series_uid`` and build its positive mask.

        Raises:
            IndexError: if no ``<series_uid>.mhd`` file matches the glob
                pattern under ``unzipped_path``.
        """
        self.seriesuid = series_uid
        path_mhdfile = glob.glob(unzipped_path + 'subset*/subset*/{}.mhd'.format(series_uid))[0]
        ct_img = sitk.ReadImage(path_mhdfile)  # also carries the metadata getters used below
        ct_img_arr = sitk.GetArrayFromImage(ct_img).astype(np.float32)
        # No HU clipping to [-1000, 1000] here: the original values of the
        # scan are kept and clipping is left to the consumers.
        self.ct_img_arr = ct_img_arr

        self.origin_xyz = np.array(ct_img.GetOrigin())
        self.vxSize_xyz = np.array(ct_img.GetSpacing())
        self.direction_matrix = np.array(ct_img.GetDirection()).reshape(3, 3)

        candidateInfo_list = get_candidate_info(series_uid)  # rows of this series only

        # Idea: build a mask of the positive candidates, then find the slices
        # that contain any of them by summing the mask over axes 1 and 2.
        # Only positive candidates are segmented, so only their slice indexes
        # are needed.
        self.positiveInfo_list = [candidate_tup for candidate_tup
                                  in candidateInfo_list if candidate_tup[1]]  # idx 1 is isNodule
        self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list)
        self.positive_indexes = (self.positive_mask.sum(axis=(1, 2))
                                 .nonzero()[0].tolist())  # the Is (of IRC) that hold positive voxels

    def _grow_radius(self, center_irc, axis, threshold_hu):
        """Return how far the box around ``center_irc`` extends along ``axis``.

        Starting at a radius of 2, the radius grows while BOTH voxels at
        ``center +/- radius`` on that axis are above ``threshold_hu``. When a
        probe would leave the volume on either side, the last in-bounds radius
        (``radius - 1``) is returned. When a probe hits a voxel at or below the
        threshold, the probing radius itself is returned; the overshoot is
        removed later by the threshold clean-up in ``buildAnnotationMask``.

        ``center_irc`` must lie inside the volume.
        """
        axis_len = self.ct_img_arr.shape[axis]
        # 1-D view of the volume through the centre along `axis`.
        line_index = list(center_irc)
        line_index[axis] = slice(None)
        profile = self.ct_img_arr[tuple(line_index)]

        center = center_irc[axis]
        radius = 2
        while True:
            low, high = center - radius, center + radius
            # Explicit bounds check: a negative index would not raise, it
            # would silently wrap around to the far end of the axis.
            if low < 0 or high >= axis_len:
                return radius - 1
            if not (profile[high] > threshold_hu and profile[low] > threshold_hu):
                return radius
            radius += 1

    def buildAnnotationMask(self, positiveInfo_list, threshold_hu = -700):
        """Build a boolean mask marking the voxels of the positive candidates.

        For every candidate, its xyz centre is converted to IRC indexes and an
        axis-aligned box is grown around it, one axis at a time, until a probe
        on that axis reaches a voxel at or below ``threshold_hu`` or the edge
        of the volume. Finally every voxel at or below the threshold is
        removed from the mask.

        Args:
            positiveInfo_list: candidate rows of this series; element 0 of
                each row is the xyz coordinate of the candidate centre.
            threshold_hu: density threshold in HU. -700 is far below water and
                soft tissue, so reaching it is taken as having left the nodule.

        Returns:
            Boolean array with the same shape as ``self.ct_img_arr``. The full
            mask is returned; cropping of scan and mask happens elsewhere,
            when a candidate chunk is actually extracted.
        """
        mask_arr = np.zeros_like(self.ct_img_arr, dtype=bool)
        shape_irc = self.ct_img_arr.shape

        for candidateInfo_tup in positiveInfo_list:
            center_irc = xyz2irc(
                candidateInfo_tup[0],  # idx 0 is xyzCoord
                self.origin_xyz,
                self.vxSize_xyz,
                self.direction_matrix,
            )
            center = tuple(int(center_irc[axis]) for axis in range(3))

            # A centre outside the volume cannot be segmented; indexing with
            # it would either raise or wrap around, so skip it explicitly.
            if any(c < 0 or c >= size for c, size in zip(center, shape_irc)):
                logging.getLogger(__name__).warning(
                    "candidate centre %s of series %s lies outside the volume %s; skipped",
                    center, self.seriesuid, shape_irc,
                )
                continue

            box = []
            for axis in range(3):
                radius = self._grow_radius(center, axis, threshold_hu)
                # Clamp the lower edge: a negative slice start would wrap
                # around. The upper edge is clipped by numpy slicing itself.
                box.append(slice(max(center[axis] - radius, 0),
                                 center[axis] + radius + 1))

            # mask_arr is False everywhere except inside the candidate boxes.
            mask_arr[tuple(box)] = True

        # Clean-up: growing stops one step AFTER the threshold was reached, so
        # the boxes may include low-density border voxels; drop them.
        mask_arr = mask_arr & (self.ct_img_arr > threshold_hu)
        return mask_arr

    @staticmethod
    @disk_cache.memoize(typed=True)
    # Cached for fast retrieval of the I-axis size and the positive slice
    # indexes; the I-axis size is needed because each scan has a different
    # number of slices.
    def get_Ct_I_axis_info(series_uid):
        """Return ``(number of slices along axis 0, positive_indexes)`` for a series."""
        ct = CTScan_seg(series_uid)
        return int(ct.ct_img_arr.shape[0]), ct.positive_indexes


class LunaSegDataset(Dataset):
    """Common base of the segmentation datasets.

    The candidate table, the list of series UIDs and the train/validation
    split position are computed once, when the class body is executed, and are
    shared by all subclasses as class attributes. Subclasses fill
    ``self.samples`` and implement ``__getitem__``.
    """

    # Kept at class level rather than in __init__ (author's note: the
    # DataLoader probably copies the dataset object into each worker when
    # num_workers > 0, and df_candidates must be reachable from those copies).
    series_uids_list = get_series_uids_list()
    df_candidates = create_df_candidates_info()
    # The first 70% of the series (rounded up) form the training part.
    split_idx = math.ceil(0.7 * len(series_uids_list))

    def __init__(self) -> None:
        """Start with an empty sample collection; subclasses replace it."""
        self.samples = []

    def __len__(self):
        """Number of samples currently held in ``self.samples``."""
        return len(self.samples)

        


# Training and validation dataset classes. Both read the class-level df_candidates
# and series_uids_list of LunaSegDataset; series_uids_list is split at split_idx
# (first 70% of the series for training, the rest for validation).
class LunaSegDataset_Train(LunaSegDataset):
    """Training dataset: one sample per positive candidate of the training series.

    ``self.samples`` is a DataFrame holding the rows of ``df_candidates`` (with
    the index turned into columns) whose ``seriesuid`` belongs to the training
    part of ``series_uids_list`` and whose ``isNodule`` equals True.
    """

    def __init__(self):
        super().__init__()
        # Instance attribute shadowing the class-level list: training part only.
        self.series_uids_list = self.series_uids_list[:self.split_idx]

        candidates = self.df_candidates.reset_index()
        in_train_split = candidates['seriesuid'].isin(self.series_uids_list)
        # Explicit comparison kept so the result is a boolean mask whatever
        # dtype the column has.
        is_nodule = candidates['isNodule'] == True
        self.samples = candidates[in_train_split & is_nodule]

    def __getitem__(self, idx):
        """Return ``(ct_tensor, pos_tensor)`` for the idx-th positive candidate.

        The crop comes from ``CTScan_seg.get_ct_cropped_disk_cache`` (not
        defined in this cell; presumably inherited from ``CTScan``), requested
        with shape (7, 96, 96) around the candidate's xyz coordinate.
        """
        row = self.samples.iloc[idx]
        crop_shape = (7, 96, 96)
        ct_arr, pos_arr = CTScan_seg.get_ct_cropped_disk_cache(
            row['seriesuid'],
            row['xyzCoord'],
            crop_shape,
        )
        # conv2d is used downstream, so no extra channel dimension is added:
        # the I axis already plays the role of the channel dimension and
        # there is no depth dimension.
        return (
            torch.tensor(ct_arr, dtype=torch.float32),
            torch.tensor(pos_arr, dtype=torch.long),
        )
        
class LunaSegDataset_Val(LunaSegDataset):
    """Validation dataset: one sample per positive slice of the validation series.

    ``self.samples`` is a list of ``(series_uid, slice_idx)`` tuples, one for
    every slice index reported as positive by
    ``CTScan_seg.get_Ct_I_axis_info`` for the series after ``split_idx``.
    """

    # Number of neighbouring slices taken on EACH side of the target slice.
    context_slices_count = 3

    def __init__(self):
        super().__init__()
        # Instance attribute shadowing the class-level list: validation part only.
        self.series_uids_list = self.series_uids_list[self.split_idx:]
        self.samples = []
        for series_uid in self.series_uids_list:
            # The slice count is part of the cached tuple but not needed here.
            _, pos_idxs = CTScan_seg.get_Ct_I_axis_info(series_uid)
            self.samples.extend((series_uid, pos_idx) for pos_idx in pos_idxs)

    def __getitem__(self, idx):
        """Return ``(ct_t, pos_t)`` for the idx-th positive slice.

        Returns:
            ct_t: float32 tensor of shape ``(2 * context_slices_count + 1, H, W)``
                holding the target slice with its neighbours, clamped to
                [-1000, 1000]. ``H, W`` are taken from the scan itself.
            pos_t: tensor of shape ``(1, H, W)`` made from the positive mask of
                the target slice (it shares memory with the mask array).
        """
        series_uid, slice_idx = self.samples[idx]
        # Not defined in this cell; presumably inherited from CTScan.
        ct = CTScan_seg._get_single_ct_lru_cache(series_uid)

        # Every 2-D input slice is accompanied by the slices before and after
        # it. Indexes beyond the scan are clipped, which duplicates the first
        # or last slice.
        n_context = self.context_slices_count
        last_idx = ct.ct_img_arr.shape[0] - 1
        context_idxs = np.clip(
            np.arange(slice_idx - n_context, slice_idx + n_context + 1),
            0,
            last_idx,
        )
        # Indexing with an index array copies, so the in-place clamp below
        # never touches the scan's own array.
        ct_t = torch.from_numpy(
            ct.ct_img_arr[context_idxs].astype(np.float32, copy=False)
        )
        ct_t.clamp_(-1000, 1000)  # the scan itself is stored unclipped

        assert len(ct.positive_mask) > 0, f"positive mask is empty for {series_uid}"
        pos_t = torch.from_numpy(ct.positive_mask[slice_idx]).unsqueeze(0)

        return ct_t, pos_t
        


In [4]:
from torch.utils.data import DataLoader

from util.utilG import enumerateWithEstimate
# Worker-process count, defined before the loaders so that the loaders and the
# `start_ndx` argument below always agree.
numworkers = 0

train_set = LunaSegDataset_Train()
val_set = LunaSegDataset_Val()
train_loader = DataLoader(train_set, batch_size=64, num_workers=numworkers)
val_loader = DataLoader(val_set, batch_size=64, num_workers=numworkers)

# Iterating the training loader once is currently disabled:
# batch_iter = enumerateWithEstimate(
#     train_loader,
#     "cache train",
#     start_ndx=numworkers,
# )
# for batch_ndx, batch_tup in batch_iter:
#     pass


# Iterate the validation loader once and discard the batches; judging by the
# "cache val" label the point is to fill the caches used by the dataset.
batch_iter = enumerateWithEstimate(
    val_loader,
    "cache val",
    start_ndx=numworkers,
)
for batch_ndx, batch_tup in batch_iter:
    pass


