In [457]:
import os
import sys
import glob
import pickle

In [458]:
import imp

import utils.data_processor as data_processor
imp.reload(data_processor)
from utils.data_processor import *

import utils.visualization_tools as visualization_tools
imp.reload(visualization_tools)
from utils.visualization_tools import *

import utils.metrics as metrics
imp.reload(metrics)
from utils.metrics import *

# Functions

In [459]:
def _as_bool_tensor(mask):
    """Return ``mask`` (numpy array or torch tensor) as a torch bool tensor."""
    if torch.is_tensor(mask):
        return mask.bool()
    # ascontiguousarray guarantees positive strides, which torch.from_numpy requires
    return torch.from_numpy(np.ascontiguousarray(mask, dtype=bool))


def binary_3dmaps_to_point_cloud_and_labels(image_brain, image_hypo, size = 256, seg =None):
    """ 
    Transforms 3d maps of brain and hippocampus into a point cloud and per-point labels. Both
    coordinates-only and coordinates + intensity modes are supported.
      Args:
          image_brain: torch tensor or numpy array of size [size,size,size] with 1 at the positions with brain and 0 otherwise
          image_hypo: torch tensor or numpy array of size [size,size,size] with 1 at the positions with hippocampus and 0 otherwise
          size: size of the input maps along each direction, default = 256
          seg: torch tensor or numpy array of size [size,size,size] with intensities of brain, default None
      Output:
          torch tensor of size [N, 3] (int64 voxel coordinates) if seg is None and [N, 4] (float32 coordinates +
          intensity) otherwise, and a numpy array of size [N,] with labels: 1 for hippocampus points (listed first),
          0 for the remaining brain points
      Raises:
          ValueError: if the maps are not of shape [size, size, size]
      """
    # Compare in the inputs' own library (numpy or torch), then move the boolean masks to torch.
    hypo_mask = _as_bool_tensor(image_hypo == 1)
    # Voxels where image_hypo is neither 0 nor 1 fall into neither group.
    rest_mask = _as_bool_tensor((image_hypo == 0) & (image_brain == 1))
    expected = (size, size, size)
    if tuple(hypo_mask.shape) != expected or tuple(rest_mask.shape) != expected:
        raise ValueError(f'expected maps of shape {expected}, got {tuple(hypo_mask.shape)} '
                         f'and {tuple(rest_mask.shape)}')

    # nonzero() yields the indices of True voxels in row-major order: the same points, in the same order,
    # as boolean-indexing an 'ij' meshgrid, without allocating a [size, size, size, 3] coordinate grid.
    pc_hypo = hypo_mask.nonzero()
    pc_rest = rest_mask.nonzero()

    if seg is not None:
        if not torch.is_tensor(seg):
            seg = torch.from_numpy(np.ascontiguousarray(seg, dtype=np.float32))
        pc_hypo = torch.cat((pc_hypo.float(), seg[hypo_mask].float().unsqueeze(-1)), -1)
        pc_rest = torch.cat((pc_rest.float(), seg[rest_mask].float().unsqueeze(-1)), -1)

    # Integer labels even when there are no points at all.
    labels = np.concatenate([np.ones(len(pc_hypo), dtype=int), np.zeros(len(pc_rest), dtype=int)])
    return torch.cat([pc_hypo, pc_rest]), labels

