In [1]:
import collections
import glob
import os

import numpy as np
import SimpleITK as sitk

# Root of the unzipped LUNA data; scans are looked up as subset*/subset*/<series_uid>.mhd below it.
# Set the LUNA_UNZIPPED_PATH environment variable to point elsewhere instead of editing this
# machine-specific default. Keep the trailing separator: callers concatenate onto it directly.
unzipped_path = os.environ.get(
    'LUNA_UNZIPPED_PATH',
    'C:\\Users\\justm\\OneDrive\\Desktop\\New folder\\unzipped\\',
)

# Voxel coordinates in array-index order: (index, row, col).
IrcTuple = collections.namedtuple('IrcTuple', ['index', 'row', 'col'])
# Patient-space coordinates (units presumably mm — TODO confirm against the data).
XyzTuple = collections.namedtuple('XyzTuple', ['x', 'y', 'z'])

def irc2xyz(coord_irc, origin_xyz, vxSize_xyz, direction_matrix):
    """Convert voxel (index, row, col) coordinates to patient-space (x, y, z).

    Computes ``xyz = D @ (cri * vxSize) + origin``, where ``cri`` is ``coord_irc``
    reversed into (col, row, index) order so its axes line up with x, y, z.

    Args:
        coord_irc: 3-element sequence or array in (index, row, col) order.
        origin_xyz: 3-element volume origin (xyz).
        vxSize_xyz: 3-element voxel spacing (xyz).
        direction_matrix: 3x3 direction matrix.

    Returns:
        np.ndarray of three floats in (x, y, z) order.
    """
    # np.asarray lets plain tuples/lists work; `tuple * tuple` would raise TypeError.
    cri_a = np.asarray(coord_irc, dtype=float)[::-1]
    vxSize_a = np.asarray(vxSize_xyz, dtype=float)
    origin_a = np.asarray(origin_xyz, dtype=float)
    return np.asarray(direction_matrix, dtype=float) @ (cri_a * vxSize_a) + origin_a

def xyz2irc(coord_xyz, origin_xyz, vxSize_xyz, direction_matrix):
    """Convert patient-space (x, y, z) coordinates to integer voxel (index, row, col).

    Inverts ``irc2xyz``: since ``xyz = D @ (cri * vxSize) + origin``,
    ``cri = D^-1 @ (xyz - origin) / vxSize``. The result is rounded to the nearest
    voxel and reversed from (col, row, index) into (index, row, col) order.

    Args:
        coord_xyz: 3-element sequence or array of x, y, z.
        origin_xyz: 3-element volume origin (xyz).
        vxSize_xyz: 3-element voxel spacing (xyz).
        direction_matrix: 3x3 direction matrix.

    Returns:
        np.ndarray of three ints in (index, row, col) order.
    """
    offset_xyz = np.asarray(coord_xyz, dtype=float) - np.asarray(origin_xyz, dtype=float)
    # Solve D @ v = offset rather than computing `offset @ inv(D)`. The row-vector form
    # multiplies by inv(D).T, which equals inv(D) only for symmetric (e.g. diagonal)
    # direction matrices, so rotated scans would map to the wrong voxel.
    cri_a = np.linalg.solve(np.asarray(direction_matrix, dtype=float), offset_xyz)
    cri_a = cri_a / np.asarray(vxSize_xyz, dtype=float)
    return np.round(cri_a).astype(int)[::-1]

In [2]:

