In [None]:
import os
import glob2 as glob
import numpy as np
import SimpleITK as sitk

from sklearn.decomposition import PCA
from skimage import io
import pickle, warnings

import seaborn as sns

In [None]:
def read_nii_from_file(filename, is_label=False):
    """Read the image stored at ``filename`` with SimpleITK and return it as an array.

    Args:
        filename: path of the image file to load.
        is_label: accepted so existing call sites keep working; it is not used.

    Returns:
        The array produced by ``sitk.GetArrayFromImage`` for the loaded image.
    """
    return sitk.GetArrayFromImage(sitk.ReadImage(filename))

In [None]:
# Case directories of the training split.
# Sorted because glob order depends on the file system, and later cells depend on which case
# happens to be visited last.
subjects = sorted(glob.glob('/datasets/isles18/TRAINING/case_*/'))

In [None]:
# Number of leading entries along axis 0 of each 4D PWI array that the SVD fit and the
# projection keep (`im[:MIN_DWI_ZDIM, ...]`).
# NOTE(review): presumably the smallest axis-0 length across the cases -- confirm.
MIN_DWI_ZDIM = 28

In [None]:
# Collect the samples used to fit the SVD basis: for every 4D PWI volume keep, at each voxel
# that passes the threshold below, the first MIN_DWI_ZDIM entries along axis 0.
dwi_4d_nz_vals_list = []
for case in subjects:
    for subdir, _dirs, files in os.walk(case):
        for file in files:
            # '_4DPWI.' is the tag the export functions in this notebook match on; it also
            # matches file names that contain 'O.MR_4DPWI.'.
            if not (file.endswith('.nii') and '_4DPWI.' in file):
                continue

            # `im` intentionally stays bound after the loop: a later cell reuses the last
            # volume that was loaded here.
            im = read_nii_from_file(os.path.join(subdir, file), is_label='O.OT.' in file)

            # NOTE(review): the cut-off -23 * MIN_DWI_ZDIM is undocumented, and the sum runs
            # over the whole of axis 0 rather than only the first MIN_DWI_ZDIM entries -- confirm.
            z, x, y = np.where(np.sum(im, axis=0) > -23*MIN_DWI_ZDIM)
            dwi_4d_nz_vals_list.append(im[:MIN_DWI_ZDIM, z, x, y])

if not dwi_4d_nz_vals_list:
    raise ValueError("no '*_4DPWI.*.nii' volume was found under the subject directories")

# Every entry has one column per selected voxel; put all voxels side by side.
dwi_4d_nz_vals = np.concatenate(dwi_4d_nz_vals_list, axis=1)

In [None]:
# Standardise the sample matrix with one global mean and one global standard deviation.
dwi_global_mean = np.mean(dwi_4d_nz_vals)
dwi_global_std = np.std(dwi_4d_nz_vals)
dwi_4d_nz_vals_nm = (dwi_4d_nz_vals - dwi_global_mean) / dwi_global_std

In [None]:
# Thin SVD of the standardised sample matrix (rows: position along axis 0, columns: voxels).
# The columns of `u` are used further down as the projection basis for PWI volumes.
u, s, vh = np.linalg.svd(dwi_4d_nz_vals_nm, full_matrices=False)

In [None]:
# Persist the left singular vectors and the singular values as a two-element list.
svd_terms = [u, s]
with open('isles_4ddwi_svd.pkl', 'wb') as svd_file:
    pickle.dump(svd_terms, svd_file)

In [None]:
# Restore the left singular vectors and the singular values from the pickle on disk.
with open('isles_4ddwi_svd.pkl', 'rb') as svd_file:
    svd_terms = pickle.load(svd_file)
u, s = svd_terms

In [None]:
# Shape check: project `im` (the last PWI volume left over from the loading loop) onto the
# first two left singular vectors and display the shape of the result.
first_two_components = np.transpose(u[:, :2])
np.tensordot(first_two_components, im[:MIN_DWI_ZDIM, :], axes=(1, 0)).shape

In [None]:
def im_normalize(im):
    """Clip an image to ``[0, mean + 0.5 * std]`` and rescale the result to ``[0, 255]``.

    The input is never modified: the work is done on a copy. Integer inputs are converted
    to float64 first so that the upper clip limit is not truncated on assignment; floating
    point inputs keep their dtype.

    Args:
        im: array-like image or volume of any shape.

    Returns:
        A new array with the shape of ``im``. It is all zeros when the clipped image is
        constant (the rescaling would otherwise divide zero by zero), and empty when
        ``im`` is empty.
    """
    arr = np.array(im, copy=True)
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float64)
    if arr.size == 0:
        return arr

    im_uplimit = np.mean(arr) + .5*np.std(arr)
    im_lowlimit = 0

    # Upper limit first, then lower limit: the same order the clipping has always used.
    arr[arr > im_uplimit] = im_uplimit
    arr[arr < im_lowlimit] = im_lowlimit

    lowest = np.min(arr)
    highest = np.max(arr)
    if highest == lowest:
        return np.zeros_like(arr)

    return (arr - lowest) / (highest - lowest) * 255

