In [None]:
import os
import sys
import multiprocessing

import torch
import numpy as np
import torchvision
import pandas as pd
import SimpleITK as sitk
import pytorch_lightning as pl
import matplotlib.pyplot as plt

# Move to top of ArtifactNet directory so relative imports can be used here
sys.path.insert(0, "/cluster/home/carrowsm/ArtifactNet")

# Exploring and Testing ArtifactNet
Use this notebook to load different models and pass sample images to them.

## Define the dataloaders and CSVs containing the labels
ArtifactNet requires a CSV for each dataset which labels the DA class of the images. In the examples below the class labels are `"2"`, `"1"`, and `"0"` for 'strong', 'weak', and 'none' respectively. These can be customized, as the actual labels are passed to the `load_image_data_frame` function which in turn selects the appropriate rows from the CSV. The data loading tools used in ArtifactNet are described in more detail below.

### The load_image_data_frame() function
This function takes the path to a CSV containing the names of the image files (without file extensions) and the corresponding DA status label. The CSV is loaded into a pandas dataframe and the rows are grouped according to their DA status to form an unpaired set of images in domains $X$ and $Y$. These groups are then randomly split to form a training and validation set, with equal proportions of images from domains $X$ and $Y$ in each set.

The function must be given a path to a CSV and the labels (as they appear in the CSV) to use for each domain. For example, if we want to create a data set containing 'strong' or 'weak' DA images in domain $X$ and only no-DA images in domain $Y$, we would write
```
x_trg, x_val, y_trg, y_val = load_image_data_frame(csv_path, ["2", "1"], ["0"])
```
This returns pandas dataframes containing the file names for images in the two domains of the training and validation sets. To get the actual images, we have to pass these dataframes to one of the dataset classes described below.

### The Data Loaders: PairedDataset, UnpairedDataset
These take the dataframes created by `load_image_data_frame()` and generate the actual pytorch datasets. `UnpairedDataset` is instantiated by passing it the data frames for domain $X$ and $Y$. This means each train, val, and test set requires defining its own dataset instance.

Once instantiated, both `UnpairedDataset` and `PairedDataset` will create a cache of images preprocessed and cropped to the model input size and voxel spacing. If such a cache already exists, they will skip this step. When called, `UnpairedDataset` will return an image from domain $X$ and a randomly selected image from domain $Y$. Conversely, `PairedDataset` will return an image from $X$ and the paired image from $Y$, according to the ordering of the CSVs.

In [None]:
# Import dataloaders
from data.data_loader import PairedDataset, UnpairedDataset, load_image_data_frame
from data.transforms import AffineTransform, ToTensor, Normalize, HorizontalFlip

def _unpaired_frames(labels_csv):
    """Return the domain X ("2"/"1") and domain Y ("0") frames for one labels CSV.

    The call uses val_split=0 and drops the two validation frames (the 2nd and
    4th return values of load_image_data_frame).
    """
    x_frame, _, y_frame, _ = load_image_data_frame(labels_csv, ["2", "1"], ["0"], val_split=0)
    return x_frame, y_frame


# Where to read the original full scans from (img_dir), where to save/read a
# cache of preprocessed images (cache_dir), and where the phantom scans live
img_dir = "/cluster/projects/radiomics/RADCURE-images/"
cyclegan_data_dir = "/cluster/projects/radiomics/Temp/colin/cyclegan_data/"
cache_dir = cyclegan_data_dir + "2-1-1mm_nrrd/"
phantom_dir = cyclegan_data_dir + "phantom_img/"

# Locations of CSVs containing relevant info about the scans (e.g. DA status)
datasets_dir = "/cluster/home/carrowsm/ArtifactNet/datasets/"
csv_path = datasets_dir + "train_labels.csv"
phantom_csv = datasets_dir + "phantoms.csv"

# Load data frames with image labels
x_df, y_df = _unpaired_frames(csv_path)
val_x_df, val_y_df = _unpaired_frames(phantom_csv)

In [None]:
# Summarise how many images each label frame holds
stats_lines = [
    "DATA STATS\n----------",
    f"DA+ images in training set: {len(x_df)}",
    f"DA- images in training set: {len(y_df)}",
    f"Number of image pairs in phantom set: {len(val_x_df)}",
]
print("\n".join(stats_lines))