In [460]:
def filename_to_pc_and_labels(file, size = 256, target = (17, 53), segfile = None):
    """ 
    Processes a brain label-volume file into a point cloud with per-point binary labels.

    Voxels whose label is in ``target`` become the positive class (label 1), every other
    non-zero voxel becomes the negative class (label 0), and zero voxels are dropped.
      Args:
          file: path to the label volume (read with ``load_nii_to_array``)
          size: edge length of the output grid. When it is smaller than the loaded volume's edge
                length (assumed cubic, e.g. 256), the volumes are max-pooled down to ``size``; the
                edge length must be a multiple of ``size`` (256 -> 32 uses an 8x8x8 kernel). Default = 256
          target: label values treated as the positive class. Default = (17, 53)
          segfile: optional path to an intensity volume; intensities inside the brain are appended
                   as a 4th point feature. Default None
      Output:
          torch tensor of size [N, 3] if segfile is None and [N, 4] otherwise, and a numpy array of
          size [N,] with labels
      Raises:
          ValueError: if the volume edge length is not a multiple of ``size``
      """
    volume = load_nii_to_array(file)

    # A single membership test stays correct even when 0 is one of the target labels.
    hypo = np.isin(volume, target).astype(volume.dtype)
    brain = (volume != 0).astype(volume.dtype)

    seg = None
    if segfile is not None:
        seg = brain * load_nii_to_array(segfile)

    full_size = volume.shape[0]  # assumes a cubic volume -- TODO confirm for all inputs
    if size != full_size:
        if full_size % size:
            raise ValueError(f'volume edge length {full_size} is not a multiple of size={size}')
        kernel = full_size // size
        pool = torch.nn.MaxPool3d(kernel, kernel)

        def _pool(vol):
            # MaxPool3d expects (N, C, D, H, W); add and then drop the two leading axes.
            return pool(torch.as_tensor(vol, dtype=torch.float64)[None, None]).squeeze()

        brain = _pool(brain)
        hypo = _pool(hypo)
        if seg is not None:
            seg = _pool(seg)

    if seg is not None and not torch.is_tensor(seg):
        # The point-cloud builder works on torch tensors, also when no pooling happened.
        seg = torch.tensor(seg, dtype=torch.float64)

    pc, labels = binary_3dmaps_to_point_cloud_and_labels(brain, hypo, size = size, seg = seg)
    return pc, labels

In [461]:
def pc_norm(pc):
    """ 
    Normalises a point cloud in place and returns it.

    The xyz columns are shifted by the midpoint of their joint (all-axis) min and max, then scaled
    so that the farthest point lies on the unit sphere. If a 4th column (intensity) is present, it
    is mean-centred and divided by its maximum absolute value. Degenerate inputs (all points at the
    centre, constant intensity) are left at 0 instead of becoming NaN; an empty cloud is returned as is.
      Args:
          pc: float numpy array of shape [N, 3] or [N, 4]; modified in place
      Output:
          the same array, normalised
    """
    if len(pc) == 0:
        return pc

    xyz = pc[:, :3]  # view: in-place ops below write straight into pc
    xyz -= (xyz.min() + xyz.max()) / 2
    radius = np.linalg.norm(xyz, axis=1).max()
    if radius > 0:
        xyz *= 1.0 / radius

    if pc.shape[1] == 4:
        intensity = pc[:, 3]  # 1-D view of the intensity column
        intensity -= intensity.mean()
        peak = np.abs(intensity).max()
        if peak > 0:
            intensity /= peak

    return pc

# Experiment 1: data creation (hippocampus labels 17/53, coordinates only)

In [None]:
# Settings for experiment 1.
POSTFIX = ''               # suffix appended to the output pickle file names
SIZE, TEST_DATA = 32, 100  # pooled grid edge length; number of subjects in the test split
TARGET_LABELS = [17, 53]   # label values treated as the positive class

In [None]:
# One coordinates-only point cloud per subject. glob.glob returns paths in arbitrary,
# filesystem-dependent order, so sort them to make the first-TEST_DATA test split reproducible.
pcs, labels = [], []
for file in tqdm(sorted(glob.glob('../datasets/fcd_classification_bank/*_aparc+aseg.nii*'))):
    pc, label = filename_to_pc_and_labels(file,
                                          size = SIZE,
                                          target = TARGET_LABELS
                                         )
    pcs.append(pc_norm(np.array(pc.detach(), dtype = float)))
    labels.append(label)

# A constant label np.array(0) for every subject (one shared object); presumably the
# per-cloud label expected by the CloserLook3D data loader -- TODO confirm.
sc_labels = [np.array(0)]*len(labels)

In [None]:
# The first TEST_DATA subjects form the test split; the rest are used for train/val.
pcs_test, pcs_train = pcs[:TEST_DATA], pcs[TEST_DATA:]
labels_test, labels_train = labels[:TEST_DATA], labels[TEST_DATA:]
sc_labels_test, sc_labels_train = sc_labels[:TEST_DATA], sc_labels[TEST_DATA:]

