In [1]:
from glob import glob
import pydicom as dicom
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from scipy.ndimage import center_of_mass

In [3]:
data_folder = Path('manifest-1686081801328')  # Set your data folder here
# Print metadata
metadata_df = pd.read_csv(data_folder/'metadata.csv')
metadata_df.tail()

Unnamed: 0,Series UID,Collection,3rd Party Analysis,Data Description URI,Subject ID,Study UID,Study Description,Study Date,Series Description,Manufacturer,Modality,SOP Class Name,SOP Class UID,Number of Images,File Size,File Location,Download Timestamp
1198,2.25.179810891732019793459496497832233425115.1,PDMR-Texture-Analysis,NO,https://doi.org/10.7937/3KQ0YK19,BL0382-F1232-1724,2.25.96089192699936339821029186870284643178,NCI PDMR Tumor Characterization,06-11-2020,TSE45 split,Philips Medical Systems,MR,Raw Data Storage,1.2.840.10008.5.1.4.1.1.66,1,265.47 KB,.\PDMR-Texture-Analysis\BL0382-F1232-1724\06-1...,2023-11-17T16:10:23.58
1199,1.3.6.1.4.1.5962.1.2.0.1670154536.70508.0.175.2,PDMR-Texture-Analysis,NO,https://doi.org/10.7937/3KQ0YK19,BL0382-F1232-1728,1.3.6.1.4.1.5962.1.2.0.1670154536.70508.0.175.1,NCI PDMR Tumor Characterization,05-20-2020,PDM Mouse Overview,PixelMed,SR,Acquisition Context SR Storage,1.2.840.10008.5.1.4.1.1.88.71,1,4.41 KB,.\PDMR-Texture-Analysis\BL0382-F1232-1728\05-2...,2023-11-17T16:10:24.266
1200,2.25.308200546840067706797055294221268540073.1,PDMR-Texture-Analysis,NO,https://doi.org/10.7937/3KQ0YK19,BL0382-F1232-1728,2.25.295695138875747345906456588608783888488,NCI PDMR Tumor Characterization,06-11-2020,TSE45 split,Philips Medical Systems,MR,Raw Data Storage,1.2.840.10008.5.1.4.1.1.66,1,265.46 KB,.\PDMR-Texture-Analysis\BL0382-F1232-1728\06-1...,2023-11-17T16:10:27.462
1201,2.25.308200546840067706797055294221268540073,PDMR-Texture-Analysis,NO,https://doi.org/10.7937/3KQ0YK19,BL0382-F1232-1728,2.25.295695138875747345906456588608783888488,NCI PDMR Tumor Characterization,06-11-2020,TSE45 split,Philips Medical Systems,MR,MR Image Storage,1.2.840.10008.5.1.4.1.1.4,36,22.51 MB,.\PDMR-Texture-Analysis\BL0382-F1232-1728\06-1...,2023-11-17T16:10:45.81
1202,2.25.243475922033648146678409629073405401370,PDMR-Texture-Analysis,NO,https://doi.org/10.7937/3KQ0YK19,997537-175-T-1327,2.25.244508863228859243918875004939444840754,NCI PDMR Tumor Characterization,11-14-2018,TSE45 split,Philips Medical Systems,MR,MR Image Storage,1.2.840.10008.5.1.4.1.1.4,36,22.51 MB,.\PDMR-Texture-Analysis\997537-175-T-1327\11-1...,2023-11-17T16:12:46.909


In [4]:
level1_paths = glob(str(data_folder/'PDMR-Texture-Analysis'/'*'))
data_dict = {}
for lvl1_path in level1_paths:
    level2_paths = glob(f'{lvl1_path}/*')
    lvl2_dict = {}
    for lvl2_path in level2_paths:
        lvl3_dict = {}
        lvl3_paths = glob(f'{lvl2_path}/*')
        for scan_path in lvl3_paths:
            dcm_slices = glob(f'{scan_path}/*.dcm')
            lvl3_dict[Path(scan_path).name] = dcm_slices
        lvl2_dict[Path(lvl2_path).name] = lvl3_dict
    data_dict[Path(lvl1_path).name] = lvl2_dict

In [6]:
import nrrd
data1, header1 = nrrd.read('segmentation_test\\3012 TSE45 split.nrrd')
data2, header2 = nrrd.read('segmentation_test\\Segmentation.seg.nrrd')

