# Runs inference on the holdout set

In [14]:
from data.CTDataSet import CTDicomSlices, DatasetManager
from data.CustomTransforms import Window, Imagify, Normalize

from models.UNet_L import UNet

import albumentations as A
from torchvision import transforms

from torch.utils.data import DataLoader

import torch

from torchmetrics.classification import BinaryConfusionMatrix
from pytorch_lightning import Trainer, loggers as pl_loggers

## Get data

In [15]:
holdout_set = "holdout_set.txt"  # may need full path to this file
dataset = '/home/hussam/imager/organized_dataset_2'

### Image preprocessing

In [16]:
# important constants

WL = 50
WW = 200

img_size = 256

mean = 61.0249
std = 78.3195

batch_size = 32
num_workers = 2

In [17]:
prep = transforms.Compose([Window(WL, WW), Imagify(WL, WW), Normalize(mean, std)])

resize_tsfm = A.Compose([A.Resize(img_size, img_size)],
            additional_targets={"image1": 'image', "mask1": 'mask'})

In [18]:
dsm = DatasetManager.load_train_val_test(dataset, holdout_set, holdout_set, holdout_set) # train, val and test will be identical. we will use test

_, _, test_dicoms = dsm.get_dicoms() # DICOM glob is preset in the class file

test_ds = CTDicomSlices(test_dicoms, preprocessing = prep, resize_transform = resize_tsfm, n_surrounding=1, mask_is_255=False)

test_dl = DataLoader(test_ds, batch_size=batch_size, num_workers = num_workers, shuffle=False)

## Get models

In [None]:
class UNetTester(UNet):
    """ Same as UNet but with a tester method that tracks more metrics. This allows us to avoid editing the original code. """
    def __init__(self, datasets, backbone :str = 'resnet34', encoder_weights :str = 'imagenet',
                 classes :int = 2, activation :str = 'softmax', batch_size :int = 32,
                 lr = 0.0001, dl_workers = 8, optimizer_params = None, in_channels=3,
                 loss = 'dice'):
        super().__init__(datasets, backbone = backbone, encoder_weights=encoder_weights, classes=classes,
                activation=activation, batch_size=batch_size, lr=lr, dl_workers=dl_workers,
                optimizer_params=optimizer_params, in_channels=in_channels, loss=loss)
        self.bcm = BinaryConfusionMatrix()

        
    def test_step(self, batch, batch_nb):
        images, masks, _, _ = batch

        y_hat = self(images)

        # loss dim is [batch, 1, img_x, img_y]
        # need to get rid of the second dimension so
        # size matches with mask
        loss = self.loss(y_hat, masks)

        # binary_classification_metrics
        ground_truth = torch.amax(masks, (1, 2)) / 256  # mask pixel values are 0 or 256. This changes then to 0 or 1
        ground_truth = torch.round(ground_truth)
        
        preds = torch.amax(y_hat[:, 0, :, :], (1, 2))   # y_hat dim is (batch, 2, img_sz, img_sz). 2 for 2 classes. We only need the first. To make dimension match ground truth masks

        conf_matrix = self.bcm(preds, ground_truth)

        # Logs
        #tensorboard_logs = {'val_loss': loss}
        return {'test_loss': loss, "tp": conf_matrix[0, 0], "fn": conf_matrix[0, 1], "fp": conf_matrix[1, 0], "tn": conf_matrix[1, 1]} #, 'log': tensorboard_logs}

    def test_epoch_end(self, outputs):
        test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
        self.log('test_loss_mean', test_loss_mean, logger=True)

        tp = torch.stack([x['tp'] for x in outputs]).sum()
        fn = torch.stack([x['fn'] for x in outputs]).sum()
        fp = torch.stack([x['fp'] for x in outputs]).sum()
        tn = torch.stack([x['tn'] for x in outputs]).sum()

        return {'test_loss': test_loss_mean, "tp": tp, "fn": fn, "fp": fp, "tn": tn}

