In [28]:
import os
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter

import fastmri
from fastmri.data import transforms
from fastmri.models.unet import Unet
from fastmri.models.varnet import *

import sigpy as sp
from sigpy import from_pytorch
import sigpy.plot as pl
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (10,10)
import skimage.metrics

from dloader import genDataLoader

In [2]:
# # command line argument parser
# parser = argparse.ArgumentParser(
#     description = 'define parameters and roots for STL training'
# )

# # hyperparameters
# parser.add_argument(
#     '--epochs', default=2, type=int,
#     help='number of epochs to run'
# )
# parser.add_argument(
#     '--lr', default=0.0002, type=float,
#     help='learning rate'
# )


# # model training
# parser.add_argument(
#     '--numblocks', default=12, type=int,
#     help='number of unrolled blocks in total'
# )
# parser.add_argument(
#     '--network', default='varnet',
#     help='type of network ie unet or varnet'
# )
# parser.add_argument(
#     '--device', default='cuda:2',
#     help='cuda:2 device default'
# )


# # dataset properties
# parser.add_argument(
#     '--datasets', nargs='+',
#     help='names of one or two sets of data files i.e. div_coronal_pd',
#     required = True
# )
# parser.add_argument(
#     '--scarcities', default=[3, 4], type=int, nargs='+',
#     help='number of samples in second contrast will be decreased by 1/2^N'
#     )
# parser.add_argument(
#     '--undersamples', default=[6], type=int, nargs='+',
#     help='undersampling factor of k-space'
#     )
# parser.add_argument(
#     '--centerfracs', default=[0.06], type=int, nargs='+',
#     help='center fractions sampled of k-space'
#     )


# # paths
# parser.add_argument(
#     '--modelpath', default='models',
#     help='path to save best model'
# )


# # save / display data
# parser.add_argument(
#     '--experimentname', default='unnamed_experiment',
#     help='experiment name i.e. STL_unet or MTAN_pareto_varnet etc.'
# )
# parser.add_argument(
#     '--verbose', default=True, type=bool,
#     help='''if true, prints to console and creatues full TensorBoard
#     (if tensorboard is also True)'''
# )
# parser.add_argument(
#     '--tensorboard', default=True, type=bool,
#     help='if true, creates TensorBoard'
# )
# parser.add_argument(
#     '--savefreq', default=15, type=int,
#     help='how many epochs per saved recon image'
# )

# opt = parser.parse_args()

In [3]:
# We can make one iteration block like this
class VarNetBlock(nn.Module):
    """
    This model applies a combination of soft data consistency with the input
    model as a regularizer. A series of these blocks can be stacked to form
    the full variational network.
    """

    def __init__(self, model: nn.Module):
        """
        Args:
            model: Module for "regularization" component of variational
                network.
        """
        super().__init__()

        self.model = model
        self.eta = nn.Parameter(torch.ones(1))

    def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        return fastmri.fft2c(fastmri.complex_mul(x, sens_maps)) # F*S operator

    def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        x = fastmri.ifft2c(x)
        return fastmri.complex_mul(x, fastmri.complex_conj(sens_maps)).sum(
            dim=1, keepdim=True
        ) # S^H * F^H operator

    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
        sens_maps: torch.Tensor,
    ) -> torch.Tensor:
        mask = mask.bool()
        zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
        soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * self.eta
        model_term = self.sens_expand(
            self.model(self.sens_reduce(current_kspace, sens_maps)), sens_maps
        )

        return current_kspace - soft_dc - model_term

In [4]:
# now we can stack VarNetBlocks to make a unrolled VarNet (with 10 blocks)


class VarNet(nn.Module):
    """
    A full variational network model.

    This model applies a combination of soft data consistency with a U-Net
    regularizer. To use non-U-Net regularizers, use VarNetBock.
    """

    def __init__(
        self,
        num_cascades: int = 12,
        chans: int = 18,
        pools: int = 4,
    ):
        super().__init__()

        self.cascades = nn.ModuleList(
            [VarNetBlock(NormUnet(chans, pools)) for _ in range(num_cascades)]
        )
        
    def forward(
        self,
        masked_kspace: torch.Tensor, 
        mask: torch.Tensor,
        sens_maps: torch.Tensor
    ) -> torch.Tensor:
        
        kspace_pred = masked_kspace.clone()

        for cascade in self.cascades:
            kspace_pred = cascade(kspace_pred, masked_kspace, mask, sens_maps)
        
        im_coil = fastmri.ifft2c(kspace_pred)
        im_comb = fastmri.complex_mul(im_coil, fastmri.complex_conj(sens_maps)).sum(
            dim=1, keepdim=True
        )
        
        return kspace_pred, im_comb