In [None]:
# Data features
n_cpus = 4                            # Number of cores to use for data loading
img_shape = [16, 256, 256]            # Size of image for the model (in pixels)
voxel_spacing = [2.0, 1.0, 1.0]       # Physical spacing between 3D pixels (mm)

# Transform sequences. Training and validation currently apply the same steps;
# the augmentations for the training set are kept here but disabled.
trg_steps = [
    # HorizontalFlip(),
    # AffineTransform(max_angle=30.0, max_pixels=[20, 20]),
    Normalize(-1000.0, 1000.0),
    ToTensor(),
]
val_steps = [
    Normalize(-1000.0, 1000.0),
    ToTensor(),
]
trg_transform = torchvision.transforms.Compose(trg_steps)
val_transform = torchvision.transforms.Compose(val_steps)

# Settings common to both datasets
shared_dataset_kwargs = dict(
    image_size=img_shape,
    image_spacing=voxel_spacing,
    dim=3,
    num_workers=n_cpus,
)

# Unpaired patient scans for training, read as DICOM and cached under cache_dir
trg_dataset = UnpairedDataset(
    x_df, y_df,
    image_dir=img_dir,
    cache_dir=os.path.join(cache_dir, "unpaired"),
    file_type="DICOM",
    transform=trg_transform,
    **shared_dataset_kwargs,
)

# Paired phantom scans for validation, read as nrrd and cached under phantom_dir
val_dataset = PairedDataset(
    val_x_df, val_y_df,
    image_dir=phantom_dir,
    cache_dir=os.path.join(phantom_dir, "paired"),
    file_type="nrrd",
    transform=val_transform,
    **shared_dataset_kwargs,
)

## Plot some sample training and validation images
We can load a sample simply by indexing the dataset object. Each sample is a pair of PyTorch tensors, one image from domain $X$ and one from domain $Y$. The 3D images will have a size defined by `img_shape` and should be centered on the slice with the strongest artifact.

In [None]:
# One unpaired training sample and the first three phantom pairs
trg1X, trg1Y = trg_dataset[16]
val1X, val1Y = val_dataset[0]
val2X, val2Y = val_dataset[1]
val3X, val3Y = val_dataset[2]

cm = "Greys_r"
z_index = img_shape[0] // 2     # Middle slice along the first image axis

# Domain X is built from the "2"/"1" (strong or weak DA) rows and domain Y from
# the "0" (no DA) rows, so the training sample is labelled DA+ / DA- rather than
# "strong" / "weak".
panels = [
    ((0, 0), trg1X, "Patient DA+ image"),
    ((0, 1), trg1Y, "Patient DA- image"),
    ((1, 0), val2X, "Phantom DA+ image"),
    ((1, 1), val2Y, "Phantom DA- image"),
]

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=[8, 8], facecolor='white')
for (row, col), volume, title in panels:
    # Each volume is indexed as [channel, z, y, x]; show channel 0 at z_index
    panel_img = ax[row, col].imshow(volume[0, z_index, :, :], cmap=cm)
    ax[row, col].set_title(title)
    fig.colorbar(panel_img, ax=ax[row, col], shrink=0.6)

plt.show()

## Test the model
Use a pretrained ArtifactNet model to clean some images. We can use the paired phantom images to test how well the cleaning worked.

In [None]:
def load_model(module: "type[pl.LightningModule]", checkpoint_path: str) :
    """Load a pretrained checkpoint and return it as a frozen, inference-ready model.

    Parameters
    ----------
    module
        The model *class* (not an instance); its ``load_from_checkpoint``
        classmethod is used to restore the weights.
    checkpoint_path
        Path to the checkpoint file to load.

    Returns
    -------
    The restored model, switched to evaluation mode and with ``requires_grad``
    disabled on every parameter.
    """
    model = module.load_from_checkpoint(checkpoint_path)
    model.eval()
    # eval() only switches the module out of training mode; gradients must be
    # disabled explicitly for the parameters to actually be frozen.
    for param in model.parameters():
        param.requires_grad_(False)
    return model

In [None]:
# Import model
from cycleGAN import GAN

checkpoint_path = (
    "/cluster/home/carrowsm/logs/cycleGAN/16_256_256px"
    "/2u1-0/version_0/checkpoints/epoch=62.ckpt"
)

