In [30]:
%load_ext autoreload
%autoreload 2

import torch
import torchio as tio
from ipywidgets import interact
import matplotlib.pyplot as plt

from context import Context
from utils import slice_volume

# Pick the compute device once, preferring the GPU when one is visible to torch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("CUDA is available. Using GPU." if device.type == 'cuda' else "CUDA is not available. Using CPU.")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
CUDA is available. Using GPU.


In [8]:
# Checkpoint to restore. `{100:08}` zero-pads the iteration number to 8 digits,
# i.e. this resolves to ".../checkpoints/iter00000100.pt".
# NOTE(review): absolute X:/ paths are machine-specific; consider moving them to a config cell.
file_name = f"X:/Checkpoints/Diffusion_MRI/dmri-hippo-northern-pond-67/checkpoints/iter{100:08}.pt"
# Path placeholders handed to Context via `variables`; presumably substituted into the
# experiment config — TODO confirm against Context.
variables = dict(DATASET_PATH="X:/Datasets/Diffusion_MRI/", CHECKPOINTS_PATH="X:/Checkpoints/")
# Star import pulls the experiment config into this notebook's global namespace. It is kept
# (despite the usual advice against `import *`) because globals() is passed to Context below,
# which presumably reads config names from it — NOTE(review): confirm before replacing it.
from configs.diffusion_hippocampus import *

# Builds the context (model, dataset, ...) on `device` from the checkpoint and config above.
context = Context(device, file_name=file_name, variables=variables, globals=globals())

In [32]:
dataset = context.dataset
# Subject identifiers in the mapping's own key order; these feed the subject picker dropdown.
all_subject_names = [subject_name for subject_name in dataset.all_subjects_map.keys()]

def vis_subject(subject):
    """Interactively browse 2D slices of a subject's images with an optional label overlay.

    Dropdowns choose one ``tio.ScalarImage`` and one ``tio.LabelMap`` (or ``'(none)'``) from
    ``subject`` plus a plane; a slider then picks the slice index. The image slice is drawn
    in grayscale and non-zero label voxels are overlaid semi-transparently.

    Args:
        subject: mapping (e.g. a ``tio.Subject``) whose values include ``tio.ScalarImage``
            and, optionally, ``tio.LabelMap`` entries.

    Raises:
        ValueError: if ``subject`` contains no ``tio.ScalarImage``.
    """
    images = {key: val for key, val in subject.items() if isinstance(val, tio.ScalarImage)}
    label_maps = {key: val for key, val in subject.items() if isinstance(val, tio.LabelMap)}
    if not images:
        raise ValueError("vis_subject: subject contains no tio.ScalarImage to display")

    no_label = '(none)'

    # Widgets select by key (string) rather than by image object, so an explicit
    # "no overlay" choice exists and subjects without any label map still work.
    # 'Saggital' keeps the original spelling because the string is passed to slice_volume.
    @interact(image=list(images), label_map=[*label_maps, no_label],
              plane=['Axial', 'Coronal', 'Saggital'])
    def select_images(image, label_map, plane):
        scalar_image = images[image]
        label_image = label_maps.get(label_map)  # None when '(none)' is selected
        W, H, D = scalar_image.spatial_shape
        # NOTE(review): which axis each plane indexes must match utils.slice_volume;
        # in RAS-ordered volumes Coronal usually indexes H and Sagittal W — confirm.
        num_slices = {'Axial': D, 'Coronal': W, 'Saggital': H}[plane]

        @interact(slice_id=(0, num_slices - 1))
        def select_slice(slice_id):
            # The `0` argument presumably selects the first channel — TODO confirm in slice_volume.
            # torch.as_tensor accepts either a tensor or an ndarray from slice_volume.
            image_slice = torch.as_tensor(slice_volume(scalar_image['data'], 0, plane, slice_id)).cpu()
            _, ax = plt.subplots(figsize=(8, 8))
            ax.imshow(image_slice, cmap='gray')
            if label_image is not None:
                label_slice = torch.as_tensor(
                    slice_volume(label_image['data'], 0, plane, slice_id)
                ).cpu().float()
                foreground = label_slice != 0
                if bool(foreground.any()):
                    # NaN pixels are drawn transparent, so background (0) does not tint the image.
                    # torch.where builds a new tensor, leaving the label map's data untouched.
                    overlay = torch.where(foreground, label_slice,
                                          torch.full_like(label_slice, float('nan')))
                    ax.imshow(overlay, cmap='tab10', alpha=0.5, interpolation='nearest')
            ax.set_title(f"{image} / {label_map} - {plane} slice {slice_id}")
            ax.axis('off')
            plt.show()
            

