<a href="https://colab.research.google.com/github/liamchalcroft/RectAngle/blob/main/demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Fetch the RectAngle sources into ./RectAngle (shell escape).
! git clone https://github.com/liamchalcroft/RectAngle.git

In [None]:
# Reinstall the package from the cloned checkout: remove any installed copy,
# pull the latest commits inside ./RectAngle, then pip-install from there (quiet).
! pip uninstall rectangle -y; cd RectAngle; git pull; pip install . -q

In [None]:
import rectangle as rect
import h5py
import matplotlib.pyplot as plt
import torch
from scipy.stats import linregress
import numpy as np
from datetime import date

In [None]:
# Bind your open HDF5 dataset to `f` in this cell: the cells below pass `f`
# to the rectangle loaders and will raise NameError if it is left undefined.
# Example:
# f = h5py.File('dataset70-200.h5', 'r')
# f = <----- Your data here

In [None]:
# Pick the compute device: CUDA when PyTorch reports one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
  # The cuDNN benchmark flag is only switched on for GPU runs.
  torch.backends.cudnn.benchmark = True

# Example single-model constructions (left disabled):
# model_ = rect.model.networks.UNet(n_layers=5, device=device, gate='attention')
# model_ = rect.model.networks.UNet(n_layers=5, device=device, gate=None)

In [None]:
# Split the dataset with ratio=(0.8, 0.2, 0); the third part is discarded.
train_ix, test_ix, _ = rect.utils.io.train_val_test(f, ratio=(0.8,0.2,0))

# Turn each index set into dataset keys.
train_keys = rect.utils.io.key_gen(f, train_ix)
test_keys = rect.utils.io.key_gen(f, test_ix)

# Training loader gets 'random' as its third positional argument
# (presumably the label mode -- TODO confirm); the test loader uses label='vote'.
train_data = rect.utils.io.H5DataLoader(f, train_keys, 'random')
test_data = rect.utils.io.H5DataLoader(f, test_keys, label='vote')

# Separate loader over the same test keys, used by the (disabled) plotting test calls.
test_plot_data = rect.utils.io.TestPlotLoader(f, test_keys, label='vote')

In [None]:
# Attention-gated UNet trained with label='random'; outputs are written under ./random.
model_rand = rect.model.networks.UNet(n_layers=5, device=device, gate='attention')
train_data_rand = rect.utils.io.H5DataLoader(f, train_keys, label='random')
trainer_rand = rect.utils.train.Trainer(model_rand, ensemble=5, outdir='./random')

# Train on the loader built in this cell. Previously the global `train_data`
# was passed and `train_data_rand` was created but never used.
trainer_rand.train(
    train_data_rand,
    train_pre=[
        rect.utils.transforms.z_score(),
        rect.utils.transforms.Flip(),
        rect.utils.transforms.Affine(),
        rect.utils.transforms.SpeckleNoise(),
    ],
    val_pre=[rect.utils.transforms.z_score()],
)

# Evaluate on the shared test loader; predictions are binarised and reduced
# to their largest connected component before scoring.
trainer_rand.test(
    test_data,
    test_pre=[rect.utils.transforms.z_score()],
    test_post=[rect.utils.transforms.Binary(), rect.utils.transforms.KeepLargestComponent()],
)
# trainer_rand.test(test_plot_data, test_pre=[rect.utils.transforms.z_score()], 
#              test_post=[rect.utils.transforms.Binary(), rect.utils.transforms.KeepLargestComponent()])

In [None]:
# Attention-gated UNet trained with label='vote'; outputs are written under ./vote.
model_vote = rect.model.networks.UNet(n_layers=5, device=device, gate='attention')
train_data_vote = rect.utils.io.H5DataLoader(f, train_keys, label='vote')
trainer_vote = rect.utils.train.Trainer(model_vote, ensemble=5, outdir='./vote')

# Bug fix: train on the 'vote' loader built above. The original passed the
# global `train_data` (constructed with 'random'), so this run never used
# vote labels and `train_data_vote` was dead code.
trainer_vote.train(
    train_data_vote,
    train_pre=[
        rect.utils.transforms.z_score(),
        rect.utils.transforms.Flip(),
        rect.utils.transforms.Affine(),
        rect.utils.transforms.SpeckleNoise(),
    ],
    val_pre=[rect.utils.transforms.z_score()],
)

