<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#DICOM-Images" data-toc-modified-id="DICOM-Images-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>DICOM Images</a></span></li><li><span><a href="#Nifti-Maker" data-toc-modified-id="Nifti-Maker-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Nifti Maker</a></span></li><li><span><a href="#xarray-generation" data-toc-modified-id="xarray-generation-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>xarray generation</a></span></li><li><span><a href="#xarray-viewer" data-toc-modified-id="xarray-viewer-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>xarray viewer</a></span></li><li><span><a href="#Model-Training" data-toc-modified-id="Model-Training-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Model Training</a></span></li><li><span><a href="#Slurm-Analysis" data-toc-modified-id="Slurm-Analysis-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>Slurm Analysis</a></span></li></ul></div>

# New Data Exploration

## DICOM Images

In [None]:
%matplotlib inline

In [None]:
from pathlib import Path
from datetime import datetime
from collections import OrderedDict
import numpy as np
import pickle as pkl
import pandas as pd
import xarray as xr
import SimpleITK as sitk

import holoviews as hv
from holoviews import opts
import panel as pn
import hvplot.pandas
hv.extension('bokeh')
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (10,8)

from mre.plotting import patient_series_viewer, chaos_viewer, xr_viewer, hv_dl_vis_chaos
from mre.preprocessing import make_nifti_atlas_v2, make_xr_dataset_for_chaos
from mre.segmentation import ChaosDataset
from mre.train_seg_model import train_seg_model 
from mre import pytorch_arch

from torch.utils.data import Dataset, DataLoader
import torchvision.utils
from torchsummary import summary
import torch
import torch.nn as nn
from collections import defaultdict
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import models

%load_ext autoreload
%autoreload 2

In [None]:
# Display the installed PyTorch version string for this kernel.
torch.__version__

In [None]:
# Root directory of the CHAOS MR training set; later cells resolve files relative to it.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
data_dir = Path('/pghbio/dbmi/batmanlab/bpollack/predictElasticity/data/CHAOS/Train_Sets/MR/')

In [None]:
# patient_series_viewer(data_dir, 'DICOMA/PA1/ST0')
# patient_series_viewer(data_dir, '1', img_type='DICOM_CHAOS')

## Nifti Maker

In [None]:
#make_nifti_atlas_v2()

In [None]:
#patient_series_viewer(data_dir, 'NIFTI/01', img_type='NIFTI')

In [None]:
#chaos_viewer(data_dir, 'NIFTI/03')


## xarray generation

In [None]:
ls ../data/CHAOS/Train_Sets/MR/NIFTI/

In [None]:
# patients = ["01",  "03",  "08",  "13",  "19",  "21",  "31",  "33",  "36",  "38",
# "02",  "05",  "10",  "15",  "20",  "22",  "32",  "34",  "37",  "39"] 
# ds = make_xr_dataset_for_chaos(patients, 256, 256, 32, 'chaos')


## xarray viewer

In [None]:
# Open `xarray_chaos.nc` from data_dir as an xarray Dataset.
# (The commented-out make_xr_dataset_for_chaos call in the previous section presumably
# produced this file — TODO confirm.)
ds_path = Path(data_dir, 'xarray_chaos.nc')
ds = xr.open_dataset(ds_path)

In [None]:
#xr_viewer(ds, overlay_data='mask')
#ds

## Model Training

In [None]:
# Setup paths
# Output root passed to train_seg_model; the Slurm-analysis section reads `<out_dir>/config`.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
out_dir = '/pghbio/dbmi/batmanlab/bpollack/predictElasticity/data/CHAOS/'

In [None]:
# Release cached, unused GPU memory before starting a new training run.
torch.cuda.empty_cache()

# --- Run configuration -------------------------------------------------------
subj = '01'
version = None
n_layers = 5
model_cap = 16
channel_growth = True
seq_mode = 't1_out'

# Use the launch timestamp as the version label when none was set explicitly.
now = datetime.today().strftime('%Y-%m-%d_%H-%M-%S')
if version is None:
    version = now
