In [1]:
%load_ext autoreload
%autoreload 2

import copy
import os

import torch
import torchio as tio
from ipywidgets import interact
import matplotlib.pyplot as plt

from torch_context import TorchContext
from evaluators import *
from transforms import *
from utils import slice_volume, load_module


# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print("CUDA is available. Using GPU." if use_cuda else "CUDA is not available. Using CPU.")

CUDA is available. Using GPU.


In [None]:
# Build the diffusion-MRI hippocampus context from its config module.
variables = {"DATASET_PATH": "X:/Datasets/Diffusion_MRI/", "CHECKPOINTS_PATH": "X:/Checkpoints/"}
config = load_module("./configs/diffusion_hippocampus.py")

context = config.get_context(device, variables)
context.init_components()

In [None]:
# Restore a specific diffusion-hippocampus checkpoint instead of building from config.
iteration = 400
variables = dict(DATASET_PATH="X:/Datasets/Diffusion_MRI/", CHECKPOINTS_PATH="X:/Checkpoints/")
# Derive the checkpoint path from CHECKPOINTS_PATH so the checkpoint root is defined once,
# instead of repeating it as a separate backslash-escaped literal.
file_path = os.path.join(
    variables["CHECKPOINTS_PATH"],
    "Diffusion_MRI",
    "dmri-hippo-seg-debugging",
    "dmri-hippo-cycle-flash-1q798mvn",
    "best_checkpoints",
    f"iter{iteration:08}.pt",
)
context = TorchContext(device, file_path=file_path, variables=variables)
context.init_components()

In [2]:
# Restore an MSSEG2 checkpoint on the CPU.
# NOTE: this rebinds the global `device`, so later cells that read `device`
# (e.g. `vis` and `vis_model`) will also run on the CPU.
iteration = 118
variables = {"DATASET_PATH": "X:/Datasets/MSSEG2_resampled/"}
file_path = f"X:\\Checkpoints\\MSSEG2\\msseg2-hooks-capacitor-2oachxos\\iter{iteration:08}.pt"

device = torch.device('cpu')
context = TorchContext(device, file_path=file_path, variables=variables)
context.init_components()

In [None]:
# Show a textual summary of the currently loaded context.
print(str(context))

In [None]:
# NOTE(review): this binds the SegmentationEvaluator class itself, not an instance,
# and `seg_evaluator` is not used anywhere in the visible notebook. Confirm whether
# an instantiation (with evaluator arguments) was intended.
seg_evaluator = SegmentationEvaluator

In [None]:
# Build the MSSEG2 context from its config module (uses whatever `device` is currently bound).
variables = {"DATASET_PATH": "X:/Datasets/MSSEG2_processed/"}
config = load_module("./configs/msseg2.py")

context = config.get_context(device, variables)
context.init_components()

In [None]:
# Split the context's dataset into its training / validation cohorts.
training_dataset = context.dataset.get_cohort_dataset('training')
validation_dataset = context.dataset.get_cohort_dataset('validation')
print(len(training_dataset))
# Names of the training-cohort subjects (shown as the cell output).
list(map(lambda training_subject: training_subject['name'], training_dataset.subjects))

In [5]:
# Shared state for the interactive viewers defined below.
dataset = context.dataset
# NOTE(review): repeats the star import from the first cell; kept so a re-run of this
# cell picks up autoreloaded changes, but it may also rebind names in this namespace.
from transforms import *
# Subject names offered by the `vis` dropdown.
all_subject_names = list(dataset.all_subjects_map.keys())
# NOTE(review): `matplotlib` is not referenced directly in the visible cells.
import matplotlib