In [5]:
def count_parameters(model):
    return sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )

def _test_result(im_fs: torch.Tensor, im_us: torch.Tensor) -> np.ndarray:

    with torch.no_grad():
        im_us = from_pytorch(im_us.cpu().detach(),iscomplex = True)
        im_fs = from_pytorch(im_fs.cpu().detach(), iscomplex = True)
        im_us = np.abs(im_us).squeeze()
        im_fs = np.abs(im_fs).squeeze()
        
        im_us = sp.resize(im_us, [360, 320])
        im_fs = sp.resize(im_fs, [360, 320])
        
        out_cat = np.concatenate((im_fs, im_us), 1)
        error_cat = np.concatenate((im_fs, im_fs), 1)
        error_cat = np.abs(error_cat - out_cat) * 5
        
        out_cat = np.concatenate((error_cat, out_cat,), axis=0)
        out_cat = out_cat * 1.5  
        
    return np.flip(out_cat)


def plot_quadrant(im_fs, im_us):
    fig = plt.figure()
    plt.imshow(_test_result(im_fs, im_us), cmap = 'gray', vmax = 2.5) # or normalize between 0-1
    plt.close(fig)
    return fig


def _param_dict(lr, epochs, undersampling, scarcities, center_fractions):
    params = {}
    params['lr'] = lr
    params['epochs'] = epochs
    
    for i in range(len(undersampling)):
        params[f'accerlation_{i}'] = undersampling[i]
    
    for i in range(len(scarcities)):
        params[f'scarcity_{i}'] = scarcities[i]
        
    for i in range(len(center_fractions)):
        params[f'center_fraction_{i}'] = center_fractions[i]
    return params


def write_tensorboard(writer, avg_cost, model, total_epochs, ratio, opt):
    if len(avg_cost.keys()) == 2:
        write_tensorboard_one_contrasts(
            writer, avg_cost, model, total_epochs, ratio, opt
        )
    else:
        write_tensorboard_two_contrasts(
            writer, avg_cost, model, total_epochs, ratio, opt
        )


def write_tensorboard_two_contrasts(writer, avg_cost, model, total_epochs, ratio, opt):
    #write to tensorboard ###opt###
    contrast_1, contrast_2, _ = avg_cost.keys()
    
    for epoch in range(0, total_epochs): ###opt###
        writer.add_scalars(
            f'{ratio}/l1', {
                f'train/{contrast_1}' : avg_cost[contrast_1][epoch, 0],
                f'val/{contrast_1}' : avg_cost[contrast_1][epoch, 4],
                f'train/{contrast_2}' : avg_cost[contrast_2][epoch, 0],
                f'val/{contrast_2}' : avg_cost[contrast_2][epoch, 4],
            }, 
            epoch
        )

        writer.add_scalars(
            f'{ratio}/ssim', {
                f'train/{contrast_1}' : avg_cost[contrast_1][epoch, 1],
                f'val/{contrast_1}' : avg_cost[contrast_1][epoch, 5],
                f'train/{contrast_2}' : avg_cost[contrast_2][epoch, 1],
                f'val/{contrast_2}' : avg_cost[contrast_2][epoch, 5],
            }, 
            epoch
        )

        writer.add_scalars(
            f'{ratio}/psnr', {
                f'train/{contrast_1}' : avg_cost[contrast_1][epoch, 2],
                f'val/{contrast_1}' : avg_cost[contrast_1][epoch, 6],
                f'train/{contrast_2}' : avg_cost[contrast_2][epoch, 2],
                f'val/{contrast_2}' : avg_cost[contrast_2][epoch, 6],
            }, 
            epoch
        )

        writer.add_scalars(
            f'{ratio}/nrmse', {
                f'train/{contrast_1}' : avg_cost[contrast_1][epoch, 3],
                f'val/{contrast_1}' : avg_cost[contrast_1][epoch, 7],
                f'train/{contrast_2}' : avg_cost[contrast_2][epoch, 3],
                f'val/{contrast_2}' : avg_cost[contrast_2][epoch, 7],
            }, 
            epoch
        )
        
        writer.add_scalars(
            'overall/l1', {
                f'val/{ratio}/{contrast_1}' : avg_cost[contrast_1][epoch, 4],
                f'val/{ratio}/{contrast_2}' : avg_cost[contrast_2][epoch, 4],
            }, 
            epoch
        )

        writer.add_scalars(
            'overall/ssim', {
                f'val/{ratio}/{contrast_1}' : avg_cost[contrast_1][epoch, 5],
                f'val/{ratio}/{contrast_2}' : avg_cost[contrast_2][epoch, 5],
            }, 
            epoch
        )

        writer.add_scalars(
            'overall/psnr', {
                f'val/{ratio}/{contrast_1}' : avg_cost[contrast_1][epoch, 6],
                f'val/{ratio}/{contrast_2}' : avg_cost[contrast_2][epoch, 6],
            }, 
            epoch
        )

        writer.add_scalars(
            'overall/nrmse', {
                f'val/{ratio}/{contrast_1}' : avg_cost[contrast_1][epoch, 7],
                f'val/{ratio}/{contrast_2}' : avg_cost[contrast_2][epoch, 7],
            }, 
            epoch
        )

    writer.add_text(
        'parameters', 
        f'{count_parameters(model)} parameters'
    )
    
    ###opts###
    writer.add_hparams(
        _param_dict(0.001, 2, [6], [0], [0.06]), 
        {'overall/l1':0}
    )
    