FileNotFoundError: [Errno 2] No such file or directory: 'segmentation_test\\3012 TSE45 split.nrrd'

In [7]:
n_cols = 10

def plot_scan(image_paths):
    N = len(image_paths)
    n_rows = int(np.ceil(N / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 14))
    axes = axes.flatten()
    for i, img_path in enumerate(image_paths):
        ds = dicom.dcmread(img_path)
        ds.pixel_array
        axes[i].imshow(ds.pixel_array[200:-200])
        axes[i].axis('off')

    for j in range(i + 1, len(axes)):
        axes[j].axis('off') 
    return fig, axes

def plot_single(image_paths, n=17):

    ds = dicom.dcmread(image_paths[n])
    img = ds.pixel_array[200:-200]
    fig, ax = plt.subplots(figsize=(8, 12))
    ax.imshow(img)
    ax.axis('off')

    return fig, ax

lvl1_keys = list(data_dict.keys())
for key1 in lvl1_keys:
    lvl2_data = data_dict[key1]
    lvl2_keys = list(lvl2_data.keys())
    for key2 in lvl2_keys:
        lvl3_data = data_dict[key1][key2]
        lvl3_keys = list(lvl3_data.keys())
        for scan_name in lvl3_keys: # Chosing second folder as it usually contains the correct scan with multiple images
            scan_image_paths = data_dict[key1][key2][scan_name]
            for path in scan_image_paths:
                img_name = Path(path).stem
                img_order = int(img_name.split('-')[1])
                img_scan_name = scan_name
                img_inspection_name = key2
                img_inspection_date = ''.join(key2.split('-')[0:3])
                img_specimen_name = key1
                img_original_patient_id = ''.join(key2.split('-')[0:3])
                img_mouse_id = key2.split('-')[-1]
                

            # #fig, axes = plot_scan(scan_image_paths)
            # #fig.suptitle(key2) 
            # #fig.tight_layout()
            # #fig.subplots_adjust(top=0.97)
            # #fig.savefig(f'{key2}.png', transparent=True)
            # fig_single, ax = plot_single(scan_image_paths)
            # #fig_single.suptitle(key2)
            # fig_single.tight_layout()
            # fig_single.savefig(f'{key1}_{key2}_single.png', transparent=True)
        #plt.subplots_adjust(wspace=0.1, hspace=0.1)

In [8]:
import shutil
from datetime import datetime
import os
import nibabel as nib
os.makedirs('images_2d', exist_ok=True)

def convert_dicom_to_nifti(dicom_pixels):
    # Stack the DICOM slices and convert to a NumPy array
    # Convert the image data to float and scale it
    image_data = dicom_pixels.astype(np.float32)
    image_data -= np.min(image_data)
    image_data /= np.max(image_data)

    # Create an affine matrix for the NIfTI image
    # This is a basic affine matrix, you might need to adjust it based on your DICOM metadata
    affine = np.eye(4)

    # Create the NIfTI image
    nifti_img = nib.Nifti1Image(image_data, affine)

    return nifti_img

data = {
    'img_name': [],
    'Z_order_no': [],
    'inspection_date': [],
    'scan_name': [],
    'human_patient_id': [],
    'mouse_id': [],

    'original_img_name': [],
    'inspection_name': [],

    'specimen_name': [],
}
img_id = 0
lvl1_keys = list(data_dict.keys())
print(
    lvl1_keys
)
for key1 in lvl1_keys[1:2]:
    lvl2_data = data_dict[key1]
    lvl2_keys = list(lvl2_data.keys())
    for key2 in lvl2_keys:
        lvl3_data = data_dict[key1][key2]
        lvl3_keys = list(lvl3_data.keys())
        for scan_name in lvl3_keys:
            scan_image_paths = data_dict[key1][key2][scan_name]

            if len(scan_image_paths) > 2:
                continue
            print(scan_image_paths)
            for path in scan_image_paths:
                new_img_name = f'{img_id}.nii'
                ds = dicom.dcmread(path)
                try:
                    nifti = convert_dicom_to_nifti(np.array([ds.pixel_array]))
                except Exception as e:
                    print(e)
                    continue
                nifti.to_filename(f'images_2d/{new_img_name}')
                #shutil.copy2(path, f'images_2d/{new_img_name}')
                img_name = Path(path).stem
                data['img_name'].append(new_img_name)
                data['original_img_name'].append(img_name)
                data['Z_order_no'].append(int(img_name.split('-')[1]))
                data['scan_name'].append(scan_name)
                img_original_patient_id = '-'.join(key1.split('-')[0:3])
                img_mouse_id = key1.split('-')[-1]

                data['human_patient_id'].append(img_original_patient_id)
                data['mouse_id'].append(img_mouse_id)

                date_str = '-'.join(key2.split('-')[0:3])
                data['inspection_date'].append(datetime.strptime(date_str, "%m-%d-%Y"))
                data['inspection_name'].append(key2)
                data['specimen_name'].append(key1)
                img_id += 1
