In [1]:
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
import os
import zarr
from tqdm import tqdm
import fastparquet
from joblib import Parallel, delayed
import cv2
# specify the paths to high and low quality datasets
hi_res_path = './datasets/High_Res_T2'
low_res_path = './datasets/Low_Res_T2'

## Load All Data From nii

In [6]:

#create a dataframe to keep track of the data
df = pd.DataFrame(columns=['pt', 'slice_num', 'file_name', 'MR_type', 'slice_width', 'voxel_size'])
# use the `load` function from `nibabel` to load the image
for root, dirs, files in os.walk(hi_res_path):
    for fp in tqdm(files):
        if '.gz' in fp:
            pt = fp.split('-')[0]
            file_name = fp

            hi_res = nib.load(os.path.join(hi_res_path, fp))
            low_res = nib.load(os.path.join(low_res_path, fp))

            # get the image data
            # hi_res_data = hi_res.get_fdata()
            # low_res_data = low_res.get_fdata()

            # get the image header
            hi_res_header = hi_res.header
            low_res_header = low_res.header
            voxel_size = hi_res_header.get_zooms()
            slice_width = hi_res_header.get_data_shape()[2]

            for i in range(slice_width):
                row = pd.DataFrame([[pt, i, file_name, 'T2', slice_width, voxel_size]], columns=['pt', 'slice_num', 'file_name', 'MR_type', 'slice_width', 'voxel_size'])
                df = pd.concat([df, row], ignore_index=True)
                
df.to_pickle('all_data.pkl')

  4%|▍         | 22/578 [00:04<01:47,  5.19it/s]


KeyboardInterrupt: 

## Split Data Into Train/Val/Test Datasets

In [1]:
df = pd.read_pickle('all_data.pkl')
pt_list = df['pt'].unique()

# randomly select 80% of the patients for training
train_pt_list = np.random.choice(pt_list, int(len(pt_list)*0.8), replace=False)

# the remaining patients are for testing
test_pt_list = np.setdiff1d(pt_list, train_pt_list)

#randomly select 20% of training patients for validation
val_pt_list = np.random.choice(train_pt_list, int(len(train_pt_list)*0.15), replace=False)

#remove val patients from train patients
train_pt_list = np.setdiff1d(train_pt_list, val_pt_list)

# create a dataframe for each of the training, validation, and testing sets
train_df = df[df['pt'].isin(train_pt_list)]
val_df = df[df['pt'].isin(val_pt_list)]
test_df = df[df['pt'].isin(test_pt_list)]

# save the dataframes to parquet files
os.makedirs('split_datasets', exist_ok=True)
os.makedirs('split_datasets/train', exist_ok=True)
os.makedirs('split_datasets/val', exist_ok=True)
os.makedirs('split_datasets/test', exist_ok=True)
train_df.to_pickle('split_datasets/train_df.pkl')
val_df.to_pickle('split_datasets/val_df.pkl')
test_df.to_pickle('split_datasets/test_df.pkl')

NameError: name 'pd' is not defined

In [2]:
#load images and save as zarr files
train_df = pd.read_pickle('split_datasets/train_df.pkl')
val_df = pd.read_pickle('split_datasets/val_df.pkl')
test_df = pd.read_pickle('split_datasets/test_df.pkl')

train_pts = train_df['pt'].unique()
val_pts = val_df['pt'].unique()
test_pts = test_df['pt'].unique()

gz_files = []
for root, dirs, files in os.walk(hi_res_path):
        for fp in tqdm(files):
            if '.gz' in fp:
                gz_files.append(fp)

#create a parralelized function to save the images as zarr files
def save_zarr(fp, hi_res_path, low_res_path, train_pts, val_pts, test_pts, resize_factor=0.5):
            pt = fp.split('-')[0]
            

            hi_res = nib.load(os.path.join(hi_res_path, fp))
            low_res = nib.load(os.path.join(low_res_path, fp))

            # get the image data
            hi_res_data = hi_res.get_fdata()
            low_res_data = low_res.get_fdata()
            #save each frame as a zarr file
            for i in range(hi_res_data.shape[2]):
                hi_res_frame = hi_res_data[:,:,i]
                low_res_frame = low_res_data[:,:,i]
                #resize image
                low_res_frame = cv2.resize(low_res_frame, (128,128))
                                
                if pt in train_pts:
                    assert pt not in val_pts and pt not in test_pts
                    zarr.save(os.path.join('split_datasets/train', f'hi_res_{pt}_{i}.zarr'), hi_res_frame)
                    zarr.save(os.path.join('split_datasets/train', f'low_res_{pt}_{i}.zarr'), low_res_frame)
                elif pt in val_pts:
                    assert pt not in train_pts and pt not in test_pts
                    zarr.save(os.path.join('split_datasets/val', f'hi_res_{pt}_{i}.zarr'), hi_res_frame)
                    zarr.save(os.path.join('split_datasets/val', f'low_res_{pt}_{i}.zarr'), low_res_frame)
                elif pt in test_pts:
                    assert pt not in train_pts and pt not in val_pts
                    zarr.save(os.path.join('split_datasets/test', f'hi_res_{pt}_{i}.zarr'), hi_res_frame)
                    zarr.save(os.path.join('split_datasets/test', f'low_res_{pt}_{i}.zarr'), low_res_frame)
                else:
                    raise ValueError('Patient not in any dataset')

#parralelize this function using joblib
Parallel(n_jobs=8, prefer='threads')(delayed(save_zarr)(fp, hi_res_path, low_res_path, train_pts, val_pts, test_pts) for fp in tqdm(gz_files))          
print('done')


100%|██████████| 578/578 [00:00<00:00, 577766.38it/s]
100%|██████████| 578/578 [09:08<00:00,  1.05it/s]


done