def write_tensorboard_one_contrasts(writer, avg_cost, model, total_epochs, ratio, opt):
    #write to tensorboard ###opt###
    contrast_1, _ = avg_cost.keys()
    
    for epoch in range(total_epochs):
        
        writer.add_scalars(
            f'{ratio}/l1', {
                f'train/{contrast_1}' : avg_cost[contrast_1][epoch, 0],
                f'val/{contrast_1}' : avg_cost[contrast_1][epoch, 4],
            }, 
            epoch
        )

        writer.add_scalars(
            f'{ratio}/ssim', {
                f'train/{contrast_1}' : avg_cost[contrast_1][epoch, 1],
                f'val/{contrast_1}' : avg_cost[contrast_1][epoch, 5],
            }, 
            epoch
        )

        writer.add_scalars(
            f'{ratio}/psnr', {
                f'train/{contrast_1}' : avg_cost[contrast_1][epoch, 2],
                f'val/{contrast_1}' : avg_cost[contrast_1][epoch, 6],
            }, 
            epoch
        )

        writer.add_scalars(
            f'{ratio}/nrmse', {
                f'train/{contrast_1}' : avg_cost[contrast_1][epoch, 3],
                f'val/{contrast_1}' : avg_cost[contrast_1][epoch, 7],
            }, 
            epoch
        )

    writer.add_text(
        'parameters', 
        f'{count_parameters(model)} parameters'
    )
    
    ###opts###
    writer.add_hparams(
        _param_dict(0.001, 2, [6], [0], [0.06]), 
        {f'{ratio}/l1':0}
    )

In [6]:
def criterion(im_fs: torch.Tensor, im_us: torch.Tensor):
    '''
    @parameter im_us: undersampled image (2D)
    @parameter im_fs: fully sampled image (2D)
    should be on GPU device for fast computation
    '''
    
    # use l1 loss between two images
    criterion = nn.L1Loss()
    
    # can add more fancy loss functions here later
    
    return criterion(im_us, im_fs)

def metrics(im_fs: torch.Tensor, im_us: torch.Tensor):
    '''
    @parameter im_us: undersampled image (2D)
    @parameter im_fs: fully sampled image (2D)
    should be on GPU device for fast computation
    '''

    # change to ndarray
    im_us = transforms.tensor_to_complex_np(im_us.cpu().detach())
    im_fs = transforms.tensor_to_complex_np(im_fs.cpu().detach())
    
    # convert complex nums to magnitude
    im_us = np.absolute(im_us)
    im_fs = np.absolute(im_fs)
    
    im_us = im_us.reshape(
        (im_us.shape[2], im_us.shape[3])
    )
    
    im_fs = im_fs.reshape(
        (im_fs.shape[2], im_fs.shape[3])
    )
    
    # psnr
    psnr = skimage.metrics.peak_signal_noise_ratio(
        im_fs, 
        im_us, 
        data_range = np.max(im_fs) - np.min(im_fs)
    )
    
    #nrmse
    nrmse = skimage.metrics.normalized_root_mse(im_fs, im_us)
    
    # ssim
    # normalize 0 to 1
    im_fs -= np.min(im_fs)
    im_fs /= np.max(im_fs)
    im_us -= np.min(im_us)
    im_us /= np.max(im_us)
    
    ssim = skimage.metrics.structural_similarity(im_fs, im_us, data_range = 1)
    
    return ssim, psnr, nrmse

