# In [1]:  (Jupyter cell marker left over from the notebook export)
from functools import lru_cache
import glob
import logging
import math
import os
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F

from nodule_classifier.dsetsG import CTScan, create_df_candidates_info
from util.utilG import getCacheHandle, unzipped_path, xyz2irc
import SimpleITK as sitk

# log = logging.getLogger('ggggg')
# # log.setLevel(logging.WARN)
# # log.setLevel(logging.INFO)
# log.setLevel(logging.DEBUG)

# Module-level cache handle obtained from util.utilG.getCacheHandle; its
# ``memoize`` decorator is applied to CTScan_seg.get_Ct_I_axis_info below.
# NOTE(review): presumably a persistent on-disk cache (per the name) -- confirm
# against getCacheHandle.
disk_cache = getCacheHandle('segmentation')


def get_candidate_info(series_uid):
    """Return every candidate row recorded for one CT series.

    Args:
        series_uid: Label looked up in the index of the table returned by
            ``create_df_candidates_info()``.

    Returns:
        A list with one entry per matching row; each entry is the list of that
        row's column values.

    Raises:
        KeyError: if ``series_uid`` is not present in the table's index.

    NOTE(review): the table is rebuilt via ``create_df_candidates_info()`` on
    every call -- confirm that helper caches its result.
    NOTE(review): ``CTScan_seg.__init__`` reads ``.isNodule_bool`` and
    ``.center_xyz`` on the returned entries, which plain lists do not provide --
    confirm the intended row type.
    """
    df_candidates = create_df_candidates_info()
    # Selecting with a list of labels always yields a DataFrame. A scalar label
    # yields a Series when exactly one row matches, and the result would then
    # collapse into a flat list of scalars instead of a list of rows.
    return df_candidates.loc[[series_uid]].values.tolist()

