In [None]:
import sys, os
if os.path.abspath(os.pardir) not in sys.path:
    sys.path.insert(0, os.path.abspath(os.pardir))
import CONFIG
%reload_ext autoreload
%autoreload 2

In [70]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn 
from torch.utils.data import Dataset, DataLoader
import pydicom
import matplotlib.pyplot as plt

In [None]:
DATA_DIR = CONFIG.CFG.DATA.BASE
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# number of images used to create a single 3D array of the scan
NUM_IMAGES = 8

In [None]:
train_df = pd.read_csv(os.path.join(DATA_DIR, "train.csv"))
test_df = pd.read_csv(os.path.join(DATA_DIR, "test.csv"))

In [None]:
TRAIN_PATIENTS = train_df['Patient'].unique().tolist()
# gave the gdcm error
BAD_PATIENT_IDS = ['ID00011637202177653955184', 'ID00052637202186188008618']
ALL_TRAIN_PATIENTS = [pat for pat in TRAIN_PATIENTS if pat not in BAD_PATIENT_IDS]
ALL_TEST_PATIENTS = test_df['Patient'].unique().tolist()

In [64]:
path = os.path.join(DATA_DIR, "train", TRAIN_PATIENTS[4])
all_files = os.listdir(path)
all_files.sort(key = lambda x: int(x.split('.')[0]))
slices = [pydicom.read_file(os.path.join(path, s)) for s in all_files]

In [66]:
def get_averaged_slices(patient_id, folder_path, num_images):
    # the preprocessed array with NUM_SLICES elements
    # TODO: Handle the case when the NUM_SLICES > the actual total slices
    # TODO: resize the image to 256 X 256?

    full_path = os.path.join(folder_path, patient_id)
    # list of all files in that path and sort them
    all_files = os.listdir(full_path)
    # sorted using the first number part of the file name
    all_files.sort(key = lambda x: int(x.split('.')[0]))

    # read all the dicom files for the patient into the slices list
    slices = [pydicom.read_file(os.path.join(full_path, s)) for s in all_files]
    # sort the slices using their order (file number works too)
    # slices.sort(key = lambda x: int(x.ImagePositionPatient[2]))

    # final array containing averaged num_images images
    out_array = []

    # how many extra files while averaging all images into (num_images) images
    remainder_array_size = len(slices)%num_images

    # how many to average to get a single averaged image
    avging_array_size = len(slices)//num_images

    # get the first one with the remainder images
    first_array = []
    # select the first remainder + avg_arrray_size imgaes and average into one
    for slice in slices[:remainder_array_size+avging_array_size]:
        first_array.append(slice.pixel_array)
    first_avged_array = np.average(first_array, axis=0)
    out_array.append(first_avged_array)

    # after the first one get the remaining ones into out_array rolling averaging (avging_array_size) at a time.
    for i in range(remainder_array_size + avging_array_size, len(slices), avging_array_size):
        temp_array = []
        for slice in slices[i:i+avging_array_size]:
            temp_array.append(slice.pixel_array)
        avged_temp_array = np.average(temp_array, axis=0)
        out_array.append(avged_temp_array)
    
    return np.array(out_array)

In [67]:
array_from_id = {}

In [68]:
# store the train and test images in array_from_id
for id in ALL_TRAIN_PATIENTS:
    array_from_id[id] = get_averaged_slices(id, os.path.join(DATA_DIR, "train"), NUM_IMAGES)

for id in ALL_TEST_PATIENTS:
    array_from_id[id] = get_averaged_slices(id, os.path.join(DATA_DIR, "test"), NUM_IMAGES)

In [72]:
class PulmonaryDataset(Dataset):
    def __init__(self, df, FV, test=False):
        self.df = df
        self.test = test
        self.FV = FV

    def __getitem__(self, idx):
        return {
            'imgarray': torch.from_numpy(array_from_id[self.df.iloc[idx]['Patient']]),
            'tabfeatures': torch.tensor(self.df[self.FV].iloc[idx].values),
            'target': torch.tensor(self.df['FVC'].iloc[idx])
        }

    def __len__(self):
        return len(self.df)

In [74]:
for key in array_from_id:
    print(array_from_id[key].shape)

(8, 512, 512)
(8, 768, 768)
(8, 512, 512)
(8, 512, 512)
(8, 843, 888)
(8, 768, 768)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 768, 768)
(8, 768, 768)
(8, 768, 768)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 768, 768)
(8, 768, 768)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 733, 888)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 768, 768)
(8, 768, 768)
(8, 788, 888)
(8, 768, 768)
(8, 512, 512)
(8, 512, 512)
(8, 752, 888)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 768, 768)
(8, 768, 768)
(8, 512, 512)
(8, 768, 768)
(8, 512, 512)
(8, 1302, 1302)
(8, 512, 512)
(8, 512, 512)
(8, 734, 888)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 768, 768)
(8, 768, 768)
(8, 512, 512)
(8, 512, 512)
(8, 768, 768)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 512, 512)
(8, 