# Alternative label: model_version = f'chaos_notebook_test_{version}'
model_version = version
print(now)

# NOTE(review): the Slurm-analysis cells read a 'bce_weight' field, while this
# call passes `bc_weight` — confirm which keyword train_seg_model expects.
inputs, targets, names, model = train_seg_model(
    data_dir,
    'xarray_chaos.nc',
    out_dir,
    model_version=model_version,
    subj=subj,
    loss='dice',
    dry_run=False,
    transform=True,
    def_seq_mode=seq_mode,
    coord_conv=False,
    step_size=80,
    num_epochs=200,
    lr=1e-2,
    model_arch='2D',
    resize=False,
    n_layers=n_layers,
    channel_growth=channel_growth,
    model_cap=model_cap,
    batch_size=200,
    test_seq_mode='all',
    test_aug=False,
    train_aug=True,
    val_aug=True,
    bc_weight=0.1,
)

# model_path = Path(out_dir, 'trained_models', subj, f'model_{model_version}.pkl')
# model = pytorch_arch.GeneralUNet3D(n_layers, 1, model_cap, 1, channel_growth, False, False)
# model_dict = torch.load(model_path, map_location='cpu')
# model_dict = OrderedDict([(key[7:], val) for key, val in model_dict.items()])
# model.load_state_dict(model_dict, strict=True)
# model.eval()

In [16]:
# Run the trained model on every (sample, sequence) slice of `inputs` and show
# the sigmoid outputs next to the inputs and targets.
#
# Changes from the previous version of this cell:
#   * inference runs under torch.no_grad() with the model in eval mode, so no
#     autograd graph is built and `model_pred` does not require grad;
#   * torch.sigmoid replaces the deprecated F.sigmoid;
#   * `.to('cpu')` returned a new tensor that was discarded; the results are now
#     actually moved to the CPU before plotting;
#   * the unused per-iteration `ones` / `zeros` tensors are gone.
model_pred = None
if model:
    model.eval()  # inference behaviour for dropout / batch-norm layers
    model_pred = torch.zeros_like(inputs)
    with torch.no_grad():
        for i in range(inputs.shape[0]):
            for j in range(inputs.shape[1]):
                # Keep singleton batch/channel dims so the model sees one slice at a time.
                logits = model(inputs[i:i+1, j:j+1, :])
                model_pred[i, j, :] = torch.sigmoid(logits)
    # Optional hard thresholding of the probabilities (disabled):
    # model_pred = torch.where(model_pred > 3e-3, torch.ones_like(model_pred), torch.zeros_like(model_pred))
    inputs = inputs.cpu()
    model_pred = model_pred.cpu()
hv_dl_vis_chaos(inputs, targets, names, ['t1_in', 't1_out', 't2'], model_pred)
# hv_dl_vis_chaos(inputs, targets, names, ['seq'], model_pred)



## Slurm Analysis

In [None]:
# Directory holding the per-job config pickles that the next cell globs for.
config_path = Path(out_dir, 'config')

In [None]:
# Build one row per job-config pickle from the 2019-09-16_14-09-07 batch; the
# row index is the pickle's file stem (the job name).
#
# The rows are collected first and the frame is constructed once:
# DataFrame.append was removed in pandas 2.0 and re-copied the frame on every
# iteration. The glob is sorted so the row order is reproducible.
# NOTE: unpickling executes arbitrary code — only read pickles you produced.
records = [
    pd.Series(pd.read_pickle(str(f)), name=f.stem)
    for f in sorted(config_path.glob('*2019-09-16_14-09-07*.pkl'))
]
df = pd.DataFrame(records)

In [None]:
# Show every column when displaying frames (global option, so it also affects later cells),
# then preview the first rows of the job table.
pd.set_option('display.max_columns', None)
df.head()

