In [49]:
import glob
import os
from collections import defaultdict
from itertools import product
from multiprocessing import cpu_count

import nibabel
import numpy as np
import pandas as pd
import yaml
from easydict import EasyDict as edict
from IPython.core.debugger import set_trace
from joblib import Parallel, delayed
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from tqdm import tqdm_notebook, tqdm

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import autograd
from torch import nn
from torch.utils.data import DataLoader, Dataset

from models.v2v import V2VModel
from utils import show_slices, check_patch, get_symmetric_value, pad_arrays, normalized, load
# Number of CPUs reported by multiprocessing; used to size the joblib worker pool.
N_CPU = cpu_count()

# Fixed random_state passed to the train/test split so the split is reproducible.
SEED = 42

In [2]:
class CatBrainMaskPatchLoader(Dataset):
    """Dataset of cubic patches cut out of per-subject tensors stored on disk.

    Every row of the metadata table names a stored tensor through its
    ``label`` column and a patch centre through its ``x``, ``y`` and ``z``
    columns. The stored tensor is indexed as ``[channel, x, y, z]``: the last
    channel is returned as the label patch, the leading channel(s) as the
    brain patch.
    """

    def __init__(self, config, train=True):
        """Read the metadata table for the requested split.

        Parameters
        ----------
        config : object
            Must expose ``root`` (directory holding the metadata file and the
            ``tensor_<label>`` files), ``patch_size`` (edge length of the
            cubic patch) and ``use_features`` (whether to return every
            non-label channel instead of only the first one).
        train : bool
            ``True`` reads ``metadata_train``, ``False`` reads ``metadata_test``.
        """
        self.root = config.root
        self.train = train
        self.patch_size = config.patch_size
        self.use_features = config.use_features

        metadata_name = 'metadata_' + ('train' if train else 'test')
        self.metadata = pd.read_csv(os.path.join(self.root, metadata_name))

    def __getitem__(self, idx):
        """Return ``(brain_patch, label_patch)`` for metadata row ``idx``.

        ``brain_patch`` keeps its channel axis; ``label_patch`` gets a leading
        channel axis of size one added.
        """
        row = self.metadata.iloc[idx]

        # Build the file name with an f-string: pandas parses purely numeric
        # label columns as integers, and ``'tensor_' + <int>`` raises TypeError.
        tensor_name = f"tensor_{row['label']}"
        tensor = torch.load(os.path.join(self.root, tensor_name))
        x, y, z = row[['x', 'y', 'z']].astype(int)

        # The patch spans ``half`` voxels on each side of the centre, so an odd
        # ``patch_size`` yields an edge of ``patch_size - 1``.
        # NOTE(review): a centre closer than ``half`` to the low border gives a
        # negative start index, which slicing counts from the end of the axis;
        # the metadata is expected to contain interior centres only.
        half = self.patch_size // 2
        x1, x2 = x - half, x + half
        y1, y2 = y - half, y + half
        z1, z2 = z - half, z + half

        if self.use_features:
            # every channel except the last (label) one
            brain_patch = tensor[0:-1, x1:x2, y1:y2, z1:z2]
        else:
            # first channel only
            brain_patch = tensor[0:1, x1:x2, y1:y2, z1:z2]
        label_patch = tensor[-1, x1:x2, y1:y2, z1:z2]

        return brain_patch, label_patch.unsqueeze(0)

    def __len__(self):
        """Number of patches, i.e. rows in the metadata table."""
        return self.metadata.shape[0]

# Creating dataset

In [47]:
# Edge length of the cubic patch, and half of it (used as the margin around a centre).
patch_size = 100
pad = patch_size // 2

# Optional extra feature volumes; they are only looked up when the flag is on.
GEOM_FEATURES = ['thickness', 'sulc', 'curv']
USE_GEOM_FEATURES = False

In [17]:
# patches_data_root = f'../fcd_data/patches_dataset_{patch_size}' + ('_features' if USE_GEOM_FEATURES else '')

# if not os.path.isdir(patches_data_root):
#     os.makedirs(patches_data_root)

In [18]:
# Map every labelled subject to the files that describe it:
#   'label', 'brain', 'mask' and, when USE_GEOM_FEATURES is set, one entry per
#   name in GEOM_FEATURES.
paths_dict = defaultdict(dict)
# sorted() makes the insertion order independent of the file system's listing order.
for p in sorted(os.listdir('../fcd_data/normalized_label')):

    # Subject id is the file name up to its first dot.
    label = p.split('.')[0]

    sub_root = f'../fcd_data/normalized_data/sub-{label}/anat/'
    # [0] raises IndexError when a subject has no matching brain / mask file.
    brain_path = glob.glob(os.path.join(sub_root, '*Asym_desc-preproc_T1w.nii.gz'))[0]
    mask_path = glob.glob(os.path.join(sub_root, '*Asym_desc-brain_mask.nii.gz'))[0]
    label_path = f'../fcd_data/normalized_label/{p}'

    # Collect this subject's paths locally and commit them only once complete,
    # so a subject with a missing feature never leaves a partial entry behind.
    subject_paths = {}

    # features
    if USE_GEOM_FEATURES:
        feature_paths = {
            feature_name: f'../fcd_data/preprocessed_data_anadezhda/{feature_name}/norm-{label}.nii'
            for feature_name in GEOM_FEATURES
        }
        if not all(os.path.isfile(path) for path in feature_paths.values()):
            # at least one feature volume is absent: skip the whole subject
            continue
        subject_paths.update(feature_paths)

    subject_paths['label'] = label_path
    subject_paths['brain'] = brain_path
    subject_paths['mask'] = mask_path

    paths_dict[label].update(subject_paths)