df = pd.DataFrame(data)
df.to_csv('image_data_2d.csv',sep=',')



['BL0382-F1232-1714', '144126-210-T-1669', '172845-121-T-1862', '698357-238-R-1953', '698357-238-R-1955', '146476-266-R-1621', '894883-131-R-2315', '833975-119-R-1572', '695669-166-R-2043', 'BL0382-F1232-1716', '521955-158-R4-2162', '625472-104-R-1544', '765638-272-R-2017', '146476-266-R-1630', '625472-104-R-1548', '765638-272-R-2007', '625472-104-R-1545', '698357-238-R-1977', 'BL0382-F1232-1728', '894883-131-R-2320', '779769-127-R-1675', '779769-127-R-1685', '146476-266-R-1623', '287954-098-R-1995', '833975-119-R-1584', '172845-121-T-1859', '172845-121-T-1847', '172845-142-T-1265', '172845-121-T-1853', '287954-098-R-2003', '287954-098-R-2004', '287954-098-R-1987', '146476-266-R-1631', '172845-142-T-1242', '144126-210-T-1664', '466636-57-R-444', '172845-142-T-1238', '698357-238-R-1970', '172845-142-T-1246', '521955-158-R4-2177', '146476-266-R-1616', '172845-142-T-1262', '695669-166-R-2051', '172845-142-T-1247', '833975-119-R-1568', '172845-142-T-1257', '172845-121-T-1854', '833975-119-

In [37]:
X = []
y = []
for i in range(6):
    X.append(np.load(f'./2d_dataset_example/data/Xy_{i}.npy')[0])
    y.append(np.load(f'./2d_dataset_example/data/Xy_{i}.npy')[1]) 


In [40]:
y[0].shape

(960, 320)

In [47]:
X[0].max()

335.0

In [61]:
print(f"Cuda available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Device name: {torch.cuda.get_device_name(0)}")

Cuda available: False


In [153]:
class UNet2D(nn.Module):
    def __init__(self):
        super(UNet2D, self).__init__()
        n_filters = 128

        # Convolutional layers with kernel size 3 and no padding (valid)
        # Encoder
        self.enc_conv1_1 = nn.Conv2d(1, n_filters, kernel_size=3, padding='valid')
        self.enc_conv1_2 = nn.Conv2d(n_filters, n_filters, kernel_size=3, padding='valid')
        self.enc_conv1_3 = nn.Conv2d(n_filters, n_filters, kernel_size=3, padding='valid')

        # Max pooling
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        # Middle encoder layers
        self.encode_conv1 = nn.Conv2d(n_filters, n_filters, kernel_size=3, padding='valid')
        self.encode_conv2 = nn.Conv2d(n_filters, n_filters, kernel_size=3, padding='valid')
        self.encode_conv3 = nn.Conv2d(n_filters, n_filters, kernel_size=3, padding='valid')
        self.encode_conv4 = nn.Conv2d(n_filters, n_filters, kernel_size=3, padding='valid')

        # Dropout
        self.dropout1 = nn.Dropout()

        # Upscaling
        self.upscale1 = nn.ConvTranspose2d(n_filters, n_filters, kernel_size=2, stride=2)

        # Concatenation and Expansion
        self.expand_conv1_1 = nn.Conv2d(2 * n_filters, 2 * n_filters, kernel_size=3, padding='valid')
        self.expand_conv1_2 = nn.Conv2d(2 * n_filters, 2 * n_filters, kernel_size=3, padding='valid')
        self.expand_conv1_3 = nn.Conv2d(2 * n_filters, n_filters, kernel_size=3, padding='valid')

        # Final convolutional layer
        self.final_conv = nn.Conv2d(n_filters, 1, kernel_size=1)

    def forward(self, x):
        # Encoder
        x1 = nn.ReLU()(self.enc_conv1_1(x))
        x1 = nn.ReLU()(self.enc_conv1_2(x1))
        x1 = nn.ReLU()(self.enc_conv1_3(x1))

        x2 = self.pool1(x1)

        # Middle encoder
        x2 = nn.ReLU()(self.encode_conv1(x2))
        x2 = nn.ReLU()(self.encode_conv2(x2))
        x2 = nn.ReLU()(self.encode_conv3(x2))
        x2 = nn.ReLU()(self.encode_conv4(x2))

        # Dropout
        x2 = self.dropout1(x2)

        # Upscale
        x3 = self.upscale1(x2)

        # Concatenation
        delta = [x1_size - x3_size for x1_size, x3_size in zip(x1.size()[2:], x3.size()[2:])]
        crop_x1 = x1[:, :, delta[0]//2:x1.size(2)-delta[0]//2, delta[1]//2:x1.size(3)-delta[1]//2]
        x3 = torch.cat((x3, crop_x1), dim=1)

        # Expansion
        x3 = nn.ReLU()(self.expand_conv1_1(x3))
        x3 = nn.ReLU()(self.expand_conv1_2(x3))
        x3 = nn.ReLU()(self.expand_conv1_3(x3))

        # Output
        x_out = torch.sigmoid(self.final_conv(x3))
        
        return x_out

    def fit(self, train_loader, num_epochs, device, patch_size, verbose=True):
        optimizer = torch.optim.Adam(self.parameters())
        for epoch in range(num_epochs):
            self.train()
            for i, (images, masks) in enumerate(train_loader):
                images, masks = images.float().to(device), masks.to(device)
                optimizer.zero_grad()
                outputs = self(images)
                
                center_crop = transforms.CenterCrop((outputs.shape[-1], outputs.shape[-1]))
                resized_masks = center_crop(masks)

                loss = -1 * dice_coefficient(outputs, resized_masks).mean()
                loss.backward()
                optimizer.step()

                if verbose and i % 1 == 0:
                    print(f'Epoch : {epoch} [{i * len(images)}/{len(train_loader.dataset)} ({100. * i / len(train_loader):.0f}%)]\tLoss: {loss:.6f}')


In [148]:
class MRIScansPatchDataset(Dataset):
    def __init__(self, images, masks, patch_size, transform=None):
        self.images = images
        self.masks = masks
        self.patch_size = patch_size
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]

        centroid = center_of_mass(mask)
        image_patch, mask_patch = self.extract_patch(image, mask, centroid)

        if self.transform:
            image_patch = self.transform(image_patch)
            mask_patch = self.transform(mask_patch)

        return image_patch[np.newaxis, ...], mask_patch[np.newaxis, ...]

    def extract_patch(self, image, mask, centroid):
        y, x = int(centroid[0]), int(centroid[1])
        half_patch = self.patch_size // 2

        # IF YOU GET THE INDEX OUT OF RANGE ERROR, THEN TRY A SMALLER PATCH SIZE!
        # patch must be a square, so in case of setting boundaries, additional padding will be required, which might reduce the quality of model
        image_patch = image[y-half_patch:y+half_patch, x-half_patch:x+half_patch]
        mask_patch = mask[y-half_patch:y+half_patch, x-half_patch:x+half_patch]

        return image_patch, mask_patch
    
def dice_coefficient(pred, target, smooth=1e-12):
    intersection = (pred * target).sum(axis=(1, 2))
    return (2. * intersection + smooth) / (pred.sum(axis=(1, 2)) + target.sum(axis=(1, 2)) + smooth)

In [154]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {DEVICE}.")
BATCH_SIZE = 2
PATCH_SIZE = 128
TRAIN_EPOCHS = 10


dataset = MRIScansPatchDataset(X,
                               y,
                               patch_size=PATCH_SIZE)
train_loader = DataLoader(dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True)
model = UNet2D()
model.fit(train_loader,
          num_epochs=TRAIN_EPOCHS,
          device=DEVICE,
          patch_size=PATCH_SIZE)

Using device cpu.


In [4]:
X = []
y = []
for i in range(100):
    X_inst = np.load(f'./2d_dataset/data/Xy_{i}.npy')[0]
    y_inst = np.load(f'./2d_dataset/data/Xy_{i}.npy')[1]
    print(X_inst.shape)
    print(y_inst.shape) 

(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)
(960, 320)