# DICOM slices to torch tensors
- Convert the DICOM slices into torch tensors and save the CT volumes in a different folder
- Window the image based on the window information in the dicom slice
- The window information is the same for all slices, so a common default value is used when the information is missing from a DICOM file.

In [20]:
import numpy as np
from pydicom import dcmread
from pydicom import multival
import os
from ipywidgets import interact
import ipywidgets as widgets

import matplotlib.pyplot as plt
from scipy import ndimage
import SimpleITK as sitk
from einops import rearrange
from tqdm import tqdm
import shutil
import torch

from skimage.transform import resize

In [2]:
# Input folder: listed per subject, each subject entry holding that subject's DICOM slice files.
CT_Slice_DIR = "/Users/sudarshan/darshanz/datasets/lung1/CT_ONLY/slices"
# Output folder: one saved tensor volume per subject, written as '<subject>.pt'.
CT_Vol_DIR = "/Users/sudarshan/darshanz/datasets/lung1/CT_ONLY/vols"

In [3]:
# Function to take care of the rescale translation (slope/intercept) and windowing.
def window_image(img, window_center, window_width, intercept, slope, rescale=True):
    """Apply the DICOM rescale and intensity window to a slice, then resize it.

    Args:
        img: 2-D array of raw pixel values (e.g. ``dataset.pixel_array``).
        window_center: Centre of the intensity window.
        window_width: Full width of the intensity window.
        intercept: Rescale intercept added after multiplying by ``slope``.
        slope: Rescale slope multiplied into the raw values.
        rescale: If True, map the window linearly onto [0, 255]; otherwise
            keep the clipped values as they are.

    Returns:
        A float array of shape (224, 224) holding the windowed slice.
    """
    # Work in float64 from the start: raw pixel arrays are usually integer
    # typed, and integer arithmetic with a negative intercept can wrap around
    # or overflow on unsigned data.
    img = np.asarray(img, dtype=np.float64) * slope + intercept

    img_min = window_center - window_width // 2  # lower bound of the window
    img_max = window_center + window_width // 2  # upper bound of the window

    # Clamp everything outside the window to its nearest bound.
    img = np.clip(img, img_min, img_max)

    if rescale:
        span = img_max - img_min
        if span == 0:
            # Degenerate (zero-width) window: avoid dividing by zero.
            img = np.zeros_like(img)
        else:
            img = (img - img_min) / span * 255.0

    # preserve_range=True keeps the windowed values instead of letting skimage
    # renormalise the image during its float conversion.
    return resize(img, (224, 224), preserve_range=True)

def get_first_of_dicom_field_as_int(x):
    """Return a DICOM field value as an int.

    If ``x`` is a ``pydicom.multival.MultiValue`` its first element is
    converted; any other value is converted directly with ``int(x)``.
    """
    # isinstance (rather than an exact type comparison) also accepts
    # subclasses of MultiValue.
    if isinstance(x, multival.MultiValue):
        return int(x[0])
    return int(x)
    
def get_windowing(data):
    """Read the windowing and rescale fields of a DICOM dataset.

    Looks up four tags on ``data`` and returns their values as a list of
    ints, in the order: window center, window width, intercept, slope.
    Each value goes through ``get_first_of_dicom_field_as_int``, so a
    multi-valued field contributes only its first entry.
    """
    tags = (
        ('0028', '1050'),  # window center
        ('0028', '1051'),  # window width
        ('0028', '1052'),  # intercept
        ('0028', '1053'),  # slope
    )
    # Fetch every field first, then convert, so a missing tag is reported
    # before any conversion is attempted.
    raw_values = [data[tag].value for tag in tags]
    return list(map(get_first_of_dicom_field_as_int, raw_values))

#### iterate through all subjects

In [4]:
# Fallback (window center, window width, intercept, slope) used when a slice
# does not provide usable windowing information.
DEFAULT_WINDOWING = (40, 400, -1024, 1)

# torch.save does not create missing folders, so make sure the target exists.
os.makedirs(CT_Vol_DIR, exist_ok=True)