In [None]:
def write_2d_im(subjects, save_dir):
    """Export every slice of every case as PNG images and return the last case's arrays.

    Each directory in ``subjects`` is walked recursively. A ``.nii`` file is assigned to a
    modality by the first of these tags found in its name: ``_4DPWI.``, ``O.CT.``,
    ``_CBF.``, ``_CBV.``, ``_MTT.``, ``_Tmax.``, ``O.OT.`` (label). Files matching none of
    them are not read. For every index ``z`` along axis 0 of the CT volume four files are
    written to ``save_dir``:

    * ``<case>_3ch_sli<z>.png``  - CT, negated first PWI projection, Tmax * 10
    * ``<case>_3ch2_sli<z>.png`` - CBF, CBV, MTT
    * ``<case>_lb_sli<z>.png``   - label * 255 (zeros when the case has no label file)
    * ``<case>_lb2_sli<z>.png``  - PWI entry 0, label * 255, PWI entry -1 (along axis 0)

    The PWI projection uses the notebook-level ``u`` and ``MIN_DWI_ZDIM``; loading and
    intensity scaling are delegated to ``read_nii_from_file`` and ``im_normalize``.

    Args:
        subjects: iterable of case directory paths.
        save_dir: output directory; it is created when it does not exist.

    Returns:
        ``(imct, imct_nm, im4dwi0, im4dwi0_nm, imtmax, imtmax_nm)`` of the last case
        processed. ``imct`` and ``im4dwi0`` are the un-normalised arrays.

    Raises:
        ValueError: if ``subjects`` is empty, or if a case lacks one of the six image
            modalities (the label is optional).
    """
    subjects = list(subjects)
    if not subjects:
        # Without this guard the final ``return`` would reference unbound locals.
        raise ValueError('write_2d_im: `subjects` is empty, nothing to export')

    os.makedirs(save_dir, exist_ok=True)

    # (file-name tag, modality key); order matters because the first matching tag wins.
    modality_tags = (
        ('_4DPWI.', 'pwi'),
        ('O.CT.', 'ct'),
        ('_CBF.', 'cbf'),
        ('_CBV.', 'cbv'),
        ('_MTT.', 'mtt'),
        ('_Tmax.', 'tmax'),
        ('O.OT.', 'label'),
    )
    required = ('pwi', 'ct', 'cbf', 'cbv', 'mtt', 'tmax')

    for case in subjects:
        # A fresh container per case: a missing modality can never be filled silently
        # with the volume loaded for the previous case.
        volumes = {}
        for subdir, _dirs, files in os.walk(case):
            for file in files:
                if not file.endswith('.nii'):
                    continue
                key = next((k for tag, k in modality_tags if tag in file), None)
                if key is None:
                    continue
                volumes[key] = read_nii_from_file(os.path.join(subdir, file),
                                                  is_label='O.OT.' in file)

        missing = [k for k in required if k not in volumes]
        if missing:
            raise ValueError('write_2d_im: case %r has no volume for: %s'
                             % (case, ', '.join(missing)))

        _im4dwi = volumes['pwi']
        imct = volumes['ct']
        imtmax = volumes['tmax']

        # Project the first MIN_DWI_ZDIM entries along axis 0 onto the first two columns of
        # `u`; only the first projection is exported, with its sign flipped.
        im4dwi = np.tensordot(np.transpose(u[:, :2]), _im4dwi[:MIN_DWI_ZDIM, :], axes=(1, 0))
        im4dwi0 = -1*im4dwi[0, :]

        # The two arrays that are returned are normalised from copies, so the caller gets
        # the raw values even when the normaliser clips its argument in place.
        im4dwi0_nm = im_normalize(im4dwi0.copy())
        imct_nm = im_normalize(imct.copy())
        imcbf_nm = im_normalize(volumes['cbf'])
        imcvf_nm = im_normalize(volumes['cbv'])
        immtt_nm = im_normalize(volumes['mtt'])
        # NOTE(review): Tmax is only multiplied by 10, not clipped; inputs above 25.5 give
        # values above 255, which do not fit the uint8 cast below -- confirm the intent.
        imtmax_nm = imtmax * 10

        _im4dwi_nm = im_normalize(_im4dwi)

        imlb = volumes.get('label')
        if imlb is None:
            imlb = np.zeros_like(imct)

        case_name = os.path.basename(os.path.normpath(case))

        def slice_path(kind, zidx):
            # e.g. <save_dir>/<case>_3ch_sli0.png
            return os.path.join(save_dir, case_name + kind + str(zidx) + '.png')

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for zidx in range(imct.shape[0]):
                im2d = np.stack([imct_nm[zidx, :, :], im4dwi0_nm[zidx, :, :],
                                 imtmax_nm[zidx, :, :]], axis=2).astype(np.uint8)
                io.imsave(slice_path('_3ch_sli', zidx), im2d)

                im2d2 = np.stack([imcbf_nm[zidx, :, :], imcvf_nm[zidx, :, :],
                                  immtt_nm[zidx, :, :]], axis=2).astype(np.uint8)
                io.imsave(slice_path('_3ch2_sli', zidx), im2d2)

                lb2d = imlb[zidx, :, :] * 255
                io.imsave(slice_path('_lb_sli', zidx), lb2d)

                lb2d2 = np.stack([_im4dwi_nm[0, zidx, :, :], lb2d,
                                  _im4dwi_nm[-1, zidx, :, :]], axis=2).astype(np.uint8)
                io.imsave(slice_path('_lb2_sli', zidx), lb2d2)

    return imct, imct_nm, im4dwi0, im4dwi0_nm, imtmax, imtmax_nm

