In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torchio as tio
from ipywidgets import interact
import matplotlib.pyplot as plt

from context import Context
from evaluators import *
from utils import slice_volume, load_module

# Pick the compute device once; everything downstream (models, tensors) is moved onto `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("CUDA is available. Using GPU." if device.type == 'cuda' else "CUDA is not available. Using CPU.")

CUDA is available. Using GPU.


In [None]:
# Alternative setup (cell not executed here): rebuild the context from a saved checkpoint of the
# diffusion-hippocampus run. `{100:08}` zero-pads the iteration number to 8 digits ("iter00000100.pt").
file_name = f"X:/Checkpoints/Diffusion_MRI/dmri-hippo-northern-pond-67/checkpoints/iter{100:08}.pt"
# Root paths handed to the context; presumably substituted into the config's path templates — TODO confirm.
variables = dict(DATASET_PATH="X:/Datasets/Diffusion_MRI/", CHECKPOINTS_PATH="X:/Checkpoints/")
# Star import pulls the config's definitions into this namespace. Context receives the namespace via
# globals(), so it presumably reads those definitions from there — do not replace with explicit imports
# without checking what Context looks up.
from configs.diffusion_hippocampus import *

context = Context(device, file_name=file_name, variables=variables, globals=globals())

In [2]:
# Root paths for the MSSEG2 run; passed to the config's get_context together with the device.
variables = {"DATASET_PATH": "X:/Datasets/MSSEG2_processed/", "CHECKPOINTS_PATH": "X:/Checkpoints/"}
config = load_module("./configs/msseg2.py")
context = config.get_context(device, variables)

In [None]:
# Split the dataset into its training and validation cohorts (fetched in that order).
training_dataset, validation_dataset = (
    context.dataset.get_cohort_dataset(cohort) for cohort in ('training', 'validation')
)
print(len(training_dataset))
# Display the names of the training subjects.
[s['name'] for s in training_dataset.subjects]

In [3]:
dataset = context.dataset
# Subject identifiers: the keys of the dataset's subject map, in the map's iteration order.
all_subject_names = [*dataset.all_subjects_map.keys()]

def vis_subject(subject):
    """Interactively browse a subject's images with a label-map contour overlay.

    Builds dropdowns over every ``tio.ScalarImage`` and ``tio.LabelMap`` found in
    ``subject`` plus an anatomical plane, then a slider over the slices of that
    plane. Each slice is rendered by ``ContourImageEvaluator`` with the selected
    label map as the target contour and no prediction overlay.

    Args:
        subject: a torchio subject (dict-like mapping of names to torchio images).
    """
    images = {key: val for key, val in subject.items() if isinstance(val, tio.ScalarImage)}
    label_maps = {key: val for key, val in subject.items() if isinstance(val, tio.LabelMap)}

    # Plane names are passed straight to ContourImageEvaluator, so the 'Saggital' spelling is kept
    # as-is (presumably the spelling the evaluator expects — verify before renaming).
    @interact(image_name=images.keys(), label_map_name=label_maps.keys(), plane=['Axial', 'Coronal', 'Saggital'])
    def select_images(image_name, label_map_name, plane):
        # Use the label map chosen in the dropdown (it is not overridden with a fixed name).
        image = images[image_name]
        W, H, D = image.spatial_shape
        # Number of slices along the axis perpendicular to the chosen plane.
        num_slices = {'Axial': D, 'Coronal': H, 'Saggital': W}[plane]

        @interact(slice_id=(0, num_slices - 1))
        def select_slice(slice_id):
            evaluator = ContourImageEvaluator(
                plane=plane, image_name=image_name,
                target_label_map_name=label_map_name,
                prediction_label_map_name=None,
                slice_id=slice_id, legend=True, ncol=1
            )
            evaluation = evaluator([subject])
            fig, ax = plt.subplots(figsize=(10, 10))
            ax.imshow(evaluation['image'], aspect='equal')
            plt.show()
            # Close explicitly so repeated slider moves do not accumulate open figures.
            plt.close(fig)
            