def vis_subject(subject):
    """Interactively browse a subject's images with target/prediction contour overlays.

    Builds two nested ``interact`` widgets: the outer one selects the image, the
    target label map and the viewing plane; the inner one selects the slice and the
    rendering options. Each redraw is rendered by ``ContourImageEvaluator`` and can
    optionally be saved under ``./images/<context.name>/iter<context.iteration>/``.

    Args:
        subject: a ``tio.Subject``. Its ``tio.ScalarImage`` entries are offered as
            images and its ``tio.LabelMap`` entries (other than ``'y_pred'``) as target
            label maps, with ``'y'`` preselected when present. If the subject contains
            ``'y_pred'``, it is drawn as the prediction contour.
    """
    import random  # needed for the 'random' plane option; not otherwise imported here

    images = {key: val for key, val in subject.items() if isinstance(val, tio.ScalarImage)}
    label_maps = {key: val for key, val in subject.items() if isinstance(val, tio.LabelMap)}

    # 'y_pred' is always drawn as the prediction, so it is not offered as a target.
    # The stable sort on `name != 'y'` moves 'y' to the front so it is the default
    # selection (the previous hard-coded choice) while the dropdown is now honoured.
    target_names = sorted((name for name in label_maps if name != 'y_pred'), key=lambda name: name != 'y')
    if not target_names:
        target_names = [None]  # no target label map: draw the image (and prediction) only
    prediction_label_map_name = 'y_pred' if 'y_pred' in subject else None

    @interact(image_name=list(images), label_map_name=target_names,
              plane=['Axial', 'Coronal', 'Saggital', 'interesting', 'random'])
    def select_images(image_name, label_map_name, plane):
        W, H, D = images[image_name].spatial_shape
        if plane == 'random':
            plane = random.choice(('Axial', 'Coronal', 'Saggital'))
        # 'interesting' is not a geometric plane; its slider spans 20 positions,
        # presumably indexing ranked slices inside the evaluator (TODO confirm).
        num_slices = {'Axial': D, 'Coronal': H, 'Saggital': W, 'interesting': 20}[plane]

        @interact(save=False, show_labels=True, legend=True, ticks=False, scale=(0.05, 0.15, 0.01),
                  line_width=(0.5, 2.5), slice_id=(0, num_slices - 1), interesting_slice=False)
        def select_slice(save, show_labels, legend, ticks, scale, line_width, slice_id, interesting_slice):
            evaluator = ContourImageEvaluator(
                plane=plane, image_name=image_name,
                target_label_map_name=label_map_name if show_labels else None,
                prediction_label_map_name=prediction_label_map_name if show_labels else None,
                slice_id=slice_id, legend=legend, ncol=1, scale=scale, line_width=line_width,
                interesting_slice=interesting_slice,
            )
            pil_image = evaluator([subject])

            fig, ax = plt.subplots(figsize=(10, 10))
            ax.imshow(pil_image)
            if not ticks:
                ax.tick_params(which='both', bottom=False, top=False, left=False,
                               labelbottom=False, labelleft=False)
            if save:
                save_dir = f"./images/{context.name}/iter{context.iteration:08}/"
                # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
                os.makedirs(save_dir, exist_ok=True)
                file_name = f"{subject['name']}_{image_name}_{plane}_{slice_id}.png"
                fig.savefig(os.path.join(save_dir, file_name), bbox_inches="tight", pad_inches=0.0,
                            facecolor="black")
            plt.show()
            plt.close(fig)
            
def vis_features(x):
    """Interactively browse 2-D slices of a 5-D feature tensor.

    Sliders select the batch index ``i``, channel ``c`` and last-axis slice ``d``;
    the slice ``x[i, c, :, :, d]`` is shown in grayscale with a colorbar.

    Args:
        x: tensor of shape ``(N, C, W, H, D)``.

    Raises:
        ValueError: if ``x`` is not 5-dimensional.
    """
    if x.dim() != 5:
        raise ValueError(f"vis_features expects a 5-D tensor (N, C, W, H, D), got shape {tuple(x.shape)}")
    N, C, W, H, D = x.shape

    @interact(i=(0, N - 1), c=(0, C - 1), d=(0, D - 1))
    def plot_feature_map(i, c, d):
        fig, ax = plt.subplots(figsize=(8, 8))
        mappable = ax.imshow(x[i, c, :, :, d].detach().cpu(), cmap="gray")
        fig.colorbar(mappable, ax=ax)
        # Show explicitly and close, so repeated slider moves don't accumulate open figures.
        plt.show()
        plt.close(fig)
        

def vis_model(subject):
    """Interactively visualize the output of a chosen sub-module of ``context.model``.

    A dropdown lists the model's sub-modules. Selecting one attaches a forward hook,
    runs the model once on the subject's ``'X'`` image (without gradients), and shows
    the hooked output with ``vis_features``.

    Args:
        subject: a ``tio.Subject`` with an ``'X'`` image; its data gets a batch
            dimension and is moved to the global ``device``.
    """
    X = subject['X']['data'].unsqueeze(0).to(device)
    # named_modules() yields (name, module) pairs; interact turns them into a dropdown
    # labelled by name. Entry 0 is the root model itself, so it is skipped.
    modules = list(context.model.named_modules())

    @interact(module=modules[1:])
    def select_module(module):

        def forward_module_hook(hooked_module, inputs, output):
            # Some modules may return a tuple/list; visualize its first element.
            if isinstance(output, (tuple, list)):
                output = output[0]
            vis_features(output.detach().cpu())

        hook_handle = module.register_forward_hook(forward_module_hook)
        try:
            with torch.no_grad():
                context.model(X)
        finally:
            # Always detach the hook, even if the forward pass or plotting raises,
            # so hooks don't pile up across dropdown selections.
            hook_handle.remove()


