In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, ConcatDataset

import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
import random
import json

In [2]:
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss
from segmentation_models_pytorch.utils.metrics import IoU, Fscore, Accuracy

In [3]:
import matplotlib.pyplot as plt

# helper function for data visualization
def visualize(**images):
    """PLot images in one row."""
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image, 'gray')
    plt.show()

In [4]:
import random

SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print(np.random.rand(5), torch.randn(5))

[0.5488135  0.71518937 0.60276338 0.54488318 0.4236548 ] tensor([ 1.5410, -0.2934, -2.1788,  0.5684, -1.0845])


# Settings

In [5]:
root = os.getcwd()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

BATCH = 4

GEN_IN_CHANNELS = 1
GEN_N_CLASSES = 1

DIS_IN_CHANNELS = 1
DIS_N_CLASSES = 2

E_IN_CHANNELS = 1
E_N_CLASSES = 2

EPOCH = 200

### Model Settings

#### encoder

In [6]:
ENCODER = "resnet152"
ENCODER_WEIGHT = None

#### decoder

In [7]:
DECODER_ATT = "scse"

#### head

In [8]:
GEN_ACT = "sigmoid"
DIS_ACT = "softmax"
E_ACT = "softmax"

# Epochs

In [9]:
from codes.losses import SSIMLoss
from codes.losses import MAELoss
from pytorch_msssim import ssim
from codes.metrics import PSNR, SNR, ContourEval
from codes.activation import Activation
from codes.utils import hu_clip_tensor
from codes.losses import PerceptualLoss
from kornia.filters.sobel import Sobel

In [10]:
def set_requires_grad(nets, requires_grad=False):
    """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
    Parameters:
    nets (network list)   -- a list of networks
    requires_grad (bool)  -- whether the networks require gradients or not
    """
    if not isinstance(nets, list):
        nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

In [11]:
def replace_relu_to_leakyReLU(model):
    for child_name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            setattr(model, child_name, nn.LeakyReLU())
        else:
            replace_relu_to_leakyReLU(child)

In [12]:
def replace_bn_to_instanceNorm(model):
    for child_name, child in model.named_children():
        if isinstance(child, nn.BatchNorm2d):
            bn = getattr(model, child_name)
            setattr(model, child_name, nn.InstanceNorm2d(bn.num_features))
        else:
            replace_bn_to_instanceNorm(child)

In [13]:
@torch.no_grad()
def eval_epoch(netG_A2B, dataloader, device):
  
    # change mode to train and move to current device
    netG_A2B = netG_A2B.eval().to(device)
    
    ssim_ = []
    psnr_ = []
    mae_ = []
    snr_ = []
    
    air_ = []
    bone_ = []
    cont_ = []

    for index, data in tqdm(enumerate(dataloader)):

        x, y, air_x, bone_x, *_ = data
        
        B, C, H, W = x.size()
        x = x.to(device)
        y = y.to(device)
        air_x = air_x.to(device)
        bone_x = bone_x.to(device)

        fake_B = netG_A2B(x.float()).float()
    
        ############################
        # (2) Eval G network: maximize log(D(G(z)))
        ###########################

        # main
        _ssim = ssim(y.float(), fake_B, data_range=1.0, size_average=True) # return (N,)
        _psnr = PSNR()(fake_B, y.float(), 1.0)
        _snr = SNR()(fake_B, y.float())
        _mae = MAELoss()(fake_B.float(), y.float())

        ssim_ += [_ssim.item()]
        psnr_ += [_psnr.item()]
        mae_ += [_mae.item()]
        snr_ += [_snr.item()]
        
        # auxilary
        _min = VIEW_BOUND[0]
        _max = VIEW_BOUND[1]
        air_window = AIR_BOUND
        upper = ((air_window[1]) - (_min))/(_max-(_min))
        lower = ((air_window[0]) - (_min))/(_max-(_min))
        air_pr = hu_clip_tensor(fake_B.double(), (lower, upper), None, True)
        dice = Fscore()(air_pr, air_x)
        air_ += [dice.item()]
        
        bone_window = BONE_BOUND
        upper = ((bone_window[1]) - (_min))/(_max-(_min))
        lower = ((bone_window[0]) - (_min))/(_max-(_min))
        bone_pr = hu_clip_tensor(fake_B.double(), (lower, upper), None, True)
        dice = Fscore()(bone_pr, bone_x)
        bone_ += [dice.item()]
        
        for b in range(x.size()[0]):
            cont = ContourEval()(fake_B[b, :, :, :], x[b, :, :, :])
            cont_ += [cont.item()]

    return  sum(ssim_)/len(ssim_), sum(psnr_)/len(psnr_), sum(snr_)/len(snr_), sum(mae_)/len(mae_), \
                    sum(air_)/len(air_), sum(bone_)/len(bone_), sum(cont_)/len(cont_)