# Run on the GPU when one is available, otherwise fall back to the CPU
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')
if cuda_available:
    n_gpus = torch.cuda.device_count()
    print(f"{n_gpus} GPUs are available")

model = load_model(GAN, checkpoint_path)
# We only want to use the X -> Y generator; Module.to() moves it in place and
# returns the same module
generator = model.g_y.to(device)
del model, GAN                 # Free up memory

In [None]:
# Create some "clean" test images
### WARNING: THIS IS MEMORY INTENSIVE ###
with torch.no_grad():
    # For each image: add a batch dimension, run the forward pass through the
    # generator on `device`, then move the output back to the CPU. The images
    # are processed one at a time, and each result is moved to the CPU once.
    val1_gen_y, val2_gen_y, val3_gen_y, trg1_gen_y = [
        generator(image.unsqueeze(0).to(device)).cpu()
        for image in (val1X, val2X, val3X, trg1X)
    ]

In [None]:
# Plot comparison between clean and original
def _show_panels(axes, panels, shrink):
    """Draw each (pixels, title) pair on its axis, in order, with a colorbar."""
    for axis, (pixels, title) in zip(axes, panels):
        img = axis.imshow(pixels, cmap=cm)
        axis.set_title(title)
        fig.colorbar(img, ax=axis, shrink=shrink)


## Phantom Images
# NOTE(review): phantom 2 is shown at z_index-1 while the other panels use
# z_index; confirm the offset is intentional.
phantom_panels = [
    (val1Y[0, z_index, :, :], "Phantom 1 - Original DA-"),
    (val1X[0, z_index, :, :], "Phantom 1 - Original DA+"),
    (val1_gen_y[0, 0, z_index, :, :].numpy(), "Phantom 1 - Generated DA-"),
    (val2Y[0, z_index-1, :, :], "Phantom 2 - Original DA-"),
    (val2X[0, z_index-1, :, :], "Phantom 2 - Original DA+"),
    (val2_gen_y[0, 0, z_index-1, :, :].numpy(), "Phantom 2 - Generated DA-"),
]
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=[8*2, 8], facecolor='white')
# ax.flat walks the 2x3 grid row by row, matching the panel order above
_show_panels(ax.flat, phantom_panels, shrink=1)
plt.show()


# Real Patient Image
patient_panels = [
    (trg1X[0, z_index, :, :], "Sample Patient - Original DA+"),
    (trg1_gen_y[0, 0, z_index, :, :].numpy(), "Sample Patient - Generated DA-"),
]
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=[8*2, 8], facecolor='white')
_show_panels(ax, patient_panels, shrink=0.78)
plt.show()

In [None]:
# Compare each phantom's DA+ input and generated output against its paired DA-
# image. The dataset tensors are reshaped to the 5D layout of the generator
# output under NEW names: overwriting val*X / val*Y would break the earlier
# plotting and generation cells (which index them as 4D) when re-run.
batch_shape = (1, 1, *img_shape)    # Follows img_shape instead of a hard-coded size
mse = torch.nn.MSELoss(reduction='mean')

phantom_batches = [
    (da_pos.reshape(batch_shape), da_neg.reshape(batch_shape), cleaned)
    for da_pos, da_neg, cleaned in [
        (val1X, val1Y, val1_gen_y),
        (val2X, val2Y, val2_gen_y),
        (val3X, val3Y, val3_gen_y),
    ]
]

# Test how well the paired images were cleaned
for position, (da_pos, da_neg, cleaned) in enumerate(phantom_batches):
    if position > 0:
        print("\n")
    # MSE to the DA- image: before cleaning, after cleaning
    print( mse(da_pos, da_neg), mse(cleaned, da_neg) )
    print(da_neg.mean(), da_pos.mean(), cleaned.mean())
    print(da_neg.std(), da_pos.std(), cleaned.std())

# Difference images for phantom 2 at slice z_index: input - DA-, then generated - DA-
val2X_batch, val2Y_batch, val2_cleaned = phantom_batches[1]

plt.figure()
plt.imshow((val2X_batch - val2Y_batch)[0, 0, z_index, :, :])
plt.show()

plt.figure()
plt.imshow((val2_cleaned - val2Y_batch)[0, 0, z_index, :, :])
plt.show()