In [None]:
# Pickle each split as a (point clouds, per-point labels, sc_labels) tuple into the CloserLook3D data directory.
data_test = (pcs_test, labels_test, sc_labels_test)
data_train = (pcs_train, labels_train, sc_labels_train)
for out_path, split in ((f'../CloserLook3D/pytorch/data/BrainData/test_data{POSTFIX}.pkl', data_test),
                        (f'../CloserLook3D/pytorch/data/BrainData/trainval_data{POSTFIX}.pkl', data_train)):
    with open(out_path, 'wb') as f:
        pickle.dump(split, f)

# Experiment 2: data creation (hippocampus labels 17/53, coordinates + intensity)

In [445]:
# Settings for experiment 2.
POSTFIX = '_2exp'          # suffix appended to the output pickle file names
SIZE, TEST_DATA = 32, 100  # pooled grid edge length; number of subjects in the test split
TARGET_LABELS = [17, 53]   # label values treated as the positive class

In [None]:
# One coordinates + intensity point cloud per subject. The intensity volume is the file that shares
# the subject prefix but is not itself an aparc+aseg file. Paths are sorted for a reproducible split.
pcs, labels = [], []
skipped = []
for file in tqdm(sorted(glob.glob('../datasets/fcd_classification_bank/*_aparc+aseg.nii*'))):
    segfiles = sorted(x for x in glob.glob(file.split('aparc+aseg.nii')[0] + '*') if 'aparc+aseg' not in x)
    if not segfiles:
        # Skip instead of silently reusing the previous subject's intensity file.
        skipped.append(file)
        continue
    pc, label = filename_to_pc_and_labels(file,
                                          size = SIZE,
                                          target = TARGET_LABELS,
                                          segfile = segfiles[0]
                                         )
    pcs.append(pc_norm(np.array(pc.detach(), dtype = float)))
    labels.append(label)

if skipped:
    print(f'Skipped {len(skipped)} subject(s) without an intensity file: {skipped}')

In [436]:
# Constant label np.array(0) for every subject; list multiplication repeats one shared object.
sc_labels = len(labels) * [np.array(0)]

In [437]:
# Hold out the first TEST_DATA subjects for testing; everything after them is train/val.
pcs_test, pcs_train = pcs[:TEST_DATA], pcs[TEST_DATA:]
labels_test, labels_train = labels[:TEST_DATA], labels[TEST_DATA:]
sc_labels_test, sc_labels_train = sc_labels[:TEST_DATA], sc_labels[TEST_DATA:]

In [438]:
# Write both splits as (point clouds, per-point labels, sc_labels) pickles into the CloserLook3D data directory.
data_test = (pcs_test, labels_test, sc_labels_test)
data_train = (pcs_train, labels_train, sc_labels_train)
for out_path, split in ((f'../CloserLook3D/pytorch/data/BrainData/test_data{POSTFIX}.pkl', data_test),
                        (f'../CloserLook3D/pytorch/data/BrainData/trainval_data{POSTFIX}.pkl', data_train)):
    with open(out_path, 'wb') as f:
        pickle.dump(split, f)

# Additional functions for point-cloud (PC) generation for the FCD task

