In [1]:
%load_ext autoreload
import glob
import nibabel
import pandas as pd
import numpy as np

from collections import defaultdict, Counter
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook, tqdm
from joblib import Parallel, delayed

from IPython.core.debugger import set_trace

import os
import shutil
import argparse
import time
import json
import pickle

import torch
from torch import nn
from torch import autograd
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from models.v2v import V2VModel

import yaml
from easydict import EasyDict as edict

from utils import show_slices, check_patch, get_symmetric_value, pad_arrays, create_dicts, normalized, load

from multiprocessing import cpu_count
N_CPU = cpu_count()

from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.model_selection import train_test_split

import cc3d

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important}</style>"))

from celluloid import Camera

SEED = 42
%autoreload 2

In [2]:
# Per-subject info dict stored as a pickled object in a .npy file; `.item()`
# unwraps the 0-d object array. allow_pickle executes pickle -> trusted files only.
labels_components = np.load('labels_info.npy', allow_pickle=True).item()

# Keep subjects whose 'cc3d' entry lists exactly two values. Presumably this is
# background + one connected lesion component — TODO confirm against the
# code that produced labels_info.npy.
single_component_keys = set(
    subject_key
    for subject_key, info in labels_components.items()
    if len(info['cc3d'][0]) == 2
)

In [3]:
# How many subjects passed the single-component filter in the previous cell.
len(single_component_keys)

87

In [4]:
# Extra per-voxel surface features to load alongside the brain volume.
USE_GEOM_FEATURES = True
GEOM_FEATURES = ['thickness', 'sulc', 'curv']

# Input locations, relative to the notebook.
# NOTE(review): root_data / root_geom_features end in '/' but root_label does
# not — harmless if create_dicts uses os.path.join, worth checking otherwise.
root_label = '../fcd_data/normalized_label'
root_data = '../fcd_data/normalized_data/'
root_geom_features = '../fcd_data/preprocessed_data_anadezhda/'

# Per-subject dict of file paths; presumably restricted to single_component_keys.
paths_dict = create_dicts(
    root_label, root_data, root_geom_features,
    single_component_keys, USE_GEOM_FEATURES, GEOM_FEATURES,
)

In [5]:
# Number of subjects that create_dicts returned paths for.
# NOTE(review): this displayed 77 while single_component_keys had 87, so
# presumably 10 subjects were dropped inside create_dicts (missing files?) — confirm.
len(paths_dict)

77

In [6]:
def make_tensor(root, label, pdict, USE_GEOM_FEATURES):
    """Normalize one subject's volumes and save them as a dict of tensors.

    Parameters
    ----------
    root : str
        Output directory. It is created if it does not exist yet.
    label : str
        Subject key; the result is written to ``<root>/tensor_<label>``.
    pdict : dict
        Per-subject paths. ``load(pdict)`` supplies the brain, mask and label
        volumes; every key other than ``'brain'``, ``'mask'`` and ``'label'``
        is treated as the path of a feature volume readable by ``nibabel``.
    USE_GEOM_FEATURES : bool
        Whether to also load, normalize and store the feature volumes.

    The saved dict maps ``'brain'`` (float32), ``'mask'`` (int64), ``'label'``
    (float32) and, optionally, each feature name (float32) to a tensor.
    Voxels inside the mask are passed through ``normalized``; voxels outside
    it are set to 0 (brain and features only).
    """
    brain_tensor, mask_tensor, label_tensor = load(pdict)

    label_tensor_torch = torch.tensor(label_tensor, dtype=torch.float32)
    mask_tensor_torch = torch.tensor(mask_tensor, dtype=torch.long)
    brain_tensor_torch = torch.tensor(brain_tensor, dtype=torch.float32)

    mask = mask_tensor_torch.type(torch.bool)
    brain_tensor_torch[mask] = normalized(brain_tensor_torch[mask])
    brain_tensor_torch[~mask] = 0

    torch_tensor = {'brain': brain_tensor_torch,
                    'mask': mask_tensor_torch,
                    'label': label_tensor_torch}

    if USE_GEOM_FEATURES:
        # Sorted so the key order of the saved dict does not depend on set
        # iteration order (which varies between processes for str keys).
        feature_names = sorted(set(pdict) - {'label', 'brain', 'mask'})
        for feature_name in feature_names:
            # Load one volume at a time instead of materializing all of them
            # up front, so only one raw feature array is alive at once.
            feature_tensor = torch.tensor(nibabel.load(pdict[feature_name]).get_fdata(),
                                          dtype=torch.float32)
            feature_tensor[mask] = normalized(feature_tensor[mask])
            feature_tensor[~mask] = 0
            torch_tensor[feature_name] = feature_tensor

    # exist_ok=True keeps this safe when several joblib workers race here.
    os.makedirs(root, exist_ok=True)
    torch.save(torch_tensor, os.path.join(root, f'tensor_{label}'))

# One joblib job per subject, using all cores. tqdm wraps the input iterator,
# so the bar counts jobs handed to joblib rather than jobs finished.
_ = Parallel(n_jobs=-1)(
    delayed(make_tensor)('../fcd_data/normalized_tensors', subject_key, subject_paths, USE_GEOM_FEATURES)
    for subject_key, subject_paths in tqdm(paths_dict.items())
)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 77/77 [00:00<00:00, 139.99it/s]


In [78]:
# make_tensor('.', '35', paths_dict['35'], USE_GEOM_FEATURES)
# tensor = torch.load('tensor_35')