class Ct:
    """A single LUNA CT scan loaded from ``unzipped_path + subset*/subset*/<series_uid>.mhd``.

    Attributes:
        series_uid: the scan's series UID.
        hu_a: float32 voxel array with values clipped to [-1000, 1000]; its axes are
            indexed in (index, row, col) order by ``get_candidate_croppedChunk_inVoxelCoord``.
        origin_xyz: volume origin as a 3-element array (xyz).
        vxSize_xyz: voxel spacing as a 3-element array (xyz).
        direction_a: 3x3 direction matrix.
    """

    def __init__(self, series_uid):
        # Raises IndexError if no .mhd file for this series_uid exists under unzipped_path.
        mhd_path = glob.glob(
            unzipped_path + 'subset*/subset*/{}.mhd'.format(series_uid)
        )[0]

        ct_metadata = sitk.ReadImage(mhd_path)
        ct_arr = sitk.GetArrayFromImage(ct_metadata).astype(np.float32)
        # Clip HU values to [-1000, 1000] in place.
        np.clip(ct_arr, -1000, 1000, out=ct_arr)

        self.series_uid = series_uid
        self.hu_a = ct_arr

        self.origin_xyz = np.array(ct_metadata.GetOrigin())
        self.vxSize_xyz = np.array(ct_metadata.GetSpacing())
        self.direction_a = np.array(ct_metadata.GetDirection()).reshape(3, 3)

    @staticmethod
    def _axis_slice(center_val, width, axis_len):
        """Return a ``width``-long slice centred on ``center_val``, shifted to fit in ``[0, axis_len)``.

        When the axis is shorter than ``width``, the whole axis is returned. The old code
        produced a negative start index here, which silently wrapped around.
        """
        start_ndx = int(round(center_val - width / 2))
        end_ndx = int(start_ndx + width)
        if start_ndx < 0:
            # Window runs off the low edge: shift it to start at 0.
            start_ndx = 0
            end_ndx = int(width)
        if end_ndx > axis_len:
            # Window runs off the high edge: shift it to end at the axis size (never below 0).
            end_ndx = axis_len
            start_ndx = max(0, int(axis_len - width))
        return slice(start_ndx, end_ndx)

    def get_candidate_croppedChunk_inVoxelCoord(self, center_xyz, width_irc = (32, 48, 48)):
        """Crop a chunk of ``hu_a`` around a patient-space point.

        Args:
            center_xyz: (x, y, z) centre of the chunk.
            width_irc: chunk size along (index, row, col); default (32, 48, 48).

        Returns:
            A view of ``hu_a``. Near an edge the window is shifted (not padded) to stay
            inside the array. Along an axis shorter than the requested width, the chunk
            is the whole axis.

        Raises:
            AssertionError: if the centre voxel lies outside ``hu_a``.
        """
        center_irc = xyz2irc(
            np.array(center_xyz),
            self.origin_xyz,
            self.vxSize_xyz,
            self.direction_a,
        )

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            # The irc centre must lie inside the CT array.
            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
            slice_list.append(self._axis_slice(center_val, width_irc[axis], self.hu_a.shape[axis]))

        return self.hu_a[tuple(slice_list)]

In [3]:
# Ct('1.3.6.1.4.1.14519.5.2.1.6279.6001.100225287222365663678666836860')

CT scan instance class:

xyz <-> irc conversion (the irc coordinate must be reversed into cri order to line up with xyz)

Read the image for a seriesuid with SimpleITK (sitk) into an array
Clip HU values to [-1000, 1000]
LRU-cache a single CT instance
    # that's why an LRU cache of size 1 is needed,
    # but because the candidate info list is shuffled, a size-1 cache is not enough,
    # so the CT object cache is only useful during prepcache; the dataset always relies on disk-cached CT chunks
Disk-cache the cropped CT chunks (function get_candidate_croppedChunk_inVoxelCoord)
Crop the CT to the size given by a width_irc tuple: along each axis the chunk spans axisCoord ± width_irc[axis]/2
# assert that the irc center is within the CT array
# if start_ndx < 0, set start_ndx to 0 and end_ndx to width_irc[axis]
# if end_ndx > axis size, set end_ndx to axis size and start_ndx to axis size - width_irc[axis]

In [4]:
from functools import lru_cache

from utilG import getCacheHandle, unzipped_path


# Cache handle from the project helper utilG.getCacheHandle. Its .memoize decorator is used
# below to persist cropped CT chunks. 'test1' presumably names the cache (directory) — confirm in utilG.
disk_cache = getCacheHandle('test1')

