In [2]:
import logging
import argparse
import os
import pandas as pd

from lightning.pytorch.tuner import Tuner
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch
from monai.inferers import sliding_window_inference
from skimage import measure

import experiments_items.nets
from monai.metrics.meandice import compute_dice
from config.config import Config
from preprocessing.covid_dataset import CovidDataset
import monai.data
from monai.data import DataLoader, decollate_batch
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from preprocessing.transforms import get_hrct_transforms, get_cbct_transforms, \
    get_val_hrct_transforms, get_val_cbct_transforms
from utils.custom_callbacks import CustomTimingCallback
from utils.helpers import load_images_from_path, check_dataset
from config.constants import (ZENODO_COVID_CASES_PATH, ZENODO_INFECTION_MASKS_PATH, SEED, VALIDATION_INFERENCE_ROI_SIZE,
                              SPATIAL_SIZE,
                              ZENODO_LUNG_MASKS_PATH, EXPERIMENTS_PATH, COVID_PREPROCESSED_CASES_PATH,
                              INFECTION_PREPROCESSED_MASKS_PATH, SPACING)
import torch
import numpy as np
from monai.metrics import DiceMetric
import lightning as L
import numpy as np
import matplotlib.pyplot as plt
from skimage import measure
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import matplotlib.pyplot as plt
import pyvista as pv
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import scipy.ndimage as ndi
import plotly.graph_objects as go


2024-07-23 13:02:08.408334: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
def plot_covid_volumes(ground_truth, predictions, save_path="test.png"):
    """
    Plots contours of ground truth and fills areas of predictions.

    Parameters:
    - ground_truth: A 3D numpy array (or metatensor) with binary values (0s and 1s) representing the ground truth.
    - predictions: A 3D numpy array (or metatensor) with binary values (0s and 1s) representing the model's predictions.
    - save_path: A string path to save the resulting plotly figure.
    """

    # Convert tensors to numpy arrays if they are not already
    if torch.is_tensor(ground_truth):
        ground_truth = ground_truth.cpu().numpy()
    if torch.is_tensor(predictions):
        predictions = predictions.cpu().numpy()

    # Ensure input arrays are binary (0s and 1s)
    assert np.all(np.isin(ground_truth, [0, 1])), "Ground truth should be binary (0s and 1s)"
    assert np.all(np.isin(predictions, [0, 1])), "Predictions should be binary (0s and 1s)"

    # Create the figure
    fig = go.Figure()
        
    # paint the prediction volume, such that we paint the whole volume with all slices, where the values are 1
    fig.update(data=[
                go.Volume(
                    x=np.repeat(np.arange(predictions.shape[0]), predictions.shape[1] * predictions.shape[2]),
                    y=np.tile(np.repeat(np.arange(predictions.shape[1]), predictions.shape[2]), predictions.shape[0]),
                    z=np.tile(np.arange(predictions.shape[2]), predictions.shape[0] * predictions.shape[1]),
                    value=predictions.flatten(),
                    opacity=0.1, # adjust for visualization
                    surface_count=17, # adjust for visualization
                    colorscale='Viridis' # or any other color scale
                )
            ]
    )

    
    fig.update_layout(
        title='3D Metatensor Visualization',
        scene=dict(
            xaxis_title='X Axis',
            yaxis_title='Y Axis',
            zaxis_title='Z Axis'
        )
    )
    
    plt.savefig(save_path)   
    

def plot_3d_volumes(ground_truth, prediction, save_path):
    """
    Plot 3D volumes of the ground truth and prediction tensors and save the plot.
    
    Parameters:
    ground_truth (torch.Tensor): 4D tensor (CHWD) representing the ground truth.
    prediction (torch.Tensor): 4D tensor (CHWD) representing the prediction.
    save_path (str): Path to save the 3D visualization.
    """
    # Extract the volume data from the tensors and convert to numpy arrays
    gt_volume = ground_truth[0].cpu().numpy()
    pred_volume = prediction[0].cpu().numpy()
    
    # Find the indices of the 'ones' in the ground truth and prediction volumes
    gt_indices = np.argwhere(gt_volume == 1)
    pred_indices = np.argwhere(pred_volume == 1)
    
    # Set up the figure and 3D axis
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    
    # Plot the ground truth volume
    if gt_indices.size > 0:
        gt_x, gt_y, gt_z = gt_indices[:, 0], gt_indices[:, 1], gt_indices[:, 2]
        ax.scatter(gt_x, gt_y, gt_z, color='blue', alpha=0.01, label='Ground Truth')

    # Plot the prediction volume
    if pred_indices.size > 0:
        pred_x, pred_y, pred_z = pred_indices[:, 0], pred_indices[:, 1], pred_indices[:, 2]
        ax.scatter(pred_x, pred_y, pred_z, color='red', alpha=0.2, label='Prediction')

    # Set labels
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('3D Volumes of Ground Truth and Prediction')
    ax.legend()
    
    # Save the plot to the specified path
    plt.savefig(save_path, dpi=300)
    plt.close()
    