# Evaluate on the shared test loader; predictions are binarised and reduced
# to their largest connected component before scoring.
trainer_vote.test(
    test_data,
    test_pre=[rect.utils.transforms.z_score()],
    test_post=[rect.utils.transforms.Binary(), rect.utils.transforms.KeepLargestComponent()],
)
# trainer_vote.test(test_plot_data, test_pre=[rect.utils.transforms.z_score()], 
#              test_post=[rect.utils.transforms.Binary(), rect.utils.transforms.KeepLargestComponent()])

In [None]:
# Attention-gated UNet trained with label='mean'; outputs are written under ./mean.
model_mean = rect.model.networks.UNet(n_layers=5, device=device, gate='attention')
train_data_mean = rect.utils.io.H5DataLoader(f, train_keys, label='mean')
trainer_mean = rect.utils.train.Trainer(model_mean, ensemble=5, outdir='./mean')

# Bug fix: train on the 'mean' loader built above. The original passed the
# global `train_data` (constructed with 'random'), so this run never used
# mean labels and `train_data_mean` was dead code.
trainer_mean.train(
    train_data_mean,
    train_pre=[
        rect.utils.transforms.z_score(),
        rect.utils.transforms.Flip(),
        rect.utils.transforms.Affine(),
        rect.utils.transforms.SpeckleNoise(),
    ],
    val_pre=[rect.utils.transforms.z_score()],
)

# Evaluate on the shared test loader; predictions are binarised and reduced
# to their largest connected component before scoring.
trainer_mean.test(
    test_data,
    test_pre=[rect.utils.transforms.z_score()],
    test_post=[rect.utils.transforms.Binary(), rect.utils.transforms.KeepLargestComponent()],
)
# trainer_mean.test(test_plot_data, test_pre=[rect.utils.transforms.z_score()], 
#              test_post=[rect.utils.transforms.Binary(), rect.utils.transforms.KeepLargestComponent()], oname='test')

In [None]:
# Tag embedded in the result file names: today's date, e.g. "Apr-22-2021".
# NOTE(review): assumes Trainer.test() names its CSVs with this same date tag
# when no `oname` is passed -- confirm. To load tables written on another day,
# set `oname` by hand (e.g. oname = 'Apr-22-2021').
oname = date.today().strftime("%b-%d-%Y")

# Fix: build the paths from `oname` instead of a hard-coded 'Apr-22-2021',
# which only matched runs performed on that date and left `oname` unused.
table_path = './{}/testing/table/{}_{}.csv'

rand_dice = np.genfromtxt(table_path.format('random', 'dice', oname), delimiter=',')
rand_prec = np.genfromtxt(table_path.format('random', 'precision', oname), delimiter=',')
rand_rec = np.genfromtxt(table_path.format('random', 'recall', oname), delimiter=',')

vote_dice = np.genfromtxt(table_path.format('vote', 'dice', oname), delimiter=',')
vote_prec = np.genfromtxt(table_path.format('vote', 'precision', oname), delimiter=',')
vote_rec = np.genfromtxt(table_path.format('vote', 'recall', oname), delimiter=',')

mean_dice = np.genfromtxt(table_path.format('mean', 'dice', oname), delimiter=',')
mean_prec = np.genfromtxt(table_path.format('mean', 'precision', oname), delimiter=',')
mean_rec = np.genfromtxt(table_path.format('mean', 'recall', oname), delimiter=',')

In [None]:
# Violin plots comparing the three training runs, one panel per metric.
sampling_names = ['Random Sampling', 'Vote Sampling', 'Mean (soft) Sampling']
violin_panels = [
    ('Dice Coefficient', [rand_dice, vote_dice, mean_dice]),
    ('Precision', [rand_prec, vote_prec, mean_prec]),
    ('Recall', [rand_rec, vote_rec, mean_rec]),
]

plt.figure(figsize=(16,8))

for column, (metric_label, scores) in enumerate(violin_panels, start=1):
    plt.subplot(1, 3, column)
    plt.violinplot(scores, showmeans=True)
    plt.ylabel(metric_label)
    plt.xticks([1, 2, 3], sampling_names)