class CTScan_seg(CTScan):
    """CT volume plus a boolean mask marking its positive (nodule) candidates.

    Attributes set by ``__init__``:
        seriesuid: the series UID passed in.
        ct_img_arr: float32 voxel array read from the ``.mhd`` file, unclipped.
        origin_xyz, vxSize_xyz, direction_matrix: geometry read from the image.
        positiveInfo_list: candidates of this series flagged as nodules.
        positive_mask: boolean array shaped like ``ct_img_arr``.
        positive_indexes: positions along axis 0 whose mask slice is non-empty.
    """

    def __init__(self, series_uid) -> None:
        """Load the volume for ``series_uid`` and build its positive mask.

        NOTE(review): ``CTScan.__init__`` is not called here; this initializer
        sets everything itself -- confirm the parent holds no other state.

        Raises:
            IndexError: if no ``.mhd`` file matching ``series_uid`` is found
                under ``unzipped_path``.
        """
        self.seriesuid = series_uid
        path_mhdfile = glob.glob(unzipped_path + 'subset*/subset*/{}.mhd'.format(series_uid))[0]
        ct_img = sitk.ReadImage(path_mhdfile)  # image object with metadata getters
        # The values are deliberately NOT clipped to [-1000, 1000] here, because
        # the original values of the CT scan are needed downstream.
        self.ct_img_arr = sitk.GetArrayFromImage(ct_img).astype(np.float32)

        self.origin_xyz = np.array(ct_img.GetOrigin())
        self.vxSize_xyz = np.array(ct_img.GetSpacing())
        self.direction_matrix = np.array(ct_img.GetDirection()).reshape(3, 3)

        candidateInfo_list = get_candidate_info(series_uid)  # this series only

        # Idea: build a mask of the positive candidates, then sum it over axes 1
        # and 2; axis-0 positions with a non-zero sum contain positive voxels.
        # NOTE(review): get_candidate_info returns plain lists of column values;
        # confirm the entries really expose ``isNodule_bool`` / ``center_xyz``.
        self.positiveInfo_list = [candidate_tup for candidate_tup
                                  in candidateInfo_list if candidate_tup.isNodule_bool]
        self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list)
        self.positive_indexes = (self.positive_mask.sum(axis=(1, 2))
                                 .nonzero()[0].tolist())

    def _grow_radius(self, center, axis, threshold_hu, start_radius=2):
        """Grow a symmetric radius along ``axis`` until a probe fails.

        Starting at ``start_radius``, the voxels at ``center + radius`` and then
        ``center - radius`` along ``axis`` are probed. Growth stops as soon as a
        probe is not above ``threshold_hu`` (the current radius is returned) or
        falls outside the array (the radius is stepped back by one, which is what
        the former ``except IndexError`` branch did).

        Unlike bare NumPy indexing, a negative probe position counts as outside
        the array instead of wrapping around to the far end of the axis.
        """
        shape = self.ct_img_arr.shape
        radius = start_radius
        while True:
            for offset in (radius, -radius):
                probe = list(center)
                probe[axis] += offset
                if not all(0 <= pos < size for pos, size in zip(probe, shape)):
                    return radius - 1
                if not self.ct_img_arr[tuple(probe)] > threshold_hu:
                    return radius
            radius += 1

    def buildAnnotationMask(self, positiveInfo_list, threshold_hu = -700):
        """Build a boolean mask covering the given positive candidates.

        For each candidate the center is converted to array coordinates with
        ``xyz2irc`` and a box is grown independently along each of the three
        axes (see ``_grow_radius``). The union of the boxes is then restricted
        to voxels whose value exceeds ``threshold_hu``.

        Args:
            positiveInfo_list: candidates of this scan; each must expose
                ``center_xyz``.
            threshold_hu: voxel value a probe must exceed for the box to keep
                growing, and that a voxel must exceed to stay in the mask.

        Returns:
            Boolean array with the same shape as ``self.ct_img_arr``. Cropping
            is left to whoever extracts the candidate chunk.
        """
        mask_arr = np.zeros_like(self.ct_img_arr, dtype=bool)
        for candidateInfo_tup in positiveInfo_list:
            center_irc = xyz2irc(
                candidateInfo_tup.center_xyz,
                self.origin_xyz,
                self.vxSize_xyz,
                self.direction_matrix,
            )
            center = (int(center_irc.index), int(center_irc.row), int(center_irc.col))
            radii = [self._grow_radius(center, axis, threshold_hu) for axis in range(3)]
            # Clamp both slice bounds at 0 so that a box touching the start of an
            # axis is truncated instead of being read as a negative index.
            box = tuple(
                slice(max(c - r, 0), max(c + r + 1, 0))
                for c, r in zip(center, radii)
            )
            mask_arr[box] = True

        # Growth stops on the first probe that fails, so the box can still
        # include voxels at or below the threshold; drop them here.
        return mask_arr & (self.ct_img_arr > threshold_hu)

    @staticmethod
    @disk_cache.memoize(typed=True)
    # Cached for fast retrieval of the axis-0 size and the positive indexes;
    # the axis-0 size is needed because each CT scan has a different number of
    # slices.
    def get_Ct_I_axis_info(series_uid):
        """Return ``(size of axis 0, positive_indexes)`` for ``series_uid``.

        Declared static because it takes no ``self``: it builds its own
        ``CTScan_seg`` from the UID.
        """
        ct = CTScan_seg(series_uid)
        return int(ct.ct_img_arr.shape[0]), ct.positive_indexes


holder = create_df_candidates_info() #anti pattern, should fix this


def _stratified_split(df_all, train_frac=0.7, label_col='isNodule'):
    """Split ``df_all`` into disjoint train/validation frames, stratified by label.

    The index of ``df_all`` is first moved into columns so every row has a
    unique positional label; the sampled training rows can then be removed from
    the validation frame exactly, even if the original index has duplicates.

    Each label group contributes
    ``int(int(train_frac * N) * len(group) / N)`` rows to the training frame,
    where ``N`` is the total row count.

    Returns:
        ``(df_train, df_val)``, both with a fresh 0..n-1 index.
    """
    flat = df_all.reset_index(drop=False)
    n_total = len(flat)
    n_train_target = int(train_frac * n_total)
    train_labels = []
    for _, group in flat.groupby(label_col):
        n_take = int(n_train_target * len(group) / n_total)
        train_labels.extend(group.sample(n_take).index)
    df_train = flat.loc[train_labels].reset_index(drop=True)
    df_val = flat.drop(index=train_labels).reset_index(drop=True)
    return df_train, df_val


