In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, ConcatDataset

import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
import random
import json

In [2]:
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss
from segmentation_models_pytorch.utils.metrics import IoU, Fscore, Accuracy

In [3]:
import matplotlib.pyplot as plt

# helper function for data visualization
def visualize(**images):
    """Show the given images side by side in a single row.

    Each keyword argument becomes one panel: the keyword (underscores turned
    into spaces, title-cased) is the panel title and the value is drawn with
    the 'gray' colormap. Axis ticks are hidden.
    """
    total = len(images)
    plt.figure(figsize=(16, 5))
    for position, key in enumerate(images, start=1):
        plt.subplot(1, total, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(key.replace('_', ' ').title())
        plt.imshow(images[key], 'gray')
    plt.show()

In [4]:
import random

# Seed the Python, NumPy and PyTorch random number generators with one shared value.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Sanity check: draw a few numbers from NumPy and PyTorch right after seeding.
# Note that this consumes RNG state, so later draws depend on this cell having run.
print(np.random.rand(5), torch.randn(5))

[0.5488135  0.71518937 0.60276338 0.54488318 0.4236548 ] tensor([ 1.5410, -0.2934, -2.1788,  0.5684, -1.0845])


# Settings

In [5]:
# Directory the kernel was started in.
# NOTE(review): `root` is not referenced in the cells visible here.
root = os.getcwd()
# First CUDA device when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Batch size setting.
# NOTE(review): the evaluation loaders in this notebook hard-code batch_size=1 instead.
BATCH = 4

# Generator: input channels / output channels (passed to smp.Unet as in_channels / classes).
GEN_IN_CHANNELS = 1
GEN_N_CLASSES = 1

# Real/fake discriminator: input channels / number of output classes.
DIS_IN_CHANNELS = 1
DIS_N_CLASSES = 2

# Edge discriminator: input channels / number of output classes.
E_IN_CHANNELS = 1
E_N_CLASSES = 2

# Epoch count setting (presumably the total number of training epochs — TODO confirm).
EPOCH = 200

### Model Settings

#### encoder

In [6]:
# Encoder backbone name and pretrained-weight selector handed to smp.Unet
# (encoder_name / encoder_weights) when the generator is built.
ENCODER = "resnet152"
ENCODER_WEIGHT = None

#### decoder

In [7]:
# Passed to smp.Unet as decoder_attention_type when the generator is built.
DECODER_ATT = "scse"

#### head

In [8]:
# Activation names, resolved through codes.activation.Activation(name=...).
# GEN_ACT is applied to the generator output in the epoch functions.
GEN_ACT = "sigmoid"
# DIS_ACT is applied to discriminator head outputs in eval_epoch (for both discriminators).
DIS_ACT = "softmax"
# E_ACT is only passed to the edge Discriminator constructor in the cells visible here.
E_ACT = "softmax"

### Optimizer Settings

In [9]:
# Generator optimizer settings.
# NOTE(review): not referenced in the cells visible here; presumably consumed by the
# training setup — TODO confirm.
GEN_OPTIM_NAME = "adam"
GEN_init_lr = 1e-4
GEN_momentum = 0.9

In [10]:
# Real/fake discriminator optimizer and LR-scheduler settings.
# NOTE(review): not referenced in the cells visible here; presumably consumed by the
# training setup — TODO confirm.
DIS_OPTIM_NAME = "adam"
DIS_init_lr = 1e-3
DIS_momentum = 0.9
DIS_scheduler = "cosineAnnWarm"

In [11]:
# Edge discriminator optimizer and LR-scheduler settings.
# NOTE(review): not referenced in the cells visible here; presumably consumed by the
# training setup — TODO confirm.
E_OPTIM_NAME = "adam"
E_init_lr = 1e-3
E_momentum = 0.9
E_scheduler = "cosineAnnWarm"

# Epoch functions (training, evaluation, test export)

In [12]:
from codes.losses import SSIMLoss
from codes.losses import MAELoss
from pytorch_msssim import ssim
from codes.metrics import PSNR, SNR, ContourEval
from codes.activation import Activation
from codes.utils import hu_clip_tensor
from codes.losses import PerceptualLoss
from kornia.filters.sobel import Sobel

In [13]:
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of one or more networks.

    Freezing a network this way avoids computing gradients it does not need.

    Parameters:
    nets (network or list/tuple of networks) -- a single network or a sequence
                                                of networks; ``None`` entries are skipped
    requires_grad (bool)                     -- whether the networks require gradients or not
    """
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    # The loop must run for list input as well as for a wrapped single network;
    # previously it sat inside the `if` above and lists were silently ignored.
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad

In [14]:
def replace_relu_to_leakyReLU(model):
    """Recursively swap every ``nn.ReLU`` submodule of ``model`` for a default ``nn.LeakyReLU``.

    The model is modified in place; nothing is returned. Modules that are not
    ``nn.ReLU`` are descended into instead of being replaced.
    """
    for attr, module in list(model.named_children()):
        if not isinstance(module, nn.ReLU):
            replace_relu_to_leakyReLU(module)
            continue
        setattr(model, attr, nn.LeakyReLU())

In [15]:
def replace_bn_to_instanceNorm(model):
    """Recursively swap every ``nn.BatchNorm2d`` submodule of ``model`` for ``nn.InstanceNorm2d``.

    The replacement keeps the original layer's ``num_features`` and otherwise uses
    the ``nn.InstanceNorm2d`` defaults. The model is modified in place; nothing is
    returned.
    """
    for attr, module in list(model.named_children()):
        if isinstance(module, nn.BatchNorm2d):
            setattr(model, attr, nn.InstanceNorm2d(module.num_features))
            continue
        replace_bn_to_instanceNorm(module)

In [16]:
# Perceptual-loss module instance.
# NOTE(review): not referenced in the cells visible here.
perceptual_ext = PerceptualLoss()

In [17]:
# Shared kornia Sobel filter on the selected device; the epoch functions below
# apply it to images to obtain the edge maps fed to the edge discriminator.
sobel_filter = Sobel().to(device)

In [18]:
def _discriminator_loss(real_heads, fake_heads, criterion, device):
    """Average over heads of CE(real -> class 0) + CE(generated -> class 1)."""
    total = 0
    for real_logits, fake_logits in zip(real_heads, fake_heads):
        batch = real_logits.size(0)
        real_target = torch.zeros(batch, dtype=torch.long, device=device)
        fake_target = torch.ones(batch, dtype=torch.long, device=device)
        total = total + criterion(real_logits.float(), real_target) \
                      + criterion(fake_logits.float(), fake_target)
    return total / len(real_heads)


def _generator_adv_loss(fake_heads, criterion, device):
    """Average over heads of CE(generated -> class 0), i.e. the generator tries to look real."""
    total = 0
    for fake_logits in fake_heads:
        real_target = torch.zeros(fake_logits.size(0), dtype=torch.long, device=device)
        total = total + criterion(fake_logits.float(), real_target)
    return total / len(fake_heads)


def train_a2b_epoch(epoch, generator, gen_optim, 
                    tf_discriminator, tf_d_optim, tf_scheduler, 
                    edge_discriminator, edge_d_optim, edge_scheduler, 
                    dataloader, device):
    """Run one adversarial epoch: a full discriminator pass, then a full generator pass.

    Pass 1 iterates over ``dataloader`` with the generator frozen (eval mode) and
    updates the real/fake discriminator on images and the edge discriminator on
    Sobel edge maps. Pass 2 iterates over ``dataloader`` again with both
    discriminators frozen and updates the generator with the two adversarial
    losses plus BCE losses on the air and bone masks derived from its output.

    Label convention used by every cross-entropy term: class 0 = real target,
    class 1 = generated.

    Parameters:
    epoch                                  -- index of the current epoch; used as the integer part
                                              of the fractional epoch given to the LR schedulers
    generator, gen_optim                   -- generator network and its optimizer
    tf_discriminator, tf_d_optim,
    tf_scheduler                           -- real/fake discriminator, its optimizer and LR scheduler
    edge_discriminator, edge_d_optim,
    edge_scheduler                         -- edge discriminator, its optimizer and LR scheduler
    dataloader                             -- yields batches whose first four items are
                                              (x, y, air_x, bone_x)
    device                                 -- device the networks and batches are moved to

    Returns:
    tuple of per-batch averages and learning rates, in this order:
    (air_loss, bone_loss, edge_gen_loss, tf_gen_loss, edge_dis_loss, tf_dis_loss,
     tf_lr, edge_lr)

    Raises:
    ValueError -- if ``dataloader`` has no batches
    """
    iteration = len(dataloader)
    if iteration == 0:
        raise ValueError("dataloader has no batches")

    # Loss objects are stateless, so build them once instead of once per batch.
    criterion = nn.CrossEntropyLoss()
    bce = nn.BCELoss()

    # Air / bone windows rescaled from [_min, _max] to [0, 1]; each is a
    # (lower, upper) pair handed to hu_clip_tensor. Loop-invariant.
    _min = -500
    _max = 500
    air_window = (-500, -499)
    bone_window = (255, 256)
    air_range = (((air_window[0]) - (_min))/(_max-(_min)), ((air_window[1]) - (_min))/(_max-(_min)))
    bone_range = (((bone_window[0]) - (_min))/(_max-(_min)), ((bone_window[1]) - (_min))/(_max-(_min)))

    ############################
    # (1) Update the discriminators: maximize log(D(x)) + log(1 - D(G(z)))
    ############################
    generator = generator.eval().to(device)
    tf_discriminator = tf_discriminator.train().to(device)
    edge_discriminator = edge_discriminator.train().to(device)

    set_requires_grad(tf_discriminator, True)
    set_requires_grad(edge_discriminator, True)
    set_requires_grad(generator, False)

    tf_dis_l = 0
    edge_dis_l = 0

    for index, data in tqdm(enumerate(dataloader), total=iteration):
        torch.cuda.empty_cache()

        x, y, *_ = data
        x = x.to(device)
        y = y.to(device)

        y_pr = generator(x)
        y_pr = Activation(name=GEN_ACT)(y_pr) # zipped value to [0, 1]

        edge_y_pr = sobel_filter(y_pr)
        edge_y = sobel_filter(y)

        # (1) real/fake discriminator on images
        tf_d_optim.zero_grad()
        tf_dis = _discriminator_loss(tf_discriminator(y),
                                     tf_discriminator(y_pr.detach()),
                                     criterion, device)
        tf_dis.backward()
        tf_d_optim.step()
        # Fractional epoch built from the `epoch` argument. The original code used the
        # global EPOCH constant here, which left `epoch` unused and fed the scheduler
        # the same epoch number on every call.
        # NOTE(review): assumes a scheduler whose step() accepts a fractional epoch
        # (e.g. CosineAnnealingWarmRestarts) — confirm.
        tf_scheduler.step(epoch + index / iteration)
        tf_dis_l += tf_dis.item()

        # (1.1) edge discriminator on Sobel edge maps
        edge_d_optim.zero_grad()
        edge_dis = _discriminator_loss(edge_discriminator(edge_y),
                                       edge_discriminator(edge_y_pr.detach()),
                                       criterion, device)
        edge_dis.backward()
        edge_d_optim.step()
        edge_scheduler.step(epoch + index / iteration)
        edge_dis_l += edge_dis.item()

    ############################
    # (2) Update the generator: maximize log(D(G(z)))
    ############################
    generator = generator.train().to(device)
    tf_discriminator = tf_discriminator.eval().to(device)
    edge_discriminator = edge_discriminator.eval().to(device)

    set_requires_grad(tf_discriminator, False)
    set_requires_grad(edge_discriminator, False)
    set_requires_grad(generator, True)

    tf_gen_l = 0
    edge_gen_l = 0
    air_l = 0
    bone_l = 0

    for index, data in tqdm(enumerate(dataloader), total=iteration):
        torch.cuda.empty_cache()

        x, y, air_x, bone_x, *_ = data
        x = x.to(device)
        air_x = air_x.to(device)
        bone_x = bone_x.to(device)

        y_pr = generator(x)
        y_pr = Activation(name=GEN_ACT)(y_pr) # zipped value to [0, 1]
        edge_y_pr = sobel_filter(y_pr)

        gen_optim.zero_grad()

        # adversarial loss
        assert y_pr.requires_grad, "ct_pred without gradient"
        assert edge_y_pr.requires_grad, "edge without gradient"

        tf_gen = _generator_adv_loss(tf_discriminator(y_pr), criterion, device)
        edge_gen = _generator_adv_loss(edge_discriminator(edge_y_pr), criterion, device)

        # auxiliary loss
        # https://discuss.pytorch.org/t/unclear-about-weighted-bce-loss/21486/2
        air_pr = hu_clip_tensor(y_pr.double(), air_range, None, True)
        air_loss = bce(air_pr.float(), air_x.float())

        bone_pr = hu_clip_tensor(y_pr.double(), bone_range, None, True)
        bone_loss = bce(bone_pr.float(), bone_x.float())

        _loss = tf_gen + edge_gen + air_loss + bone_loss
        _loss.backward()
        gen_optim.step()

        edge_gen_l += edge_gen.item()
        tf_gen_l += tf_gen.item()
        air_l += air_loss.item()
        bone_l += bone_loss.item()

    return  air_l/iteration, bone_l/iteration, \
                    edge_gen_l/iteration, tf_gen_l/iteration, \
                    edge_dis_l/iteration, tf_dis_l/iteration, \
                    tf_scheduler.get_last_lr()[0], edge_scheduler.get_last_lr()[0]

In [19]:
def _per_head_accuracy(head_outputs, expected_label, n_samples, device):
    """Return, for each discriminator head, the fraction of samples classified as ``expected_label``."""
    target = torch.full((n_samples,), expected_label, dtype=torch.long, device=device)
    accuracies = []
    for head_out in head_outputs:
        probs = Activation(name=DIS_ACT)(head_out)
        _, predicted = torch.max(probs.data, 1)
        hits = (predicted == target).sum().item()
        accuracies.append(hits / n_samples)
    return accuracies


@torch.no_grad()
def eval_epoch(generator, tf_discriminator, edge_discriminator, dataloader, device):
    """Evaluate the generator and both discriminators over ``dataloader`` without gradients.

    Per batch this records SSIM, PSNR, SNR and MAE between prediction and target,
    Dice (Fscore) of the air and bone masks derived from the prediction, and the
    discriminators' accuracy (class 0 = real target, class 1 = generated).
    ContourEval is recorded per sample, comparing the prediction with the input ``x``.

    Returns:
    tuple of means, in this order:
    (ssim, psnr, snr, mae, air dice, bone dice, contour score,
     real/fake discriminator accuracy, edge discriminator accuracy)
    """
    # switch everything to eval mode and move to the current device
    generator = generator.eval().to(device)
    tf_discriminator = tf_discriminator.eval().to(device)
    edge_discriminator = edge_discriminator.eval().to(device)

    # Air / bone windows rescaled from [_min, _max] to [0, 1], as (lower, upper) pairs.
    _min = -500
    _max = 500
    air_window = (-500, -499)
    bone_window = (255, 256)
    air_range = (((air_window[0]) - (_min))/(_max-(_min)), ((air_window[1]) - (_min))/(_max-(_min)))
    bone_range = (((bone_window[0]) - (_min))/(_max-(_min)), ((bone_window[1]) - (_min))/(_max-(_min)))

    ssim_, psnr_, mae_, snr_ = [], [], [], []
    air_, bone_, cont_ = [], [], []
    tf_acc, edge_acc = [], []

    for index, data in tqdm(enumerate(dataloader)):

        x, y, air_x, bone_x, *_ = data

        x = x.to(device)
        y = y.to(device)
        air_x = air_x.to(device)
        bone_x = bone_x.to(device)

        y_pr = Activation(name=GEN_ACT)(generator(x)) # zipped value to [0, 1]

        edge_y_pr = sobel_filter(y_pr)
        edge_y = sobel_filter(y)

        total = x.size()[0]

        # (1) real/fake discriminator: real targets should be class 0, predictions class 1
        tf_acc += _per_head_accuracy(tf_discriminator(y), 0, total, device)
        tf_acc += _per_head_accuracy(tf_discriminator(y_pr), 1, total, device)

        # (1.1) edge discriminator, same convention on the Sobel edge maps
        edge_acc += _per_head_accuracy(edge_discriminator(edge_y), 0, total, device)
        edge_acc += _per_head_accuracy(edge_discriminator(edge_y_pr), 1, total, device)

        # (2) generator image-quality metrics (one scalar per batch)
        ssim_.append(ssim(y, y_pr, data_range=1.0, size_average=True).item())
        psnr_.append(PSNR()(y_pr, y, 1.0).item())
        mae_.append(MAELoss()(y_pr.float(), y.float()).item())
        snr_.append(SNR()(y_pr, y).item())

        # auxiliary: Dice of the air / bone masks extracted from the prediction
        air_pr = hu_clip_tensor(y_pr.double(), air_range, None, True)
        air_.append(Fscore()(air_pr, air_x).item())

        bone_pr = hu_clip_tensor(y_pr.double(), bone_range, None, True)
        bone_.append(Fscore()(bone_pr, bone_x).item())

        # contour score is computed per sample, prediction vs. input
        for b in range(total):
            cont_.append(ContourEval()(y_pr[b, :, :, :], x[b, :, :, :]).item())

    return  sum(ssim_)/len(ssim_), sum(psnr_)/len(psnr_), sum(snr_)/len(snr_), sum(mae_)/len(mae_), \
                    sum(air_)/len(air_), sum(bone_)/len(bone_), sum(cont_)/len(cont_), sum(tf_acc)/len(tf_acc), sum(edge_acc)/len(edge_acc)

In [20]:
@torch.no_grad()
def test_epoch(iid, model, dataloader, device, save=False, path=None):
    """Run the generator over ``dataloader`` and optionally export per-sample images and scores.

    For every sample this computes SSIM, PSNR, SNR, MAE, air/bone Dice and the
    contour score. When ``save`` is true, each sample gets its own folder
    ``<path>/file_<iid>`` holding the input, target, prediction, mask and edge
    images as JPEGs plus ``scores.txt`` (JSON).

    Parameters:
    iid        -- index used for the first sample's folder name; incremented per sample
    model      -- generator network
    dataloader -- yields batches whose first four items are (x, y, air_x, bone_x)
    device     -- device the model and batches are moved to
    save       -- whether to write images and scores to disk
    path       -- output directory; required when ``save`` is true

    Returns:
    None

    Raises:
    ValueError -- if ``save`` is true and ``path`` is None
    """
    if save and path is None:
        raise ValueError("path must be given when save=True")

    def _to_uint8(tensor):
        """Drop singleton dims and scale by 255 into a uint8 array (values assumed in [0, 1])."""
        return (tensor.squeeze().cpu().numpy() * 255).astype(np.uint8)

    # switch to eval mode and move to the current device
    model = model.eval().to(device)

    # Air / bone windows rescaled from [_min, _max] to [0, 1], as (lower, upper) pairs.
    _min = -500
    _max = 500
    air_window = (-500, -499)
    bone_window = (255, 256)
    air_range = (((air_window[0]) - (_min))/(_max-(_min)), ((air_window[1]) - (_min))/(_max-(_min)))
    bone_range = (((bone_window[0]) - (_min))/(_max-(_min)), ((bone_window[1]) - (_min))/(_max-(_min)))

    for index, data in tqdm(enumerate(dataloader)):

        x, y, air_x, bone_x, *_ = data

        x = x.to(device)
        y = y.to(device)
        air_x = air_x.to(device)
        bone_x = bone_x.to(device)

        y_pr = model(x)
        y_pr = Activation(name=GEN_ACT)(y_pr) # zipped value to [0, 1]

        edge_x = sobel_filter(x)
        edge_y_pr = sobel_filter(y_pr)
        edge_y = sobel_filter(y)

        air_pr = hu_clip_tensor(y_pr.double(), air_range, None, True)
        bone_pr = hu_clip_tensor(y_pr.double(), bone_range, None, True)

        for b in range(x.shape[0]):
            tmp_y = y[b, :, :, :].unsqueeze(0)
            tmp_y_pr = y_pr[b, :, :, :].unsqueeze(0)
            tmp_air = air_x[b, :, :, :].unsqueeze(0)
            tmp_air_pr = air_pr[b, :, :, :].unsqueeze(0)
            tmp_bone = bone_x[b, :, :, :].unsqueeze(0)
            tmp_bone_pr = bone_pr[b, :, :, :].unsqueeze(0)

            _ssim = ssim(tmp_y, tmp_y_pr, data_range=1.0, size_average=True)
            _psnr = PSNR()(tmp_y_pr, tmp_y, 1.0)
            _snr = SNR()(tmp_y_pr, tmp_y)
            _mae = MAELoss()(tmp_y_pr.float(), tmp_y.float())
            _air = Fscore()(tmp_air_pr, tmp_air)
            _bone = Fscore()(tmp_bone_pr, tmp_bone)
            # NOTE(review): the prediction keeps its batch dimension here while x[b] does
            # not (eval_epoch passes both without it) — kept as in the original; confirm
            # that ContourEval treats the two forms the same.
            _cont = ContourEval()(tmp_y_pr, x[b, :, :, :])

            scores = {
                "ssim score": _ssim.item(),
                "psnr score": _psnr.item(),
                "snr score": _snr.item(),
                "mae error": _mae.item(),
                "air dice score": _air.item(),
                "bone dice score": _bone.item(),
                "contour dice score": _cont.item()
            }

            if save:
                # Every image is taken from sample b. The original squeezed the whole
                # batch for the cbct / edge_cbct / edge_ct images, which is only a
                # single 2-D image when batch_size == 1.
                images = {
                    "cbct.jpg": _to_uint8(x[b]),
                    "ct.jpg": _to_uint8(tmp_y),
                    "ct_pred.jpg": _to_uint8(tmp_y_pr),
                    "air.jpg": _to_uint8(tmp_air),
                    "air_pred.jpg": _to_uint8(tmp_air_pr),
                    "bone.jpg": _to_uint8(tmp_bone),
                    "bone_pred.jpg": _to_uint8(tmp_bone_pr),
                    "edge_pred.jpg": _to_uint8(edge_y_pr[b]),
                    "edge_cbct.jpg": _to_uint8(edge_x[b]),
                    "edge_ct.jpg": _to_uint8(edge_y[b]),
                }

                path_dir = os.path.join(path, "file_{}".format(iid))
                os.makedirs(path_dir, exist_ok=True)

                for file_name, image in images.items():
                    cv2.imwrite(os.path.join(path_dir, file_name), image)
                with open(os.path.join(path_dir, "scores.txt"), "w") as file:
                    file.write(json.dumps(scores))
            iid += 1

# Discriminator

In [21]:
from codes.activation import Activation
import torchvision.models as models
from codes.losses import MultiScaleHeads

In [22]:
class Discriminator(nn.Module):
    """Classifier built from a ResNet-18 encoder (taken from an smp U-Net) and a MultiScaleHeads head.

    Only the deepest encoder feature map is classified. ``activation`` is accepted
    by the constructor but not used: the head is always created with
    ``activation=None``. The epoch functions feed the outputs to CrossEntropyLoss
    and apply their own activation at evaluation time.
    """

    def __init__(self, in_channel=1, n_classes=2, activation=None):
        super().__init__()

        # Build a full U-Net only to borrow its encoder; the decoder is discarded.
        unet = smp.Unet(encoder_name="resnet18", in_channels=in_channel, classes=n_classes)
        self.encoder = unet.encoder
        self.fc = MultiScaleHeads(n_classes=n_classes, channels=(512, ), activation=None)

    def forward(self, x):
        """Return the head outputs for the deepest encoder feature map of ``x``."""
        features = self.encoder(x)
        deepest = features[-1]
        return self.fc([deepest])

# Generator

In [23]:
import segmentation_models_pytorch as smp
from codes.decoder import UnetDecoder
from segmentation_models_pytorch.base.heads import SegmentationHead

In [24]:
class Generator(nn.Module):
    """Generator that decodes from the deepest encoder feature map only.

    The encoder is taken from an smp U-Net; the decoder is the project's
    ``UnetDecoder`` and receives ``None`` in place of the four shallower
    feature maps, followed by a ``SegmentationHead`` without activation.
    """

    def __init__(self, encoder_name, encoder_weights, in_channels, classes, attention_type):
        super().__init__()

        # Build a full U-Net only to borrow its encoder.
        backbone = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
        )
        self.encoder = backbone.encoder

        # The intermediate encoder widths are declared as 0 and the deepest one is
        # hard-coded to 2048.
        # NOTE(review): presumably this ties the class to encoders whose last stage
        # has 2048 channels — confirm before using another backbone.
        self.decoder = UnetDecoder(
            encoder_channels=[in_channels, 0, 0, 0, 0, 2048],
            decoder_channels=[512, 256, 128, 64, 16],
            n_blocks=5,
            use_batchnorm=True,
            attention_type=attention_type,
        )
        self.head = SegmentationHead(in_channels=16, out_channels=classes, activation=None)

    def forward(self, x):
        """Encode ``x``, decode from the last feature map alone, and apply the head."""
        deepest = self.encoder(x)[-1]
        decoded = self.decoder(None, None, None, None, deepest)
        return self.head(decoded)

# Read Data

In [25]:
import glob
from codes.dataset import Dataset, DicomDataset, DicomsDataset
import codes.augmentation as aug

In [26]:
# Name of the run to evaluate. It selects the checkpoint file under "weight-gan"
# and the output folder under "eval-gan". Hard-coded here; the commented line
# shows the alternative of taking it from an active wandb run.
# run_name = wandb.run.name
run_name = "dashing-valley-161"

In [27]:
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; the networks are moved to `device` later by the epoch functions.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(os.path.join("weight-gan", "{}.pth".format(run_name)), map_location="cpu")

In [28]:
# Rebuild the generator U-Net from the settings above.
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHT,
    in_channels=GEN_IN_CHANNELS,
    classes=GEN_N_CLASSES,
    decoder_attention_type=DECODER_ATT,
)
# Apply the layer swaps before loading (presumably the checkpoint was saved from
# a network modified the same way — confirm).
replace_relu_to_leakyReLU(model)
replace_bn_to_instanceNorm(model)
model.load_state_dict(checkpoint["model"])

<All keys matched successfully>

In [29]:
# Rebuild the real/fake discriminator, apply the same layer swaps as for the
# generator, then restore its weights from the checkpoint.
tf_discriminator = Discriminator(
    in_channel=DIS_IN_CHANNELS,
    n_classes=DIS_N_CLASSES,
    activation=DIS_ACT,
)
replace_relu_to_leakyReLU(tf_discriminator)
replace_bn_to_instanceNorm(tf_discriminator)
tf_discriminator.load_state_dict(checkpoint["tf_discriminator"])

<All keys matched successfully>

In [30]:
# Rebuild the edge discriminator, apply the same layer swaps as for the
# generator, then restore its weights from the checkpoint.
edge_discriminator = Discriminator(
    in_channel=E_IN_CHANNELS,
    n_classes=E_N_CLASSES,
    activation=E_ACT,
)
replace_relu_to_leakyReLU(edge_discriminator)
replace_bn_to_instanceNorm(edge_discriminator)
edge_discriminator.load_state_dict(checkpoint["edge_discriminator"])

<All keys matched successfully>

# Pelvic

## Pelvic Test CBCT

In [31]:
# Scan folders of the pelvic test set. Sorting matters: the per-patient loop
# below reads the list in consecutive pairs, paths[i] as CT and paths[i+1] as CBCT.
test_case_path = 'raw/test/*_*'
paths = sorted(glob.glob(test_case_path))

In [32]:
# Idempotent, and also creates missing parent directories (e.g. "eval-gan" itself),
# which plain os.mkdir does not.
os.makedirs(os.path.join("eval-gan", run_name), exist_ok=True)

In [33]:
# Idempotent, and also creates missing parent directories, which plain os.mkdir does not.
os.makedirs(os.path.join("eval-gan", run_name, "test_pelvic_cbct"), exist_ok=True)

In [34]:
# Export per-slice images and scores for every patient (CBCT input, identity=False).
# `paths` is sorted and consumed in consecutive (CT, CBCT) pairs, so an odd count
# means a folder is missing; fail before any work is done rather than midway.
assert len(paths) % 2 == 0, "expected an even number of scan folders (CT/CBCT pairs)"
out_root = os.path.join("eval-gan", run_name, "test_pelvic_cbct")
iid = 0
for ct_path, cbct_path in zip(paths[0::2], paths[1::2]):
    scans = DicomDataset(cbct_path=cbct_path, ct_path=ct_path, ditch=3, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=False)
    patient_id = scans.patientID()
    testloader = torch.utils.data.DataLoader(scans, batch_size=1, shuffle=False, num_workers=4)
    patient_dir = os.path.join(out_root, patient_id)
    # makedirs is idempotent and also creates missing parent directories.
    os.makedirs(patient_dir, exist_ok=True)

    test_epoch(iid, model, testloader, device, True, patient_dir)
    iid += len(scans)

28it [00:02, 12.76it/s]
27it [00:01, 20.43it/s]
26it [00:01, 19.67it/s]
28it [00:01, 19.85it/s]
27it [00:01, 19.66it/s]
28it [00:01, 20.35it/s]
27it [00:01, 19.87it/s]


In [35]:
testset = DicomsDataset(test_case_path, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)
# eval_epoch returns the means in this order:
#   (ssim, psnr, snr, mae, air dice, bone dice, contour score,
#    real/fake discriminator accuracy, edge discriminator accuracy)
eval_epoch(model, tf_discriminator, edge_discriminator, testloader, device)

191it [00:09, 19.26it/s]


(0.8335763627946065,
 23.3061956435598,
 14.310467700059501,
 0.02841728902299991,
 0.996848789770518,
 0.8615822573440921,
 0.4753542253796343,
 0.5,
 0.4973821989528796)

## Pelvic Test CT

In [36]:
# Scan folders of the pelvic test set. Sorting matters: the per-patient loop
# below reads the list in consecutive pairs, paths[i] as CT and paths[i+1] as CBCT.
test_case_path = 'raw/test/*_*'
paths = sorted(glob.glob(test_case_path))

In [37]:
# Idempotent, and also creates missing parent directories (e.g. "eval-gan" itself),
# which plain os.mkdir does not.
os.makedirs(os.path.join("eval-gan", run_name), exist_ok=True)

In [38]:
# Idempotent, and also creates missing parent directories, which plain os.mkdir does not.
os.makedirs(os.path.join("eval-gan", run_name, "test_pelvic_ct"), exist_ok=True)

In [39]:
# Export per-slice images and scores for every patient (identity=True).
# `paths` is sorted and consumed in consecutive (CT, CBCT) pairs, so an odd count
# means a folder is missing; fail before any work is done rather than midway.
assert len(paths) % 2 == 0, "expected an even number of scan folders (CT/CBCT pairs)"
out_root = os.path.join("eval-gan", run_name, "test_pelvic_ct")
iid = 0
for ct_path, cbct_path in zip(paths[0::2], paths[1::2]):
    scans = DicomDataset(cbct_path=cbct_path, ct_path=ct_path, ditch=3, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=True)
    patient_id = scans.patientID()
    testloader = torch.utils.data.DataLoader(scans, batch_size=1, shuffle=False, num_workers=4)
    patient_dir = os.path.join(out_root, patient_id)
    # makedirs is idempotent and also creates missing parent directories.
    os.makedirs(patient_dir, exist_ok=True)

    test_epoch(iid, model, testloader, device, True, patient_dir)
    iid += len(scans)

28it [00:01, 18.82it/s]
27it [00:01, 19.33it/s]
26it [00:01, 19.21it/s]
28it [00:01, 19.74it/s]
27it [00:01, 19.16it/s]
28it [00:01, 19.47it/s]
27it [00:01, 19.13it/s]


In [40]:
testset = DicomsDataset(test_case_path, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)
# eval_epoch returns the means in this order:
#   (ssim, psnr, snr, mae, air dice, bone dice, contour score,
#    real/fake discriminator accuracy, edge discriminator accuracy)
eval_epoch(model, tf_discriminator, edge_discriminator, testloader, device)

191it [00:10, 19.08it/s]


(0.9155489952152313,
 27.91344012515083,
 18.91771220661583,
 0.01751434677207345,
 0.9967022513157912,
 0.8472905809715715,
 0.48818213964632046,
 0.5,
 0.4973821989528796)

## Pelvic L1 CBCT

In [41]:
# Scan folders of the pelvic L1 set. Sorting matters: the per-patient loop
# below reads the list in consecutive pairs, paths[i] as CT and paths[i+1] as CBCT.
test_case_path = 'L1_pelvic_processed/reg_pelvic_l1/*_*'
paths = sorted(glob.glob(test_case_path))

In [42]:
# Idempotent, and also creates missing parent directories (e.g. "eval-gan" itself),
# which plain os.mkdir does not.
os.makedirs(os.path.join("eval-gan", run_name), exist_ok=True)

In [43]:
# Idempotent, and also creates missing parent directories, which plain os.mkdir does not.
os.makedirs(os.path.join("eval-gan", run_name, "L1_pelvic_cbct"), exist_ok=True)

In [44]:
# Export per-slice images and scores for every patient (CBCT input, identity=False).
# `paths` is sorted and consumed in consecutive (CT, CBCT) pairs, so an odd count
# means a folder is missing; fail before any work is done rather than midway.
assert len(paths) % 2 == 0, "expected an even number of scan folders (CT/CBCT pairs)"
out_root = os.path.join("eval-gan", run_name, "L1_pelvic_cbct")
iid = 0
for ct_path, cbct_path in zip(paths[0::2], paths[1::2]):
    scans = DicomDataset(cbct_path=cbct_path, ct_path=ct_path, ditch=3, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=False)
    patient_id = scans.patientID()
    testloader = torch.utils.data.DataLoader(scans, batch_size=1, shuffle=False, num_workers=4)
    patient_dir = os.path.join(out_root, patient_id)
    # makedirs is idempotent and also creates missing parent directories.
    os.makedirs(patient_dir, exist_ok=True)

    test_epoch(iid, model, testloader, device, True, patient_dir)
    iid += len(scans)

29it [00:01, 19.35it/s]
26it [00:01, 18.72it/s]
28it [00:01, 18.14it/s]
28it [00:01, 19.53it/s]
27it [00:01, 19.38it/s]
26it [00:01, 18.66it/s]
27it [00:01, 18.55it/s]
30it [00:01, 19.61it/s]
28it [00:01, 19.90it/s]
29it [00:01, 19.60it/s]
29it [00:01, 19.78it/s]
27it [00:01, 19.92it/s]
28it [00:01, 19.91it/s]
26it [00:01, 19.61it/s]
28it [00:01, 20.14it/s]
28it [00:01, 19.20it/s]
27it [00:01, 19.71it/s]
26it [00:01, 19.90it/s]
27it [00:01, 19.58it/s]
29it [00:01, 19.63it/s]
30it [00:01, 20.26it/s]
32it [00:01, 20.96it/s]
28it [00:01, 18.56it/s]
28it [00:01, 19.79it/s]
15it [00:00, 18.00it/s]
28it [00:01, 19.79it/s]
26it [00:01, 19.04it/s]
28it [00:01, 20.08it/s]
26it [00:01, 19.80it/s]
27it [00:01, 20.28it/s]
36it [00:01, 20.06it/s]
28it [00:01, 19.61it/s]
27it [00:01, 19.90it/s]
28it [00:01, 20.00it/s]
30it [00:01, 19.86it/s]
27it [00:01, 19.94it/s]
26it [00:01, 19.77it/s]
29it [00:01, 19.62it/s]
27it [00:01, 18.96it/s]
27it [00:01, 19.29it/s]
28it [00:01, 20.01it/s]
26it [00:01, 19.

In [45]:
testset = DicomsDataset(test_case_path, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)
# eval_epoch returns the means in this order:
#   (ssim, psnr, snr, mae, air dice, bone dice, contour score,
#    real/fake discriminator accuracy, edge discriminator accuracy)
eval_epoch(model, tf_discriminator, edge_discriminator, testloader, device)

1457it [01:12, 20.21it/s]


(0.8182522566987785,
 22.719802633244942,
 13.51878494102004,
 0.030733568246771027,
 0.9961217021741574,
 0.8366674953045895,
 0.4995913408293026,
 0.5013726835964311,
 0.47632120796156485)

## Pelvic L1 CT

In [46]:
# Scan folders of the pelvic L1 set. Sorting matters: the per-patient loop
# below reads the list in consecutive pairs, paths[i] as CT and paths[i+1] as CBCT.
test_case_path = 'L1_pelvic_processed/reg_pelvic_l1/*_*'
paths = sorted(glob.glob(test_case_path))

In [47]:
# Idempotent, and also creates missing parent directories (e.g. "eval-gan" itself),
# which plain os.mkdir does not.
os.makedirs(os.path.join("eval-gan", run_name), exist_ok=True)

In [48]:
# Idempotent, and also creates missing parent directories, which plain os.mkdir does not.
os.makedirs(os.path.join("eval-gan", run_name, "L1_pelvic_ct"), exist_ok=True)

In [49]:
# Export per-slice images and scores for every patient (identity=True).
# `paths` is sorted and consumed in consecutive (CT, CBCT) pairs, so an odd count
# means a folder is missing; fail before any work is done rather than midway.
assert len(paths) % 2 == 0, "expected an even number of scan folders (CT/CBCT pairs)"
out_root = os.path.join("eval-gan", run_name, "L1_pelvic_ct")
iid = 0
for ct_path, cbct_path in zip(paths[0::2], paths[1::2]):
    scans = DicomDataset(cbct_path=cbct_path, ct_path=ct_path, ditch=3, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=True)
    patient_id = scans.patientID()
    testloader = torch.utils.data.DataLoader(scans, batch_size=1, shuffle=False, num_workers=4)
    patient_dir = os.path.join(out_root, patient_id)
    # makedirs is idempotent and also creates missing parent directories.
    os.makedirs(patient_dir, exist_ok=True)

    test_epoch(iid, model, testloader, device, True, patient_dir)
    iid += len(scans)

29it [00:01, 19.52it/s]
26it [00:01, 18.80it/s]
28it [00:01, 19.33it/s]
28it [00:01, 19.09it/s]
27it [00:01, 19.63it/s]
26it [00:01, 19.37it/s]
27it [00:01, 19.64it/s]
30it [00:01, 20.01it/s]
28it [00:01, 19.45it/s]
29it [00:01, 19.36it/s]
29it [00:01, 19.26it/s]
27it [00:01, 19.43it/s]
28it [00:01, 18.97it/s]
26it [00:01, 18.84it/s]
28it [00:01, 18.85it/s]
28it [00:01, 19.55it/s]
27it [00:01, 19.31it/s]
26it [00:01, 19.06it/s]
27it [00:01, 18.89it/s]
29it [00:01, 19.17it/s]
30it [00:01, 19.55it/s]
32it [00:01, 19.93it/s]
28it [00:01, 19.21it/s]
28it [00:01, 19.01it/s]
15it [00:00, 16.91it/s]
28it [00:01, 19.86it/s]
26it [00:01, 18.34it/s]
28it [00:01, 18.40it/s]
26it [00:01, 19.18it/s]
27it [00:01, 19.07it/s]
36it [00:01, 19.67it/s]
28it [00:01, 19.57it/s]
27it [00:01, 18.31it/s]
28it [00:01, 19.71it/s]
30it [00:01, 19.96it/s]
27it [00:01, 19.04it/s]
26it [00:01, 19.04it/s]
29it [00:01, 19.67it/s]
27it [00:01, 19.45it/s]
27it [00:01, 18.73it/s]
28it [00:01, 19.19it/s]
26it [00:01, 18.

In [50]:
testset = DicomsDataset(test_case_path, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)
# eval_epoch returns the means in this order:
#   (ssim, psnr, snr, mae, air dice, bone dice, contour score,
#    real/fake discriminator accuracy, edge discriminator accuracy)
eval_epoch(model, tf_discriminator, edge_discriminator, testloader, device)

1457it [01:12, 20.05it/s]


(0.9122636873344075,
 27.98485933040705,
 18.783841616827573,
 0.01768491023418915,
 0.9969199726447608,
 0.8401054924853236,
 0.5039504731077811,
 0.5,
 0.4766643788606726)

## Pelvic L2 CBCT

In [51]:
# Scan folders of the pelvic L2 set. Sorting matters: the per-patient loop
# below reads the list in consecutive pairs, paths[i] as CT and paths[i+1] as CBCT.
test_case_path = 'L2_pelvic_processed/reg_pelvic_l2/*_*'
paths = sorted(glob.glob(test_case_path))

In [52]:
# Idempotent, and also creates missing parent directories (e.g. "eval-gan" itself),
# which plain os.mkdir does not.
os.makedirs(os.path.join("eval-gan", run_name), exist_ok=True)

In [53]:
# Idempotent, and also creates missing parent directories, which plain os.mkdir does not.
os.makedirs(os.path.join("eval-gan", run_name, "L2_pelvic_cbct"), exist_ok=True)

In [54]:
# Export per-slice images and scores for every patient (CBCT input, identity=False).
# `paths` is sorted and consumed in consecutive (CT, CBCT) pairs, so an odd count
# means a folder is missing; fail before any work is done rather than midway.
assert len(paths) % 2 == 0, "expected an even number of scan folders (CT/CBCT pairs)"
out_root = os.path.join("eval-gan", run_name, "L2_pelvic_cbct")
iid = 0
for ct_path, cbct_path in zip(paths[0::2], paths[1::2]):
    scans = DicomDataset(cbct_path=cbct_path, ct_path=ct_path, ditch=3, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=False)
    patient_id = scans.patientID()
    testloader = torch.utils.data.DataLoader(scans, batch_size=1, shuffle=False, num_workers=4)
    patient_dir = os.path.join(out_root, patient_id)
    # makedirs is idempotent and also creates missing parent directories.
    os.makedirs(patient_dir, exist_ok=True)

    test_epoch(iid, model, testloader, device, True, patient_dir)
    iid += len(scans)

26it [00:01, 18.44it/s]
26it [00:01, 16.73it/s]
26it [00:01, 19.27it/s]
27it [00:01, 18.70it/s]
31it [00:01, 19.62it/s]
27it [00:01, 18.07it/s]
27it [00:01, 18.84it/s]
26it [00:01, 18.42it/s]
28it [00:01, 19.04it/s]
28it [00:01, 19.41it/s]
27it [00:01, 18.52it/s]
27it [00:01, 18.74it/s]
26it [00:01, 19.03it/s]
26it [00:01, 18.45it/s]
26it [00:01, 19.01it/s]
26it [00:01, 18.66it/s]
26it [00:01, 18.41it/s]
26it [00:01, 18.59it/s]
27it [00:01, 18.96it/s]
26it [00:01, 18.29it/s]
26it [00:01, 18.23it/s]
28it [00:01, 18.44it/s]
28it [00:01, 19.16it/s]
29it [00:01, 18.59it/s]
28it [00:01, 19.31it/s]
26it [00:01, 19.10it/s]
27it [00:01, 18.64it/s]
26it [00:01, 19.03it/s]
28it [00:01, 18.85it/s]
30it [00:01, 19.00it/s]
27it [00:01, 18.52it/s]
29it [00:01, 17.25it/s]
27it [00:01, 19.08it/s]
27it [00:01, 18.82it/s]
26it [00:01, 18.90it/s]
26it [00:01, 18.89it/s]
26it [00:01, 19.05it/s]
26it [00:01, 18.20it/s]
27it [00:01, 18.80it/s]
26it [00:01, 18.00it/s]
26it [00:01, 18.61it/s]
26it [00:01, 18.

In [55]:
# Aggregate evaluation over every case matched by `test_case_path` (identity=False).
# The eval_epoch call is the cell's last expression, so its result tuple is displayed.
# Expected order of the scores (TODO confirm against eval_epoch):
#   ssim, psnr, snr, mae, air, bone, cont, tf_acc, edge_acc
testset = DicomsDataset(
    test_case_path,
    geometry_aug=aug.get_validation_augmentation(),
    intensity_aug=None,
    identity=False,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)
eval_epoch(model, tf_discriminator, edge_discriminator, testloader, device)

3064it [02:32, 20.03it/s]


(0.8070026322537863,
 22.09904223353682,
 12.514194093936418,
 0.032503613919158306,
 0.9961601279546994,
 0.8330500069193886,
 0.5798903052493268,
 0.502121409921671,
 0.4968994778067885)

## Pelvic L2 CT

In [56]:
# Glob pattern for the registered pelvic L2 case folders. The matches are kept in
# sorted order because the inference loop below walks `paths` two entries at a time.
test_case_path = 'L2_pelvic_processed/reg_pelvic_l2/*_*'
paths = glob.glob(test_case_path)
paths.sort()

In [57]:
# Ensure the per-run output directory exists. makedirs(exist_ok=True) also creates a
# missing "eval-gan" parent and is a no-op when the directory is already present.
os.makedirs(os.path.join("eval-gan", run_name), exist_ok=True)

In [58]:
# Ensure the output directory for this evaluation exists. Parents are created as
# needed, so this cell does not depend on the previous mkdir cell having been run.
os.makedirs(os.path.join("eval-gan", run_name, "L2_pelvic_ct"), exist_ok=True)

In [59]:
# Per-patient inference: build a DicomDataset from each CT/CBCT folder pair and run
# test_epoch on it, writing into eval-gan/<run_name>/L2_pelvic_ct/<patient_id>.
# `paths` is sorted; even positions are passed as ct_path, odd positions as cbct_path.
if len(paths) % 2:
    # Fail before any output is written instead of hitting an IndexError mid-run.
    raise ValueError(f"expected CT/CBCT folder pairs, got an odd number of paths: {len(paths)}")

save_root = os.path.join("eval-gan", run_name, "L2_pelvic_ct")
iid = 0  # number of samples processed for earlier patients; first argument of test_epoch
for ct_dir, cbct_dir in zip(paths[0::2], paths[1::2]):
    # identity=True here vs. False in the "CBCT" section — presumably selects CT as input; confirm in DicomDataset
    scans = DicomDataset(cbct_path=cbct_dir, ct_path=ct_dir, ditch=3, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=True)
    patient_id = scans.patientID()
    testloader = torch.utils.data.DataLoader(scans, batch_size=1, shuffle=False, num_workers=4)

    # Create the patient folder (and any missing parents); no-op if it already exists.
    patient_dir = os.path.join(save_root, patient_id)
    os.makedirs(patient_dir, exist_ok=True)

    # NOTE(review): the positional True is presumably a "save results" flag — confirm against test_epoch.
    test_epoch(iid, model, testloader, device, True, patient_dir)
    iid += len(scans)

26it [00:01, 18.47it/s]
26it [00:01, 17.13it/s]
26it [00:01, 18.45it/s]
27it [00:01, 18.41it/s]
31it [00:01, 18.79it/s]
27it [00:01, 17.78it/s]
27it [00:01, 18.25it/s]
26it [00:01, 17.91it/s]
28it [00:01, 18.11it/s]
28it [00:01, 18.72it/s]
27it [00:01, 18.14it/s]
27it [00:01, 18.87it/s]
26it [00:01, 18.50it/s]
26it [00:01, 18.47it/s]
26it [00:01, 18.37it/s]
26it [00:01, 18.48it/s]
26it [00:01, 18.44it/s]
26it [00:01, 18.21it/s]
27it [00:01, 18.06it/s]
26it [00:01, 17.95it/s]
26it [00:01, 18.29it/s]
28it [00:01, 18.57it/s]
28it [00:01, 18.50it/s]
29it [00:01, 17.59it/s]
28it [00:01, 18.35it/s]
26it [00:01, 18.04it/s]
27it [00:01, 18.45it/s]
26it [00:01, 18.21it/s]
28it [00:01, 18.56it/s]
30it [00:01, 18.31it/s]
27it [00:01, 18.47it/s]
29it [00:01, 18.53it/s]
27it [00:01, 19.07it/s]
27it [00:01, 17.78it/s]
26it [00:01, 18.42it/s]
26it [00:01, 18.43it/s]
26it [00:01, 18.79it/s]
26it [00:01, 17.84it/s]
27it [00:01, 18.29it/s]
26it [00:01, 17.87it/s]
26it [00:01, 18.59it/s]
26it [00:01, 18.

In [60]:
# Aggregate evaluation over every case matched by `test_case_path` (identity=True).
# The eval_epoch call is the cell's last expression, so its result tuple is displayed.
# Expected order of the scores (TODO confirm against eval_epoch):
#   ssim, psnr, snr, mae, air, bone, cont, tf_acc, edge_acc
testset = DicomsDataset(
    test_case_path,
    geometry_aug=aug.get_validation_augmentation(),
    intensity_aug=None,
    identity=True,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)
eval_epoch(model, tf_discriminator, edge_discriminator, testloader, device)

3064it [02:33, 20.01it/s]


(0.919069783708601,
 28.746581126130902,
 19.161732978048896,
 0.0149197413907666,
 0.9957276478780008,
 0.8510225283182197,
 0.51591270302689,
 0.5001631853785901,
 0.49510443864229764)

# Abdomen

In [61]:
# Glob pattern for the abdomen test case folders. The matches are kept in sorted
# order because the inference loops below walk `paths` two entries at a time.
test_case_path = 'abdomen/test/*_*'
paths = glob.glob(test_case_path)
paths.sort()

## Abdomen on CBCT

In [62]:
# Ensure the per-run output directory exists. makedirs(exist_ok=True) also creates a
# missing "eval-gan" parent and is a no-op when the directory is already present.
os.makedirs(os.path.join("eval-gan", run_name), exist_ok=True)

In [63]:
# Ensure the output directory for this evaluation exists. Parents are created as
# needed, so this cell does not depend on the previous mkdir cell having been run.
os.makedirs(os.path.join("eval-gan", run_name, "abdomen_test_cbct"), exist_ok=True)

In [64]:
# Per-patient inference: build a DicomDataset from each CT/CBCT folder pair and run
# test_epoch on it, writing into eval-gan/<run_name>/abdomen_test_cbct/<patient_id>.
# `paths` is sorted; even positions are passed as ct_path, odd positions as cbct_path.
if len(paths) % 2:
    # Fail before any output is written instead of hitting an IndexError mid-run.
    raise ValueError(f"expected CT/CBCT folder pairs, got an odd number of paths: {len(paths)}")

save_root = os.path.join("eval-gan", run_name, "abdomen_test_cbct")
iid = 0  # number of samples processed for earlier patients; first argument of test_epoch
for ct_dir, cbct_dir in zip(paths[0::2], paths[1::2]):
    # identity=False here vs. True in the "on CT" section — presumably selects CBCT as input; confirm in DicomDataset
    scans = DicomDataset(cbct_path=cbct_dir, ct_path=ct_dir, ditch=3, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=False)
    patient_id = scans.patientID()
    testloader = torch.utils.data.DataLoader(scans, batch_size=1, shuffle=False, num_workers=4)

    # Create the patient folder (and any missing parents); no-op if it already exists.
    patient_dir = os.path.join(save_root, patient_id)
    os.makedirs(patient_dir, exist_ok=True)

    # NOTE(review): the positional True is presumably a "save results" flag — confirm against test_epoch.
    test_epoch(iid, model, testloader, device, True, patient_dir)
    iid += len(scans)

28it [00:01, 17.30it/s]
78it [00:04, 19.22it/s]
50it [00:02, 19.15it/s]
29it [00:01, 18.45it/s]
50it [00:02, 18.24it/s]
71it [00:03, 18.35it/s]
39it [00:02, 18.29it/s]


In [65]:
# Aggregate evaluation over every case matched by `test_case_path` (identity=False).
# The eval_epoch call is the cell's last expression, so its result tuple is displayed.
# Expected order of the scores (TODO confirm against eval_epoch):
#   ssim, psnr, snr, mae, air, bone, cont, tf_acc, edge_acc
testset = DicomsDataset(
    test_case_path,
    geometry_aug=aug.get_validation_augmentation(),
    intensity_aug=None,
    identity=False,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)
eval_epoch(model, tf_discriminator, edge_discriminator, testloader, device)

345it [00:17, 19.22it/s]


(0.7360915315323981,
 20.447586551610975,
 10.978227572510209,
 0.04378857581835726,
 0.9910460487732943,
 0.7480698697937649,
 0.4359599198984063,
 0.4782608695652174,
 0.6)

## Abdomen on CT

In [66]:
# Ensure the per-run output directory exists. makedirs(exist_ok=True) also creates a
# missing "eval-gan" parent and is a no-op when the directory is already present.
os.makedirs(os.path.join("eval-gan", run_name), exist_ok=True)

In [67]:
# Ensure the output directory for this evaluation exists. Parents are created as
# needed, so this cell does not depend on the previous mkdir cell having been run.
os.makedirs(os.path.join("eval-gan", run_name, "abdomen_test_ct"), exist_ok=True)

In [68]:
# Per-patient inference: build a DicomDataset from each CT/CBCT folder pair and run
# test_epoch on it, writing into eval-gan/<run_name>/abdomen_test_ct/<patient_id>.
# `paths` is sorted; even positions are passed as ct_path, odd positions as cbct_path.
if len(paths) % 2:
    # Fail before any output is written instead of hitting an IndexError mid-run.
    raise ValueError(f"expected CT/CBCT folder pairs, got an odd number of paths: {len(paths)}")

save_root = os.path.join("eval-gan", run_name, "abdomen_test_ct")
iid = 0  # number of samples processed for earlier patients; first argument of test_epoch
for ct_dir, cbct_dir in zip(paths[0::2], paths[1::2]):
    # identity=True here vs. False in the "on CBCT" section — presumably selects CT as input; confirm in DicomDataset
    scans = DicomDataset(cbct_path=cbct_dir, ct_path=ct_dir, ditch=3, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=True)
    patient_id = scans.patientID()
    testloader = torch.utils.data.DataLoader(scans, batch_size=1, shuffle=False, num_workers=4)

    # Create the patient folder (and any missing parents); no-op if it already exists.
    patient_dir = os.path.join(save_root, patient_id)
    os.makedirs(patient_dir, exist_ok=True)

    # NOTE(review): the positional True is presumably a "save results" flag — confirm against test_epoch.
    test_epoch(iid, model, testloader, device, True, patient_dir)
    iid += len(scans)

28it [00:01, 16.91it/s]
78it [00:04, 19.41it/s]
50it [00:02, 19.09it/s]
29it [00:01, 18.45it/s]
50it [00:02, 18.22it/s]
71it [00:03, 19.06it/s]
39it [00:02, 17.09it/s]


In [69]:
# Aggregate evaluation over every case matched by `test_case_path` (identity=True).
# The eval_epoch call is the cell's last expression, so its result tuple is displayed.
# Expected order of the scores (TODO confirm against eval_epoch):
#   ssim, psnr, snr, mae, air, bone, cont, tf_acc, edge_acc
testset = DicomsDataset(
    test_case_path,
    geometry_aug=aug.get_validation_augmentation(),
    intensity_aug=None,
    identity=True,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)
eval_epoch(model, tf_discriminator, edge_discriminator, testloader, device)

345it [00:18, 19.15it/s]


(0.8804858138595802,
 26.812738435164743,
 17.343379399396373,
 0.019925909394911234,
 0.9912097250665444,
 0.7521050641476357,
 0.43140424265377764,
 0.46956521739130436,
 0.4971014492753623)

# Chest

In [70]:
# Glob pattern for the chest test case folders. The matches are kept in sorted
# order because the inference loops below walk `paths` two entries at a time.
test_case_path = 'chest/test/*_*'
paths = glob.glob(test_case_path)
paths.sort()

## Chest on CBCT

In [71]:
# Ensure the per-run output directory exists. makedirs(exist_ok=True) also creates a
# missing "eval-gan" parent and is a no-op when the directory is already present.
os.makedirs(os.path.join("eval-gan", run_name), exist_ok=True)

In [72]:
# Ensure the output directory for this evaluation exists. Parents are created as
# needed, so this cell does not depend on the previous mkdir cell having been run.
os.makedirs(os.path.join("eval-gan", run_name, "chest_test_cbct"), exist_ok=True)

In [73]:
# Per-patient inference: build a DicomDataset from each CT/CBCT folder pair and run
# test_epoch on it, writing into eval-gan/<run_name>/chest_test_cbct/<patient_id>.
# `paths` is sorted; even positions are passed as ct_path, odd positions as cbct_path.
if len(paths) % 2:
    # Fail before any output is written instead of hitting an IndexError mid-run.
    raise ValueError(f"expected CT/CBCT folder pairs, got an odd number of paths: {len(paths)}")

save_root = os.path.join("eval-gan", run_name, "chest_test_cbct")
iid = 0  # number of samples processed for earlier patients; first argument of test_epoch
for ct_dir, cbct_dir in zip(paths[0::2], paths[1::2]):
    # identity=False here vs. True in the "on CT" section — presumably selects CBCT as input; confirm in DicomDataset
    scans = DicomDataset(cbct_path=cbct_dir, ct_path=ct_dir, ditch=3, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=False)
    patient_id = scans.patientID()
    testloader = torch.utils.data.DataLoader(scans, batch_size=1, shuffle=False, num_workers=4)

    # Create the patient folder (and any missing parents); no-op if it already exists.
    patient_dir = os.path.join(save_root, patient_id)
    os.makedirs(patient_dir, exist_ok=True)

    # NOTE(review): the positional True is presumably a "save results" flag — confirm against test_epoch.
    test_epoch(iid, model, testloader, device, True, patient_dir)
    iid += len(scans)

49it [00:02, 18.79it/s]
39it [00:02, 18.38it/s]
40it [00:02, 17.97it/s]
29it [00:01, 16.14it/s]
35it [00:01, 17.67it/s]
37it [00:02, 17.51it/s]
34it [00:02, 16.85it/s]


In [74]:
# Aggregate evaluation over every case matched by `test_case_path` (identity=False).
# The eval_epoch call is the cell's last expression, so its result tuple is displayed.
# Expected order of the scores (TODO confirm against eval_epoch):
#   ssim, psnr, snr, mae, air, bone, cont, tf_acc, edge_acc
testset = DicomsDataset(
    test_case_path,
    geometry_aug=aug.get_validation_augmentation(),
    intensity_aug=None,
    identity=False,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)
eval_epoch(model, tf_discriminator, edge_discriminator, testloader, device)

263it [00:13, 19.05it/s]


(0.7609564967935076,
 20.659521534415706,
 10.211650753202548,
 0.03748260943410288,
 0.9602242858203407,
 0.7783437568264306,
 0.6013020702527956,
 0.42775665399239543,
 0.5684410646387833)

## Chest on CT

In [75]:
# Ensure the per-run output directory exists. makedirs(exist_ok=True) also creates a
# missing "eval-gan" parent and is a no-op when the directory is already present.
os.makedirs(os.path.join("eval-gan", run_name), exist_ok=True)

In [76]:
# Ensure the output directory for this evaluation exists. Parents are created as
# needed, so this cell does not depend on the previous mkdir cell having been run.
os.makedirs(os.path.join("eval-gan", run_name, "chest_test_ct"), exist_ok=True)

In [77]:
# Per-patient inference: build a DicomDataset from each CT/CBCT folder pair and run
# test_epoch on it, writing into eval-gan/<run_name>/chest_test_ct/<patient_id>.
# `paths` is sorted; even positions are passed as ct_path, odd positions as cbct_path.
if len(paths) % 2:
    # Fail before any output is written instead of hitting an IndexError mid-run.
    raise ValueError(f"expected CT/CBCT folder pairs, got an odd number of paths: {len(paths)}")

save_root = os.path.join("eval-gan", run_name, "chest_test_ct")
iid = 0  # number of samples processed for earlier patients; first argument of test_epoch
for ct_dir, cbct_dir in zip(paths[0::2], paths[1::2]):
    # identity=True here vs. False in the "on CBCT" section — presumably selects CT as input; confirm in DicomDataset
    scans = DicomDataset(cbct_path=cbct_dir, ct_path=ct_dir, ditch=3, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=True)
    patient_id = scans.patientID()
    testloader = torch.utils.data.DataLoader(scans, batch_size=1, shuffle=False, num_workers=4)

    # Create the patient folder (and any missing parents); no-op if it already exists.
    patient_dir = os.path.join(save_root, patient_id)
    os.makedirs(patient_dir, exist_ok=True)

    # NOTE(review): the positional True is presumably a "save results" flag — confirm against test_epoch.
    test_epoch(iid, model, testloader, device, True, patient_dir)
    iid += len(scans)

49it [00:02, 18.26it/s]
39it [00:02, 18.42it/s]
40it [00:02, 18.06it/s]
29it [00:01, 16.33it/s]
35it [00:02, 17.30it/s]
37it [00:02, 17.66it/s]
34it [00:02, 16.93it/s]


In [78]:
# Aggregate evaluation over every case matched by `test_case_path` (identity=True).
# The eval_epoch call is the cell's last expression, so its result tuple is displayed.
# Expected order of the scores (TODO confirm against eval_epoch):
#   ssim, psnr, snr, mae, air, bone, cont, tf_acc, edge_acc
testset = DicomsDataset(
    test_case_path,
    geometry_aug=aug.get_validation_augmentation(),
    intensity_aug=None,
    identity=True,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)
eval_epoch(model, tf_discriminator, edge_discriminator, testloader, device)

263it [00:13, 19.07it/s]


(0.9101409864516312,
 28.163540749495475,
 17.715669976441127,
 0.014904872634706615,
 0.9615348073937652,
 0.7688699186743759,
 0.584569799186159,
 0.4144486692015209,
 0.435361216730038)

# Headneck

In [31]:
# Glob pattern for the head-and-neck test case folders. The matches are kept in sorted
# order because the inference loops below walk `paths` two entries at a time.
test_case_path = 'headneck/test/*_*'
paths = glob.glob(test_case_path)
paths.sort()

## Headneck on CBCT

In [32]:
# Ensure the per-run output directory exists. makedirs(exist_ok=True) also creates a
# missing "eval-gan" parent and is a no-op when the directory is already present.
os.makedirs(os.path.join("eval-gan", run_name), exist_ok=True)

In [33]:
# Ensure the output directory for this evaluation exists. Parents are created as
# needed, so this cell does not depend on the previous mkdir cell having been run.
os.makedirs(os.path.join("eval-gan", run_name, "hn_test_cbct"), exist_ok=True)

In [34]:
# Per-patient inference: build a DicomDataset from each CT/CBCT folder pair and run
# test_epoch on it, writing into eval-gan/<run_name>/hn_test_cbct/<patient_id>.
# `paths` is sorted; even positions are passed as ct_path, odd positions as cbct_path.
if len(paths) % 2:
    # Fail before any output is written instead of hitting an IndexError mid-run.
    raise ValueError(f"expected CT/CBCT folder pairs, got an odd number of paths: {len(paths)}")

save_root = os.path.join("eval-gan", run_name, "hn_test_cbct")
iid = 0  # number of samples processed for earlier patients; first argument of test_epoch
for ct_dir, cbct_dir in zip(paths[0::2], paths[1::2]):
    # identity=False here vs. True in the "on CT" section — presumably selects CBCT as input; confirm in DicomDataset
    scans = DicomDataset(cbct_path=cbct_dir, ct_path=ct_dir, ditch=3, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=False)
    patient_id = scans.patientID()
    testloader = torch.utils.data.DataLoader(scans, batch_size=1, shuffle=False, num_workers=4)

    # Create the patient folder (and any missing parents); no-op if it already exists.
    patient_dir = os.path.join(save_root, patient_id)
    os.makedirs(patient_dir, exist_ok=True)

    # NOTE(review): the positional True is presumably a "save results" flag — confirm against test_epoch.
    test_epoch(iid, model, testloader, device, True, patient_dir)
    iid += len(scans)

58it [00:03, 16.03it/s]
55it [00:02, 20.88it/s]
55it [00:02, 21.30it/s]
56it [00:02, 20.97it/s]
56it [00:02, 20.84it/s]
55it [00:02, 21.16it/s]
56it [00:02, 21.74it/s]


In [35]:
# Aggregate evaluation over every case matched by `test_case_path` (identity=False).
# The eval_epoch call is the cell's last expression, so its result tuple is displayed.
# Expected order of the scores (TODO confirm against eval_epoch):
#   ssim, psnr, snr, mae, air, bone, cont, tf_acc, edge_acc
testset = DicomsDataset(
    test_case_path,
    geometry_aug=aug.get_validation_augmentation(),
    intensity_aug=None,
    identity=False,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)
eval_epoch(model, tf_discriminator, edge_discriminator, testloader, device)

391it [00:19, 19.69it/s]


(0.7980937566370001,
 20.324446195226802,
 6.37189347329347,
 0.03392029772310153,
 0.5509570363082971,
 0.6837179195462165,
 0.1401999675400658,
 0.40664961636828645,
 0.5179028132992327)

## Headneck on CT

In [36]:
# Ensure the per-run output directory exists. makedirs(exist_ok=True) also creates a
# missing "eval-gan" parent and is a no-op when the directory is already present.
os.makedirs(os.path.join("eval-gan", run_name), exist_ok=True)

In [37]:
# Ensure the output directory for this evaluation exists. Parents are created as
# needed, so this cell does not depend on the previous mkdir cell having been run.
os.makedirs(os.path.join("eval-gan", run_name, "hn_test_ct"), exist_ok=True)

In [38]:
# Per-patient inference: build a DicomDataset from each CT/CBCT folder pair and run
# test_epoch on it, writing into eval-gan/<run_name>/hn_test_ct/<patient_id>.
# `paths` is sorted; even positions are passed as ct_path, odd positions as cbct_path.
if len(paths) % 2:
    # Fail before any output is written instead of hitting an IndexError mid-run.
    raise ValueError(f"expected CT/CBCT folder pairs, got an odd number of paths: {len(paths)}")

save_root = os.path.join("eval-gan", run_name, "hn_test_ct")
iid = 0  # number of samples processed for earlier patients; first argument of test_epoch
for ct_dir, cbct_dir in zip(paths[0::2], paths[1::2]):
    # identity=True here vs. False in the "on CBCT" section — presumably selects CT as input; confirm in DicomDataset
    scans = DicomDataset(cbct_path=cbct_dir, ct_path=ct_dir, ditch=3, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, identity=True)
    patient_id = scans.patientID()
    testloader = torch.utils.data.DataLoader(scans, batch_size=1, shuffle=False, num_workers=4)

    # Create the patient folder (and any missing parents); no-op if it already exists.
    patient_dir = os.path.join(save_root, patient_id)
    os.makedirs(patient_dir, exist_ok=True)

    # NOTE(review): the positional True is presumably a "save results" flag — confirm against test_epoch.
    test_epoch(iid, model, testloader, device, True, patient_dir)
    iid += len(scans)

58it [00:02, 21.29it/s]
55it [00:02, 20.60it/s]
55it [00:02, 20.98it/s]
56it [00:02, 21.21it/s]
56it [00:02, 20.56it/s]
55it [00:02, 20.93it/s]
56it [00:02, 20.88it/s]


In [39]:
# Aggregate evaluation over every case matched by `test_case_path` (identity=True).
# The eval_epoch call is the cell's last expression, so its result tuple is displayed.
# Expected order of the scores (TODO confirm against eval_epoch):
#   ssim, psnr, snr, mae, air, bone, cont, tf_acc, edge_acc
testset = DicomsDataset(
    test_case_path,
    geometry_aug=aug.get_validation_augmentation(),
    intensity_aug=None,
    identity=True,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)
eval_epoch(model, tf_discriminator, edge_discriminator, testloader, device)

391it [00:19, 20.09it/s]


(0.8420657558971659,
 23.8565428482602,
 9.9039901095583,
 0.03822702361875788,
 0.586710221456378,
 0.7214187184565952,
 0.16228404077024855,
 0.5140664961636828,
 0.49744245524296676)