In [55]:
ckpt_imagenet = "/mnt/d/pretrainer/model_runs/best_model/logs/default/version_0/checkpoints/epoch=99-step=226899.ckpt"
ckpt_random = "/mnt/d/pretrainer/model_runs/random_nopretrain_nooptim/logs/default/version_0/checkpoints/epoch=99-step=226899.ckpt"
ckpt_jigsaw = "/mnt/d/pretrainer-2/model_runs/2021-05-31-17_34_04/logs/default/version_0/checkpoints/epoch=99-step=226899.ckpt"
ckpt_felz = "/mnt/d/pretrainer-2/model_runs/2021-05-19-15_14_53/logs/default/version_0/checkpoints/epoch=99-step=250299.ckpt"
ckpt_jigsaw_sr = "/mnt/d/pretrainer-2/model_runs/after jigsaw SR with HNSCC_test/logs/default/version_0/checkpoints/last.ckpt"

In [56]:
def get_model(path):
    return UNetTester.load_from_checkpoint(path, datasets=None, in_channels=3, classes=2)

def get_model_dir(ckpt):
    return ckpt[0:ckpt.find('/logs/default')]

# imagenet, random, felz, jigsaw, jigsawSR

#model_imagenet = get_model(ckpt_imagenet)
#model_random = get_model(ckpt_random)
#model_jigsaw = get_model(ckpt_jigsaw)
#model_felz = get_model(ckpt_felz)
#model_jigsaw_sr = get_model(ckpt_jigsaw_sr)

In [57]:
model_ckpts = [ckpt_imagenet]
ckpt_labels = ["imagenet"]

for m_ckpt, m_label in zip(model_ckpts, ckpt_labels):
    model = get_model(m_ckpt)
    model_dir = get_model_dir(m_ckpt)
    
    tb_logger = pl_loggers.TensorBoardLogger('{}-test-{}/logs/'.format(model_dir, m_label))

    trainer = Trainer(gpus=1, accelerator='cpu', precision=16, logger = tb_logger, default_root_dir=model_dir)

    trainer.test(model = model, dataloaders=test_dl)


# do trainer
# test model
# follow up on logs

  rank_zero_deprecation(
  rank_zero_warn(
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /mnt/d/pretrainer/model_runs/best_model-test-imagenet/logs/lightning_logs
  rank_zero_warn(


Testing DataLoader 0:   0%|          | 0/311 [00:00<?, ?it/s]

  return self._call_impl(*args, **kwargs)



y_hat: torch.Size([32, 2, 256, 256]) - masks: torch.Size([32, 256, 256])

mean of y_hat: 0.5

Y-HAT!!!!!!: 
 tensor([[[7.1886e-09, 2.2737e-12, 8.9813e-12,  ..., 2.1316e-12,
          6.1675e-12, 1.4319e-08],
         [3.7517e-12, 7.5460e-17, 3.3827e-16,  ..., 2.0470e-16,
          7.1471e-16, 4.8431e-11],
         [3.2969e-12, 2.0470e-16, 7.6605e-15,  ..., 8.0838e-16,
          3.4139e-15, 1.4916e-10],
         ...,
         [3.9790e-12, 4.9266e-16, 1.1158e-14,  ..., 4.9682e-15,
          1.5127e-15, 1.3188e-10],
         [2.7626e-11, 1.7139e-15, 4.1078e-15,  ..., 4.6629e-15,
          1.8180e-15, 7.0486e-11],
         [2.0396e-07, 8.0036e-11, 1.8986e-11,  ..., 1.2369e-10,
          6.6393e-11, 8.7544e-08]],

        [[7.1886e-09, 2.2737e-12, 8.9813e-12,  ..., 2.1316e-12,
          6.1675e-12, 1.3446e-08],
         [3.7517e-12, 7.5460e-17, 3.3827e-16,  ..., 2.0470e-16,
          7.1471e-16, 4.5475e-11],
         [3.2969e-12, 2.0470e-16, 7.6605e-15,  ..., 7.1471e-16,
          3.0115e-

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