In [14]:
@torch.no_grad()
def test_epoch(netG_A2B, dataloader, device, save=False, path=None):
  
    # change mode to train and move to current device
    netG_A2B = netG_A2B.eval().to(device)
    
    iid = 0
    for index, data in tqdm(enumerate(dataloader)):

        x, y, air_x, bone_x, *_ = data
        
        B, C, H, W = x.size()
        x = x.to(device)
        y = y.to(device)
        air_x = air_x.to(device)
        bone_x = bone_x.to(device)

        y_pr = netG_A2B(x.float())

        _min = VIEW_BOUND[0]
        _max = VIEW_BOUND[1]
        air_window = AIR_BOUND
        upper = ((air_window[1]) - (_min))/(_max-(_min))
        lower = ((air_window[0]) - (_min))/(_max-(_min))
        air_pr = hu_clip_tensor(y_pr.double(), (lower, upper), None, True)
        
        bone_window = BONE_BOUND
        upper = ((bone_window[1]) - (_min))/(_max-(_min))
        lower = ((bone_window[0]) - (_min))/(_max-(_min))
        bone_pr = hu_clip_tensor(y_pr.double(), (lower, upper), None, True)
        
        for b in range(x.shape[0]):           
            tmp_y = y[b, :, :, :].unsqueeze(0).float()
            tmp_y_pr = y_pr[b, :, :, :].unsqueeze(0)
            tmp_air = air_x[b, :, :, :].unsqueeze(0)
            tmp_air_pr = air_pr[b, :, :, :].unsqueeze(0)
            tmp_bone = bone_x[b, :, :, :].unsqueeze(0)
            tmp_bone_pr = bone_pr[b, :, :, :].unsqueeze(0)
            
            __cbct = (x.squeeze().cpu().numpy() * 255).astype(np.uint8)
            __ct = (tmp_y.squeeze().cpu().numpy() * 255).astype(np.uint8)
            __ct_pred = (tmp_y_pr.squeeze().cpu().numpy() * 255).astype(np.uint8)
            __air = (tmp_air.squeeze().cpu().numpy() * 255).astype(np.uint8)
            __air_pr = (tmp_air_pr.squeeze().cpu().numpy() * 255).astype(np.uint8)
            __bone = (tmp_bone.squeeze().cpu().numpy() * 255).astype(np.uint8)
            __bone_pr = (tmp_bone_pr.squeeze().cpu().numpy() * 255).astype(np.uint8)
 
            if save:
                path_dir = os.path.join(path, "file_{}".format(iid))
                try:
                    os.mkdir(path_dir)
                except FileExistsError:
                    pass
                
                cv2.imwrite(os.path.join(path_dir, "cbct.jpg"), __cbct)
                cv2.imwrite(os.path.join(path_dir, "ct.jpg"), __ct)
                cv2.imwrite(os.path.join(path_dir, "ct_pred.jpg"), __ct_pred)
                cv2.imwrite(os.path.join(path_dir, "air.jpg"), __air)
                cv2.imwrite(os.path.join(path_dir, "air_pred.jpg"), __air_pr)
                cv2.imwrite(os.path.join(path_dir, "bone.jpg"), __bone)
                cv2.imwrite(os.path.join(path_dir, "bone_pred.jpg"), __bone_pr)
                
                iid += 1

