In [15]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader,TensorDataset
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
!git clone https://github.com/msskzx/FCT
dir_path = '/content/FCT'
os.chdir(dir_path)

# ACDC Dataset

## Download Test Dataset

In [None]:
!wget --output-document=data.zip https://humanheart-project.creatis.insa-lyon.fr/database/api/v1/folder/6372203a73e9f0047faa117e/download
!unzip data.zip
!rm data.zip

## Test One Image Only

Inference on one image only, by providing the link of the patient folder

In [None]:
!wget --output-document=data.zip https://humanheart-project.creatis.insa-lyon.fr/database/api/v1/folder/63721d7073e9f0047faa052a/download
!unzip data.zip
!rm data.zip
!mkdir testing
!mv patient001/ testing/patient001

# Model

In [None]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
!cp -r /content/gdrive/MyDrive/models models
drive.flush_and_unmount()
!pip install monai lightning

In [10]:
model = torch.load('models/fct.model')
print(model.loss_fn)

# Testing

In [55]:
from utils.data_utils import get_acdc,convert_masks

# test dataloader
acdc_data, _, _ = get_acdc('testing', input_size=(224, 224, 1))
acdc_data[1] = convert_masks(acdc_data[1])
acdc_data[0] = np.transpose(acdc_data[0], (0, 3, 1, 2)) # for the channels
acdc_data[1] = np.transpose(acdc_data[1], (0, 3, 1, 2)) # for the channels
acdc_data[0] = torch.Tensor(acdc_data[0]) # convert to tensors
acdc_data[1] = torch.Tensor(acdc_data[1]) # convert to tensors
acdc_data = TensorDataset(acdc_data[0], acdc_data[1])
test_loader = DataLoader(acdc_data, batch_size=1, num_workers=2)
print(len(test_loader))

1076


## Prediction and Visualization

In [64]:
def evaluate_model(model, dataloader):
    device = torch.device("cuda")
    model.eval()
    model = model.to(device)
    patient_id = 101
    slice_id = 1
    i = 0
    scores = pd.DataFrame(columns=['patient_id', 'slice_id', 'dice_avg', 'dice_lv', 'dice_rv', 'dice_myo'])

    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        with torch.no_grad():
            outputs = model(inputs)

        y_pred = torch.argmax(outputs[2], axis=1)

        """
        # Visualize the input image, ground truth mask, and predicted mask
        input_image = inputs[0].cpu().numpy().transpose(1, 2, 0)
        # convert into a single channel to visualize
        ground_truth_mask = torch.argmax(targets[0], dim=0)
        predicted_mask = y_pred.cpu().numpy().transpose(1, 2, 0)
        

        plt.figure(figsize=(12, 4))

        plt.subplot(1, 3, 1)
        plt.title("Input Image")
        plt.imshow(input_image, cmap='gray')

        plt.subplot(1, 3, 2)
        plt.title("Ground Truth Mask")
        plt.imshow(ground_truth_mask, cmap='gray')

        plt.subplot(1, 3, 3)
        plt.title("Predicted Mask")
        plt.imshow(predicted_mask, cmap='gray')

        plt.show()
        """

        # compute dice
        # convert to 4 channels to compare with gt, since gt has 4 channels
        y_pred_onehot = F.one_hot(y_pred, 4).permute(0, 3, 1, 2)

        dice = compute_dice(y_pred_onehot, targets)
        dice_lv = dice[3].item()
        dice_rv = dice[1].item()
        dice_myo = dice[2].item()
        # skip background for mean
        dice_avg = dice[1:].mean().item()

        scores.loc[i] = {
                'patient_id': patient_id,
                'slice_id': slice_id % 10 + 1,
                'dice_avg': dice_avg,
                'dice_lv': dice_lv,
                'dice_rv': dice_rv,
                'dice_myo': dice_myo
            }
        if slice_id == 20:
          patient_id += 1
          slice_id = 0
        slice_id += 1
        i+= 1

    return scores

def compute_dice(pred_y, y):
    """
    Computes the Dice coefficient for each class in the ACDC dataset.
    Assumes binary masks with shape (num_masks, num_classes, height, width).
    """
    epsilon = 1e-6
    num_masks = pred_y.shape[0]
    num_classes = pred_y.shape[1]
    device = torch.device("cuda")
    dice_scores = torch.zeros((num_classes,), device=device)

    for c in range(num_classes):
        intersection = torch.sum(pred_y[:, c] * y[:, c])
        sum_masks = torch.sum(pred_y[:, c]) + torch.sum(y[:, c])
        dice_scores[c] = (2. * intersection + epsilon) / (sum_masks + epsilon)

    return dice_scores

In [65]:
scores = evaluate_model(model, test_loader)

In [66]:
scores

Unnamed: 0,patient_id,slice_id,dice_avg,dice_lv,dice_rv,dice_myo
0,101,2,7.154040e-01,7.531149e-01,8.771700e-01,5.159269e-01
1,101,3,7.526994e-01,7.737578e-01,8.591394e-01,6.252010e-01
2,101,4,7.692711e-01,7.820248e-01,8.838133e-01,6.419753e-01
3,101,5,7.697833e-01,8.335992e-01,8.073702e-01,6.683804e-01
4,101,6,7.970834e-01,9.008464e-01,7.522698e-01,7.381342e-01
...,...,...,...,...,...,...
1071,154,3,8.508065e-01,9.328461e-01,8.625712e-01,7.570023e-01
1072,154,4,8.364512e-01,9.033434e-01,8.568935e-01,7.491166e-01
1073,154,5,6.120958e-01,8.599671e-01,2.626642e-01,7.136564e-01
1074,154,6,4.065759e-09,3.030303e-09,5.291005e-09,3.875969e-09


## Save Results

So it could be used for further analysis at anytime without infering again.

In [68]:
!mkdir results

In [69]:
export_path = 'results/fct_scores.csv'
scores.to_csv(export_path, index=False)
print(f"The scores have been saved to {export_path}")

The scores have been saved to results/fct_scores.csv