In [13]:
def video(brain_tensor, mask_tensor=None, n_slides=100):
    """Record a slice-by-slice animation of a 3D volume along its second axis.

    Parameters
    ----------
    brain_tensor : array-like with a 3-element ``.shape`` (X, Y, Z)
        Volume drawn in grayscale; slices are taken as ``brain_tensor[:, y, :]``.
    mask_tensor : array-like of the same shape, optional
        Overlay drawn on top with the ``'jet'`` colormap at alpha 0.7.
    n_slides : int
        Number of frames. Slice positions are spread evenly over axis 1,
        skipping the two outermost positions.

    Returns
    -------
    Camera
        celluloid camera holding one snapshot per frame; call ``.animate()``.
    """
    fig, ax = plt.subplots()
    _, n_y, _ = brain_tensor.shape  # unpacking also enforces a 3D input
    camera = Camera(fig)

    # Evenly spaced interior positions. A truncated integer step
    # (n_y // (n_slides + 2)) left the far end of the axis unvisited, e.g. a
    # 200-voxel axis with 100 frames only reached slice 100, and became 0
    # (every frame = slice 0) when n_slides + 2 > n_y.
    n_frames = max(n_slides, 0)
    positions = np.linspace(0, n_y - 1, n_frames + 2)[1:-1].round().astype(int).tolist()

    for y_slice_pos in positions:
        ax.imshow(brain_tensor[:, y_slice_pos, :], 'gray')

        if mask_tensor is not None:
            ax.imshow(mask_tensor[:, y_slice_pos, :], 'jet', interpolation='none', alpha=0.7)

        camera.snap()

    return camera
    

In [8]:
# Load one saved subject dict for inspection. torch.load unpickles the file,
# so only load tensors this notebook produced.
tensor = torch.load(os.path.join('../fcd_data/normalized_tensors', 'tensor_1'))

In [9]:
# tensor['mask']

In [10]:
# Sanity check: value range and shape of every stored volume.
for name, values in tensor.items():
    print(name, values.max(), values.min(), values.shape)

brain tensor(1.) tensor(0.) torch.Size([241, 336, 283])
mask tensor(1) tensor(0) torch.Size([241, 336, 283])
label tensor(1.) tensor(0.) torch.Size([241, 336, 283])
sulc tensor(1.) tensor(0.) torch.Size([241, 336, 283])
thickness tensor(1.) tensor(0.) torch.Size([241, 336, 283])
curv tensor(1.) tensor(0.) torch.Size([241, 336, 283])


In [11]:
# tensor['brain'].numpy()

In [14]:
# NOTE(review): presumably ioff() keeps the figure built inside video() from
# also being rendered as a static image under the animation.
plt.ioff()
# Whole (untrimmed) brain with the lesion label overlaid.
camera = video(tensor['brain'].numpy(), tensor['label'].numpy())
animation = camera.animate()
html5_video = animation.to_html5_video()
HTML(html5_video)

# Metadata

In [21]:
# Train/test split of subject keys, persisted next to the notebook.
metadata_root = './metadata'  # np.save appends '.npy' -> ./metadata.npy

# Sort the keys before splitting: paths_dict is built from the set
# single_component_keys, and set iteration order for str keys changes between
# Python processes (hash randomization). Without sorting, the same SEED can
# yield a different split after a kernel restart.
train_keys, test_keys = train_test_split(sorted(paths_dict), test_size=0.1, random_state=SEED)
metadata = {'train': train_keys,
            'test': test_keys,
            'seed': SEED}

np.save(metadata_root, metadata)

In [22]:
# Reload the saved split (a dict pickled into a 0-d object array).
with open('metadata.npy', 'rb') as metadata_file:
    metadata = np.load(metadata_file, allow_pickle=True).item()

In [48]:
def trim(brain_tensor, mask_tensor, label_tensor):
    '''Crop all three tensors to the bounding box of ``mask_tensor``.

    mask_tensor - [H,W,D]
    brain_tensor - [N_features, H,W,D]
    label_tensor - [H,W,D]

    Along each spatial axis, the box spans from the first to the last plane
    whose mask sum is > 0. Empty planes *inside* the box are kept, so the
    spatial layout is preserved; only the empty border is removed.
    Returns (brain_tensor_trim, mask_tensor_trim, label_tensor_trim) as new
    tensors (copies, not views of the inputs). An all-zero mask yields empty
    tensors.
    '''
    def _bounds(keep):
        # slice from the first to the last True entry of a 1D bool tensor
        idx = torch.nonzero(keep).flatten()
        if idx.numel() == 0:
            return slice(0, 0)
        return slice(int(idx[0]), int(idx[-1]) + 1)

    xs = _bounds(mask_tensor.sum(dim=[1, 2]) > 0)
    ys = _bounds(mask_tensor.sum(dim=[0, 2]) > 0)
    zs = _bounds(mask_tensor.sum(dim=[0, 1]) > 0)

    # .clone() keeps the original contract of returning independent tensors.
    brain_tensor_trim = brain_tensor[:, xs, ys, zs].clone()
    mask_tensor_trim = mask_tensor[xs, ys, zs].clone()
    label_tensor_trim = label_tensor[xs, ys, zs].clone()

    return brain_tensor_trim, mask_tensor_trim, label_tensor_trim

In [49]:
# trim expects a leading feature axis on the brain volume: [N_features, H, W, D].
brain_tensor, mask_tensor, label_tensor = trim(
    tensor['brain'][None],
    tensor['mask'],
    tensor['label'],
)

In [54]:
# NOTE(review): presumably ioff() keeps the figure built inside video() from
# also being rendered as a static image under the animation.
plt.ioff()
# Trimmed brain (feature channel 0) with the lesion label overlaid.
camera = video(brain_tensor[0], mask_tensor=label_tensor)
animation = camera.animate()
html5_video = animation.to_html5_video()
HTML(html5_video)