In [None]:
import os
import sys
import torch
import numpy as np

In [None]:
# Clone the FCT repository; the notebook later chdirs into it and imports `utils.data_utils` from it.
# TODO(review): pin a specific commit (git checkout <sha>) so re-runs use the same code.
!git clone https://github.com/msskzx/FCT

Cloning into 'FCT'...
remote: Enumerating objects: 287, done.[K
remote: Counting objects: 100% (64/64), done.[K
remote: Compressing objects: 100% (49/49), done.[K
remote: Total 287 (delta 39), reused 32 (delta 15), pack-reused 223[K
Receiving objects: 100% (287/287), 53.49 KiB | 4.86 MiB/s, done.
Resolving deltas: 100% (178/178), done.


In [None]:
# Work from the root of the cloned repository so that `utils.*` imports and
# relative paths such as `models/fct.model` resolve against it.
dir_path = os.path.join('/content', 'FCT')
os.chdir(dir_path)

# ACDC Dataset

## Download Test Dataset Only

In [None]:
# Download the ACDC test data from the CREATIS human-heart database as a single archive saved to data.zip.
!wget --output-document=data.zip https://humanheart-project.creatis.insa-lyon.fr/database/api/v1/folder/6372203a73e9f0047faa117e/download

--2023-09-13 19:35:35--  https://humanheart-project.creatis.insa-lyon.fr/database/api/v1/folder/6372203a73e9f0047faa117e/download
Resolving humanheart-project.creatis.insa-lyon.fr (humanheart-project.creatis.insa-lyon.fr)... 195.220.108.28
Connecting to humanheart-project.creatis.insa-lyon.fr (humanheart-project.creatis.insa-lyon.fr)|195.220.108.28|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/zip]
Saving to: ‘data.zip’

data.zip                [            <=>     ] 771.85M  16.6MB/s    in 48s     

2023-09-13 19:36:24 (16.1 MB/s) - ‘data.zip’ saved [809346596]



In [None]:
!unzip data.zip
!rm data.zip

# Model

In [None]:
# Mount Google Drive and copy the pretrained model into ./models, from which the
# next cells load `models/fct.model`.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
# NOTE(review): plain `cp` (no `-r`) fails if MyDrive/models is a directory, and if
# ./models already exists the source is copied *into* it rather than replacing it.
# Confirm the intended layout (e.g. copy MyDrive/models/fct.model into models/).
!cp /content/gdrive/MyDrive/models models

In [None]:
!pip install monai lightning

In [None]:
# Load the whole pickled model object (later cells use attributes such as
# `model.loss_fn`, so this is a full module, not a bare state_dict).
# weights_only=False is required for whole-module checkpoints on PyTorch >= 2.6,
# where the default became True. Only do this for trusted files: unpickling can
# execute arbitrary code.
# map_location='cpu' lets the cell run on CPU-only runtimes; evaluate_model moves
# the model to its inference device afterwards.
model = torch.load('models/fct.model', map_location='cpu', weights_only=False)

In [None]:
# Sanity check on the loaded checkpoint: show the loss module stored on it as `loss_fn`.
print(f"{model.loss_fn}")

BCEWithLogitsLoss()


# Test

In [None]:
from utils.data_utils import get_acdc,convert_masks
from torch.utils.data import DataLoader,TensorDataset
import torch.nn.functional as F

In [None]:
def evaluate_model(model, dataloader, device=None):
    """Run ``model`` over ``dataloader`` and build a per-batch Dice report.

    For every ``(inputs, targets)`` batch, the model's prediction
    (``outputs[2]``) is turned into a one-hot mask via ``argmax`` over the
    class axis and scored against ``targets`` with ``compute_dice``.

    Args:
        model: segmentation network. As a side effect it is switched to eval
            mode and moved to ``device``.
        dataloader: iterable of ``(inputs, targets)`` pairs; ``targets`` must be
            one-hot masks laid out like the permuted prediction,
            i.e. (batch, classes, H, W).
        device: where to run inference. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        A string with one section per batch. Each section contains the mean
        Dice over classes 1..C-1 followed by the LV, RV and MYO Dice scores.
        Each line is ``key: value``, and each section ends with a ``=`` rule.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    model = model.to(device)

    lines = []
    # Pure inference: without no_grad autograd would record a graph (and keep
    # activations alive) for every batch.
    with torch.no_grad():
        # NOTE(review): each batch is labelled as one "patient"; with the
        # batch_size=1 loader used in this notebook that is one dataset item,
        # which may be a 2-D slice rather than a whole patient — confirm.
        for patient, (inputs, targets) in enumerate(dataloader, start=1):
            inputs, targets = inputs.to(device), targets.to(device)

            # Call the module rather than .forward() so registered hooks run.
            # NOTE(review): outputs[2] is assumed to be the final prediction of
            # a multi-output model — confirm against the FCT architecture.
            logits = model(inputs)[2]
            num_classes = logits.shape[1]
            y_pred = torch.argmax(logits, dim=1)
            # one_hot appends the class axis last; move it to dim 1 to match targets.
            y_pred_onehot = F.one_hot(y_pred, num_classes).permute(0, 3, 1, 2)
            dice = compute_dice(y_pred_onehot, targets)

            # Channel-to-structure mapping used by this notebook:
            # 1 = RV, 2 = MYO, 3 = LV; channel 0 is excluded from the mean.
            lines.append(f"Patient: {patient}")
            lines.append(f"dice/all_validate_dice: {dice[1:].mean().item()}")
            lines.append(f"dice/LV_dice: {dice[3].item()}")
            lines.append(f"dice/RV_dice: {dice[1].item()}")
            lines.append(f"dice/MYO_dice: {dice[2].item()}")
            lines.append("========================================================")

    # Every line (including the last rule) is newline-terminated; an empty
    # dataloader yields an empty report.
    return "".join(f"{line}\n" for line in lines)

@torch.no_grad()
def compute_dice(pred_y, y, epsilon=1e-6):
    """
    Computes the Dice coefficient for each class in the ACDC dataset.

    Args:
        pred_y: binary (one-hot) predicted masks with the class axis at dim 1,
            e.g. (num_masks, num_classes, height, width).
        y: ground-truth masks with the same shape and device as ``pred_y``.
        epsilon: smoothing term added to numerator and denominator. It makes a
            class that is absent from both masks score 1.0 instead of 0/0.

    Returns:
        1-D tensor of length ``num_classes``, on the inputs' device, in the
        default floating dtype. Entry c is
        (2 * sum(pred_y[:, c] * y[:, c]) + epsilon)
        / (sum(pred_y[:, c]) + sum(y[:, c]) + epsilon),
        pooled over all masks and pixels.
    """
    # Pool over every axis except the class axis (dim 1) in one vectorized pass.
    reduce_dims = tuple(d for d in range(pred_y.dim()) if d != 1)
    intersection = torch.sum(pred_y * y, dim=reduce_dims)
    sum_masks = torch.sum(pred_y, dim=reduce_dims) + torch.sum(y, dim=reduce_dims)
    dice_scores = (2. * intersection + epsilon) / (sum_masks + epsilon)
    # The result stays on the inputs' device (no hard-coded CUDA), so CPU
    # tensors work too. It is cast to the default float dtype, as the original
    # per-class zeros buffer was.
    return dice_scores.to(torch.get_default_dtype())

In [None]:
# Test dataloader: ACDC 'testing' split at 224x224 with a single image channel.
# NOTE(review): get_acdc's first return value is indexed as [images, masks]
# below, both channels-last (NHWC) — confirm against utils.data_utils.
acdc_raw, _, _ = get_acdc('testing', input_size=(224, 224, 1))
test_images_nhwc = acdc_raw[0]
test_masks_nhwc = convert_masks(acdc_raw[1])

# PyTorch convolutions expect channels first: NHWC -> NCHW.
# torch.tensor always copies into a fresh float32 tensor, replacing the legacy
# torch.Tensor(ndarray) constructor.
test_images = torch.tensor(np.transpose(test_images_nhwc, (0, 3, 1, 2)), dtype=torch.float32)
test_masks = torch.tensor(np.transpose(test_masks_nhwc, (0, 3, 1, 2)), dtype=torch.float32)

# `acdc_data` keeps its original final meaning (the TensorDataset).
acdc_data = TensorDataset(test_images, test_masks)
test_loader = DataLoader(acdc_data, batch_size=1, num_workers=2)

results = evaluate_model(model, test_loader)

In [None]:
file_name = "fct_results.txt"

# Persist the Dice report built by evaluate_model. The explicit UTF-8 encoding
# keeps the file independent of the platform's default locale encoding.
with open(file_name, 'w', encoding='utf-8') as results_file:
    results_file.write(results)

print(f"The string has been saved to {file_name}.")

The string has been saved to fct_results.txt.