# Read Data

In [15]:
import glob
from codes.dataset import DicomDataset, DicomsDataset
import codes.augmentation as aug
from codes.RegGAN.Cyclegan import *

In [16]:
# run_name = wandb.run.name
run_name = "sleek-tree-25"
ELECTRON = False
G_COORD = False
L_COORD= False

In [17]:
VIEW_BOUND = (-500, 500)
AIR_BOUND = (-500, -499)
BONE_BOUND = (255, 256)
if ELECTRON:
    VIEW_BOUND = (0.5, 1.5)
    AIR_BOUND = (0.5, 0.5009)
    BONE_BOUND = (1.2, 1.2009)  

In [18]:
config = {
    # lamda weight
    "Adv_lamda": 1,
    "Cyc_lamda": 10,
    "Corr_lamda": 20,
    "Smooth_lamda": 10,

    "n_epoch": 80,        # starting epoch
    "batchSize": 4,               # size of the batches
    "lr": 0.0001,                   # initial learning rate
    "decay_epoch": 20,            # epoch to start linearly decaying the learning rate to 0
    "input_nc": 1,         
    "output_nc": 1,                         
    "n_cpu": 1,
    "size": 256,
    "cuda": True,
}

In [19]:
checkpoint = torch.load(os.path.join("weight-gan", "{}.pth".format(run_name)))

In [20]:
netG_A2B = Generator(config['input_nc'], config['output_nc'])
netG_A2B.load_state_dict(checkpoint["netG_A2B"])

<All keys matched successfully>

# Pelvic

## Pelvic test

In [21]:
test_case_path = 'raw/test/*_*'
paths = sorted(glob.glob(test_case_path))

In [22]:
try:
    os.mkdir(os.path.join("eval-gan", run_name))
except FileExistsError:
    pass

In [23]:
try:
    os.mkdir(os.path.join("eval-gan", run_name, "c_test"))
except FileExistsError:
    pass

In [24]:
# read cbct and ct
for i in range(0, len(paths), 2):
    scans = DicomDataset(cbct_path=paths[i+1], ct_path=paths[i],
                         geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, 
                         identity=False, electron=ELECTRON, position="pelvic", g_coord=G_COORD, l_coord=L_COORD)
    patient_id = scans.patientID()
    testloader = torch.utils.data.DataLoader(scans, batch_size=1, shuffle=False, num_workers=4)
    try:
        os.mkdir(os.path.join("eval-gan", run_name, "c_test", patient_id))
    except FileExistsError:
        pass
    test_epoch(netG_A2B, testloader, device, True, os.path.join("eval-gan", run_name, "c_test", patient_id)) 

28it [00:03,  9.11it/s]
27it [00:01, 14.79it/s]
26it [00:01, 14.62it/s]
28it [00:01, 14.83it/s]
27it [00:01, 14.79it/s]
28it [00:01, 14.58it/s]
27it [00:01, 14.70it/s]


In [25]:
testset = DicomsDataset(test_case_path, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, 
                        identity=False, electron=ELECTRON, position="pelvic", g_coord=G_COORD, l_coord=L_COORD)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)
eval_epoch(netG_A2B, testloader, device)

191it [00:13, 14.11it/s]


(0.7201153152276084,
 20.503642436721563,
 11.507914543151855,
 0.04118932826512771,
 0.9956907481380713,
 0.7974198721891713,
 0.0801106410806544)

## Pelvic L1

In [26]:
test_case_path = 'L1_pelvic_processed/test/*_*'
paths = sorted(glob.glob(test_case_path))

In [27]:
try:
    os.mkdir(os.path.join("eval-gan", run_name))
except FileExistsError:
    pass

In [28]:
try:
    os.mkdir(os.path.join("eval-gan", run_name, "L1_test"))
except FileExistsError:
    pass