class CTScan:
    """A single LUNA CT scan loaded from ``unzipped_path + subset*/subset*/<seriesuid>.mhd``.

    Attributes:
        seriesuid: the scan's series UID.
        ct_img_arr: float32 voxel array with values clipped to [-1000, 1000]; its axes are
            indexed in (index, row, col) order by ``get_candidate_croppedChunk_inVoxelCoord``.
        origin_xyz: volume origin as a 3-element array (xyz).
        vxSize_xyz: voxel spacing as a 3-element array (xyz).
        direction_matrix: 3x3 direction matrix.
    """

    def __init__(self, seriesuid) -> None:
        self.seriesuid = seriesuid
        # Raises IndexError if no .mhd file for this seriesuid exists under unzipped_path.
        path_mhdfile = glob.glob(unzipped_path + 'subset*/subset*/{}.mhd'.format(seriesuid))[0]
        ct_img = sitk.ReadImage(path_mhdfile)  # image object; also provides the metadata getters used below
        ct_img_arr = sitk.GetArrayFromImage(ct_img).astype(np.float32)
        # Clip HU values to [-1000, 1000] in place.
        np.clip(ct_img_arr, -1000, 1000, out=ct_img_arr)
        self.ct_img_arr = ct_img_arr

        self.origin_xyz = np.array(ct_img.GetOrigin())
        self.vxSize_xyz = np.array(ct_img.GetSpacing())
        self.direction_matrix = np.array(ct_img.GetDirection()).reshape(3, 3)

    @staticmethod
    def _axis_slice(center_idx, axis_size, ct_size):
        """Return an ``axis_size``-long slice centred on ``center_idx``, shifted to fit in ``[0, ct_size)``.

        When the CT axis is shorter than ``axis_size``, the whole axis is returned. The old
        code produced a negative start index here, which silently wrapped around.
        """
        start_idx = int(round(center_idx - axis_size / 2))
        end_idx = int(start_idx + axis_size)
        if start_idx < 0:
            # Window runs off the low edge: shift it to start at 0.
            start_idx = 0
            end_idx = axis_size
        if end_idx > ct_size:
            # Window runs off the high edge: shift it to end at ct_size (never below 0).
            end_idx = ct_size
            start_idx = max(0, int(ct_size - axis_size))
        return slice(start_idx, end_idx)

    def get_candidate_croppedChunk_inVoxelCoord(self, center_xyz, axis_sizes=(32, 48, 48)):
        """Crop a chunk of the CT around a patient-space point.

        center_xyz: tuple of 3 floats, centre of the chunk in xyz coordinates.
        axis_sizes: tuple of 3 integers, size of the chunk along (index, row, col).
            Default is (32, 48, 48).

        Returns a view of ``ct_img_arr``. Near an edge the window is shifted (not padded)
        to stay inside the array. Along an axis shorter than the requested size, the
        chunk is the whole axis.
        Raises AssertionError if the centre voxel lies outside the CT array.
        """
        center_irc = xyz2irc(np.array(center_xyz), self.origin_xyz, self.vxSize_xyz, self.direction_matrix)
        ct_sizes = self.ct_img_arr.shape
        slices = []
        for idx, axis_size in enumerate(axis_sizes):
            # The irc centre must lie inside the CT array; otherwise the crop would silently
            # return an unrelated edge chunk.
            assert 0 <= center_irc[idx] < ct_sizes[idx], repr([self.seriesuid, center_xyz, center_irc, ct_sizes, idx])
            slices.append(self._axis_slice(center_irc[idx], axis_size, ct_sizes[idx]))
        return self.ct_img_arr[tuple(slices)]

@lru_cache(maxsize=1, typed=True)
def get_single_ct_lru_cache(seriesuid):
    """Return a CTScan for ``seriesuid``, keeping only the most recently loaded one in memory."""
    return CTScan(seriesuid)

