In [16]:
%load_ext autoreload
import glob
from sklearn.model_selection import train_test_split
import nibabel

import pandas as pd
import numpy as np
from collections import defaultdict
from matplotlib import pyplot as plt
from tqdm import tqdm
from joblib import Parallel, delayed
from IPython.core.debugger import set_trace
import os

import torch
from torch import nn
from torch import autograd
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from models.v2v import V2VModel

import yaml
from easydict import EasyDict as edict

from utils import show_slices, check_patch, pad_arrays, normalized, load, create_dicts

from multiprocessing import cpu_count
N_CPU = cpu_count()  # CPU count reported by multiprocessing; not referenced in the cells shown here

# NOTE(review): SEED is defined, but nothing in the visible cells seeds numpy/torch with it — confirm it is applied later.
SEED = 42
# Reload edited project modules (e.g. models, utils) automatically before code runs.
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
class CatBrainMaskPatchLoader(Dataset):
    """Dataset of cubic 3-D patches cut around centres listed in a metadata CSV.

    Each metadata row names a subject tensor file (`label` column -> file
    `tensor_<label>` under `root`), a patch centre (`x`, `y`, `z`), a patch-level
    target (`is_fcd`) and its split (`is_train`). The subject file is loaded with
    `torch.load` and must contain at least a `'brain'` tensor.

    Args:
        config: object exposing `root`, `metadata_path`, `patch_size`,
            `use_features`, `concatenate_adjacent_patch` and, optionally,
            `difference_with_adjacent_patch` (defaults to False).
        train: keep the training rows of the metadata if True, else the test rows.
    """

    def __init__(self, config, train=True):
        self.root = config.root
        self.metadata_path = config.metadata_path
        self.train = train
        self.patch_size = config.patch_size
        self.use_features = config.use_features

        self.concatenate_adjacent_patch = config.concatenate_adjacent_patch
        # Was mistakenly copied from `concatenate_adjacent_patch`; falls back to
        # False for configs that do not define the option.
        self.difference_with_adjacent_patch = getattr(
            config, 'difference_with_adjacent_patch', False)

        metadata = pd.read_csv(self.metadata_path)
        # The previous `query(f"is_train == {metadata_key}")` compared against an
        # unquoted name (`train`/`test`) that is not a column, which raises
        # UndefinedVariableError; select the split with an explicit mask instead.
        self.metadata = metadata[self._split_mask(metadata['is_train'], train)]

    @staticmethod
    def _split_mask(is_train, train):
        """Boolean row mask selecting the requested split from the `is_train` column.

        A boolean / 0-1 column is compared with `train`; any other column is
        compared, as text, with the split name 'train' or 'test'.
        NOTE(review): the on-disk encoding of `is_train` is not visible here —
        both encodings are accepted; confirm against the CSV writer.
        """
        if pd.api.types.is_bool_dtype(is_train) or pd.api.types.is_numeric_dtype(is_train):
            return is_train.astype(bool) == bool(train)
        return is_train.astype(str) == ('train' if train else 'test')

    def __getitem__(self, idx):
        """Return `(patch, target)` for metadata row `idx` (positional).

        `patch` is `brain[:, x1:x2, y1:y2, z1:z2]` when `use_features` is set
        (the stored brain tensor is indexed with a leading feature axis),
        otherwise `brain[x1:x2, y1:y2, z1:z2]` with a new leading axis of size 1.
        Each spatial extent is `patch_size` voxels starting at
        `centre - patch_size // 2`. `target` is a 1-element long tensor holding
        the row's `is_fcd` value.
        """
        row = self.metadata.iloc[idx]

        # Read the id from its column so its dtype is preserved (a mixed-dtype row
        # can upcast ints to floats), and format it rather than concatenating:
        # read_csv parses numeric ids as ints and `str + int` raised TypeError.
        subject_id = self.metadata['label'].iat[idx]
        tensor_dict = torch.load(os.path.join(self.root, f'tensor_{subject_id}'))
        brain_tensor_torch = tensor_dict['brain']

        x, y, z = row[['x', 'y', 'z']].astype(int)
        half = self.patch_size // 2
        # Take exactly patch_size voxels from centre - half; the old
        # `centre +/- patch_size // 2` bounds were one voxel short for odd sizes
        # (even sizes give the same slices as before).
        x1, y1, z1 = x - half, y - half, z - half
        x2, y2, z2 = x1 + self.patch_size, y1 + self.patch_size, z1 + self.patch_size

        if self.use_features:
            brain_patch = brain_tensor_torch[:, x1:x2, y1:y2, z1:z2]  # [N_features,H,W,D]
        else:
            brain_patch = brain_tensor_torch[x1:x2, y1:y2, z1:z2].unsqueeze(0)  # [1,H,W,D]

        label_patch = torch.tensor(row.is_fcd, dtype=torch.long)  # patch-level label
        return brain_patch, label_patch.unsqueeze(0)

    def __len__(self):
        """Number of metadata rows in the selected split."""
        return self.metadata.shape[0]

In [27]:
# torch.load('../fcd_data/normalized_tensors/tensor_1')

# Creating dataset

In [18]:
# Per-subject connected-component summary. allow_pickle=True unpickles arbitrary
# objects, so this must only ever point at a trusted, locally produced file.
labels_components = np.load('labels_info.npy', allow_pickle=True).item()

# Subjects whose `cc3d` entry has a first element of length 2.
# NOTE(review): presumably "background + exactly one lesion component" — confirm
# against the script that wrote labels_info.npy.
single_component_keys = set(
    subject
    for subject, component_info in labels_components.items()
    if len(component_info['cc3d'][0]) == 2
)

# Geometric feature names forwarded to create_dicts.
USE_GEOM_FEATURES = True
GEOM_FEATURES = ['thickness', 'sulc', 'curv']

# Input folders, relative to the notebook's working directory.
root_label = '../fcd_data/normalized_label'
root_data = '../fcd_data/normalized_data/'
root_geom_features = '../fcd_data/preprocessed_data_anadezhda/'

# Per-subject path mapping; later cells index it by subject key.
paths_dict = create_dicts(
    root_label, root_data, root_geom_features,
    single_component_keys, USE_GEOM_FEATURES, GEOM_FEATURES,
)

In [19]:
# Sanity check: number of entries (subject keys) in paths_dict.
len(paths_dict)

77

In [20]:
# Split definition, iterated later as (split_name, subject_keys) pairs.
# allow_pickle=True unpickles arbitrary objects — only load trusted, locally produced files.
metadata = np.load('metadata.npy', allow_pickle=True).item()

# Make metadata - all patches

In [133]:
# Cubic patches of PATCH_SIZE voxels per side; `pad` is the half-width used both
# to keep patch centres away from the volume borders and to size the midline gap.
PATCH_SIZE = 50
pad = PATCH_SIZE // 2
# NOTE(review): hard-coded GPU index — adjust for the machine running this.
DEVICE = torch.device('cuda:3')

# Minimum tissue fraction a patch needs (compared against check_patch's `p_mask`).
PERC_THRESHOLD = 0.9
# Minimum number of FCD-labelled voxels for a patch to count as an FCD patch:
# 1% of the patch volume.
LABEL_THRESHOLD = PATCH_SIZE ** 3 // 100

In [134]:
# Output folder for the per-subject patch-index CSVs; the thresholds are encoded
# in its name. (The `seleted` spelling is kept because later cells use this name.)
indexes_seleted_root = f'../fcd_data/indexes_selected_P{PERC_THRESHOLD}_L{LABEL_THRESHOLD}'
# exist_ok makes the cell idempotent without a separate check-then-create step.
os.makedirs(indexes_seleted_root, exist_ok=True)

In [None]:
# For every subject, write a CSV of candidate patch centres that pass the tissue
# and FCD-voxel thresholds. Very slow: one joblib task per candidate voxel.

# True -> skip subjects whose CSV already exists (resume an interrupted run);
# False -> recompute every subject (the original behaviour).
SKIP_EXISTING = False


def _drop_midline(mask, pad):
    """Return `mask` as a bool array with the slab X//2 - pad <= x <= X//2 + pad
    (first axis) zeroed out — the "mid-brain" region excluded from sampling."""
    mask = np.asarray(mask, dtype=bool)
    x_idx = np.arange(mask.shape[0])
    x_mid = mask.shape[0] // 2
    keep_x = (x_idx < x_mid - pad) | (x_idx > x_mid + pad)
    # Broadcast over y/z instead of materialising a tiled (X, Y, Z) copy.
    return mask & keep_x[:, None, None]


def _candidate_centers(mask, pad):
    """Return an (N, 3) array of in-mask voxel coordinates lying at least `pad`
    voxels from every border, in row-major (x, then y, then z) order."""
    X, Y, Z = mask.shape
    # Explicit upper bounds so pad == 0 keeps the whole volume (`0:-0` is empty).
    inner = mask[pad:X - pad, pad:Y - pad, pad:Z - pad]
    # argwhere lists indices in the same C order as boolean-indexing an 'ij' meshgrid.
    return np.argwhere(inner) + pad


for split_type, split_keys in metadata.items():
    for k in tqdm(split_keys, desc=str(split_type)):
        indexes_path = os.path.join(indexes_seleted_root, f'{k}')
        if SKIP_EXISTING and os.path.exists(indexes_path):
            continue

        path_dict = paths_dict[k]
        brain_tensor, mask_tensor, label_tensor = load(path_dict)  # float, bool, int
        mask_tensor = _drop_midline(mask_tensor, pad)

        # One check_patch call per candidate centre, spread over all cores.
        indexes_selected = Parallel(n_jobs=-1)(
            delayed(check_patch)(x, y, z, mask_tensor, label_tensor, pad,
                                 p_thresh=PERC_THRESHOLD)
            for x, y, z in _candidate_centers(mask_tensor, pad)
        )
        # Drop the None results.
        indexes_selected = [info for info in indexes_selected if info is not None]

        if indexes_selected:
            # @-references let pandas read the thresholds directly instead of
            # formatting them into the query string.
            df = pd.DataFrame(indexes_selected).query(
                'p_mask >= @PERC_THRESHOLD & (n_label >= @LABEL_THRESHOLD | n_label == 0)'
            )
            df.to_csv(indexes_path)

 84%|█████████████████████████████████████████████████████████████████████████████████████████████████████████                    | 58/69 [1:35:44<18:44, 102.24s/it]

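Precomputing patch indexes this way is a bad idea: dispatching one `check_patch` task per candidate voxel is very slow (~100 s per subject; the run above needed over 1.5 h for 58 of 69 subjects).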

# Make metadata - pivot patches

In [None]:
# Draft (disabled): build metadata from patches shifted around each subject's FCD centre.
# NOTE(review): this draft does not run as written if uncommented:
#   - `for shift in len(list(product(...)))` iterates over an int; it should iterate
#     `product(aug_iter, aug_iter, aug_iter)` (and `product` must be imported from itertools),
#   - `np.save('../fcd_data/indexes_selected/')` is missing the array argument,
#   - the `info_patch` dicts are built but never stored in `metadata_pivot_patches`.
# AUGMENTATION_STEPS = 2

# metadata_pivot_patches = defaultdict(dict)

# for split_type, split_keys in metadata.items():
    
#     for k in split_keys:
        
#         path_dict = paths_dict[k]
        
        
#         brain_tensor, mask_tensor, label_tensor = load(path_dict) # float, bool, int
#         label_info = labels_components[k]
#         # center of the pivot patch with FCD
#         center = label_info['center']
        
#         # shift around the pivot patch
#         aug_iter = np.arange(-AUGMENTATION_STEPS, AUGMENTATION_STEPS+1)
#         for shift in len(list(product(aug_iter, aug_iter, aug_iter))):
#             info_patch = {}
#             center_i = center + np.array(shift)
#             info_patch['patch_center'] = center_i
#             info_patch['subject'] = k
#             info_patch['label'] = 1
        
#         X,Y,Z = mask_tensor.shape
#         X_mean = X//2
        
#         # get rid of a mid-brain
#         thresh_mask = (np.arange(X) < (X_mean - pad)) | (np.arange(X) > (X_mean + pad))
#         thresh_mask = np.tile(thresh_mask, (Y,Z,1)).transpose(2,0,1)
#         mask_tensor = mask_tensor*(thresh_mask > 0)
        
#         ###########################
#         # CREATE RELEVANT INDEXES #
#         ###########################
#         xyz_grid = np.stack(np.meshgrid(np.arange(X), np.arange(Y), np.arange(Z), indexing='ij'), -1)
#         xyz_grid = xyz_grid[mask_tensor]

#         indexes_selected = Parallel(n_jobs=-1)(delayed(check_patch)(x,y,z,\
#                                                                     mask_tensor,\
#                                                                     label_tensor,\
#                                                                     pad) \
#                                                                for x,y,z in xyz_grid)

#         indexes_selected = list(filter(lambda x: x is not None, indexes_selected))
#         np.save('../fcd_data/indexes_selected/')
#         break
#     break