def plot_3d_volumes2(ground_truth, prediction, save_path):
    """
    Plot 3D volumes of the ground truth and prediction tensors and save the plot with volume rendering.
    Ground truth is visualized as contours, and prediction is visualized as points.
    
    Parameters:
    ground_truth (torch.Tensor): 4D tensor (CHWD) representing the ground truth.
    prediction (torch.Tensor): 4D tensor (CHWD) representing the prediction.
    save_path (str): Path to save the 3D visualization.
    """
    # Extract the volume data from the tensors and convert to numpy arrays
    gt_volume = ground_truth[0].cpu().numpy()
    pred_volume = prediction[0].cpu().numpy()

    # Downsample the volumes to speed up processing
    def downsample_volume(volume, factor):
        zoom_factors = [1/factor, 1/factor, 1]  # Downsample spatial dimensions
        return ndi.zoom(volume, zoom_factors, order=1)  # Bilinear interpolation

    gt_volume_ds = downsample_volume(gt_volume, 4)
    pred_volume_ds = downsample_volume(pred_volume, 4)

    # Set up the figure and 3D axis
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')

  # Function to add contours to the plot
    def add_volume(ax, volume, color, alpha=0.3):
        # Compute contours
        verts, faces, _, _ = measure.marching_cubes(volume, level=0.5)
        mesh = Poly3DCollection(verts[faces], alpha=alpha, facecolors=color, linewidths=0.1)
        ax.add_collection3d(mesh)

    # Plot the ground truth contours
    if np.any(gt_volume_ds):
        add_volume(ax, gt_volume_ds, 'blue', alpha=0.3)
    
    # Plot the prediction points
    gt_indices = np.argwhere(gt_volume_ds == 1)
    if gt_indices.size > 0:
        gt_x, gt_y, gt_z = gt_indices[:, 0], gt_indices[:, 1], gt_indices[:, 2]
        ax.scatter(gt_x, gt_y, gt_z, color='blue', alpha=0.01, label='Ground Truth')
    
    # Set labels and aspect
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('3D Volumes of Ground Truth and Prediction')

    # Set aspect ratio
    ax.set_box_aspect([gt_volume_ds.shape[0], gt_volume_ds.shape[1], gt_volume_ds.shape[2]])

    # Add a legend
    ax.legend()

    # Save the plot to the specified path
    plt.savefig(save_path)
    plt.close()
    
    
def plot_3d_volumes3(ground_truth, prediction, save_path):
    # using plotply
    # Extract the volume data from the tensors and convert to numpy arrays
    gt_volume = ground_truth[0].cpu().numpy()
    pred_volume = prediction[0].cpu().numpy()
    
    # Find the indices of the 'ones' in the ground truth and prediction volumes
    gt_indices = np.argwhere(gt_volume == 1)
    pred_indices = np.argwhere(pred_volume == 1)
    
    # Set up the figure and 3D axis
    fig = go.Figure()
    
    # Plot the prediction volume
    if pred_indices.size > 0:
        pred_x, pred_y, pred_z = pred_indices[:, 0], pred_indices[:, 1], pred_indices[:, 2]
        fig.add_trace(go.Scatter3d(x=pred_x, y=pred_y, z=pred_z, mode='markers', opacity=0.1, marker=dict(size=2, color='red'), name='Prediction'))

    if gt_indices.size > 0:
        print(f"shape of gt_indices: {gt_indices.shape}")
        pred_x, pred_y, pred_z = gt_indices[:, 0], gt_indices[:, 1], gt_indices[:, 2]
        fig.add_trace(go.Scatter3d(x=pred_x, y=pred_y, z=pred_z, mode='markers', opacity=0.1, marker=dict(size=2, color='blue'), name='Ground Truth'))

    # Set layout
    fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'), title='3D Volumes of Ground Truth and Prediction')
        
    # plot
    fig.show()
    