for subject_ in tqdm(sorted(os.listdir(CT_Slice_DIR))):
    subject_dir = os.path.join(CT_Slice_DIR, subject_)
    # Only sub-folders are subjects. An explicit directory check skips stray
    # files such as '.DS_Store' without relying on substring matching of the
    # entry name.
    if not os.path.isdir(subject_dir):
        continue

    ls_vol = []
    # NOTE(review): slices are stacked in file-name order; this assumes the
    # file names sort in acquisition order -- confirm for this dataset.
    for dcm_file in sorted(os.listdir(subject_dir)):
        if dcm_file.startswith('.'):
            # Hidden files (e.g. '.DS_Store') are not DICOM slices.
            continue
        data = dcmread(os.path.join(subject_dir, dcm_file))
        try:
            window_center, window_width, intercept, slope = get_windowing(data)
        except (KeyError, IndexError, TypeError, ValueError):
            # Tag missing, empty or not convertible to int: use the defaults.
            window_center, window_width, intercept, slope = DEFAULT_WINDOWING
        output = window_image(data.pixel_array, window_center, window_width, intercept, slope, rescale=False)
        ls_vol.append(output)

    # One tensor per subject, slices stacked along the first axis.
    torch.save(torch.tensor(np.array(ls_vol), dtype=torch.float64), f"{CT_Vol_DIR}/{subject_}.pt")

100%|█████████████████████████████████████████| 423/423 [04:49<00:00,  1.46it/s]


### make number of slices equal

In [15]:
# Collect the number of slices of every saved volume and print the largest.
slice_counts = []
for subject_ in tqdm(sorted(os.listdir(CT_Vol_DIR))):
    # Only the saved '.pt' volumes can be loaded. Matching on the suffix skips
    # '.DS_Store' and any other stray file without substring matching.
    if not subject_.endswith('.pt'):
        continue
    ct = torch.load(f'{CT_Vol_DIR}/{subject_}')
    slice_counts.append(ct.shape[0])  # first axis holds the slices
print(max(slice_counts))

100%|█████████████████████████████████████████| 423/423 [00:12<00:00, 34.36it/s]

297





In [17]:
# Dry run: pad every volume with zero slices up to a common depth and check the
# resulting slice counts. Nothing is written to disk here.
TARGET_SLICES = 300  # must be >= the largest slice count of any volume

slice_counts_new = []
for subject_ in tqdm(sorted(os.listdir(CT_Vol_DIR))):
    # Only the saved '.pt' volumes can be loaded. Matching on the suffix skips
    # '.DS_Store' and any other stray file without substring matching.
    if not subject_.endswith('.pt'):
        continue
    ct = torch.load(f'{CT_Vol_DIR}/{subject_}')
    # Padding shares the volume's dtype so torch.cat never mixes dtypes.
    # NOTE(review): padding slices are all zeros -- confirm that 0 is the
    # intended background value for the windowed intensities.
    padding = torch.zeros(TARGET_SLICES - ct.size(0), 224, 224, dtype=ct.dtype)
    ct_resized = torch.cat((ct, padding), dim=0)
    slice_counts_new.append(ct_resized.shape[0])
    cv_volume = ct_resized  # keep the last padded volume for inspection
print(max(slice_counts_new))

100%|█████████████████████████████████████████| 423/423 [00:15<00:00, 27.12it/s]

300





In [22]:
# Display the shape of the most recently padded volume: (slices, height, width).
ct_resized.shape

torch.Size([300, 224, 224])

In [24]:
def plot_func(slice_num, img_):
    """Show slice ``slice_num`` of the volume ``img_`` in grayscale."""
    plt.imshow(img_[slice_num], cmap=plt.cm.gray)

# The slider's upper bound follows the depth of the volume being shown, so it
# stays valid if the padded depth changes.
interact(plot_func, slice_num=widgets.IntSlider(value=1, min=0, max=ct_resized.shape[0] - 1, step=1), img_=widgets.fixed(ct_resized))

interactive(children=(IntSlider(value=1, description='slice_num', max=299), Output()), _dom_classes=('widget-i…

<function __main__.plot_func(slice_num, img_)>

### Save the resized volumes

In [25]:
# Pad every saved volume with zero slices up to a common depth and overwrite
# the file in place. Re-running is harmless: a volume already at the target
# depth receives zero extra slices.
TARGET_SLICES = 300  # must be >= the largest slice count of any volume

for subject_ in tqdm(sorted(os.listdir(CT_Vol_DIR))):
    # Only the saved '.pt' volumes can be loaded. Matching on the suffix skips
    # '.DS_Store' and any other stray file without substring matching.
    if not subject_.endswith('.pt'):
        continue
    volume_path = f'{CT_Vol_DIR}/{subject_}'
    ct = torch.load(volume_path)
    # Padding shares the volume's dtype so torch.cat never mixes dtypes.
    padding = torch.zeros(TARGET_SLICES - ct.size(0), 224, 224, dtype=ct.dtype)
    ct_resized = torch.cat((ct, padding), dim=0)
    torch.save(ct_resized, volume_path)

100%|█████████████████████████████████████████| 423/423 [01:22<00:00,  5.13it/s]