In [7]:
"""
=========== Universal Single-task Trainer =========== 
code modified from https://github.com/lorenmt/mtan/blob/master/im2im_pred/utils.py
"""


def single_task_trainer(
    train_loader, val_loader,
    train_ratios, val_ratios,
    single_task_model, device, writer, 
    optimizer, scheduler,
    opt = 0, total_epochs=2 ###opt###
):
    
    contrast_count = len(list(train_ratios.values()))
    ratio = f"N={'_N='.join(str(key) for key in train_ratios.values())}"
    
    best_val_loss = np.infty
    
    # contains info for all epochs and contrasts
    avg_cost = {
        contrast : np.zeros([total_epochs, 8])
        for contrast in train_ratios.keys()
    }
    avg_cost['overall'] = np.zeros([total_epochs, 8])
    
    for epoch in range(total_epochs):
        # contains info for single batch of a single epoch
        cost = np.zeros(8, dtype = np.float32)

        # train the data
        single_task_model.train()
        train_batch = len(train_loader)
        train_dataset = iter(train_loader)
        
        for kspace, mask, sens, im_fs, contrast in train_dataset:
            contrast = contrast[0] # torch dataset loader returns as tuple
            kspace, mask = kspace.to(device), mask.to(device)
            sens, im_fs = sens.to(device), im_fs.to(device)

            optimizer.zero_grad()
            _, im_us = single_task_model(kspace, mask, sens) # forward pass
            loss = criterion(im_fs, im_us)
            loss.backward()
            optimizer.step()
            
            # losses and metrics are averaged over epoch
            # L1 loss for now
            cost[0] = loss.item() 
            # ssim, psnr, nrmse
            cost[1], cost[2], cost[3] = metrics(im_fs, im_us)

            # update overall
            avg_cost[contrast][epoch, :4] += cost[:4] / train_ratios[contrast]
            avg_cost['overall'][epoch, :4] += cost[:4] / train_batch

        
        # get losses and metrics for each epoch
        single_task_model.eval()
        with torch.no_grad():
            val_batch = len(val_loader)
            
#             # training data (calculate during training) start calculating loss / metrics after a few epochs
#             train_dataset = iter(train_loader)
#             for kspace, mask, sens, im_fs, contrast in train_dataset:
#                 contrast = contrast[0]
#                 kspace, mask = kspace.to(device), mask.to(device)
#                 sens, im_fs = sens.to(device), im_fs.to(device)

#                 _, im_us = single_task_model(kspace, mask, sens) # forward pass
#                 loss = criterion(im_fs, im_us)
                
#                 # L1 loss for now
#                 cost[0] = loss.item() 
#                 # ssim, psnr, nrmse
#                 cost[1], cost[2], cost[3] = metrics(im_fs, im_us)

