In [None]:
import os, glob
import numpy as np
import SimpleITK as sitk
import torch
import nibabel as nib
import shutil
import matplotlib.pyplot as plt
import torch.nn.functional as F
from skimage.exposure import equalize_hist
from einops.einops import rearrange

In [None]:
def getSymmetricRepresentation(ct_volume):
    """Register the left-right mirror of every image in a stack back onto the original.

    A rigid Elastix registration is estimated once, with ``ct_volume[0]`` as the fixed
    image and its left-right flipped copy as the moving image. That single transform is
    then applied with Transformix to the flipped copy of every entry along axis 0, so all
    modalities / frames share the geometry estimated from the first one.

    Parameters
    ----------
    ct_volume : array-like, shape (n, h, w)
        Stack of 2-D images indexed along axis 0 (modalities or time frames).
        ``ct_volume[0]`` drives the registration (expected to be the non-contrast CT).

    Returns
    -------
    numpy.ndarray, shape (n, h, w), float64
        The registered mirror image of every entry of ``ct_volume``.
    """
    fixed_slice = np.asarray(ct_volume[0])

    # Elastix: estimate the rigid transform mapping the flipped slice onto the original.
    elastixImageFilter = sitk.ElastixImageFilter()
    elastixImageFilter.SetFixedImage(sitk.GetImageFromArray(fixed_slice))
    elastixImageFilter.SetMovingImage(sitk.GetImageFromArray(np.fliplr(fixed_slice)))
    elastixImageFilter.LogToFileOn()
    elastixImageFilter.SetParameterMap(sitk.GetDefaultParameterMap("rigid"))
    # Only the estimated transform is needed; the registered image itself is discarded.
    elastixImageFilter.Execute()
    transform_maps = elastixImageFilter.GetTransformParameterMap()

    # Apply the same transform to the flipped version of every modality / frame.
    # Sized from the input instead of assuming 256 x 256 slices.
    mirrored = np.zeros(ct_volume.shape, dtype=np.float64)
    for index in range(ct_volume.shape[0]):
        moving = sitk.GetImageFromArray(np.fliplr(ct_volume[index, :, :]))
        mirrored[index, :, :] = sitk.GetArrayFromImage(sitk.Transformix(moving, transform_maps))
    return mirrored

def getSym(volume):
    """Pair every image in ``volume`` with its registered mirror image.

    Parameters
    ----------
    volume : array-like, shape (n, h, w)
        Stack of 2-D images; ``volume[0]`` drives the mirror registration
        (see ``getSymmetricRepresentation``).

    Returns
    -------
    numpy.ndarray, shape (n, h, w, 2), float64
        ``[..., 0]`` is ``volume`` itself, ``[..., 1]`` its registered mirror image.
    """
    # Sized from the input instead of assuming 256 x 256 slices.
    paired = np.zeros(tuple(volume.shape) + (2,))
    paired[..., 0] = volume
    paired[..., 1] = getSymmetricRepresentation(volume)
    return paired

def preprocessor(nifti, clip_value, nbins=20000):
    """Clip, histogram-equalize and centre an image while keeping background at zero.

    Steps:

    1. clip intensities to ``[0, clip_value]``;
    2. histogram-equalize with ``skimage.exposure.equalize_hist``, building the CDF from
       the strictly positive voxels only (``mask``); the mapping is applied to all voxels;
    3. shift so the minimum becomes 0;
    4. subtract 0.5 from every voxel that is still > 0, leaving zero voxels at 0.

    Parameters
    ----------
    nifti : array-like
        Image of any shape (the caller in this notebook passes one (1, h, w) slice).
    clip_value : number
        Upper clipping bound applied before equalization.
    nbins : int, optional
        Number of histogram bins used by ``equalize_hist`` (default 20000).

    Returns
    -------
    numpy.ma.MaskedArray
        Same shape as ``nifti``, float64, with no masked entries. If no voxel is > 0 after
        clipping, an all-zero array is returned, because ``equalize_hist`` cannot build a
        histogram from an empty mask.
    """
    clipped = np.clip(nifti, 0, clip_value)
    foreground = clipped > 0
    if not np.any(foreground):
        return np.ma.masked_array(np.zeros(np.shape(clipped), dtype=np.float64))

    equalized = equalize_hist(clipped, nbins=nbins, mask=foreground)
    equalized = equalized - np.min(equalized)

    # Mask the zero background so the -0.5 shift only touches foreground voxels
    # ("shift the distribution, not the zeroes"), then drop the mask again.
    centred = np.ma.masked_array(equalized, mask=~(equalized > 0))
    centred = centred - 0.5
    centred.mask = np.ma.nomask
    return centred