In [None]:
# Export the training split as PNG slices. The arrays returned belong to the last case
# processed and are inspected in the cells below.
training_png_dir = '/datasets/isles18/training_png'
(imct, imct_nm, im4dwi0,
 im4dwi0_nm, imtmax, imtmax_nm) = write_2d_im(subjects, training_png_dir)

In [None]:
# Distribution of the values in `imct` as returned by write_2d_im.
imct_values = imct.flatten()
sns.distplot(imct_values)

In [None]:
# Mean, median, std, and the two candidate upper limits mean + 0.5*std and mean + std.
imct_mean = np.mean(imct)
imct_median = np.median(imct)
imct_std = np.std(imct)
print(imct_mean, imct_median, imct_std, imct_mean + .5*imct_std,
      imct_mean + imct_std)

In [None]:
# Distribution of the values in `im4dwi0` as returned by write_2d_im.
im4dwi0_values = im4dwi0.flatten()
sns.distplot(im4dwi0_values)

In [None]:
# Mean, median, std, and the two candidate upper limits mean + 0.5*std and mean + std.
im4dwi0_mean = np.mean(im4dwi0)
im4dwi0_median = np.median(im4dwi0)
im4dwi0_std = np.std(im4dwi0)
print(im4dwi0_mean, im4dwi0_median, im4dwi0_std, im4dwi0_mean + .5*im4dwi0_std,
      im4dwi0_mean + im4dwi0_std)

In [None]:
# Distribution of the values in `imtmax` as returned by write_2d_im.
imtmax_values = imtmax.flatten()
sns.distplot(imtmax_values)

In [None]:
# Distribution of the normalised CT values (`imct_nm`).
imct_nm_values = imct_nm.flatten()
sns.distplot(imct_nm_values)

In [None]:
# Distribution of the normalised first PWI projection (`im4dwi0_nm`).
im4dwi0_nm_values = im4dwi0_nm.flatten()
sns.distplot(im4dwi0_nm_values)

In [None]:
# Distribution of the scaled Tmax values (`imtmax_nm`).
imtmax_nm_values = imtmax_nm.flatten()
sns.distplot(imtmax_nm_values)

In [None]:
# Same PNG export for the testing split.
# `subjects` is rebound here, so any earlier cell re-run after this one sees the testing cases.
# Sorted because glob order depends on the file system and the arrays returned are those of
# the last case processed.
subjects = sorted(glob.glob('/datasets/isles18/TESTING/case_*/'))
(imct, imct_nm, im4dwi0,
 im4dwi0_nm, imtmax, imtmax_nm) = write_2d_im(subjects, '/datasets/isles18/testing_png')

