In [1]:
import numpy as np
import pandas as pd
import pydicom
import os
import scipy.ndimage
import matplotlib.pyplot as plt
from matplotlib import patches
import SimpleITK as sitk
from skimage import measure, morphology
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from pathlib import Path
import json
from preproc import load_scan, resample
from tqdm.notebook import tqdm
from extract_cubes import coord_to_idx, slice_with_padding, get_slices
from torch.utils.data import Dataset, DataLoader, Subset, SequentialSampler
from timeit import default_timer
import itertools

In [2]:
class LUNA16DatasetFromCubes(Dataset):
    def __init__(self, cube_root_path, candidates_file, subsets=None):
        self.root = Path(cube_root_path)
        self.candidates = pd.read_csv(self.root / candidates_file)
        self.pos_arr = np.load(self.root / 'pos.npy')
        with open(self.root / 'uid_to_subset.json') as f:
            self.uid_to_subset = json.load(f)
        self.candidates['subset'] = self.candidates['seriesuid'].apply(self.uid_to_subset.get)
        if subsets is not None:
            self.candidates = self.candidates[self.candidates['subset'].isin(subsets)]
            self.candidates = self.candidates.reset_index(drop=True)
        self.pos_sample_idx = self.candidates[self.candidates['class'] == 1].index.to_numpy()
        self.neg_sample_idx = self.candidates[self.candidates['class'] == 0].index.to_numpy()
    
    def __len__(self):
        return len(self.candidates)
    
    def __getitem__(self, idx):
        row = self.candidates.iloc[idx]
        y = int(row['class'])
        if row['class'] == 1:
            X = self.pos_arr[row['i'],:,:,:,:]
        else:
            arr = np.load(self.root / self.uid_to_subset[row['seriesuid']] / ('neg_%s.npy' % row['seriesuid']))
            X = arr[row['i'],:,:,:,:]
        return X, y

In [3]:
class LUNA16DatasetFromIso(Dataset):
    def __init__(self, iso_root_path, candidates_file, subsets=None):
        self.root = Path(iso_root_path)
        self.candidates = pd.read_csv(self.root / candidates_file)
        with open(self.root / 'uid_to_subset.json') as f:
            self.uid_to_subset = json.load(f)
        self.candidates['subset'] = self.candidates['seriesuid'].apply(self.uid_to_subset.get)
        if subsets is not None:
            self.candidates = self.candidates[self.candidates['subset'].isin(subsets)]
            self.candidates = self.candidates.reset_index(drop=True)
        self.metadata = pd.read_csv(self.root / 'seriesuid_isometric_spacing_origin_direction.csv').set_index('seriesuid')
        self.pos_sample_idx = self.candidates[self.candidates['class'] == 1].index.to_numpy()
        self.neg_sample_idx = self.candidates[self.candidates['class'] == 0].index.to_numpy()
        self.cached_arr = None
        self.cached_seriesuid = None
    
    def __len__(self):
        return len(self.candidates)
    
    def __getitem__(self, idx):
        row = self.candidates.iloc[idx]
        seriesuid = row['seriesuid']
        if (self.cached_seriesuid is not None) and (self.cached_seriesuid == seriesuid):
            arr = self.cached_arr
        else:
            self.cached_seriesuid = seriesuid
            arr = np.load(self.root / self.uid_to_subset[seriesuid] / ('%s.npy' % seriesuid))
            self.cached_arr = arr
        coord = row[['coordX', 'coordY', 'coordZ']].astype(float).to_numpy()
        spacing = self.metadata.loc[seriesuid][:3].to_numpy()
        origin = self.metadata.loc[seriesuid][3:6].to_numpy()
        direction = self.metadata.loc[seriesuid][6:9].to_numpy()
        idx = coord_to_idx(coord, spacing, origin, direction)
        slices = get_slices(idx)
        X = slice_with_padding(slices, arr)
        y = int(row['class'])
        return X, y

In [4]:
d2 = LUNA16DatasetFromIso(
    '/scratch/zc2357/cv/final/datasets/luna16_iso',
    'candidates_V2.csv',
    subsets=['subset0'],
)

In [5]:
def shuffle_wrapper(x):
    np.random.shuffle(x)
    return x

df = (
    d2.candidates.loc[d2.neg_sample_idx]
    .copy().reset_index()
    .groupby('seriesuid')['index'].unique()
    .apply(shuffle_wrapper)  # shuffle within cases
    .apply(list)
)
df = df.sample(len(df))      # shuffle case order
neg_idx_shuffled = list(itertools.chain.from_iterable(df.values))  # flatten

In [6]:
dl = DataLoader(d2, batch_size=64, shuffle=False, sampler=neg_idx_shuffled, num_workers=1)

In [7]:
for i in range(100):
    start = default_timer()
    next(iter(dl))
    stop = default_timer()
    print(stop-start)

1.1050026081502438
0.9274624157696962
0.9287596885114908
0.9240182358771563


KeyboardInterrupt: 