In [None]:
# Add a per-job average of the three per-sequence test columns (t1_in, t1_out, t2).
# Note: this mutates `df` in place; a missing value in any of the three propagates as NaN.
df['test_dice_mean'] = (df.test_dice_t1_in+df.test_dice_t1_out+df.test_dice_t2)/3.0

In [None]:
# Channel growth on (1) vs. off (0): jobs in each group are sorted by
# test_dice_t1_in and plotted against their rank; the original job name is
# kept in the 'job_name' column.
hover = ['model_cap', 'def_seq_mode', 'bce_weight']
df1, df2 = (
    df.query(f'channel_growth=={flag}')
      .sort_values('test_dice_t1_in')
      .reset_index()
      .rename(columns={'index': 'job_name'})
    for flag in (1, 0)
)
curves = [
    frame.hvplot.line(x='index', y='test_dice_t1_in', hover_cols=list(hover), label=label)
    for frame, label in zip((df1, df2), ('C Growth', 'C Static'))
]
(curves[0] * curves[1]).opts(legend_position='top_left', show_legend=True)

In [None]:
# One curve per def_seq_mode: jobs in each group are sorted by test_dice_mean
# and plotted against their rank; the job name is available on hover.
modes = ('t1_in', 't1_out', 't2', 'random')
hover = ['model_cap', 'def_seq_mode', 'bce_weight', 'job_name']
df1, df2, df3, df4 = (
    df.query(f'def_seq_mode=="{mode}"')
      .sort_values('test_dice_mean')
      .reset_index()
      .rename(columns={'index': 'job_name'})
    for mode in modes
)
curves = [
    frame.hvplot.line(x='index', y='test_dice_mean', hover_cols=list(hover), label=mode)
    for frame, mode in zip((df1, df2, df3, df4), modes)
]
(curves[0] * curves[1] * curves[2] * curves[3]).opts(legend_position='top_left', show_legend=True)

In [None]:
# One curve per bce_weight value: jobs in each group are sorted by best_loss
# and plotted against their rank.
weights = (0.1, 0.3, 0.5, 0.7)
hover = ['model_cap', 'def_seq_mode', 'bce_weight']
df1, df2, df3, df4 = (
    df.query(f'bce_weight=={weight}')
      .sort_values('best_loss')
      .reset_index()
      .rename(columns={'index': 'job_name'})
    for weight in weights
)
curves = [
    frame.hvplot.line(x='index', y='best_loss', hover_cols=list(hover), label=str(weight))
    for frame, weight in zip((df1, df2, df3, df4), weights)
]
(curves[0] * curves[1] * curves[2] * curves[3]).opts(legend_position='top_left', show_legend=True)

In [None]:

# One curve per model_cap value: jobs in each group are sorted by
# test_dice_mean and plotted against their rank.
caps = (8, 12, 16)
hover = ['model_cap', 'def_seq_mode', 'bce_weight']
df1, df2, df3 = (
    df.query(f'model_cap=={cap}')
      .sort_values('test_dice_mean')
      .reset_index()
      .rename(columns={'index': 'job_name'})
    for cap in caps
)
curves = [
    frame.hvplot.line(x='index', y='test_dice_mean', hover_cols=list(hover), label=str(cap))
    for frame, cap in zip((df1, df2, df3), caps)
]
(curves[0] * curves[1] * curves[2]).opts(legend_position='top_left', show_legend=True)

In [None]:
# Mean of the test_dice_t1_out column for every (model_cap, def_seq_mode) combination.
df.groupby(['model_cap', 'def_seq_mode'])['test_dice_t1_out'].mean()

**Notes:** `t1_out` seems to outperform all other `def_seq_mode` settings (including `random`). Best overall configuration so far: `def_seq_mode='t1_out'`, `model_cap=8`. **TODO:** Why would adding additional images decrease performance?

In [None]:
# Inspect the jobs for subject '01' that used model_cap == 8 and def_seq_mode == 't1_out'.
df.query('model_cap==8 and def_seq_mode=="t1_out" and subj=="01"')