#                 # update overall
#                 avg_cost[contrast][epoch, :4] += cost[:4] / train_ratios[contrast]
#                 avg_cost['overall'][epoch, :4] += cost[:4] / train_batch

        
            # validation data
            val_dataset = iter(val_loader)
            for val_idx, val_data in enumerate(val_dataset):
                kspace, mask, sens, im_fs, contrast = val_data
                contrast = contrast[0]
                kspace, mask = kspace.to(device), mask.to(device)
                sens, im_fs = sens.to(device), im_fs.to(device)

                _, im_us = single_task_model(kspace, mask, sens) # forward pass
                loss = criterion(im_fs, im_us)
                
                # L1 loss for now
                cost[4] = loss.item()
                # ssim, psnr, nrmse
                cost[5], cost[6], cost[7] = metrics(im_fs, im_us)
                
                # update overall
                avg_cost[contrast][epoch, 4:] += cost[4:] / val_ratios[contrast]
                avg_cost['overall'][epoch, 4:] += cost[4:] / val_batch
                
               # visualize reconstruction every few epochs
                if epoch % 1 == 0 and True: ###opt###
                    # if single contrast, only visualize 17th slice
                    if (
                        val_idx == 17 or 
                        val_idx == val_batch - 17 and contrast_count > 1
                    ):
                        writer.add_figure(
                            f'{ratio}/{contrast}', 
                            plot_quadrant(im_fs, im_us),
                            epoch, close = True,
                        )
                               
                
        # early stopping        
        if avg_cost['overall'][epoch, 4] < best_val_loss:
            best_val_loss = avg_cost['overall'][epoch, 4]
            torch.save(
                single_task_model.state_dict(), 
                f'models/experiment_name/{ratio}_l1.pt'
            ) ###opt###
            
        scheduler.step()
        
        print(f'''
        >Epoch: {epoch + 1:04d}
        TRAIN: loss {avg_cost['overall'][epoch, 0]:.4f} | ssim {avg_cost['overall'][epoch, 1]:.4f} | psnr {avg_cost['overall'][epoch, 2]:.4f} | nrmse {avg_cost['overall'][epoch, 3]:.4f} 
        VAL: loss {avg_cost['overall'][epoch, 4]:.4f} | ssim {avg_cost['overall'][epoch, 5]:.4f} | psnr {avg_cost['overall'][epoch, 6]:.4f} | nrmse {avg_cost['overall'][epoch, 7]:.4f}
        
        ''')
    
    # write to tensorboard
    ###opt###
    if True:
        write_tensorboard(writer, avg_cost, single_task_model, ratio, opt)   

In [8]:
def main(scarcities, dataset_names):
    basedirs = [
        f'/mnt/dense/vliu/summer_dset/{dataset_name}'
        for dataset_name in dataset_names
    ]
    
    for scarcity in scarcities:
        print(f'experiment w scarcity {scarcity}')
        train_dloader = genDataLoader(
            [f'{basedir}/Train' for basedir in basedirs], # choose randomly
            [4, scarcity] # downsample
        )

        val_dloader = genDataLoader(
            [f'{basedir}/Val' for basedir in basedirs], # choose randomly
            [4, scarcity], # no downsampling
            shuffle = False,
        )
        print('generated dataloaders')

        # other inputs to STL wrapper
        device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
        varnet = VarNet().to(device)

        optimizer = torch.optim.Adam(varnet.parameters(),lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)
        print('start training')
        single_task_trainer(
            train_dloader[0], val_dloader[0], 
            train_dloader[1], val_dloader[1], # ratios dicts
            varnet, device, writer_tensorboard,
            optimizer, scheduler,
            opt = 0, total_epochs=2
        )

In [9]:
if not os.path.isdir('modelsa/hi'):
    os.makedirs('modelsa/hi')

In [10]:
# two contrasts
# datasets
scarcities = [2, 4]

dataset_names = [
    'div_coronal_pd',
    'div_coronal_pd_fs',
]

# # single contrast
# # datasets
# scarcities = [4]

# dataset_names = [
#     'div_coronal_pd',
# ]

run_name = f"demo/STL_{'_'.join(dataset_names)}"
writer_tensorboard = SummaryWriter(log_dir = run_name)

In [13]:
main(scarcities, dataset_names)
writer_tensorboard.flush()
writer_tensorboard.close()

experiment w scarcity 2
generated dataloaders
start training

        >Epoch: 0001
        TRAIN: loss 0.0522 | ssim 0.7002 | psnr 29.3284 | nrmse 0.2586 
        VAL: loss 0.0418 | ssim 0.7517 | psnr 30.4592 | nrmse 0.2188
        
        

        >Epoch: 0002
        TRAIN: loss 0.0450 | ssim 0.7466 | psnr 30.9216 | nrmse 0.2141 
        VAL: loss 0.0405 | ssim 0.7717 | psnr 30.9871 | nrmse 0.2072
        
        
experiment w scarcity 4
generated dataloaders
start training

        >Epoch: 0001
        TRAIN: loss 0.0520 | ssim 0.6841 | psnr 28.3595 | nrmse 0.2706 
        VAL: loss 0.0437 | ssim 0.7498 | psnr 30.1756 | nrmse 0.2248
        
        

        >Epoch: 0002
        TRAIN: loss 0.0413 | ssim 0.7566 | psnr 30.0946 | nrmse 0.2175 
        VAL: loss 0.0423 | ssim 0.7435 | psnr 30.8858 | nrmse 0.2104
        
        