class LunaSegDataset(Dataset):
    """Dataset yielding ``(cropped CT chunk, label)`` pairs for candidates.

    The candidate table is split once, at class-definition time, into
    ``df_train`` and ``df_val``. The split is stratified on ``isNodule``
    because the dataset is extremely imbalanced.

    The tables are class attributes rather than being built in ``__init__``.
    NOTE(review): the original comment says this is so DataLoader workers
    receive them when ``num_workers > 0`` -- unverified.
    """

    df_candidates = holder
    grouped = df_candidates.groupby('isNodule')
    df_train, df_val = _stratified_split(holder)

    def __init__(self, *, frac=.7, balance=True) -> None:
        """Store the balancing flag and pre-split positives/negatives.

        Args:
            frac: accepted for compatibility; not used by this initializer.
            balance: when truthy, ``__getitem__`` alternates positive and
                negative samples; when falsy, rows are served in table order.
        """
        self.balance = balance
        self.positives, self.negatives = self.split_neg_pos(self.df_candidates)

    def __len__(self):
        return len(self.df_candidates)

    def __getitem__(self, idx):
        """Return ``(ct_cropped, label)`` for position ``idx``.

        With balancing on, positions that are multiples of ``balance + 1`` map
        to positives and all others to negatives; both lists are indexed modulo
        their length, so the short positive list wraps around.
        """
        if self.balance:
            period = self.balance + 1  # one positive every `period` samples
            pos_idx = idx // period

            if idx % period:
                # Shift past the positive slots consumed so far.
                neg_idx = (idx - 1 - pos_idx) % len(self.negatives)
                candidateInfo = self.negatives.iloc[neg_idx]
            else:
                pos_idx %= len(self.positives)
                candidateInfo = self.positives.iloc[pos_idx]
        else:
            # No balancing: used for the validation set.
            candidateInfo = self.df_candidates.iloc[idx]

        # NOTE(review): get_ct_cropped_disk_cache is neither defined nor
        # imported in the visible part of this file -- confirm it is in scope.
        ct_cropped = get_ct_cropped_disk_cache(candidateInfo['seriesuid'], candidateInfo['xyzCoord'])
        ct_cropped = torch.tensor(ct_cropped).unsqueeze(0)  # add a leading channel dimension

        # The label is a single class index (dtype long), not a one-hot vector.
        label = torch.tensor(candidateInfo['isNodule']).to(torch.long)
        return ct_cropped, label

    def split_neg_pos(self, df_candidates) -> tuple[pd.DataFrame, pd.DataFrame]:
        """Return ``(positives, negatives)`` rows of ``df_candidates``.

        The mask is taken from the frame being split rather than from
        ``self.df_candidates``, so the result is correct for any frame passed
        in.
        """
        is_nodule = df_candidates['isNodule'].astype(bool)
        return df_candidates[is_nodule], df_candidates[~is_nodule]
        


# Training and validation dataset classes. Both inherit the candidate table from
# LunaSegDataset, but each one uses only its own split of it.
class LunaSegDataset_Train(LunaSegDataset):
    """Balanced dataset over the training split (``LunaSegDataset.df_train``)."""

    def __init__(self):
        # Point the instance at the training rows before the base initializer
        # runs, so its positive/negative split is computed once, on the training
        # rows, instead of once on the full table and again here.
        self.df_candidates = self.df_train
        super().__init__(balance=True)
        
        
class LunaSegDataset_Val(LunaSegDataset):
    """Unbalanced dataset over the validation split (``LunaSegDataset.df_val``)."""

    def __init__(self):
        # Point the instance at the validation rows before the base initializer
        # runs, so ``positives`` / ``negatives`` describe the validation split
        # rather than the full candidate table.
        self.df_candidates = self.df_val
        super().__init__(balance=False)


# KeyboardInterrupt:  (tail of a notebook traceback left over from the export)