@disk_cache.memoize(typed=True)
def get_ct_cropped_disk_cache(seriesuid, center_xyz, axis_sizes=(32, 48, 48)):
    """Return the cropped chunk of CT ``seriesuid`` around ``center_xyz``, memoised on disk.

    On a disk-cache miss, the CT is loaded via the size-1 in-memory LRU cache. Because the
    candidate list is shuffled, consecutive misses rarely hit that LRU entry. So the in-memory
    cache mainly helps while pre-filling the disk cache (prepcache). The datasets then
    always read the disk-cached chunks.
    """
    scan = get_single_ct_lru_cache(seriesuid)
    chunk = scan.get_candidate_croppedChunk_inVoxelCoord(center_xyz, axis_sizes)
    return chunk


In [5]:
# ctG = CTScan('1.3.6.1.4.1.14519.5.2.1.6279.6001.625270601160880745954773142570')
# ct = Ct('1.3.6.1.4.1.14519.5.2.1.6279.6001.625270601160880745954773142570')
# ctG

<__main__.CTScan at 0x1807eb2f590>

In [7]:
# (ctG.get_candidate_croppedChunk_inVoxelCoord((-108.2007072, 46.48017452, -143.2481594)) == ct.get_candidate_croppedChunk_inVoxelCoord((-108.2007072, 46.48017452, -143.2481594))).all()

True

In [4]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from test.util.util import enumerateWithEstimate
from dsetsG import LunaDataset, LunaDataset_Train, LunaDataset_Val


# Build the full, training and validation datasets (project classes from dsetsG).
dataset, trainset, valset = LunaDataset(), LunaDataset_Train(), LunaDataset_Val()
print(len(dataset), len(trainset), len(valset))

batch_size = 64
shuffle = True  # NOTE(review): not passed to any DataLoader below, so every loader iterates in dataset order
num_workers = 4
loader_kwargs = {'batch_size': batch_size, 'num_workers': num_workers}
dataloader = DataLoader(dataset, **loader_kwargs)
trainloader = DataLoader(trainset, **loader_kwargs)
valloader = DataLoader(valset, **loader_kwargs)

# Sanity check: no candidate row may appear in both splits
# (an inner merge with no `on=` joins on every column the two frames share).
duplicate_rows = pd.merge(trainset.df_candidates, valset.df_candidates, how='inner')
assert len(duplicate_rows) == 0, "duplicate rows in train and val set"
len(duplicate_rows), len(trainset.df_candidates), len(valset.df_candidates)

# Cache warm-up: iterate each loader once so the disk cache gets filled.
# Uncomment to run (slow).
# batch_iter = enumerateWithEstimate(dataloader,"test cache 1", start_ndx=num_workers, jump = 4)
# for _ in batch_iter:
#     pass
# batch_iter = enumerateWithEstimate(trainloader,"test cache 2", start_ndx=num_workers,)
# for _ in batch_iter:
#     pass
# batch_iter = enumerateWithEstimate(valloader,"test cache 3", start_ndx=num_workers,)
# for _ in batch_iter:
#     pass
# dataset.df_candidates

55603 38922 16681


(0, 38922, 16681)

In [3]:
class A:
    """Scratch check: an attribute bound on the class inside ``__init__`` is visible through ``self``."""

    def __init__(self) -> None:
        # Bind on the class object (not the instance); instance lookup falls back to it.
        setattr(A, 'a', 'hehe')

    def test(self):
        value = self.a
        print(value)


A().test()

hehe


In [1]:
import os
import shutil

# Default location of the on-disk cache directory wiped by cleanCache.
DISK_CACHE_DIR = 'C:\\Users\\justm\\OneDrive\\Desktop\\Luna Grand Chalenge\\disk_cache'

# Clean up any old data that might be around.
# We don't call this by default because it's destructive,
# and would waste a lot of time if it ran when nothing
# on the application side had changed.
def cleanCache(cache_dir=DISK_CACHE_DIR):
    """Delete ``cache_dir`` and everything in it, then recreate it empty.

    Args:
        cache_dir: directory to reset; defaults to DISK_CACHE_DIR.

    Works whether or not the directory already exists. Missing parent directories
    are created.
    """
    # The old code called rmtree unconditionally, which raised FileNotFoundError on a fresh machine.
    if os.path.isdir(cache_dir):
        shutil.rmtree(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)

# cleanCache()