def burn_masks_in_ct(ct,mask,predictions, path_to_save=None):
    # Definir los valores de las clases
    BACKGROUND = 0
    CLASS1 = 1
 
    # Cambiar los valores de la matriz "predictions_without_black" a los valores de las clases
    predictions_without_black = np.where(predictions == BACKGROUND, np.nan, CLASS1)
 
    # Crear el mapa de colores personalizado
    cmap = ListedColormap(['yellow', 'black'])
 
    # Crear la figura y los ejes
    dpi = 80
    figsize = 512 / dpi
 
    # Crear la figura y los ejes
    fig, ax = plt.subplots(figsize=(figsize + 1, figsize + 1.2), dpi=dpi)
 
    # Mostrar la imagen con el mapa de colores personalizado
    ax.imshow(ct, cmap="gray")
    ax.imshow(predictions_without_black, alpha=0.6, cmap=cmap)
 
    # Encontrar los contornos de la máscara
    contours = measure.find_contours(mask)
 
    # Dibujar los contornos
    for contour in contours:
        ax.plot(contour[:, 1], contour[:, 0], linewidth=1, color="blue")
 
    # Añadir el título
    #ax.set_title("Dice score: {:.3%}".format(dice_score), fontsize=16)
 
    # Crear la leyenda con texto y colores
    class_patches = [Patch(color='yellow', label='Prediction')]
    line_patches = [Patch(color='blue', label='Mask contour')]
 
    legend = ax.legend(handles=class_patches + line_patches, loc='upper left')
    legend.get_frame().set_facecolor('white')
    legend.get_frame().set_linewidth(0.0)
    legend.get_frame().set_alpha(0.8)  # Agregar transparencia al marco
    legend.get_texts()[0].set_fontsize(8)  # Ajustar el tamaño del texto
    legend.get_texts()[1].set_fontsize(8)
 
    # Quitar los ejes xticks e yticks
    ax.set_xticks([])
    ax.set_yticks([])
 
    # Guardar la figura
    plt.savefig(path_to_save, bbox_inches='tight')
 
    # Cerrar la figura
    plt.close()