In [462]:
def fcd_filename_to_pc_and_labels(file, file_mask, 
                                  size = 256, 
                                  LIST_FCD = (8, 10, 11, 12, 13, 16, 17, 18, 26, 47, 49, 50,
                                              51, 52, 53, 54, 58, 85, 251, 252, 253, 254, 255), 
                                  segfile = None):
    """ 
    Processes a brain label volume and an FCD mask into a point cloud with labels.

    A voxel counts as brain when its label is in ``LIST_FCD`` or is >= 1000 (voxels already
    labelled 1 are kept as brain too). Voxels where the mask equals 1 become the positive class.
      Args:
          file: path to the brain label volume (read with ``load_nii_to_array``)
          file_mask: path to the mask volume
          size: edge length of the output grid. When it is smaller than the loaded volume's edge
                length (assumed cubic, e.g. 256), the volumes are max-pooled down to ``size``; the
                edge length must be a multiple of ``size``. Default = 256
          LIST_FCD: label values counted as brain in addition to labels >= 1000
          segfile: optional path to an intensity volume; intensities inside the brain are appended
                   as a 4th point feature. Default None
      Output:
          torch tensor of size [N, 3] if segfile is None and [N, 4] otherwise, and a numpy array of
          size [N,] with labels
      Raises:
          ValueError: if the volume edge length is not a multiple of ``size``
      """
    brain = load_nii_to_array(file)
    # Recode the whole volume; indexing brain[0][0] would only touch a single row of voxels.
    brain[np.isin(brain, LIST_FCD)] = 1.0
    brain[brain >= 1000] = 1.0
    brain[brain != 1] = 0.0

    mask = load_nii_to_array(file_mask)

    seg = None
    if segfile is not None:
        seg = torch.tensor(brain * load_nii_to_array(segfile), dtype=torch.float64)

    full_size = brain.shape[0]  # assumes a cubic volume -- TODO confirm for all inputs
    if size != full_size:
        if full_size % size:
            raise ValueError(f'volume edge length {full_size} is not a multiple of size={size}')
        kernel = full_size // size
        pool = torch.nn.MaxPool3d(kernel, kernel)

        def _pool(vol):
            # MaxPool3d expects (N, C, D, H, W); add and then drop the two leading axes.
            return pool(torch.as_tensor(vol, dtype=torch.float64)[None, None]).squeeze()

        brain = _pool(brain)
        mask = _pool(mask)
        if seg is not None:
            seg = _pool(seg)

    pc, labels = binary_3dmaps_to_point_cloud_and_labels(brain, mask, size = size, seg = seg)
    return pc, labels

# Experiment 3: data creation (FCD masks, coordinates only, leave-one-out splits)

In [344]:
# Settings for experiment 3 (leave-one-out over FCD subjects).
SIZE, UPSAMPLE_RATE = 256, 10  # unpooled grid edge length; how many times each split list is repeated
POSTFIX = '_3exp'              # suffix appended to the output pickle file names

In [345]:
# One coordinates-only point cloud per FCD subject, labelled by that subject's mask.
# Paths are sorted so the leave-one-out index e maps to the same subject on every run.
pcs, labels = [], []
skipped = []
for file in tqdm(sorted(glob.glob('../datasets/fcd_classification_bank/fcd_*_aparc+aseg.nii*'))):
    subject = os.path.basename(file).split('_aparc')[0]
    # NOTE(review): `{subject}*` also matches longer ids sharing the prefix (e.g. fcd_1 vs fcd_10) -- confirm mask naming.
    masks = sorted(glob.glob(f'../masks/{subject}*'))
    if not masks:
        # Skip instead of silently reusing the previous subject's mask.
        skipped.append(file)
        continue
    pc, label = fcd_filename_to_pc_and_labels(file, masks[0],
                                              size = SIZE)
    pcs.append(pc_norm(np.array(pc.detach(), dtype = float)))
    labels.append(label)

if skipped:
    print(f'Skipped {len(skipped)} subject(s) without a mask: {skipped}')

100%|██████████| 15/15 [00:29<00:00,  1.95s/it]


In [346]:
# Constant label np.array(0) for every subject; list multiplication repeats one shared object.
sc_labels = len(labels) * [np.array(0)]

In [347]:
# Leave-one-out: for every subject e, write a test split holding only e and a train/val split
# holding all other subjects; both lists are repeated UPSAMPLE_RATE times.
for e in range(len(pcs)):
    TEST_DATA_INDEXES = [e]
    train_indexes = [i for i in range(len(pcs)) if i not in TEST_DATA_INDEXES]

    pcs_test = [pcs[i] for i in TEST_DATA_INDEXES] * UPSAMPLE_RATE
    labels_test = [labels[i] for i in TEST_DATA_INDEXES] * UPSAMPLE_RATE
    sc_labels_test = [sc_labels[i] for i in TEST_DATA_INDEXES] * UPSAMPLE_RATE

    pcs_train = [pcs[i] for i in train_indexes] * UPSAMPLE_RATE
    labels_train = [labels[i] for i in train_indexes] * UPSAMPLE_RATE
    sc_labels_train = [sc_labels[i] for i in train_indexes] * UPSAMPLE_RATE

    data_test = (pcs_test, labels_test, sc_labels_test)
    data_train = (pcs_train, labels_train, sc_labels_train)
    with open(f'../CloserLook3D/pytorch/data/BrainData/test_data{POSTFIX}_{e}.pkl', 'wb') as f:
        pickle.dump(data_test, f)
    with open(f'../CloserLook3D/pytorch/data/BrainData/trainval_data{POSTFIX}_{e}.pkl', 'wb') as f:
        pickle.dump(data_train, f)