In [19]:
# The .npy file wraps a pickled dict in a 0-d object array; .item() unwraps it.
# NOTE(review): allow_pickle=True can execute code embedded in the file - only
# load files from a trusted source.
labels_info = np.load('labels_info.npy', allow_pickle=True).item()

# Keep subjects whose first 'cc3d' array has exactly two entries
# (presumably background plus one connected lesion component - TODO confirm).
labels_info = {
    subject: info
    for subject, info in labels_info.items()
    if len(info['cc3d'][0]) == 2
}

In [20]:
# Hold out 10% of the retained subjects for testing; SEED fixes the shuffle.
train_keys, test_keys = train_test_split(
    list(labels_info),
    test_size=0.1,
    random_state=SEED,
)

In [22]:
# Subject ids grouped by split name.
labels_all = dict(train=train_keys, test=test_keys)

In [25]:
# Inspect the stored record for subject '30'.
labels_info['30']

{'cc3d': [array([0, 1], dtype=uint16), array([22911852,     4356])],
 'd_s': [19, 26, 28],
 'center': array([152, 212,  60])}

# Make metadata

In [None]:
# Maximum shift (in voxels, along each axis) applied to the lesion centre when
# generating augmented FCD patch centres.
AUGMENTATION_STEPS = 2

# split name -> subject id -> list of FCD patch records.
# NOTE(review): the original cell built the records but never stored them; this
# nesting is an assumption based on the defaultdict(dict) container - confirm
# against whatever consumes the metadata.
metadata_fcd_pivot_patches = defaultdict(dict)

for split_type, split_keys in labels_all.items():
    for k in split_keys:
        path_dict = paths_dict[k]

        brain_tensor, mask_tensor, label_tensor = load(path_dict)
        label_info = labels_info[k]
        center = label_info['center']

        ###############################
        # AUGMENTED FCD PATCH CENTRES #
        ###############################

        # Every integer shift in [-AUGMENTATION_STEPS, AUGMENTATION_STEPS]^3.
        # Iterate over the product itself: the original looped over
        # len(list(product(...))), i.e. over an int, which raises TypeError.
        aug_iter = np.arange(-AUGMENTATION_STEPS, AUGMENTATION_STEPS + 1)
        fcd_patches = []
        for shift in product(aug_iter, aug_iter, aug_iter):

            center_i = center + np.array(shift)

            # A fresh dict per shift, so earlier records are not overwritten.
            info_patch = {
                'patch_center': center_i,
                'patient_number': k,
                'is_fcd': 1,
            }
            fcd_patches.append(info_patch)

        metadata_fcd_pivot_patches[split_type][k] = fcd_patches

        # Exclude the cube around the lesion centre from the candidate mask.
        # Lower bounds are clamped at 0: a negative start would be counted from
        # the end of the axis and blank the wrong region (or nothing at all).
        x_c, y_c, z_c = center
        mask_tensor[max(x_c - pad, 0):x_c + pad,
                    max(y_c - pad, 0):y_c + pad,
                    max(z_c - pad, 0):z_c + pad] = False

        ###################################
        # CREATE RELEVANT non-FCD INDEXES #
        ###################################

        # Volume extents come from the mask itself; it has to be a 3-D boolean
        # array for the boolean indexing and the (x, y, z) unpacking below.
        X, Y, Z = mask_tensor.shape
        xyz_grid = np.stack(np.meshgrid(np.arange(X), np.arange(Y), np.arange(Z), indexing='ij'), -1)
        xyz_grid = xyz_grid[mask_tensor]  # (N, 3) coordinates of voxels still in the mask

        # NOTE(review): this schedules one joblib task per remaining mask voxel,
        # which can be very slow for a whole-brain mask.
        # max(1, ...) keeps n_jobs valid on a single-core machine (n_jobs=0 is rejected).
        indexes_selected = Parallel(n_jobs=max(1, N_CPU // 2))(
            delayed(check_patch)(x, y, z, mask_tensor, label_tensor, patch_size)
            for x, y, z in xyz_grid
        )

        # check_patch yields None for rejected voxels; keep the accepted results.
        # NOTE(review): nothing visible in this cell stores indexes_selected in
        # the metadata yet - TODO persist the non-FCD centres.
        indexes_selected = [idx for idx in indexes_selected if idx is not None]