In [29]:
class Net(L.pytorch.LightningModule):
    def __init__(self, learning_rate: float, model: torch.nn.Module, loss_function: torch.nn, volumes_path: str,
                 masks_path: str, experiment_name: str):
        super(Net, self).__init__()

        # volumes paths
        self.volumes_path = volumes_path
        self.masks_path = masks_path

        # Model, loss function and learning rate
        self.model = model
        print(f"Using model: {type(self.model)}")
        self.loss_function = loss_function
        print(f"Using loss: {type(self.loss_function)}")
        self.learning_rate = learning_rate
        print(f"Using lr: {learning_rate}")

        self.experiment_name = experiment_name

        self.save_hyperparameters(ignore=["model", "loss_function", "volumes_path", "masks_path", "experiment_name"])

        # Define the post-processing transforms
        self.post_pred = monai.transforms.Compose(
            [monai.transforms.EnsureType(data_type='tensor'), monai.transforms.Activations(sigmoid=True), monai.transforms.AsDiscrete(threshold=0.5)])
        self.post_label = monai.transforms.Compose([monai.transforms.AsDiscrete(threshold=0.5)])

        # Dice metric
        self.dice_metric = DiceMetric(include_background=True, reduction="mean")
        self.train_dice_metric = DiceMetric(include_background=True, reduction="mean")

        # Surface dice metric
        self.surface_dice_metric = monai.metrics.SurfaceDiceMetric(include_background=True, distance_metric="euclidean", class_thresholds=[1.0])
        self.train_surface_dice_metric = monai.metrics.SurfaceDiceMetric(include_background=True, distance_metric="euclidean", class_thresholds=[1.0])

        # Haussdorf metric
        self.haussdorf_metric = monai.metrics.HausdorffDistanceMetric(include_background=True, distance_metric="euclidean", percentile=95)
        self.train_haussdorf_metric = monai.metrics.HausdorffDistanceMetric(include_background=True, distance_metric="euclidean", percentile=95)

        # IoU metric
        self.iou_metric = monai.metrics.MeanIoU(include_background=True)
        self.train_iou_metric = monai.metrics.MeanIoU(include_background=True)

        # Best validation dice and epoch
        self.best_val_dice = 0
        self.best_val_epoch = 0

        # Losses lists
        self.validation_step_outputs = []
        self.train_step_outputs = []
        self.test_step_outputs = []

        # CSV data frame for metrics
        self.train_val_dump_data_frame = []
        self.test_dump_data_frame = []

        # Paths lists for datasets
        self.test_paths = None
        self.val_paths = None
        self.train_paths = None

        # Datasets
        self.training_ds = None
        self.validation_ds = None
        self.test_ds = None

    def forward(self, x):
        return self.model(x)

    def prepare_data(self) -> None:
        # Load images and masks
        logging.info(f"Loading images from {self.volumes_path} and masks from {self.masks_path}")
        images = load_images_from_path(self.volumes_path)
        labels = load_images_from_path(self.masks_path)

        # Take only the images that are from Mosmed
        train_images = [image for image in images if "radiopaedia" not in image and "coronacases" not in image]
        train_labels = [label for label in labels if "radiopaedia" not in label and "coronacases" not in label]

        # Convert images and masks to a list of dictionaries with keys "img" and "mask"
        data_train_dicts = np.array([{"img": img, "mask": mask} for img, mask in zip(train_images, train_labels)])
        logging.debug(data_train_dicts)

        # Shuffle the data
        shuffler = np.random.RandomState(SEED)
        shuffler.shuffle(data_train_dicts)
        data_train_dicts = list(data_train_dicts)

        # Split the training data into training and validation
        val_split = int(len(data_train_dicts) * 0.2)

        self.train_paths = data_train_dicts[val_split:]
        self.val_paths = data_train_dicts[:val_split]

        # Take coronacases and radiopeadia images for testing
        test_images = [image for image in images if "radiopaedia" in image or "coronacases" in image]
        test_labels = [label for label in labels if "radiopaedia" in label or "coronacases" in label]
        self.test_paths = np.array([{"img": img, "mask": mask} for img, mask in zip(test_images, test_labels)])

    def setup(self, stage: str) -> None:
        if stage == "fit" or stage is None:
            # Define the CovidDataset instances for training, validation, and test
            self.training_ds = CovidDataset(volumes=self.train_paths, hrct_transform=get_hrct_transforms(),
                                            cbct_transform=get_cbct_transforms())
            self.validation_ds = CovidDataset(volumes=self.val_paths, hrct_transform=get_val_hrct_transforms(),
                                              cbct_transform=get_val_cbct_transforms())
            # Check the dataset
            print("Checking the validation dataset")
            # check_dataset(self.validation_ds)

        if stage == "validate" or stage is None:
            self.validation_ds = CovidDataset(volumes=self.val_paths, hrct_transform=get_val_hrct_transforms(),
                                              cbct_transform=get_val_cbct_transforms())
            # Check the dataset
            print("Checking the validation dataset")
            # check_dataset(self.validation_ds)

        if stage == "test" or stage is None:
            self.test_ds = CovidDataset(volumes=self.test_paths, hrct_transform=get_val_hrct_transforms(),
                                        cbct_transform=get_val_cbct_transforms())
            # Check the dataset
            print("Checking the test dataset")
            # check_dataset(self.test_ds)

    def train_dataloader(self):
        train_dataloader = DataLoader(self.training_ds, batch_size=1, shuffle=True, num_workers=4)
        return train_dataloader

    def val_dataloader(self):
        val_dataloader = DataLoader(self.validation_ds, batch_size=1, num_workers=4)
        return val_dataloader

    def test_dataloader(self):
        test_dataloader = DataLoader(self.test_ds, batch_size=1, num_workers=4)
        return test_dataloader

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=1e-5)
        return optimizer

    def training_step(self, batch, batch_idx):
        inputs, labels = batch["img"], batch["mask"]

        # Forward pass
        raw_outputs = self.forward(inputs)
        loss = self.loss_function(raw_outputs, labels)
        self.log("ts_loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=False)
        outputs = [self.post_pred(i) for i in decollate_batch(raw_outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]

        # Dice metric
        self.train_dice_metric(y_pred=outputs, y=labels)
        self.log("ts_dice", self.train_dice_metric.aggregate().item(), on_step=True, on_epoch=False, prog_bar=True, logger=False)

        # NSDice Metric
        self.train_surface_dice_metric(y_pred=outputs, y=labels, spacing=SPACING)
        self.log("ts_surface_dice", self.train_surface_dice_metric.aggregate().item(), on_step=True, on_epoch=False, prog_bar=True, logger=False)

        # Haussdorf Metric
        self.train_haussdorf_metric(y_pred=outputs, y=labels, spacing=SPACING)
        self.log("ts_hd95", self.train_haussdorf_metric.aggregate().item(), on_step=True, on_epoch=False, prog_bar=True, logger=False)

        # IoU Metric
        self.train_iou_metric(y_pred=outputs, y=labels)
        self.log("ts_iou", self.train_iou_metric.aggregate().item(), on_step=True, on_epoch=False, prog_bar=True, logger=False)

        # Store the loss
        train_loss_dictionary = {"loss": loss}
        self.train_step_outputs.append(train_loss_dictionary)

        return loss

    def on_train_epoch_end(self) -> None:
        # Loss
        avg_loss = torch.stack([i["loss"] for i in self.train_step_outputs]).mean()
        self.train_step_outputs.clear()

        # Dice
        mean_train_dice = self.train_dice_metric.aggregate().item()
        self.train_dice_metric.reset()

        # NSDice
        mean_train_surface_dice = self.train_surface_dice_metric.aggregate().item()
        self.train_surface_dice_metric.reset()

        # Haussdorf
        mean_train_haussdorf = self.train_haussdorf_metric.aggregate().item()
        self.train_haussdorf_metric.reset()

        # IoU
        mean_train_iou = self.train_iou_metric.aggregate().item()
        self.train_iou_metric.reset()

        # Log the metrics
        self.log_dict({
            "train_loss": avg_loss,
            "train_dice": mean_train_dice,
            "train_surface_dice": mean_train_surface_dice,
            "train_haussdorf": mean_train_haussdorf,
            "train_iou": mean_train_iou
        }, on_epoch=True, on_step=False, prog_bar=True)

        # Save the metrics to a pandas dataframe
        self.train_val_dump_data_frame[-1].update({
            "train_loss": avg_loss.item(),
            "train_dice": mean_train_dice,
            "train_surface_dice": mean_train_surface_dice,
            "train_haussdorf": mean_train_haussdorf,
            "train_iou": mean_train_iou
        })

        # Log the metrics to tensorboard
        self.logger.experiment.add_scalars("losses", {"train": avg_loss}, self.current_epoch)
        self.logger.experiment.add_scalars("dice", {"train": mean_train_dice}, self.current_epoch)
        self.logger.experiment.add_scalars("surface_dice", {"train": mean_train_surface_dice}, self.current_epoch)
        self.logger.experiment.add_scalars("haussdorf", {"train": mean_train_haussdorf}, self.current_epoch)
        self.logger.experiment.add_scalars("iou", {"train": mean_train_iou}, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch["img"], batch["mask"]

        # Inference
        roi_size = VALIDATION_INFERENCE_ROI_SIZE
        sw_batch_size = 4
        outputs = sliding_window_inference(inputs, roi_size, sw_batch_size, self.forward, overlap=0.6)
        loss = self.loss_function(outputs, labels)
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]
        self.validation_step_outputs.append({"val_loss": loss})

        # Validation metrics
        self.dice_metric(y_pred=outputs, y=labels)
        self.surface_dice_metric(y_pred=outputs, y=labels, spacing=SPACING)
        self.haussdorf_metric(y_pred=outputs, y=labels, spacing=SPACING)
        self.iou_metric(y_pred=outputs, y=labels)

    def on_validation_epoch_end(self) -> None:
        avg_loss = torch.stack([x["val_loss"] for x in self.validation_step_outputs]).mean()
        self.validation_step_outputs.clear()

        # Dice
        mean_val_dice = self.dice_metric.aggregate().item()
        self.dice_metric.reset()

        # NSDice
        mean_val_surface_dice = self.surface_dice_metric.aggregate().item()
        self.surface_dice_metric.reset()

        # Haussdorf
        mean_val_haussdorf = self.haussdorf_metric.aggregate().item()
        self.haussdorf_metric.reset()

        # IoU
        mean_val_iou = self.iou_metric.aggregate().item()
        self.iou_metric.reset()

        # Log the metrics
        self.log_dict({
            "val_loss": avg_loss.item(),
            "val_dice": mean_val_dice,
            "val_surface_dice": mean_val_surface_dice,
            "val_haussdorf": mean_val_haussdorf,
            "val_iou": mean_val_iou
        }, prog_bar=True, on_epoch=True, on_step=False)

        # Save the metrics to a pandas dataframe
        self.train_val_dump_data_frame.append({
            "epoch": self.current_epoch,
            "val_loss": avg_loss.item(),
            "val_dice": mean_val_dice,
            "val_surface_dice": mean_val_surface_dice,
            "val_haussdorf": mean_val_haussdorf,
            "val_iou": mean_val_iou
        })

        # Log the metrics to tensorboard
        self.logger.experiment.add_scalars("losses", {"val_loss": avg_loss}, self.current_epoch)
        self.logger.experiment.add_scalars("dice", {"val_dice": mean_val_dice}, self.current_epoch)
        self.logger.experiment.add_scalars("surface_dice", {"val_surface_dice": mean_val_surface_dice}, self.current_epoch)
        self.logger.experiment.add_scalars("haussdorf", {"val_haussdorf": mean_val_haussdorf}, self.current_epoch)
        self.logger.experiment.add_scalars("iou", {"val_iou": mean_val_iou}, self.current_epoch)

        # Save the best model
        if mean_val_dice > self.best_val_dice:
            self.best_val_dice = mean_val_dice
            self.best_val_epoch = self.current_epoch

    def test_step(self, batch, batch_idx):
        inputs, labels = batch["img"], batch["mask"]
        roi_size = VALIDATION_INFERENCE_ROI_SIZE
        sw_batch_size = 4

        outputs = sliding_window_inference(inputs, roi_size, sw_batch_size, self.forward, overlap=0.6)
        loss = self.loss_function(outputs, labels)
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]

        # Dice metric
        self.dice_metric(y_pred=outputs, y=labels)
        volume_dice = self.dice_metric.aggregate().item()
        self.dice_metric.reset()

        # Surface dice metric
        self.surface_dice_metric(y_pred=outputs, y=labels, spacing=SPACING)
        volume_surface_dice = self.surface_dice_metric.aggregate().item()
        self.surface_dice_metric.reset()

        # Haussdorf metric
        self.haussdorf_metric(y_pred=outputs, y=labels, spacing=SPACING)
        volume_haussdorf = self.haussdorf_metric.aggregate().item()
        self.haussdorf_metric.reset()

        # IoU metric
        self.iou_metric(y_pred=outputs, y=labels)
        volume_iou = self.iou_metric.aggregate().item()
        self.iou_metric.reset()

        # Create a pandas dataframe with batch_idx, test_loss, and test_metric columns
        df = pd.DataFrame({
            "volume": [batch_idx],
            "test_loss": [loss.item()],
            "test_dice": [volume_dice],
            "test_surface_dice": [volume_surface_dice],
            "test_haussdorf": [volume_haussdorf],
            "test_iou": [volume_iou]
        })
        self.test_dump_data_frame.append(df)
        self.test_step_outputs.append({"test_loss": loss})

        # Paint the figure
        print(f"len(outputs): {len(outputs)}")
        print(f"len(labels): {len(labels)}")
        print(f"shape of outputs: {outputs[0].shape}")
        print(f"shape of labels: {labels[0].shape}")
    
        output_writer = monai.data.NibabelWriter()
        output_writer.set_data_array(outputs[0])
        output_writer.write(f"images/{self.test_ds.volumes[batch_idx]['img'].split('/')[-1].replace('.nii.gz','')}_output.nii.gz")
        
        for j in range(outputs[0].shape[-1]):
            burn_masks_in_ct(inputs[0, 0, :, :, j].cpu().numpy(), labels[0][0, :, :, j].cpu().numpy(), outputs[0][0, :, :, j].cpu().numpy(), path_to_save=f"images/{self.test_ds.volumes[batch_idx]['img'].split('/')[-1].replace('.nii.gz','')}_slice{j}.png")
           
        
    def on_test_epoch_end(self):
        df = pd.concat(self.test_dump_data_frame)
        df.to_csv(EXPERIMENTS_PATH + self.experiment_name + "/test_metrics.csv", header=True, index=False)

    def on_train_end(self) -> None:
        self.metrics_df.to_csv(EXPERIMENTS_PATH + self.experiment_name + "/train__val_metrics.csv", header=True, index=False)
    



In [30]:
# take path from lightning_logs
path = '../Experiments/unetr_dice_2/checkpoints/epoch=178-step=7160.ckpt'

# Set the seed for reproducibility
L.seed_everything(SEED)
torch.manual_seed(SEED)
monai.utils.set_determinism(seed=SEED)

# Experiment name
experiment_name = f"test"
print(f"Experiment name: {experiment_name}")


Seed set to 420


Experiment name: test


In [31]:
net = Net(
    learning_rate=1e-3,
    model=experiments_items.nets.covid_unetr,
    loss_function=DiceLoss(to_onehot_y=True, softmax=True, squared_pred=True),
    volumes_path="../" + COVID_PREPROCESSED_CASES_PATH,
    masks_path="../" + INFECTION_PREPROCESSED_MASKS_PATH,
    experiment_name=experiment_name
)


tensorboard_logger = L.pytorch.loggers.TensorBoardLogger(save_dir="../" + EXPERIMENTS_PATH, name=experiment_name, version="tensorboard",)

callbacks = [
    L.pytorch.callbacks.ModelCheckpoint(
        dirpath="../" + EXPERIMENTS_PATH + experiment_name + "/checkpoints",
        monitor="val_dice",
        save_top_k=1,
        mode="max",
        save_last=True
    ),
    CustomTimingCallback()
]

trainer = L.pytorch.Trainer(
    default_root_dir="../" + EXPERIMENTS_PATH,
    devices=[0],
    accelerator="gpu",
    strategy="auto",
    max_epochs=200,
    logger=tensorboard_logger,
    callbacks=callbacks,
    log_every_n_steps=14,
    deterministic=True,
    num_sanity_val_steps=0,
)

trainer.test(net)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Using model: <class 'monai.networks.nets.unetr.UNETR'>
Using loss: <class 'monai.losses.dice.DiceLoss'>
Using lr: 0.001
Checking the test dataset


Testing: |          | 0/? [00:00<?, ?it/s]

len(outputs): 1
len(labels): 1
shape of outputs: torch.Size([1, 577, 577, 51])
shape of labels: torch.Size([1, 577, 577, 51])


In [41]:
import numpy as np
import plotly.graph_objects as go

# Example 3D tensor with shape (577, 577, 51) with 0s and 1s
tensor = np.random.choice([0, 1], size=(577, 577, 51), p=[0.8, 0.2])

# Ensure there's a mix of 0s and 1s
print("Tensor shape:", tensor.shape)

# Dimensions of the tensor
dims = tensor.shape

# Create x, y, z coordinates for each voxel
x = np.repeat(np.arange(dims[0]), dims[1] * dims[2])
y = np.tile(np.repeat(np.arange(dims[1]), dims[2]), dims[0])
z = np.tile(np.arange(dims[2]), dims[0] * dims[1])

# Flatten the tensor to get the values for each voxel
values = tensor.flatten()

# Create a volume plot
fig = go.Figure(data=go.Volume(
    x=x,
    y=y,
    z=z,
    value=values,
    opacity=0.2,  # Adjust opacity for better visualization
    surface_count=17,  # Adjust for better visualization
    colorscale='Viridis'  # Or any other color scale
))

# Set plot title and axis labels
fig.update_layout(
    title='3D Metatensor Volume Visualization',
    scene=dict(
        xaxis_title='X Axis',
        yaxis_title='Y Axis',
        zaxis_title='Z Axis'
    )
)

# Show the plot
fig.show()


Tensor shape: (577, 577, 51)


KeyboardInterrupt: 