Iniciamos o projeto importando e extraindo o dataset que está compactado (.zip) no Google Drive. Nesse caso, estamos importando o dataset já contendo as pastas de planejamento, pré-processamento e treinamento, a fim de realizar a inferência. (O funcionamento é o mesmo caso queira utilizar um dataset novo contendo apenas as dependências originais explicadas à frente.)

In [None]:
from google.colab import drive
import zipfile
import os
import shutil

# Mount Google Drive at /content/drive; the archives below are read from it.
drive.mount('/content/drive')

# Locations of the three zipped nnU-Net folders on the mounted Drive.
raw_data_zip_path, preprocessed_zip_path, results_zip_path = (
    '/content/drive/MyDrive/nnUNet_raw_data_base.zip',
    '/content/drive/MyDrive/nnUNet_preprocessed.zip',
    '/content/drive/MyDrive/nnUNet_results.zip',
)

# Directory that every archive is extracted into.
extract_base_dir = '/content'
os.makedirs(extract_base_dir, exist_ok=True)

def _move_merging(src_path, dest_dir):
    """Move ``src_path`` into ``dest_dir``, merging into an existing directory.

    ``shutil.move`` raises when ``dest_dir`` already holds an entry with the
    same name. Here two directories are merged recursively instead, and an
    existing file is overwritten (the same policy ``ZipFile.extractall`` uses).
    """
    dest_path = os.path.join(dest_dir, os.path.basename(src_path))
    src_is_dir = os.path.isdir(src_path) and not os.path.islink(src_path)
    dest_is_dir = os.path.isdir(dest_path) and not os.path.islink(dest_path)

    if src_is_dir and dest_is_dir:
        for child in os.listdir(src_path):
            _move_merging(os.path.join(src_path, child), dest_path)
        os.rmdir(src_path)  # empty once every child has been moved
        return

    if os.path.lexists(dest_path):
        # Type mismatch or plain file: the freshly extracted entry wins.
        if dest_is_dir:
            shutil.rmtree(dest_path)
        else:
            os.remove(dest_path)
    shutil.move(src_path, dest_path)


def extract_and_fix(zip_file_path, target_extract_dir):
    """Extract a zip archive and flatten a nested ``content/`` folder.

    If the archive unpacks into ``<target_extract_dir>/content`` and that
    folder is not empty, its entries are moved one level up and the folder is
    removed. Entries already present in the target are merged, so re-running
    the extraction (or extracting archives that share a top-level folder) no
    longer aborts half-way.

    Args:
        zip_file_path: Path of the ``.zip`` archive to extract.
        target_extract_dir: Directory the archive is extracted into.

    Returns:
        ``True`` when extraction (and the optional flattening) finished,
        ``False`` when an error was reported. Errors are printed, not raised,
        so the calling script keeps running as before.
    """
    zip_name = os.path.basename(zip_file_path)
    print(f"\nExtracting {zip_name} to {target_extract_dir}...")
    try:
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(target_extract_dir)
        print(f"{zip_name} extracted.")

        nested_content_dir = os.path.join(target_extract_dir, 'content')
        if os.path.isdir(nested_content_dir) and os.listdir(nested_content_dir):
            print("Detected nested '/content/content' directory. Moving contents up...")
            for item in os.listdir(nested_content_dir):
                _move_merging(os.path.join(nested_content_dir, item), target_extract_dir)
            # rmdir (not rmtree): if anything is unexpectedly left behind we
            # fail loudly instead of deleting extracted data.
            os.rmdir(nested_content_dir)
            print("Contents moved successfully. Nested folder removed.")
        else:
            print("No nested '/content/content' detected for this zip.")

    except FileNotFoundError:
        print(f"Error: {zip_file_path} not found. Please check the path in your Google Drive.")
        return False
    except zipfile.BadZipFile as e:
        print(f"An error occurred during extraction of {zip_name}: not a valid zip file ({e})")
        return False
    except Exception as e:
        # Deliberate best-effort: report the problem and let the notebook continue.
        print(f"An error occurred during extraction of {zip_name}: {e}")
        return False
    return True
# Unpack the three archives into the shared base directory, in the original order.
for archive_path in (raw_data_zip_path, preprocessed_zip_path, results_zip_path):
    extract_and_fix(archive_path, extract_base_dir)

print("\n--- Final Verification of Extracted Folders ---")
expected_final_dirs = [
    '/content/nnUNet_raw_data_base',
    '/content/nnUNet_preprocessed',
    '/content/nnUNet_results'
]

# Report each expected folder; for the raw-data folder also check that the
# dataset directory and a non-empty imagesTs folder are present.
for d in expected_final_dirs:
    if not os.path.exists(d):
        print(f"NOT FOUND: {d}. Something went wrong with the extraction or fix.")
        continue

    print(f"Found: {d}")
    if 'nnUNet_raw_data_base' not in d:
        continue

    dataset_path = os.path.join(d, 'nnUNet_raw_data', 'Dataset505_BraTS2020_subset')
    if not os.path.exists(dataset_path):
        print(f"   --> WARNING: Dataset505_BraTS2020_subset not found inside {d}.")
        continue

    print(f"   --> Found Dataset505_BraTS2020_subset at: {dataset_path}")
    images_ts_path = os.path.join(dataset_path, 'imagesTs')
    has_test_images = os.path.exists(images_ts_path) and os.listdir(images_ts_path)
    if has_test_images:
        print("   --> Confirmed imagesTs folder exists and is not empty.")
    else:
        print(f"   --> WARNING: imagesTs folder not found or empty at: {images_ts_path}")

print("\nSetup for inference complete! You can now proceed with setting environment variables and running prediction.")

Instalação das dependências e da nnUNet através do GitHub:

In [None]:
# Upgrade the packaging tools before installing nnU-Net.
!pip install -U pip setuptools wheel

# Fetch the nnU-Net sources from GitHub.
!git clone https://github.com/MIC-DKFZ/nnUNet.git

# Enter the cloned repository and install it in editable mode.
# NOTE(review): %cd changes the notebook's working directory for later cells.
%cd nnUNet
!pip install -e .

Definição de variáveis de ambiente (padrão para o nnUNetv2).

In [None]:
import os
# Folder locations read by the nnU-Net v2 commands in the following cells.
os.environ.update({
    'nnUNet_raw': '/content/nnUNet_raw_data_base',
    'nnUNet_preprocessed': '/content/nnUNet_preprocessed',
    'nnUNet_results': '/content/nnUNet_results',
})

In [None]:
import os
# Settings describing the prediction run. This cell only prepares the output
# folder and prints the settings; it does not launch nnU-Net itself.
# NOTE(review): the input path uses the nested nnUNet_raw_data/ layout, i.e.
# the location of the dataset before the later `mv` cell relocates it.
task_id = 505
model_config = '3d_fullres'
folds_to_predict = 0
input_images_ts_dir = '/content/nnUNet_raw_data_base/nnUNet_raw_data/Dataset505_BraTS2020_subset/imagesTs'
output_predictions_dir = '/content/nnUNet_predictions'

os.makedirs(output_predictions_dir, exist_ok=True)
print(f"Output directory for predictions created: {output_predictions_dir}")

print(
    f"\nStarting nnUNet prediction for Task {task_id} with {model_config} model...",
    f"Input: {input_images_ts_dir}",
    f"Output: {output_predictions_dir}",
    sep="\n",
)

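Mover o conteúdo extraído para o local correto (a pasta do dataset deve ficar diretamente dentro de /content/nnUNet_raw_data_base, que é o valor definido para nnUNet_raw):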

In [None]:
# Move the dataset folder out of the nested nnUNet_raw_data/ directory so it
# sits directly under /content/nnUNet_raw_data_base (the value of nnUNet_raw).
# NOTE(review): re-running this cell after the move will fail because the
# source path no longer exists.
!mv /content/nnUNet_raw_data_base/nnUNet_raw_data/Dataset505_BraTS2020_subset /content/nnUNet_raw_data_base/

Para executar o planejamento e o pré-processamento dos dados, as pastas devem conter a seguinte estrutura:
/content/nnUNet_raw_data_base/(Dataset)/imagesTr                                
/content/nnUNet_raw_data_base/(Dataset)/labelsTr                                
/content/nnUNet_raw_data_base/(Dataset)/dataset.json

O código também realiza a verificação da integridade do dataset.

In [None]:
# Plan and preprocess dataset 505, with the dataset-integrity check enabled.
!nnUNetv2_plan_and_preprocess -d 505 --verify_dataset_integrity

Caso queira dar início ao treinamento:

In [None]:
# Start training for dataset 505 with the 3d_fullres configuration, fold 0.
!nnUNetv2_train 505 3d_fullres 0

Caso queira retomar o treinamento:

In [None]:
# Same training command with --c added, used here to resume a previous run.
!nnUNetv2_train 505 3d_fullres 0 --c

Realizar a inferência baseada nos passos realizados acima. Note que, no mesmo local onde existem as pastas imagesTr e labelsTr, devem existir as pastas imagesTs e labelsTs para teste.

In [None]:
# Run inference on imagesTs with the fold-0 3d_fullres model (best checkpoint).
# The output folder now uses the post-`mv` dataset location, matching both the
# -i path and the folder the plotting cell reads summary.json from; the old
# value still pointed into the nested nnUNet_raw_data/ directory.
# NOTE(review): the plotting cell expects predictions/summary.json; confirm
# that an evaluation step (e.g. nnUNetv2_evaluate_folder against labelsTs)
# writes it, since prediction alone may not.
!nnUNetv2_predict \
  -i /content/nnUNet_raw_data_base/Dataset505_BraTS2020_subset/imagesTs \
  -o /content/nnUNet_raw_data_base/Dataset505_BraTS2020_subset/predictions \
  -d Dataset505_BraTS2020_subset \
  -c 3d_fullres \
  -f 0 \
  -chk checkpoint_best.pth

O código a seguir realiza uma plotagem dos dados obtidos acima para melhor visualização de desempenho e resultados conclusivos.

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt

# Load the evaluation summary and show per-class metrics for the overall
# mean, the worst case and the best case (ranked by mean Dice over classes).
json_path = "/content/nnUNet_raw_data_base/Dataset505_BraTS2020_subset/predictions/summary.json"
with open(json_path, encoding="utf-8") as f:
    summary = json.load(f)


def _metrics_frame(metrics_by_class):
    """Turn a ``{class: {metric: value}}`` mapping into a table with a 'Classe' column."""
    frame = pd.DataFrame.from_dict(metrics_by_class, orient="index")
    frame.index.name = "Classe"
    frame.reset_index(inplace=True)
    return frame


def _plot_metrics(frame, scope, color):
    """Draw the per-class Dice bar chart and the Dice/IoU/FP/FN comparison chart."""
    plt.figure(figsize=(8, 5))
    plt.bar(frame['Classe'].astype(str), frame['Dice'], color=color)
    plt.title(f"Dice Score por Classe ({scope})")
    plt.xlabel("Classe")
    plt.ylabel("Dice")
    plt.ylim(0, 1)
    plt.grid(True)
    plt.show()

    frame[["Classe", "Dice", "IoU", "FP", "FN"]].set_index("Classe").plot(
        kind="bar", figsize=(10, 6), title=f"Comparação de Métricas por Classe ({scope})"
    )
    plt.ylabel("Valor")
    plt.grid(True)
    plt.show()


def _mean_valid_dice(metrics_by_class):
    """Return the mean Dice over classes with a defined value, or ``None`` if there is none.

    NaN/None Dice values are skipped: a single NaN would turn the whole mean
    into NaN, and NaN compares False against everything, so the case could
    never be picked as best or worst.
    """
    valid = [m["Dice"] for m in metrics_by_class.values() if pd.notna(m["Dice"])]
    if not valid:
        return None
    return sum(valid) / len(valid)


df_mean = _metrics_frame(summary["mean"])

print("MÉTRICAS MÉDIAS POR CLASSE (ORIGINAL)")
display(df_mean)
_plot_metrics(df_mean, "Média Geral", "skyblue")

worst_case, best_case = None, None
worst_dice, best_dice = float("inf"), float("-inf")
case_dice_averages = []

for case in summary["metric_per_case"]:
    mean_dice = _mean_valid_dice(case["metrics"])
    if mean_dice is None:
        # No usable Dice for this case (no classes, or all NaN): cannot rank it.
        continue
    case_dice_averages.append((mean_dice, case))
    if mean_dice < worst_dice:
        worst_dice = mean_dice
        worst_case = case
    if mean_dice > best_dice:
        best_dice = mean_dice
        best_case = case

if worst_case is None or best_case is None:
    # Previously this situation crashed with a TypeError on worst_case["metrics"].
    print("\nNenhum caso com Dice válido em 'metric_per_case'; gráficos de pior/melhor caso omitidos.")
else:
    df_worst = _metrics_frame(worst_case["metrics"])
    print(f"\nPIOR CASO: {worst_case['prediction_file'].split('/')[-1]} | Média Dice: {worst_dice:.4f}")
    display(df_worst)
    _plot_metrics(df_worst, "Pior Caso", "salmon")

    df_best = _metrics_frame(best_case["metrics"])
    print(f"\nMELHOR CASO: {best_case['prediction_file'].split('/')[-1]} | Média Dice: {best_dice:.4f}")
    display(df_best)
    _plot_metrics(df_best, "Melhor Caso", "mediumseagreen")


O código a seguir também utiliza meios visuais para fazermos a análise dos resultados, porém utilizando puramente os dados do treinamento.

In [None]:
import re
import matplotlib.pyplot as plt

# NOTE(review): the log file name embeds a run timestamp, so it has to be
# updated for every new training run.
log_file_path = "/content/nnUNet_results/Dataset505_BraTS2020_subset/nnUNetTrainer__nnUNetPlans__3d_fullres/fold_0/training_log_2025_6_3_17_58_45.txt"

epochs = []
train_losses = []
val_losses = []
learning_rates = []
epoch_times = []
dice_0, dice_1, dice_2 = [], [], []

# A decimal number with optional sign and exponent. The previous `[\d.]+`
# stopped at the "e" of values written in scientific notation, so a learning
# rate logged as 2e-05 was read as 2.0.
FLOAT_PATTERN = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
# One pseudo-dice entry, either bare (0.75) or wrapped as np.float32(0.75).
DICE_VALUE_PATTERN = rf"(?:np\.float(?:16|32|64)\()?({FLOAT_PATTERN})\)?"

# Compiled once instead of on every log line.
epoch_re = re.compile(r"Epoch (\d+)")
learning_rate_re = re.compile(rf"Current learning rate: ({FLOAT_PATTERN})")
train_loss_re = re.compile(rf"train_loss ({FLOAT_PATTERN})")
val_loss_re = re.compile(rf"val_loss ({FLOAT_PATTERN})")
epoch_time_re = re.compile(rf"Epoch time: ({FLOAT_PATTERN}) s")
# Exactly three classes are expected, matching dice_0 / dice_1 / dice_2.
pseudo_dice_re = re.compile(
    rf"\[\s*{DICE_VALUE_PATTERN},\s*{DICE_VALUE_PATTERN},\s*{DICE_VALUE_PATTERN}\s*\]"
)

with open(log_file_path, "r", encoding="utf-8") as f:
    for line in f:
        # Each check is independent: a line is tested against every pattern.
        match = epoch_re.search(line)
        if match:
            epochs.append(int(match.group(1)))

        match = learning_rate_re.search(line)
        if match:
            learning_rates.append(float(match.group(1)))

        match = train_loss_re.search(line)
        if match:
            train_losses.append(float(match.group(1)))

        match = val_loss_re.search(line)
        if match:
            val_losses.append(float(match.group(1)))

        if "Pseudo dice" in line:
            match = pseudo_dice_re.search(line)
            if match:
                dice_0.append(float(match.group(1)))
                dice_1.append(float(match.group(2)))
                dice_2.append(float(match.group(3)))

        match = epoch_time_re.search(line)
        if match:
            epoch_times.append(float(match.group(1)))

def plotar(x, y_list, labels, title, ylabel):
    """Plot one or more series against ``x`` on a single figure.

    All series (and ``x``) are truncated to the shortest length among them,
    so lists parsed from an incomplete log can still be drawn together.

    Args:
        x: Sequence of x values (epoch numbers).
        y_list: List of y sequences, one per curve; may be empty.
        labels: Legend labels, paired with ``y_list`` in order.
        title: Figure title.
        ylabel: Label of the y axis.
    """
    plt.figure(figsize=(8, 5))
    # min() over one list: `min(len(x), *())` with an empty y_list would call
    # min() on a single int and raise TypeError.
    min_len = min([len(x), *(len(y) for y in y_list)])
    x = x[:min_len]
    drew_series = False
    for y, label in zip(y_list, labels):
        plt.plot(x, y[:min_len], label=label)
        drew_series = True
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.title(title)
    if drew_series:
        # A legend without any labelled curve only produces a warning.
        plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.show()

# One figure per quantity parsed from the training log, all plotted against the epoch number.
plotar(
    epochs,
    [train_losses, val_losses],
    labels=["Train Loss", "Val Loss"],
    title="Train vs Val Loss",
    ylabel="Loss",
)
plotar(
    epochs,
    [dice_0, dice_1, dice_2],
    labels=["Class 0", "Class 1", "Class 2"],
    title="Pseudo Dice per Class",
    ylabel="Dice Score",
)
plotar(
    epochs,
    [learning_rates],
    labels=["Learning Rate"],
    title="Learning Rate over Epochs",
    ylabel="Learning Rate",
)
plotar(
    epochs,
    [epoch_times],
    labels=["Epoch Time"],
    title="Time per Epoch",
    ylabel="Seconds",
)