plt.show()

In [None]:
def bland_altman_plot(data1, data2, *args, **kwargs):
    """Draw a Bland-Altman plot for two paired series using pyplot.

    Scatters the pairwise difference (data1 - data2) against the pairwise
    mean, overlays a least-squares fit of difference on mean annotated with
    its R^2, and draws horizontal lines at the mean difference and at
    mean difference +/- 1.96 standard deviations.

    Args:
        data1, data2: paired measurements; any array-like accepted by
            ``np.asarray`` (lists, tuples, ndarrays).
        *args, **kwargs: forwarded unchanged to ``plt.scatter``.

    Returns:
        None. The plot is drawn through the ``plt`` state machine.
    """
    # Coerce to arrays so plain Python sequences work: `list - list` raises TypeError.
    data1 = np.asarray(data1)
    data2 = np.asarray(data2)

    mean = np.mean([data1, data2], axis=0)      # Pairwise mean of the two series
    diff = data1 - data2                        # Pairwise difference
    md = np.mean(diff)                          # Mean of the difference
    sd = np.std(diff, axis=0)                   # Standard deviation of the difference

    # Linear fit of difference against mean, evaluated across the observed mean range.
    fit = linregress(mean, diff)
    rsq = fit.rvalue**2
    x = np.linspace(mean.min(), mean.max())
    y = fit.slope * x + fit.intercept

    plt.scatter(mean, diff, *args, **kwargs)
    plt.plot(x, y, '--', c='r')
    # The R^2 label is anchored at the right-hand end of the fitted line.
    plt.text(x[-1], y[-1], '$R^2$ = {:.3f}'.format(rsq), c='r')
    plt.axhline(md,           color='gray', linestyle='--')
    plt.axhline(md + 1.96*sd, color='gray', linestyle='--')
    plt.axhline(md - 1.96*sd, color='gray', linestyle='--')

In [None]:
# Pairwise Bland-Altman comparisons of the Dice scores of the three runs.
dice_comparisons = [
    ('Random vs. Vote', rand_dice, vote_dice),
    ('Mean vs. Vote', mean_dice, vote_dice),
    ('Random vs. Mean', rand_dice, mean_dice),
]

plt.figure(figsize=(16,8))
for column, (panel_title, first, second) in enumerate(dice_comparisons, start=1):
    plt.subplot(1, 3, column)
    bland_altman_plot(first, second)
    plt.ylim([-1.1, 1.1])
    plt.xlim([-0.1, 1.1])
    if column == 1:
        # Axis labels are only drawn on the left-most panel.
        plt.xlabel('Mean Dice')
        plt.ylabel('Difference in Dice')
    plt.title(panel_title)
plt.show()

Classifier screening

In [None]:
# Independent split for the classifier with ratio=(0.6, 0.2, 0.2); the third part is discarded.
# NOTE(review): this split is drawn separately from the segmentation split above,
# so classifier training keys may overlap `test_keys` -- confirm this is intended.
class_train_ix, class_val_ix, _ = rect.utils.io.train_val_test(f, ratio=(0.6,0.2,0.2))

# Index sets -> dataset keys.
class_train_keys = rect.utils.io.key_gen(f, class_train_ix)
class_val_keys = rect.utils.io.key_gen(f, class_val_ix)

# One classification loader per key set.
class_train_data = rect.utils.io.ClassifyDataLoader(f, class_train_keys)
class_val_data = rect.utils.io.ClassifyDataLoader(f, class_val_keys)

In [None]:
# Rebuilds the classifier training loader with the same arguments as the
# previous cell (redundant unless that loader needs to be reset).
class_train_data = rect.utils.io.ClassifyDataLoader(f, class_train_keys)

In [None]:
# Build the DenseNet classifier with freeze_weights=False, then move it to the selected device.
class_model = rect.model.networks.MakeDenseNet(freeze_weights=False)
class_model = class_model.to(device)

In [None]:
# Trainer for the screening classifier; its outputs go under ./classlogs.
class_trainer = rect.utils.train.ClassTrainer(
    class_model,
    outdir='./classlogs',
    ensemble=None,
    early_stop=1000,
)