# Experiment 4: data creation (FCD masks, coordinates + intensity, leave-one-out splits)

In [463]:
# Settings for experiment 4 (leave-one-out over FCD subjects, with intensities).
SIZE, UPSAMPLE_RATE = 256, 10  # unpooled grid edge length; how many times each split list is repeated
POSTFIX = '_4exp'              # suffix appended to the output pickle file names

In [464]:
# One coordinates + intensity point cloud per FCD subject. The intensity volume is the file that
# shares the subject prefix but is not an aparc+aseg file. Paths are sorted so the leave-one-out
# index e maps to the same subject on every run.
pcs, labels = [], []
skipped = []
for file in tqdm(sorted(glob.glob('../datasets/fcd_classification_bank/fcd_*_aparc+aseg.nii*'))):
    subject = os.path.basename(file).split('_aparc')[0]
    # NOTE(review): `{subject}*` also matches longer ids sharing the prefix (e.g. fcd_1 vs fcd_10) -- confirm mask naming.
    masks = sorted(glob.glob(f'../masks/{subject}*'))
    segfiles = sorted(x for x in glob.glob(file.split('aparc+aseg.nii')[0] + '*') if 'aparc+aseg' not in x)
    if not masks or not segfiles:
        # Skip instead of reusing the previous subject's mask or crashing on a missing intensity file.
        skipped.append(file)
        continue
    # Use the configured SIZE rather than a hard-coded 256.
    pc, label = fcd_filename_to_pc_and_labels(file, masks[0],
                                              size = SIZE, segfile = segfiles[0])
    pcs.append(pc_norm(np.array(pc.detach(), dtype = float)))
    labels.append(label)

if skipped:
    print(f'Skipped {len(skipped)} subject(s) without a mask or intensity file: {skipped}')

100%|██████████| 15/15 [00:43<00:00,  2.89s/it]


In [465]:
# Constant label np.array(0) for every subject; list multiplication repeats one shared object.
sc_labels = len(labels) * [np.array(0)]

In [466]:
# Leave-one-out: subject e alone is the test split and every other subject is train/val;
# each list is repeated UPSAMPLE_RATE times before pickling.
for e in range(len(pcs)):
    TEST_DATA_INDEXES = [e]
    train_indexes = [i for i in range(len(pcs)) if i not in TEST_DATA_INDEXES]

    pcs_test = [pcs[i] for i in TEST_DATA_INDEXES] * UPSAMPLE_RATE
    labels_test = [labels[i] for i in TEST_DATA_INDEXES] * UPSAMPLE_RATE
    sc_labels_test = [sc_labels[i] for i in TEST_DATA_INDEXES] * UPSAMPLE_RATE

    pcs_train = [pcs[i] for i in train_indexes] * UPSAMPLE_RATE
    labels_train = [labels[i] for i in train_indexes] * UPSAMPLE_RATE
    sc_labels_train = [sc_labels[i] for i in train_indexes] * UPSAMPLE_RATE

    data_test = (pcs_test, labels_test, sc_labels_test)
    data_train = (pcs_train, labels_train, sc_labels_train)
    with open(f'../CloserLook3D/pytorch/data/BrainData/test_data{POSTFIX}_{e}.pkl', 'wb') as f:
        pickle.dump(data_test, f)
    with open(f'../CloserLook3D/pytorch/data/BrainData/trainval_data{POSTFIX}_{e}.pkl', 'wb') as f:
        pickle.dump(data_train, f)