In [29]:
# read cbct and ct
for i in range(0, len(paths), 2):
    scans = DicomDataset(cbct_path=paths[i+1], ct_path=paths[i],
                         geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, 
                         identity=False, electron=ELECTRON, position="pelvic", g_coord=G_COORD, l_coord=L_COORD)
    patient_id = scans.patientID()
    testloader = torch.utils.data.DataLoader(scans, batch_size=1, shuffle=False, num_workers=4)
    try:
        os.mkdir(os.path.join("eval-gan", run_name, "L1_test", patient_id))
    except FileExistsError:
        pass
    test_epoch(netG_A2B, testloader, device, True, os.path.join("eval-gan", run_name, "L1_test", patient_id)) 

26it [00:01, 14.49it/s]
30it [00:02, 14.79it/s]
26it [00:01, 14.46it/s]
32it [00:02, 14.96it/s]
36it [00:02, 14.62it/s]
28it [00:01, 14.55it/s]
27it [00:01, 14.37it/s]
29it [00:01, 14.75it/s]
26it [00:01, 14.46it/s]
27it [00:01, 14.40it/s]
27it [00:01, 14.70it/s]


In [30]:
testset = DicomsDataset(test_case_path, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, 
                        identity=False, electron=ELECTRON, position="pelvic", g_coord=G_COORD, l_coord=L_COORD)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)
eval_epoch(netG_A2B, testloader, device)

314it [00:22, 14.04it/s]


(0.7324433286858213,
 20.928894674702054,
 11.370021000409581,
 0.04111557270572254,
 0.9921685395604201,
 0.7624190412749761,
 0.06701108271418264)

## Pelvic L2

In [31]:
test_case_path = 'L2_pelvic_processed/test/*_*'
paths = sorted(glob.glob(test_case_path))

In [32]:
try:
    os.mkdir(os.path.join("eval-gan", run_name))
except FileExistsError:
    pass

In [33]:
try:
    os.mkdir(os.path.join("eval-gan", run_name, "L2_test"))
except FileExistsError:
    pass

In [34]:
# read cbct and ct
for i in range(0, len(paths), 2):
    scans = DicomDataset(cbct_path=paths[i+1], ct_path=paths[i],
                         geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, 
                         identity=False, electron=ELECTRON, position="pelvic", g_coord=G_COORD, l_coord=L_COORD)
    patient_id = scans.patientID()
    testloader = torch.utils.data.DataLoader(scans, batch_size=1, shuffle=False, num_workers=4)
    try:
        os.mkdir(os.path.join("eval-gan", run_name, "L2_test", patient_id))
    except FileExistsError:
        pass
    test_epoch(netG_A2B, testloader, device, True, os.path.join("eval-gan", run_name, "L2_test", patient_id)) 

31it [00:02, 14.89it/s]
27it [00:01, 14.63it/s]
28it [00:01, 14.82it/s]
27it [00:01, 14.83it/s]
27it [00:01, 14.62it/s]
26it [00:01, 14.65it/s]
30it [00:02, 14.79it/s]
27it [00:01, 14.50it/s]
29it [00:01, 14.82it/s]
26it [00:01, 14.64it/s]
26it [00:01, 14.64it/s]
27it [00:01, 13.94it/s]
26it [00:01, 14.56it/s]
27it [00:01, 14.38it/s]
26it [00:01, 14.55it/s]
30it [00:02, 14.85it/s]
26it [00:01, 14.56it/s]
26it [00:01, 14.60it/s]
26it [00:01, 14.49it/s]
27it [00:01, 14.55it/s]
27it [00:01, 14.64it/s]
26it [00:01, 14.41it/s]
29it [00:01, 14.52it/s]
29it [00:01, 14.51it/s]


In [35]:
testset = DicomsDataset(test_case_path, geometry_aug=aug.get_validation_augmentation(), intensity_aug=None, 
                        identity=False, electron=ELECTRON, position="pelvic", g_coord=G_COORD, l_coord=L_COORD)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)
eval_epoch(netG_A2B, testloader, device)

656it [00:47, 13.93it/s]


(0.7182728342893647,
 19.89013300581676,
 10.344775846455155,
 0.0447093009375172,
 0.9951992599975208,
 0.7572215146713059,
 0.054645513602516425)