In [None]:
# Fit the classifier on the training loader, passing the validation loader as the second argument.
class_trainer.train(class_train_data, class_val_data)

In [None]:
# Sweep 20 evenly spaced classifier thresholds from 0 to 0.6 and re-test the
# 'vote' segmentation trainer on test data wrapped in PreScreenLoader at each one.
threshRange = np.linspace(0, 0.6, 20)

for i in range(len(threshRange)):
    thresh = threshRange[i]
    print('Threshold = {}'.format(thresh))

    test_screen_data = rect.utils.io.PreScreenLoader(
        class_model.eval(),
        f,
        test_keys,
        label='vote',
        threshold=thresh,
    )

    # Each threshold writes its results under its own output name.
    trainer_vote.test(
        test_screen_data,
        test_pre=[rect.utils.transforms.z_score()],
        test_post=[rect.utils.transforms.Binary(), rect.utils.transforms.KeepLargestComponent()],
        oname='class_thresh_{}'.format(i),
    )

In [None]:
# Per-threshold result tables written by the sweep, keyed by metric.
sweep_tables = {
    'dice': './vote/testing/table/dice_class_thresh_{}.csv',
    'prec': './vote/testing/table/precision_class_thresh_{}.csv',
    'rec': './vote/testing/table/recall_class_thresh_{}.csv',
}

# Read every table, threshold by threshold.
sweep_scores = {metric: [] for metric in sweep_tables}
for idx in range(len(threshRange)):
    for metric, template in sweep_tables.items():
        sweep_scores[metric].append(np.genfromtxt(template.format(idx), delimiter=','))

# Mean and standard deviation of each metric at each threshold.
diceMu = np.array([np.mean(scores) for scores in sweep_scores['dice']])
diceSig = np.array([np.std(scores) for scores in sweep_scores['dice']])
precMu = np.array([np.mean(scores) for scores in sweep_scores['prec']])
precSig = np.array([np.std(scores) for scores in sweep_scores['prec']])
recMu = np.array([np.mean(scores) for scores in sweep_scores['rec']])
recSig = np.array([np.std(scores) for scores in sweep_scores['rec']])

# Baseline statistics from the unscreened 'vote' test results.
baseDiceMu, baseDiceSig = np.mean(vote_dice), np.std(vote_dice)
basePrecMu, basePrecSig = np.mean(vote_prec), np.std(vote_prec)
baseRecMu, baseRecSig = np.mean(vote_rec), np.std(vote_rec)

In [None]:
# Metric vs. classifier threshold (mean +/- one std band), with the unscreened
# baseline drawn as a flat line and band for reference.
n_thresh = len(threshRange)
sweep_panels = [
    ('DSC', diceMu, diceSig, baseDiceMu, baseDiceSig),
    ('Precision', precMu, precSig, basePrecMu, basePrecSig),
    ('Recall', recMu, recSig, baseRecMu, baseRecSig),
]

plt.figure(figsize=(16,8))

for column, (metric_label, mu, sig, base_mu, base_sig) in enumerate(sweep_panels, start=1):
    plt.subplot(1, 3, column)
    plt.plot(threshRange, mu)
    plt.fill_between(threshRange, mu-sig, mu+sig, alpha=0.2)
    plt.plot(threshRange, n_thresh*[base_mu])
    plt.fill_between(threshRange, n_thresh*[base_mu-base_sig], n_thresh*[base_mu+base_sig], alpha=0.1)
    if column == 1:
        # Only the first panel carries the x-axis label.
        plt.xlabel('Classifier threshold')
    plt.ylabel(metric_label)

plt.show()

In [None]:
# Compare Dice at the last (highest) swept threshold against the unscreened 'vote' baseline.
last_thresh_ix = len(threshRange) - 1
screen_dice = np.genfromtxt(
    './vote/testing/table/dice_class_thresh_{}.csv'.format(last_thresh_ix),
    delimiter=',',
)

plt.figure(figsize=(8,6))
bland_altman_plot(screen_dice, vote_dice)
plt.xlim([-0.1, 1.1])
plt.ylim([-1.1, 1.1])
plt.xlabel('Mean Dice')
plt.ylabel('Difference in Dice')
plt.show()