@interact(name=all_subject_names, mode=['vis_subject', 'model_contour', 'vis_model'])
def vis(name, mode):
    """Top-level viewer: pick a subject by name and a visualization mode.

    Modes:
        'vis_subject':   apply a random elastic deformation and browse the result.
        'model_contour': run the model, threshold its output at 0.5 and overlay it
                         as the ``'y_pred'`` label map.
        'vis_model':     crop/pad to 96^3 and browse intermediate feature maps.
    """
    subject = context.dataset[name]
    context.model.eval()

    if mode == 'vis_subject':
        # Only the elastic deformation is applied. To preview affine augmentation instead,
        # use e.g. tio.RandomAffine(scales=0.2, degrees=45).
        augmentation = tio.RandomElasticDeformation()
        subject = augmentation(subject)
        vis_subject(subject)

    elif mode == 'model_contour':
        # Make every spatial dim a multiple of 32 — presumably required by the model's
        # downsampling depth (TODO confirm).
        subject = tio.EnsureShapeMultiple(32)(subject)

        X = subject['X']['data'][None].to(device)  # add a batch dimension
        with torch.no_grad():
            y_pred = context.model(X)[0] > 0.5
        # Reuse the target LabelMap as a template for the prediction; keep its data on the
        # CPU like the target's, so downstream plotting does not receive a device tensor.
        subject['y_pred'] = copy.deepcopy(subject['y'])
        subject['y_pred'].set_data(y_pred.cpu())
        vis_subject(subject)

    elif mode == 'vis_model':
        subject = tio.CropOrPad((96, 96, 96))(subject)
        vis_model(subject)

  and should_run_async(code)


interactive(children=(Dropdown(description='name', options=('013', '015', '016', '018', '019', '020', '021', '…

In [None]:
# Print per-subject intensity statistics of the 'flair_time01' image.
for subject in context.dataset:
    time01 = subject['flair_time01'].data
    # mean()/std() require a floating dtype; cast only if the volume is integer-typed.
    time01_stats = time01 if time01.is_floating_point() else time01.float()
    print(subject['name'], subject['flair_time01'], time01_stats.mean(), time01_stats.std(), time01.min(), time01.max())

In [None]:
# List each subject's name next to its 'flair_time01' image summary.
for subject in context.dataset:
    flair_image = subject['flair_time01']
    print(subject['name'], flair_image)

In [None]:
# Pick the first subject of the dataset for the single-subject experiments below.
first_index = 0
subject = context.dataset[first_index]

In [None]:
import numpy as np

# Drop the leading (channel) axis of the ROI and binarize it.
roi_image = subject['hbt_roi']
mask = roi_image.data.bool()[0]
W, H, D = mask.shape
W_where, H_where, D_where = torch.where(mask)

# For every slice index along the first spatial axis that contains ROI voxels,
# count those voxels, then order the slice indices from most to fewest voxels.
slice_ids, counts = W_where.unique(return_counts=True)
interesting_slice_ids_ids = counts.argsort(descending=True)
interesting_slice_ids = slice_ids[interesting_slice_ids_ids]

print(interesting_slice_ids)


In [None]:
# Annotate the subject via the FindInterestingSlice transform.
find_interesting_slice = FindInterestingSlice()
subject = find_interesting_slice(subject)


In [None]:
# Inspect the slice ids stored on the 'y' label map.
target_label_map = subject['y']
target_label_map['interesting_slice_ids']

In [None]:
# Sanity check: with the default dtype setting, torch.ones produces torch.float (float32).
x = torch.ones(1)
torch.float == x.dtype

In [None]:
from random import Random

N_ITEMS = 42  # presumably the number of subjects to split — TODO confirm
N_FOLDS = 5

# Round-robin fold labels, then a seeded shuffle so the split is reproducible.
fold_ids = [item_index % N_FOLDS for item_index in range(N_ITEMS)]
Random(0).shuffle(fold_ids)
fold_ids

In [None]:
# Display the `dataset` bound in the visualization cell (`dataset = context.dataset`).
# NOTE: it is not refreshed if `context` has been rebuilt since that cell ran.
dataset

In [None]:
import numpy as np

# Show that argsort of a permutation is its inverse: permuting by `perm` and then
# by `inverse_perm` restores the original shape.
x = torch.ones(3, 5, 7)

perm = (1, 2, 0)
x = x.permute(*perm)

# Sorting the positions by the value stored at each position is an argsort.
inverse_perm = tuple(sorted(range(len(perm)), key=lambda axis: perm[axis]))
print(inverse_perm)
x = x.permute(*inverse_perm)

print(x.shape)

In [None]:
from itertools import permutations

# Enumerate all 3! orderings of (0, 1, 2).
[ordering for ordering in permutations(range(3))]