In [None]:
## Notebook for processing raw MMWHS data to HighRes data and corresponding segmentations

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import glob
import numpy as np
from tqdm import tqdm
import tensorflow as tf
import skimage
import nibabel as nib
from process_utils import *

In [None]:
#Makes directories for MMWHS dataset

# The dataset root is listed first, followed by the image and segmentation
# sub-folders; any folder that already exists is left untouched.
for mmwhs_dir in ('./MMWHS', './MMWHS/ims', './MMWHS/segmentations'):
    if not os.path.exists(mmwhs_dir):
        os.makedirs(mmwhs_dir)

In [None]:
#Loads segmentation data
#Must split downloaded MMWHS data into './MMWHS/ims' and './MMWHS/segmentations' folders
def load_data_samples_seg(no,filename="./MMWHS/segmentations/*.nii.gz"):
    """Load a single segmentation volume and its voxel dimensions.

    Parameters
    ----------
    no : int
        Index into the sorted list of files matching ``filename``.
    filename : str
        Glob pattern selecting the NIfTI segmentation files.

    Returns
    -------
    clean_ims : list
        One-element list holding the re-oriented segmentation array.
    pix_dimensions : list
        One-element list holding the ``pixdim`` field of the file header.

    Raises
    ------
    Exception
        If the directory part of ``filename`` does not exist.
    IndexError
        If ``no`` does not address one of the matched files.
    """
    # Validate the directory that is actually globbed (not a hard-coded one),
    # so that a custom ``filename`` pattern is checked correctly. A directory
    # part that itself contains wildcards is accepted when it matches something.
    data_dir = os.path.dirname(filename) or "."
    if not (os.path.exists(data_dir) or glob.glob(data_dir)):
        raise Exception("Error with file path.")

    files = sorted(glob.glob(filename))
    if not -len(files) <= no < len(files):
        raise IndexError(
            f"Sample index {no} is out of range: {len(files)} file(s) match '{filename}'."
        )

    nii_img = nib.load(files[no])
    nii_data = nii_img.get_fdata()

    if nii_data.shape[0] == 512 and nii_data.shape[1] == 512:
        # First two axes are 512x512: only the second axis is flipped.
        clean_im = nii_data[:, ::-1, :]
    else:
        # Otherwise flip the last two axes and reverse the axis order.
        clean_im = np.transpose(nii_data[:, ::-1, ::-1], [2, 1, 0])

    # Single-element lists are returned to keep the original return shape.
    return [clean_im], [nii_img.header['pixdim']]

#Loads image data
def load_data_samples(no,filename="./MMWHS/ims/*.nii.gz"):
    """Load a single image volume with its voxel dimensions and orientation.

    Parameters
    ----------
    no : int
        Index into the sorted list of files matching ``filename``.
    filename : str
        Glob pattern selecting the NIfTI image files.

    Returns
    -------
    clean_ims : list
        One-element list holding the re-oriented image array.
    pix_dimensions : list
        One-element list holding the ``pixdim`` field of the file header.
    orientations : list
        One-element list holding the axis codes of the file's affine joined
        into a single string (e.g. ``'LSP'``).

    Raises
    ------
    Exception
        If the directory part of ``filename`` does not exist.
    IndexError
        If ``no`` does not address one of the matched files.
    """
    # Validate the directory that is actually globbed (not a hard-coded one),
    # so that a custom ``filename`` pattern is checked correctly. A directory
    # part that itself contains wildcards is accepted when it matches something.
    data_dir = os.path.dirname(filename) or "."
    if not (os.path.exists(data_dir) or glob.glob(data_dir)):
        raise Exception("Error with file path.")

    files = sorted(glob.glob(filename))
    if not -len(files) <= no < len(files):
        raise IndexError(
            f"Sample index {no} is out of range: {len(files)} file(s) match '{filename}'."
        )

    nii_img = nib.load(files[no])
    nii_data = nii_img.get_fdata()

    if nii_data.shape[0] == 512 and nii_data.shape[1] == 512:
        # First two axes are 512x512: only the second axis is flipped.
        clean_im = nii_data[:, ::-1, :]
    else:
        # Otherwise flip the last two axes and reverse the axis order.
        clean_im = np.transpose(nii_data[:, ::-1, ::-1], [2, 1, 0])

    # Orientation string derived from the affine stored in the file.
    orientation = ''.join(nib.aff2axcodes(nii_img.affine))

    # Single-element lists are returned to keep the original return shape.
    return [clean_im], [nii_img.header['pixdim']], [orientation]

In [None]:
#Makes directory to store data
if not os.path.exists('./seg_data'):
    os.makedirs('./seg_data')

#Set number of MMWHS datasets to process
MMWHS_No = 1

#Set output data size
z=112
x=256
y=128

unique_values = [500,600,420,550,820,850,205,0] # [LV,RV,LA,RA,Aorta,PA,Myocardium,Background]

# Only the first six labels (LV, RV, LA, RA, Aorta, PA) are exported as masks.
n_structures = 6