In [None]:
def write_2dext_im(subjects, save_dir, n_frames=64):
    """Export every slice of every case as a pickled multi-channel array plus a label PNG.

    Each directory in ``subjects`` is walked recursively. A ``.nii`` file is assigned to a
    modality by the first of these tags found in its name: ``_4DPWI.``, ``O.CT.``,
    ``_CBF.``, ``_CBV.``, ``_MTT.``, ``_Tmax.``, ``O.OT.`` (label). Files matching none of
    them are not read. For every index ``z`` along axis 0 of the CT volume two files are
    written to ``save_dir``:

    * ``<case>_extch_sli<z>.pkl`` - pickled array whose last axis holds CT, CBF, CBV, MTT,
      Tmax * 10 and then the ``n_frames`` PWI channels
    * ``<case>_lb_sli<z>.png``    - label * 255 (zeros when the case has no label file)

    Loading and intensity scaling are delegated to the notebook-level
    ``read_nii_from_file`` and ``im_normalize``.

    Args:
        subjects: iterable of case directory paths.
        save_dir: output directory; it is created when it does not exist.
        n_frames: length the PWI volume is zero-padded to along axis 0 before that axis
            is moved to the end. Defaults to 64.

    Returns:
        None.

    Raises:
        ValueError: if a case lacks one of the six image modalities (the label is
            optional) or its PWI volume has more than ``n_frames`` entries along axis 0.
    """
    os.makedirs(save_dir, exist_ok=True)

    # (file-name tag, modality key); order matters because the first matching tag wins.
    modality_tags = (
        ('_4DPWI.', 'pwi'),
        ('O.CT.', 'ct'),
        ('_CBF.', 'cbf'),
        ('_CBV.', 'cbv'),
        ('_MTT.', 'mtt'),
        ('_Tmax.', 'tmax'),
        ('O.OT.', 'label'),
    )
    required = ('pwi', 'ct', 'cbf', 'cbv', 'mtt', 'tmax')

    for case in subjects:
        # A fresh container per case: a missing modality can never be filled silently
        # with the volume loaded for the previous case.
        volumes = {}
        for subdir, _dirs, files in os.walk(case):
            for file in files:
                if not file.endswith('.nii'):
                    continue
                key = next((k for tag, k in modality_tags if tag in file), None)
                if key is None:
                    continue
                volumes[key] = read_nii_from_file(os.path.join(subdir, file),
                                                  is_label='O.OT.' in file)

        missing = [k for k in required if k not in volumes]
        if missing:
            raise ValueError('write_2dext_im: case %r has no volume for: %s'
                             % (case, ', '.join(missing)))

        pwi = volumes['pwi']
        if pwi.shape[0] > n_frames:
            raise ValueError('write_2dext_im: case %r has %d PWI entries along axis 0, '
                             'more than n_frames=%d' % (case, pwi.shape[0], n_frames))

        # Zero-pad axis 0 to n_frames, then move that axis to the end so that the leading
        # axis can be indexed per slice like the other volumes.
        # NOTE(review): the padding zeros take part in the mean/std used by im_normalize.
        im4dwi = np.zeros((n_frames,) + pwi.shape[1:])
        im4dwi[:pwi.shape[0]] = pwi
        im4dwi = np.moveaxis(im4dwi, 0, -1)

        imct = volumes['ct']

        im4dwi_nm = im_normalize(im4dwi)
        imct_nm = im_normalize(imct)
        # Each map is normalised exactly once; a repeated call would clip already clipped
        # data again if the normaliser works in place.
        imcbf_nm = im_normalize(volumes['cbf'])
        imcbv_nm = im_normalize(volumes['cbv'])
        immtt_nm = im_normalize(volumes['mtt'])
        imtmax_nm = volumes['tmax'] * 10

        imlb = volumes.get('label')
        if imlb is None:
            imlb = np.zeros_like(imct)

        case_name = os.path.basename(os.path.normpath(case))

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for zidx in range(imct.shape[0]):
                im2dext = np.stack([imct_nm[zidx, :, :], imcbf_nm[zidx, :, :],
                                    imcbv_nm[zidx, :, :], immtt_nm[zidx, :, :],
                                    imtmax_nm[zidx, :, :]], axis=2)
                im2dext = np.concatenate((im2dext, im4dwi_nm[zidx, :, :, :]), axis=2)

                imfname = os.path.join(save_dir,
                                       case_name + '_extch_sli' + str(zidx) + '.pkl')
                with open(imfname, 'wb') as f:
                    pickle.dump(im2dext, f)

                lb2d = imlb[zidx, :, :] * 255
                lbfname = os.path.join(save_dir,
                                       case_name + '_lb_sli' + str(zidx) + '.png')
                io.imsave(lbfname, lb2d)

In [None]:
# Extended (pickled multi-channel) export of the training split.
training_case_glob = '/raid/datasets/ISLES2018/TRAINING/case_*/'
subjects = glob.glob(training_case_glob)
write_2dext_im(subjects, save_dir='/raid/datasets/ISLES2018/training_2dext')

In [None]:
# Extended (pickled multi-channel) export of the testing split.
testing_case_glob = '/raid/datasets/ISLES2018/TESTING/case_*/'
subjects = glob.glob(testing_case_glob)
write_2dext_im(subjects, save_dir='/raid/datasets/ISLES2018/testing_2dext')