In [8]:
filedir = f"models/{opt.experimentname}_{opt.network}_{'_'.join(opt.datasets)}"

In [7]:
model_filedir = f"models/STL_varnet_div_coronal_pd_div_coronal_pd_fs"
modelnames = os.listdir(model_filedir)

In [8]:
modelnames

['N=38_N=40_l1.pt', 'N=494_N=511_l1.pt', 'N=494_N=259_l1.pt']

In [9]:
dataset_names = [
    'div_coronal_pd',
    'div_coronal_pd_fs',
]

basedirs = [
    f'/mnt/dense/vliu/summer_dset/{dataset_name}'
    for dataset_name in dataset_names
]

test_dloader = genDataLoader(
    [f'{basedir}/Test' for basedir in basedirs], # choose randomly
    [4, 4],
    shuffle = False,
)

In [10]:
import pandas as pd

In [25]:
device = 'cuda:3'
with torch.no_grad():
    the_model = VarNet().to(device)
df_row = np.zeros([len(modelnames), 6])
for idx, model in enumerate(modelnames):

    
    model_filepath = os.path.join(model_filedir, model)

    # load model
    the_model.load_state_dict(torch.load(model_filepath))
    
    # iterate thru test set
    ####################### separate two contrasts in test set ############################
    test_batch = len(test_dloader)
    test_dataset = iter(test_dloader[0])
    for test_idx, test_data in enumerate(test_dataset):
        with torch.no_grad():
            
            kspace, mask, sens, im_fs, contrast = test_data
            contrast = contrast[0]
            kspace, mask = kspace.to(device), mask.to(device)
            sens, im_fs = sens.to(device), im_fs.to(device)

            _, im_us = the_model(kspace, mask, sens) # forward pass
            
            # L1 loss for now
            loss = criterion(im_fs, im_us)
            df_row[idx][0] += loss.item() / test_batch
            
            # ssim, psnr, nrmse
            ssim, psnr, nrmse = metrics(im_fs, im_us)
            for j in range(3):
                df_row[idx][j + 1] += metrics(im_fs, im_us)[j] / test_batch
    
    # define x axis
    ratio_1 = int(model.split('_')[0].split('=')[1])
    ratio_2 = int(model.split('_')[1].split('=')[1])
    df_row[idx][4] = ratio_1
    df_row[idx][5] = ratio_2

In [26]:
df_row

array([[   1.45900832,   25.10502428, 1025.94458722,    7.53519702,
          38.        ,   40.        ],
       [   1.18314077,   27.24668827, 1147.36300301,    5.34349928,
         494.        ,  511.        ],
       [   1.20969392,   26.73536128, 1130.10879287,    5.61200281,
         494.        ,  259.        ]])

In [36]:
contrast1 = 'div coronal proton density'
contrast2 = 'div coronal proton density fat suppression'

In [40]:
df2 = pd.DataFrame(
    df_row,
    columns=['loss', 'ssim', 'psnr', 'nrmse', contrast1, contrast2]
)
df2 = df2.drop(0)

In [45]:
import bokeh.plotting

In [41]:
df2

Unnamed: 0,loss,ssim,psnr,nrmse,div coronal proton density,div coronal proton density fat suppression
1,1.183141,27.246688,1147.363003,5.343499,494.0,511.0
2,1.209694,26.735361,1130.108793,5.612003,494.0,259.0


In [50]:
bokeh.io.output_notebook()
# For convenience
x = contrast2
y = "loss"

# Make figure
p = bokeh.plotting.figure(
    width=600,
    height=400,
    x_axis_label=x,
    y_axis_label=y,
    tooltips=[
        ("loss", "@{loss}"),
        ("contrast", f"@{contrast2}"),
    ],
)

# Add glyphs
p.circle(
    source=df2,
    x=x,
    y=y,
    legend_label="first contrast",
)

p.line(
    source=df2,
    x=x,
    y=y,
    legend_label="first contrast",
)

p.legend.location = "top_right"
p.legend.click_policy = "hide"
p.title = "loss"

bokeh.io.show(p)