def vis_features(x):
    """Interactively browse 2-D slices of a 5-D feature tensor.

    Sliders select the batch index ``i``, channel ``c`` and last-axis slice ``d``;
    the slice ``x[i, c, :, :, d]`` is shown in grayscale with a colorbar.

    Args:
        x: 5-D tensor, unpacked as (N, C, W, H, D) — presumably
            (batch, channel, spatial...) activations; TODO confirm axis meaning.
    """
    # Detach and move to CPU once, rather than on every slider event; detach also lets
    # matplotlib convert tensors that require grad.
    x = x.detach().cpu()
    N, C, W, H, D = x.shape

    @interact(i=(0, N - 1), c=(0, C - 1), d=(0, D - 1))
    def plot_feature_map(i, c, d):
        fig, ax = plt.subplots(figsize=(10, 10))
        im = ax.imshow(x[i, c, :, :, d], cmap="gray")
        fig.colorbar(im, ax=ax)
        ax.set_title(f"sample {i}, channel {c}, slice {d}")
        plt.show()
        # Close explicitly so repeated slider moves do not accumulate open figures.
        plt.close(fig)
        

def vis_model(subject):
    """Pick a submodule of ``context.model`` and browse its output feature maps for ``subject``.

    Runs one forward pass of ``subject['X']`` (with an added batch dimension) with a
    forward hook on the selected submodule; the hook hands that submodule's output to
    ``vis_features``. The model runs in eval mode under ``torch.no_grad()``, and its
    previous train/eval mode is restored afterwards.

    Args:
        subject: a subject whose ``subject['X']['data']`` is the model input tensor.
    """
    X = subject['X']['data'].unsqueeze(0).to(device)
    modules = list(context.model.named_modules())

    # named_modules() yields (name, module) pairs, and ipywidgets treats a list of
    # 2-tuples as (label, value) options, so the dropdown shows names and passes modules.
    # The first pair is the root model itself, hence [1:].
    @interact(module=modules[1:])
    def select_module(module):

        def forward_module_hook(_module, _inputs, output):
            # Some modules return tuples/lists; visualise the first element in that case.
            if isinstance(output, (tuple, list)):
                output = output[0]
            vis_features(output.detach().cpu())

        model = context.model
        was_training = model.training
        # eval() disables dropout and stops BatchNorm from updating its running statistics,
        # so inspecting features does not mutate the model.
        model.eval()
        hook_handle = module.register_forward_hook(forward_module_hook)
        try:
            with torch.no_grad():
                model(X)
        finally:
            # Always detach the hook and restore the mode, even if the forward pass raises.
            hook_handle.remove()
            model.train(was_training)


@interact(name=all_subject_names)
def vis(name):
    """Dropdown over every subject name; opens the interactive viewer for the chosen subject."""
    vis_subject(context.dataset[name])

interactive(children=(Dropdown(description='name', options=('013', '015', '016', '018', '019', '020', '021', '…

In [None]:
# Per-subject intensity summary (mean, std, min, max) of the 'flair_time01' image.
for subject in context.dataset:
    time01 = subject['flair_time01'].data
    # torch mean/std are only defined for floating-point tensors; cast integer data first.
    if not time01.is_floating_point():
        time01 = time01.float()
    print(subject['name'], subject['flair_time01'], time01.mean(), time01.std(), time01.min(), time01.max())

In [None]:
# Print each subject's name next to the repr of its 'flair_time01' image object.
for subject in context.dataset:
    print(subject['name'], subject['flair_time01'])

In [None]:
# Take the first subject of the dataset for the bounding-box computation that follows.
# NOTE(review): other cells index the dataset by subject name; this assumes integer indexing
# is also supported — confirm.
subject = context.dataset[0]

In [None]:
import numpy as np

# Number of voxels to trim from each side so the volume tightly bounds the brain mask, as
# (w_ini, w_fin, h_ini, h_fin, d_ini, d_fin) — presumably intended for tio.Crop; confirm.
mask = subject['brain_mask'].data.bool()[0]  # first channel of the mask image
if not mask.any():
    raise ValueError("brain_mask is empty; cannot compute a bounding-box cropping")
W, H, D = mask.shape
W_where, H_where, D_where = np.nonzero(mask.cpu().numpy())
# The trailing margin is the count of voxels *after* the last masked index, i.e.
# (size - 1) - max_index; using size - max_index would also crop away the last masked slice.
cropping = tuple(int(v) for v in (
    W_where.min(), W - 1 - W_where.max(),
    H_where.min(), H - 1 - H_where.max(),
    D_where.min(), D - 1 - D_where.max(),
))
cropping