In [239]:
import glob
from skimage.measure import marching_cubes_lewiner, marching_cubes
import nibabel
import trimesh
import numpy as np
from natsort import natsorted
import plotly.graph_objects as go
import pymeshlab
sorted = natsorted
from numba import njit
import cc3d
import re
from scipy.ndimage import gaussian_filter
from collections import defaultdict
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook
from joblib import Parallel, delayed
from tensorboardX import SummaryWriter  
from IPython.core.debugger import set_trace
from datetime import datetime
import os
import shutil
import argparse
import time
import json
from collections import defaultdict
import pickle

import torch
from torch import nn
from torch import autograd
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DistributedDataParallel

from models.v2v import V2VModel

import yaml
from easydict import EasyDict as edict

In [362]:
def show_slices(brain_tensor, n_slices_show=5, mask_tensor=None):
    
    fig, axes = plt.subplots(ncols=3, nrows=n_slices_show, figsize=(15,n_slices_show*5))
    X_max, Y_max, Z_max = brain_tensor.shape
    for i in range(n_slices_show):

        x_slice_pos = (X_max//(n_slices_show+2))*(i+1)
        y_slice_pos = (Y_max//(n_slices_show+2))*(i+1)
        z_slice_pos = (Z_max//(n_slices_show+2))*(i+1)

        brain_tensor_x_slice = brain_tensor[x_slice_pos,:,:]
        brain_tensor_y_slice = brain_tensor[:,y_slice_pos,:]
        brain_tensor_z_slice = brain_tensor[:,:,z_slice_pos]

        axes[i,0].imshow(brain_tensor_x_slice, 'gray')
        axes[i,1].imshow(brain_tensor_y_slice, 'gray')
        axes[i,2].imshow(brain_tensor_z_slice, 'gray')
        
        if mask_tensor is not None:
            
            mask_tensor_x_slice = mask_tensor[x_slice_pos,:,:]
            mask_tensor_y_slice = mask_tensor[:,y_slice_pos,:]
            mask_tensor_z_slice = mask_tensor[:,:,z_slice_pos]

            axes[i,0].imshow(mask_tensor_x_slice, 'jet', interpolation='none', alpha=0.7)
            axes[i,1].imshow(mask_tensor_y_slice, 'jet', interpolation='none', alpha=0.7)
            axes[i,2].imshow(mask_tensor_z_slice, 'jet', interpolation='none', alpha=0.7)
    plt.tight_layout()
    plt.show()
    
def is_filled_volume(x,y,z,mask_tensor, threshold=0.5):
    x1,x2 = x-patch_size//2, x+patch_size//2
    y1,y2 = y-patch_size//2, y+patch_size//2
    z1,z2 = z-patch_size//2, z+patch_size//2
    volume = mask_tensor[x1:x2,y1:y2,z1:z2]
    if volume.sum()/np.prod(volume.shape) > threshold:
        return [x,y,z]
    else:
        return None
    
def get_symmetric_value(a, a_sym):
    diff = a-a_sym
    return a_sym - diff

In [363]:
class PatchBrainMaskLoader(Dataset):
    
    def __init__(self, root):
        self.root = root
        paths = [os.path.join(root, p) for p in os.listdir(root)]
        self.tensors_paths = list(filter())
        self.indexes = list(filter())
        
        brain_label_
        indexes_selected_

    def __getitem__(self, idx):
        brain_tensor_torch, mask_tensor_torch = torch.load(self.paths[idx])
        return brain_tensor_torch, mask_tensor_torch

    def __len__(self):
        return len(self.paths)

In [364]:
# brain_mask_paths = defaultdict(dict)
# data_root = '../fcd_data/fmriprep/'
# for sub_path in glob.glob(os.path.join(data_root,'sub-*/')):
#     mask_path = glob.glob(os.path.join(sub_path,'anat/*Asym_desc-brain_mask*.nii.gz'))[0]
#     T1_path = glob.glob(os.path.join(sub_path,'anat/*Asym_desc-preproc_T1w*.nii.gz'))[0]
    
#     label = os.path.basename(os.path.normpath(sub_path))
#     brain_mask_paths[label]['mask_path'] = mask_path
#     brain_mask_paths[label]['T1_path'] = T1_path

In [260]:
paths_dict = defaultdict(dict)
for p in os.listdir('../fcd_data/normalized_label'):
    label = p.split('.')[0]
    sub_root = f'../fcd_data/normalized_data/sub-{label}/anat/'
    brain_path = glob.glob(os.path.join(sub_root, '*Asym_desc-preproc_T1w.nii.gz'))[0]
    mask_path = glob.glob(os.path.join(sub_root, '*Asym_desc-brain_mask.nii.gz'))[0]
    
    label_path = f'../fcd_data/normalized_label/{p}'
    
    paths_dict[label]['label'] = label_path
    paths_dict[label]['brain'] = brain_path    
    paths_dict[label]['mask'] = mask_path    

In [None]:
patch_size=64
for label, path_dict in tqdm_notebook(paths_dict.items()):
    
    mask_tensor = nibabel.load(path_dict['mask']).get_fdata() > 0
    brain_tensor = nibabel.load(path_dict['brain']).get_fdata()
    brain_tensor = brain_tensor*mask_tensor.astype(int)
    label_tensor = nibabel.load(path_dict['label']).get_fdata()
    
    brain_tensor_torch = torch.tensor(brain_tensor, dtype=torch.float)
    label_tensor_torch = torch.tensor(label_tensor, dtype=torch.float)
    
    torch_tensor = torch.stack([brain_tensor_torch, label_tensor_torch])
    torch.save(torch_tensor, f'../fcd_data/patches_dataset_{patch_size}/brain_label_{label}')

    X,Y,Z = mask_tensor.shape

    thresh_mask = (np.arange(X) < (X_mean - patch_size//2)) | (np.arange(X) > (X_mean + patch_size//2))
    thresh_mask = np.tile(thresh_mask, (Y,Z,1)).transpose(2,0,1)
    mask_tensor = mask_tensor*thresh_mask > 0

    xyz_grid = np.stack(np.meshgrid(np.arange(X), np.arange(Y), np.arange(Z), indexing='ij'), -1)

    indexes_selected = Parallel(n_jobs=-1)(delayed(is_filled_volume)(x,y,z,mask_tensor) \
                                           for x,y,z in xyz_grid[mask_tensor])

    indexes_selected_ = np.array(list(filter(lambda x: x is not None, indexes_selected)))

    np.save(f'../fcd_data/patches_dataset_{patch_size}/indexes_selected_{label}', indexes_selected_)


This function will be removed in tqdm==5.0.0
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`



  0%|          | 0/92 [00:00<?, ?it/s]