## Postprocessing
ArtifactNet has a PostProcessor object which is used to reinsert the subvolume cleaned by the network back into the full original SITK image.

In [None]:
## Test post processor
from data.postprocessing import PostProcessor

In [None]:
# Named separately from `csv_path` so the training-label CSV path defined in the
# data-loading cell is not overwritten when cells are re-run out of order.
# NOTE(review): this path is not passed to PostProcessor in this cell; confirm
# whether it is still needed.
test_csv_path = "datasets/radcure_challenge_test.csv"
orig_img_dir = "/cluster/projects/radiomics/RADCURE-images/"
out_img_dir = "."

# Configure the post processor: inputs use the model's voxel spacing and are
# DICOM; outputs are written as nrrd with output_spacing="orig"
postprocess = PostProcessor(orig_img_dir, out_img_dir,
                            input_spacing=voxel_spacing,
                            output_spacing="orig",
                            input_file_type="dicom",
                            output_file_type="nrrd")

In [None]:
# ID at position 16 of the dataset's X ids
# NOTE(review): assumed to be the patient behind trg_dataset[16], which produced
# trg1_gen_y; confirm x_ids is ordered like the dataset indices.
patient_id = trg_dataset.x_ids[16]

# Image centre for this patient, read from the label frame
img_center_x, img_center_y, img_center_z = (
    float(x_df.at[patient_id, f"img_center_{axis}"]) for axis in ("x", "y", "z")
)

# Drop the batch and channel dimensions using the configured img_shape rather
# than a hard-coded size, so this stays correct if img_shape changes
postprocess(trg1_gen_y.reshape(*img_shape), patient_id,
            [img_center_x, img_center_y, img_center_z])

In [None]:
### Load recreated image
# Read from out_img_dir instead of assuming it is the working directory.
# NOTE(review): assumes the PostProcessor saved the result as "<patient_id>.nrrd"
# in out_img_dir, as the previous hard-coded path implied.
img = sitk.ReadImage(os.path.join(out_img_dir, f"{patient_id}.nrrd"))
ximg = sitk.GetArrayFromImage(img)

In [None]:
# Views to show, in order: (pixels, aspect, draw a colorbar?)
generated_volume = trg1_gen_y.reshape(16, 256, 256).numpy()
artifact_slice = int(x_df.at[patient_id, "a_slice"])

views = [
    # Slice 8 of the generated sub-volume
    (generated_volume[8, :, :], 'equal', False),
    # Two cuts through the recreated full image, with the first axis reversed
    (ximg[::-1, 190, :], 'auto', True),
    (ximg[::-1, :, 256], 'auto', True),
    # The recreated full image at the slice given by the "a_slice" column
    (ximg[artifact_slice, :, :], 'equal', True),
]

for pixels, aspect, with_colorbar in views:
    plt.figure()
    plt.imshow(pixels, aspect=aspect, cmap=cm)
    if with_colorbar:
        plt.colorbar()
    plt.show()

### Check available datasets

In [None]:
# Check how many images can be used to test OAR segmentation
def _read_fold(csv_file):
    """Read the "index" and "0" columns of a fold CSV and re-index the frame by MRN.

    The MRN is taken as characters [-12:-5] of each string in column "0".
    """
    fold = pd.read_csv(csv_file, usecols=["index", "0"], index_col="index")
    return fold.assign(mrn=fold["0"].str[-12:-5]).set_index("mrn")


oar_df = _read_fold("../datasets/test_fold_oar_1.csv")
gtv_df = _read_fold("../datasets/test_fold_gtv_1.csv")

# DA label tables, read entirely as strings and indexed by patient_id
da_test_df, da_train_df = (
    pd.read_csv(f"../datasets/{split}_labels.csv", dtype=str).set_index("patient_id")
    for split in ("test", "train")
)

In [None]:
# common_gtv = da_test_df[da_test_df.index.isin(gtv_df.index)]
# common_oar = da_test_df[da_test_df.index.isin(oar_df.index)]
# full_da = pd.concat([da_train_df, da_test_df])

# full_oar = full_da[full_da.index.isin(oar_df.index)]
# full_oar.to_csv("../datasets/oar_segment_imgs.csv")

# full_gtv = full_da[full_da.index.isin(gtv_df.index)]
# full_gtv.to_csv("../datasets/gtv_segment_imgs.csv")