# Loop for data generation
for pat in tqdm(range(MMWHS_No)):

    # Loading data. Both loaders return single-element lists for file `pat`.
    # The voxel dimensions used below are the ones read from the segmentation
    # file (the second assignment wins, as in the original code).
    im_1, pix_dimensions_1, orientations = load_data_samples(pat)
    seg_1, pix_dimensions_1 = load_data_samples_seg(pat)

    image = np.expand_dims(im_1[0], -1)
    segmentation = seg_1[0]
    if segmentation.shape != image.shape[:-1]:
        raise ValueError(
            f"Image shape {image.shape[:-1]} and segmentation shape "
            f"{segmentation.shape} differ for dataset {pat}."
        )

    # One binary mask channel per exported structure, stacked on the last axis.
    segmentations = np.stack(
        [segmentation == label for label in unique_values[:n_structures]], axis=-1
    )

    # Channel 0 is the image, channels 1.. are the masks.
    temp = np.concatenate((image, segmentations), -1)

    #Fixing orientation for all volumes
    # The loaders return one entry per call, so the entry is always at index 0
    # (indexing with `pat` fails for every dataset after the first).
    pixdim = pix_dimensions_1[0]
    if orientations[0] == 'LSP':
        sf_x, sf_y, sf_z = pixdim[1], pixdim[2], pixdim[3]
    else:
        sf_x, sf_y, sf_z = pixdim[3], pixdim[2], pixdim[1]

    #Rescaling to 1.5mm isotropic data
    scale = [sf_x/1.5, sf_y/1.5, sf_z/1.5]
    rescaled = np.stack(
        [
            skimage.transform.rescale(temp[..., channel], scale=scale, order=3,
                                      anti_aliasing=True, preserve_range=True,
                                      mode='constant', cval=0)
            for channel in range(temp.shape[-1])
        ],
        axis=-1,
    )

    #Threshold to keep masks binary
    rounded_data = rescaled[..., 1:] > 0.2
    temp = np.concatenate((rescaled[..., :1], rounded_data), -1)

    #Crop or pad to desired shape
    y_mid = find_com(temp[int(np.rint(temp.shape[0]/2)),:,:,0])[0]
    z_com = find_com(temp[:,:,int(y_mid),0])[1]
    cropped_vol = []
    seg_cropped_vol = []

    # Window of z slices centred on z_com along axis 0, shifted back inside
    # the volume when it would run past either end.
    start = z_com - z/2
    if start + z > temp.shape[0]:
        start = temp.shape[0] - z
    if start < 0:
        start = 0
    # Number of all-zero slices needed when the volume has fewer than z slices.
    missing = max(0, z - temp.shape[0])
    pad_number_1 = int(np.rint(missing/2))
    pad_number_2 = missing - pad_number_1

    # Zero slices sized from the configured output shape and the real number of
    # mask channels, so they always stack with the cropped slices.
    empty_slice_1 = np.zeros((x, y, 1))
    empty_slice = np.zeros((x, y, temp.shape[-1] - 1))

    cropped_vol.extend([empty_slice_1] * pad_number_1)
    seg_cropped_vol.extend([empty_slice] * pad_number_1)

    for i in range(temp.shape[0]):
        if not (start <= i < start + z):
            continue

        im_slice = temp[i, ..., :1]
        seg_slice = temp[i, ..., 1:]

        # Centre-pad every in-plane axis that is smaller than the target:
        # mirrored content for the image, zeros for the masks.
        pad_x = max(x - im_slice.shape[0], 0)
        pad_y = max(y - im_slice.shape[1], 0)
        if pad_x or pad_y:
            pad_width = ((pad_x//2, pad_x - pad_x//2),
                         (pad_y//2, pad_y - pad_y//2),
                         (0, 0))
            im_slice = np.pad(im_slice, pad_width, mode='symmetric')
            seg_slice = np.pad(seg_slice, pad_width, mode='constant')

        # Any axis still larger than the target is centre-cropped.
        if im_slice.shape[:2] != (x, y):
            im_slice = tf.image.resize_with_crop_or_pad(im_slice, x, y)
            seg_slice = tf.image.resize_with_crop_or_pad(seg_slice, x, y)

        cropped_vol.append(im_slice)
        seg_cropped_vol.append(seg_slice)

    cropped_vol.extend([empty_slice_1] * pad_number_2)
    seg_cropped_vol.extend([empty_slice] * pad_number_2)

    cropped_vol_1 = np.array(cropped_vol)
    seg_cropped_vol_1 = np.array(seg_cropped_vol)

    cropped_vol_1 = apply_clahe(cropped_vol_1) #CLAHE

    # Final array: image channel followed by the mask channels.
    cropped_vol_fin = np.concatenate((cropped_vol_1, seg_cropped_vol_1), -1)

    np.save(f'./seg_data/High_Res_MMWHS_NoAug_{pat}.npy', cropped_vol_fin)