def smoothing(array, device=None):
    """Temporally smooth and 2x-downsample a (h, w, c, t) array.

    Every (h, w, c) time series is convolved with the kernel ``[0.25, 0.5, 0.25]`` using
    ``torch.nn.functional.conv1d`` with stride 2 and no padding, which yields
    ``(t - 3) // 2 + 1`` output samples.

    Parameters
    ----------
    array : array-like, shape (h, w, c, t)
        Input volume with time on the last axis.
    device : str or torch.device, optional
        Device used for the convolution. Defaults to ``'cuda'`` when available,
        otherwise ``'cpu'``.

    Returns
    -------
    numpy.ndarray, shape (h, w, c, (t - 3) // 2 + 1), float32
        The smoothed, downsampled series.

    Raises
    ------
    ValueError
        If ``t < 3`` (the kernel does not fit even once).
    """
    data = np.asarray(array)
    h, w, c, t = data.shape
    if t < 3:
        raise ValueError(f'smoothing needs at least 3 time points, got {t}')
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # conv1d kernel layout: (out_channels, in_channels, kernel_size) = (1, 1, 3).
    kernel = torch.tensor([0.25, 0.5, 0.25], dtype=torch.float32, device=device).view(1, 1, -1)

    # Every voxel/channel time series becomes one row of a single conv1d batch:
    # (h, w, c, t) -> (h*w*c, 1, t).
    series = torch.as_tensor(data.reshape(-1, t), dtype=torch.float32, device=device).unsqueeze(1)
    with torch.no_grad():
        smoothed = F.conv1d(series, kernel, stride=2, padding=0)  # (h*w*c, 1, (t-3)//2+1)

    return smoothed.squeeze(1).cpu().numpy().reshape(h, w, c, -1)



def write_np(array, path):
    """Serialize ``array`` to exactly ``path`` in NumPy ``.npy`` format.

    The file is opened in binary write mode (truncating any existing file), so
    ``np.save`` never appends a ``.npy`` suffix to the given path.
    """
    handle = open(path, mode='wb')
    try:
        np.save(handle, array)
    finally:
        handle.close()

In [None]:
def _read_volume(file, pattern, **axes_lengths):
    """Load ``file`` with SimpleITK and reorder its axes with the einops ``pattern``."""
    return rearrange(sitk.GetArrayFromImage(sitk.ReadImage(file)), pattern, **axes_lengths)


def _slice_filename(case, z):
    """Name shared by every per-slice output: ``case_<case>_<z>.npy``, both zero-padded to 2 digits."""
    return 'case_%s_%s.npy' % (str(case).zfill(2), str(z).zfill(2))


