# Runs inference on the holdout set

In [6]:
from data.CTDataSet import CTDicomSlices, DatasetManager
from data.CustomTransforms import Window, Imagify, Normalize

from models.UNet_L import UNet

import albumentations as A
from torchvision import transforms

from torch.utils.data import DataLoader

import torch

from torchmetrics.classification import BinaryConfusionMatrix
from pytorch_lightning import Trainer, loggers as pl_loggers

## Get data

In [None]:
holdout_set = "/root/checkpoints/holdout_set.txt"  # may need full path to this file
dataset = '/root/imager/organized_dataset_2'

### Image preprocessing

In [8]:
# important constants

WL = 50
WW = 200

img_size = 256

mean = 61.0249
std = 78.3195

batch_size = 32
num_workers = 2

In [9]:
prep = transforms.Compose([Window(WL, WW), Imagify(WL, WW), Normalize(mean, std)])

resize_tsfm = A.Compose([A.Resize(img_size, img_size)],
            additional_targets={"image1": 'image', "mask1": 'mask'})

In [10]:
dsm = DatasetManager.load_train_val_test(dataset, holdout_set, holdout_set, holdout_set) # train, val and test will be identical. we will use test

_, _, test_dicoms = dsm.get_dicoms() # DICOM glob is preset in the class file

test_ds = CTDicomSlices(test_dicoms, preprocessing = prep, resize_transform = resize_tsfm, n_surrounding=1)

datasets = {}
datasets['train'] = None
datasets['val'] = None
datasets['test'] = test_ds

## Get models

In [11]:
class UNetTester(UNet):
    """ Same as UNet but with a tester method that tracks more metrics. This allows us to avoid editing the original code. """
    def __init__(self, datasets, backbone :str = 'resnet34', encoder_weights :str = 'imagenet',
                 classes :int = 2, activation :str = 'softmax', batch_size :int = 32,
                 lr = 0.0001, dl_workers = 8, optimizer_params = None, in_channels=3,
                 loss = 'dice'):
        super().__init__(datasets, backbone = backbone, encoder_weights=encoder_weights, classes=classes,
                activation=activation, batch_size=batch_size, lr=lr, dl_workers=dl_workers,
                optimizer_params=optimizer_params, in_channels=in_channels, loss=loss)
        self.bcm = BinaryConfusionMatrix()
        
    def test_step(self, batch, batch_nb):
        images, masks, _, _ = batch

        y_hat = self(images)

        # loss dim is [batch, 1, img_x, img_y]
        # need to get rid of the second dimension so
        # size matches with mask
        loss = self.loss(y_hat, masks)

        # binary_classification_metrics
        ground_truth = torch.amax(masks, (1, 2)) 
        
        preds = torch.amax(y_hat[:, 0, :, :], (1, 2))   # y_hat dim is (batch, 2, img_sz, img_sz). 2 for 2 classes. We only need the first. To make dimension match ground truth masks

        conf_matrix = self.bcm(preds, ground_truth)

        # Logs
        #tensorboard_logs = {'val_loss': loss}
        return {'test_loss': loss, "tp": conf_matrix[0, 0], "fn": conf_matrix[0, 1], "fp": conf_matrix[1, 0], "tn": conf_matrix[1, 1]} #, 'log': tensorboard_logs}

    def test_epoch_end(self, outputs):
        test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
        self.log('test_loss_mean', test_loss_mean, logger=True)

        tp = torch.stack([x['tp'] for x in outputs]).sum()
        fn = torch.stack([x['fn'] for x in outputs]).sum()
        fp = torch.stack([x['fp'] for x in outputs]).sum()
        tn = torch.stack([x['tn'] for x in outputs]).sum()

        self.log('tp', tp, logger=True)
        self.log('fn', fn, logger=True)
        self.log('fp', fp, logger=True)
        self.log('tn', tn, logger=True)

        return {'test_loss': test_loss_mean, "tp": tp, "fn": fn, "fp": fp, "tn": tn}

In [None]:
ckpt_imagenet = "/root/checkpoints/imagenet.ckpt"
ckpt_random = "/root/checkpoints/random_nopretrain.ckpt"
ckpt_jigsaw = "/root/checkpoints/ckpt_jigsaw_classic.ckpt"
ckpt_felz = "/root/checkpoints/felz.ckpt"
ckpt_jigsaw_sr = "/root/checkpoints/jigsaw_sr.ckpt"

In [13]:
def get_model(path):
    return UNetTester.load_from_checkpoint(path, datasets=datasets, in_channels=3, classes=2)

def get_model_dir(ckpt):
    return ckpt[0:ckpt.find('/logs/default')]

# imagenet, random, felz, jigsaw, jigsawSR

#model_imagenet = get_model(ckpt_imagenet)
#model_random = get_model(ckpt_random)
#model_jigsaw = get_model(ckpt_jigsaw)
#model_felz = get_model(ckpt_felz)
#model_jigsaw_sr = get_model(ckpt_jigsaw_sr)

In [None]:
# felz_seg was trained using a different dataset split without a holdout. make sure dataset is set to
# felz before running felz_seg
model_ckpts = [ckpt_imagenet, ckpt_random, ckpt_jigsaw, ckpt_felz, ckpt_jigsaw_sr]
ckpt_labels = ["imagenet", "random", "jigsaw_classic", "felz_seg", "jigsaw_sr"]

for m_ckpt, m_label in zip(model_ckpts, ckpt_labels):
    model = get_model(m_ckpt)
    model_dir = '/root/test_outputs/'
    
    tb_logger = pl_loggers.TensorBoardLogger('{}-test-{}'.format(model_dir, m_label))

    trainer = Trainer(gpus=1, accelerator='gpu', precision=16, logger = tb_logger, default_root_dir=model_dir)

    trainer.test(model = model)


  rank_zero_deprecation(
Using 16bit native Automatic Mixed Precision (AMP)
  scaler = torch.cuda.amp.GradScaler()
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

  return self._call_impl(*args, **kwargs)


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Runningstage.testing metric      DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
           fn                      296.0
           fp                      432.0
     test_loss_mean         0.22868618369102478
           tn                     1718.0
           tp                     8710.0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