def vis_features(x):
    """Interactively browse 2D slices of a batch of feature maps.

    Sliders pick the batch item ``i``, channel ``c`` and depth slice ``d``; the selected
    ``x[i, c, :, :, d]`` is shown as a grayscale image with a colorbar.

    Args:
        x: tensor of shape ``(N, C, W, H, D)``. A 4D ``(N, C, W, H)`` tensor is also
            accepted and treated as having a single depth slice.

    Raises:
        ValueError: if ``x`` is neither 4D nor 5D.
    """
    if x.dim() == 4:
        # 2D feature maps: add a singleton depth axis so the same viewer applies.
        x = x.unsqueeze(-1)
    if x.dim() != 5:
        raise ValueError(f"vis_features expects a 4D or 5D tensor, got shape {tuple(x.shape)}")
    N, C, W, H, D = x.shape

    @interact(i=(0, N - 1), c=(0, C - 1), d=(0, D - 1))
    def plot_feature_map(i, c, d):
        fig, ax = plt.subplots(figsize=(10, 10))
        # detach() so tensors that still track gradients can be converted for plotting.
        im = ax.imshow(x[i, c, :, :, d].detach().cpu(), cmap="gray")
        fig.colorbar(im, ax=ax)
        ax.set_title(f"sample {i}, channel {c}, depth slice {d}")
        plt.show()
        

def vis_model(subject):
    """Pick a submodule of ``context.model`` and browse its output feature maps for ``subject``.

    The subject's ``'X'`` image gets a leading batch dimension, is moved to ``device`` and
    is run through the whole model once per selection. A temporary forward hook records the
    selected module's output(s), which are then passed to ``vis_features``.

    Args:
        subject: subject mapping with an ``'X'`` image whose ``'data'`` is a tensor.
    """
    X = subject['X']['data'].unsqueeze(0).to(device)
    # named_modules() yields the root model first (name ''), hence the [1:] below.
    modules = list(context.model.named_modules())

    @interact(module=modules[1:])
    def select_module(module):
        captured = []

        def forward_module_hook(_module, _inputs, output):
            # Only record here; plotting is done after the forward pass has finished.
            captured.append(output)

        model = context.model
        was_training = model.training
        hook_handle = module.register_forward_hook(forward_module_hook)
        try:
            # eval() so dropout is off and BatchNorm does not update its running
            # statistics (it does so in train mode even under no_grad).
            model.eval()
            with torch.no_grad():
                model(X)
        finally:
            # Always remove the hook and restore the mode, even if forward raises;
            # a leaked hook would keep firing on every later selection.
            hook_handle.remove()
            model.train(was_training)

        if not captured:
            print("Selected module was not called during the forward pass; nothing to visualize.")
        for output in captured:
            if isinstance(output, torch.Tensor):
                vis_features(output.cpu())
            else:
                print(f"Module output of type {type(output).__name__} is not a tensor; skipping.")


@interact(name=all_subject_names)
def get_subject(name):
    """Dropdown callback: fetch the subject called ``name`` and open the feature viewer for it."""
    vis_model(context.dataset[name])

interactive(children=(Dropdown(description='name', options=('ab300_001', 'ab300_002', 'ab300_003', 'ab300_004'…