def datamaker(case = 8, dataset='TRAINING'):
    """Preprocess one ISLES-2018 case into per-slice ``.npy`` files.

    Reads every ``/data/ISLES-2018/<dataset>/case_<case>/*/*.nii`` file and routes it by a
    substring of its path: ``.CT.`` (baseline CT), ``CT_MTT``, ``CT_Tmax``, ``CT_CBF``,
    ``CT_CBV``, ``CT_4DPWI`` and ``OT`` (lesion mask). Outputs go under ``<cwd>/train``
    when ``dataset == 'TRAINING'`` and under ``<cwd>/test`` otherwise:

    * ``COMPLETE_MASK/case_<case>.nii`` -- copy of the OT file, when the case has one;
    * ``CT``, ``CTP_MTT``, ``CTP_Tmax``, ``CTP_CBF``, ``CTP_CBV`` -- for each slice, the
      preprocessed modality paired with its registered mirror image (see ``getSym``),
      stored with shape ``(2, h, w)``;
    * ``MASK`` -- the OT slice (TRAINING only).

    For TRAINING, only slices whose mask maximum equals 1 are written.

    Raises
    ------
    FileNotFoundError
        If one of the required volumes is not found for this case.
    """
    path = os.path.join('/data/ISLES-2018/', dataset)
    is_training = dataset == 'TRAINING'
    case_files = glob.glob(os.path.join(path, f'case_{case}', '*/*.nii'))

    CT = CT_MTT = CT_Tmax = CT_CBF = CT_CBV = CT_4DPWI = CT_MASK = None
    mask_file = None
    for file in case_files:
        if 'CT_MTT' in file:
            CT_MTT = _read_volume(file, 'd h (w c) -> c h w d', c=1)
        if 'CT_Tmax' in file:
            CT_Tmax = _read_volume(file, 'd h (w c) -> c h w d', c=1)
        if 'CT_CBF' in file:
            CT_CBF = _read_volume(file, 'd h (w c) -> c h w d', c=1)
        if 'CT_CBV' in file:
            CT_CBV = _read_volume(file, 'd h (w c) -> c h w d', c=1)
        if 'OT' in file:
            # The first OT match in glob order is the one copied to COMPLETE_MASK.
            if mask_file is None:
                mask_file = file
            if is_training:
                CT_MASK = _read_volume(file, 'd h w -> h w d')
        if 'CT_4DPWI' in file:
            CT_4DPWI = _read_volume(file, 't d h w -> t h w d')
        if '.CT.' in file:
            CT = _read_volume(file, 'd h (w c) -> c h w d', c=1)

    required = {'.CT.': CT, 'CT_MTT': CT_MTT, 'CT_Tmax': CT_Tmax, 'CT_CBF': CT_CBF,
                'CT_CBV': CT_CBV, 'CT_4DPWI': CT_4DPWI}
    if is_training:
        required['OT'] = CT_MASK
    missing = sorted(key for key, volume in required.items() if volume is None)
    if missing:
        raise FileNotFoundError('case %s (%s): no .nii file matching %s under %s'
                                % (case, dataset, ', '.join(missing), path))
    print('Data case %i loaded'%(case))

    folder = 'train' if is_training else 'test'
    out_root = os.path.join(os.getcwd(), folder)
    for subdir in ('COMPLETE_MASK', 'CTP', 'CT', 'CTP_CBF', 'CTP_CBV', 'CTP_MTT', 'CTP_Tmax', 'MASK'):
        os.makedirs(os.path.join(out_root, subdir), exist_ok=True)

    # Test cases have no OT file, so only copy the full mask when one exists.
    if mask_file is not None:
        savename = os.path.join(out_root, 'COMPLETE_MASK', 'case_{}.nii'.format(str(case).zfill(2)))
        shutil.copy(mask_file, savename)

    # NOTE(review): `skull` is not defined anywhere in this notebook; it must already exist
    # as a global in the kernel. It is presumably a brain mask laid out like CT,
    # i.e. (1, h, w, d) -- TODO confirm and load it explicitly here.
    CT = np.multiply(skull, CT)
    # Remove the skull from every CTP frame, not just frame 0.
    skull_frames = np.zeros(CT_4DPWI.shape)
    for i in range(CT_4DPWI.shape[0]):
        skull_frames[i, ...] = skull[0, ...]
    CT_4DPWI = np.multiply(skull_frames, CT_4DPWI)
    # TODO: the skull-stripped CT_4DPWI is not used below yet; nothing is written to CTP/.

    clip_value = 500
    saved_modalities = ('CT', 'CTP_MTT', 'CTP_Tmax', 'CTP_CBF', 'CTP_CBV')
    for _z in range(CT.shape[-1]):
        # Only slices with infarcts are used for training.
        if is_training and np.max(CT_MASK[:, :, _z]) != 1.0:
            print('Slice without infarct found, continuing')
            continue

        ct_modalities = np.concatenate(
            [preprocessor(volume[:, :, :, _z], clip_value)
             for volume in (CT, CT_MTT, CT_Tmax, CT_CBF, CT_CBV)],
            axis=0)
        # (mods, h, w, 2) -> (mods, 2, h, w): channel 0 is the slice, channel 1 its mirror.
        ct_modalities_sym = rearrange(getSym(ct_modalities), 'mods h w c -> mods c h w')

        filename = _slice_filename(case, _z)
        for index, subdir in enumerate(saved_modalities):
            write_np(ct_modalities_sym[index], os.path.join(out_root, subdir, filename))
        if is_training:
            write_np(CT_MASK[:, :, _z], os.path.join(out_root, 'MASK', filename))

In [None]:
# Preprocess every TRAINING case_<n> directory that actually exists on disk, in numeric
# order, instead of assuming a fixed case count (a missing case would abort the run).
TRAINING_ROOT = os.path.join('/data/ISLES-2018/', 'TRAINING')
training_case_ids = sorted(
    int(name[len('case_'):])
    for name in (os.path.basename(p) for p in glob.glob(os.path.join(TRAINING_ROOT, 'case_*')))
    if name[len('case_'):].isdigit()
)
for case_id in training_case_ids:
    datamaker(case=case_id)