In [11]:
import argparse
import json
import os
import warnings

import numpy as np
import pandas as pd
import seaborn as sns
import torch
import wandb
from lightning import Trainer
from matplotlib import pyplot as plt
from sklearn.metrics import classification_report
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
from tqdm.notebook import tqdm

from src.data.datamodule import DataModule
from src.models.regnety.regnety import RegNetY

warnings.filterwarnings("ignore", ".*does not have many workers.*")

%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Preparation

In [12]:
# Load the model config
# The hyper-parameters of the training run are fetched from Weights & Biases and
# exposed through attribute access (config.img_size, config.fold_id, ...).
api = wandb.Api()
# Alternative run kept for reference: 'mvrcii_/SEER/f5v984te'
run = api.run('wuesuv/CV2024/qt8ov5wo')
config = argparse.Namespace(**run.config)

In [13]:
# Project paths
# Raw strings keep the Windows backslashes literal (the previous literals relied
# on invalid escape sequences such as "\M" and "\G"). Each location can be
# overridden through an environment variable so the notebook also runs on
# machines with a different checkout layout; the defaults are unchanged.
vissl_project_dir = os.environ.get(
    'VISSL_PROJECT_DIR', r'C:\Users\Marce\Git-Master\JMU\Masterarbeit\vissl'
)
endoscopy_project_dir = os.environ.get(
    'ENDOSCOPY_PROJECT_DIR', r'C:\Users\Marce\Git-Master\JMU\Masterarbeit\endoscopy'
)
cvip_project_dir = os.environ.get(
    'CVIP_PROJECT_DIR', r'C:\Users\Marce\Git-Master\Privat\cv2024'
)

## Loading Class Mapping

In [14]:
# Locate the class-name -> class-index mapping stored with the EndoExtend dataset.
class_mapping_path = os.path.join(endoscopy_project_dir, 'datasets/endoextend_dataset/class_mapping.json')
absolute_path = os.path.abspath(class_mapping_path)

# Fail early with a readable message; isfile() also rejects a directory of that name.
if not os.path.isfile(absolute_path):
    raise FileNotFoundError(f"Class mapping file not found at {class_mapping_path}")

# Open the same path that was just checked. JSON is UTF-8 by specification, so
# do not depend on the platform's default encoding.
with open(absolute_path, 'r', encoding='utf-8') as f:
    class_mapping = json.load(f)
class_mapping

{'angioectasia': 0,
 'bleeding': 1,
 'erosion': 2,
 'erythema': 3,
 'foreign_body': 4,
 'lymphangiectasia': 5,
 'normal-mucosa': 6,
 'polyp': 7,
 'ulcer': 8,
 'worms': 9}

## Loading Model with Checkpoint

In [15]:
# Checkpoint to evaluate, relative to <cvip_project_dir>/pretrained_models.
# Previously used checkpoint: 'unique-sweep-1_epoch03_val_mAP_weighted0.81.ckpt'
ckpt_filename = 'run-20240827_135850-honest-salad-51/best_epoch07_val_AUC_macro0.99.ckpt'
ckpt_path = os.path.join(cvip_project_dir, 'pretrained_models', ckpt_filename)


def create_model(device=None):
    """Load the RegNetY checkpoint at ``ckpt_path`` and prepare it for inference.

    Args:
        device: Optional ``torch.device`` or device string to place the model on.
            When omitted, CUDA is used if available, then Apple MPS, then the CPU
            (the original implementation always required CUDA).

    Returns:
        The loaded model, moved to the selected device and switched to eval mode.
    """
    if device is None:
        if torch.cuda.is_available():
            device = 'cuda'
        elif torch.backends.mps.is_available():
            device = 'mps'
        else:
            device = 'cpu'

    model = RegNetY.load_from_checkpoint(checkpoint_path=ckpt_path, config=config, class_to_idx=class_mapping)
    model.to(torch.device(device))
    model.eval()
    return model

In [16]:
# Pick the Lightning accelerator: Apple MPS first, then CUDA ("gpu"), else CPU.
if torch.backends.mps.is_available():
    accelerator = "mps"
elif torch.cuda.is_available():
    accelerator = "gpu"
else:
    accelerator = "cpu"

# Single-device trainer; in this notebook it is only used for trainer.validate().
trainer = Trainer(
    accelerator=accelerator,
    devices=1,
    precision="16-mixed",
    gradient_clip_val=0.5,
    inference_mode=True,
    enable_progress_bar=True,
    enable_model_summary=False,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


## Loading Transforms

In [17]:
from src.utils.transform_utils import load_transforms

# The transform description is read from the W&B run summary (key 'transforms')
# and rebuilt for the image size stored in the run config.
transforms_str = run.summary.get('transforms')
transforms = load_transforms(
    transforms_string=transforms_str,
    img_size=config.img_size,
)

## Loading DataModules

In [21]:
# EndoExtend DataModule
# dataset_path is resolved relative to the vissl checkout, dataset_csv_path
# points into the endoscopy project; the fold index comes from the run config.
ee_dataset_csv_path = os.path.join(endoscopy_project_dir, 'capsulevision', 'endoextend_dataset')
ee_dataset_path = os.path.join(vissl_project_dir, '../data')

ee_data_module = DataModule(
    dataset_path=ee_dataset_path,
    dataset_csv_path=ee_dataset_csv_path,
    class_mapping=class_mapping,
    transforms=transforms,
    fold_idx=config.fold_id,
    train_bs=32,
    val_bs=64,
    num_workers=0,
)
ee_data_module.setup()

# Report the size of every split as a quick sanity check.
print("EndoExtend Dataset with Val-Fold Index:", config.fold_id)
print("Train Images:", len(ee_data_module.train_dataloader().dataset))
print("Val Images:", len(ee_data_module.val_dataloader().dataset))
print("Test Images:", len(ee_data_module.test_dataloader().dataset))

EndoExtend Dataset with Val-Fold Index: 1
Train Images: 54209
Val Images: 17961
Test Images: 18042


### CVIP DataModule

In [22]:
# CVIP DataModule
# Both the data directory and the CSV directory live inside the CVIP project.
dataset_csv_path = os.path.join(cvip_project_dir, 'dataset')
dataset_path = os.path.join(cvip_project_dir, 'data')

# The validation fold is fixed here rather than read from the run config.
fold_id = 1
cvip_data_module = DataModule(
    dataset_path=dataset_path,
    dataset_csv_path=dataset_csv_path,
    class_mapping=class_mapping,
    transforms=transforms,
    fold_idx=fold_id,
    train_bs=32,
    val_bs=64,
    num_workers=0,
)
cvip_data_module.setup()

# Report the size of the train/val splits as a quick sanity check.
print("CVIP Dataset with Val-Fold Index:", fold_id)
print("Train Images:", len(cvip_data_module.train_dataloader().dataset))
print("Val Images:", len(cvip_data_module.val_dataloader().dataset))

CVIP Dataset with Val-Fold Index: 1
Train Images: 37310
Val Images: 15987


## Sanity Check: Validate Model Checkpoint on EndoExtend Data

In [23]:
# Validate a freshly loaded checkpoint on the EndoExtend validation split; the
# value returned by trainer.validate() is displayed as the cell output.
trainer.validate(create_model(), ee_data_module.val_dataloader())

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

## Predict on CVIP Data

In [None]:
# Validate a freshly loaded checkpoint on the CVIP validation split; the value
# returned by trainer.validate() is displayed as the cell output.
trainer.validate(create_model(), cvip_data_module.val_dataloader())

In [None]:
from torch.utils.data import DataLoader

# Evaluate the checkpoint on the CVIP training split through a plain,
# unshuffled DataLoader built directly on the data module's 'train' dataset.
train_ds = cvip_data_module.datasets['train']
train_loader = DataLoader(train_ds, batch_size=64, num_workers=0)

# `model` is kept as a notebook-level variable because the manual inference
# loop in the next cell reuses it.
model = create_model()
trainer.validate(model, train_loader)

In [None]:
dataloader = cvip_data_module.val_dataloader()

# Pick the inference device explicitly instead of assuming CUDA, and put the
# model back on it in eval mode: the Trainer used in the previous cell may have
# moved the model or changed its train/eval state.
if torch.cuda.is_available():
    inference_device = torch.device('cuda')
elif torch.backends.mps.is_available():
    inference_device = torch.device('mps')
else:
    inference_device = torch.device('cpu')
model.to(inference_device)
model.eval()

print("Model Device:", model.device)
model_result = []
# Without no_grad() every stored logit tensor would keep its autograd graph
# (and the activations behind it) alive for the whole validation set.
with torch.no_grad():
    for batch in tqdm(dataloader, total=len(dataloader)):
        images, labels = batch
        images = images.to(inference_device)

        # Call the module itself rather than .forward() so registered hooks run.
        logits = model(images)
        # Keep results on the CPU so accelerator memory is released per batch.
        model_result.append((logits.detach().cpu(), labels))

In [None]:
def extract_results(results):
    """Flatten per-batch ``(logits, labels)`` pairs into two NumPy arrays.

    Batches are concatenated along the sample axis, so they may differ in size
    (e.g. a smaller final batch) and the number of classes is taken from the
    logits themselves instead of being hard-coded.

    Args:
        results: Iterable of ``(logits, labels)`` tuples. Each element may be a
            torch tensor (on any device, with or without grad) or anything
            ``np.asarray`` accepts. Logits are expected to carry the class
            scores on their last axis.

    Returns:
        Tuple ``(logits, labels)`` where ``logits`` has shape
        ``(n_samples, n_classes)`` and ``labels`` has shape ``(n_samples,)``.

    Raises:
        ValueError: If ``results`` contains no batches.
    """

    def _as_array(batch):
        # Tensors must be detached and moved to host memory before conversion.
        if isinstance(batch, torch.Tensor):
            return batch.detach().cpu().numpy()
        return np.asarray(batch)

    logit_batches = []
    label_batches = []
    for batch_logits, batch_labels in results:
        batch_logits = _as_array(batch_logits)
        logit_batches.append(batch_logits.reshape(-1, batch_logits.shape[-1]))
        label_batches.append(_as_array(batch_labels).reshape(-1))

    if not logit_batches:
        raise ValueError("results is empty: there are no batches to extract")

    logits = np.concatenate(logit_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)

    return logits, labels


# Collapse the per-batch results and take the highest-scoring class per sample.
logits, y_true = extract_results(model_result)
y_pred = logits.argmax(axis=1)

print(y_pred.shape, y_true.shape)

In [None]:
# Per-class precision / recall / F1, shown with one row per class.
class_labels = list(class_mapping.keys())
class_indices = list(class_mapping.values())
report = classification_report(
    y_true=y_true,
    y_pred=y_pred,
    labels=class_indices,
    target_names=class_labels,
    zero_division=np.nan,
    output_dict=True,
)
pd.DataFrame(report).transpose()

In [None]:
def plot_roc_curve(preds: np.array, labels: np.array, class_mapping):
    """Plot per-class ROC curves next to the micro- and macro-averaged curves.

    Args:
        preds: Scores of shape ``(n_samples, n_classes)``, one column per class
            (e.g. softmax probabilities). CPU/GPU torch tensors are converted.
        labels: Either class indices of shape ``(n_samples,)`` or
            ``(n_samples, 1)``, or an indicator matrix of shape
            ``(n_samples, n_classes)``.
        class_mapping: Mapping whose keys are the class names; the i-th key
            labels column ``i`` of ``preds``.

    Raises:
        ValueError: If ``preds`` does not have one column per class.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    print("Creating ROC curve plot")
    sns.set(style="whitegrid", context="poster", palette="bright")
    sns.set_style('ticks')

    class_labels = list(class_mapping.keys())
    num_classes = len(class_labels)

    # Accept torch tensors as well as array-likes.
    if hasattr(preds, 'detach'):
        preds = preds.detach().cpu().numpy()
    if hasattr(labels, 'detach'):
        labels = labels.detach().cpu().numpy()
    preds = np.asarray(preds)
    labels = np.asarray(labels)

    if preds.ndim != 2 or preds.shape[1] != num_classes:
        raise ValueError(f"preds must have shape (n_samples, {num_classes}), got {preds.shape}")

    # Build one indicator column per class. Doing this once keeps the per-class
    # and micro-average curves consistent for both label layouts, and always
    # yields num_classes columns (also for two classes).
    if labels.ndim == 1 or labels.shape[1] == 1:
        class_ids = labels.reshape(-1)
        label_matrix = (class_ids[:, None] == np.arange(num_classes)[None, :]).astype(int)
    else:
        label_matrix = labels

    micro_labels = label_matrix.ravel()
    micro_preds = preds.ravel()

    mean_fpr = np.linspace(0, 1, 100)
    tprs = []

    # Create a subplot figure with 1 row and 2 columns
    fig, axes = plt.subplots(1, 2, figsize=(28, 10))
    fig.subplots_adjust(hspace=0.4, wspace=0.4, bottom=0.15)

    ax1 = axes[0]
    for i in range(num_classes):
        binary_labels = label_matrix[:, i]
        n_positive = int(np.sum(binary_labels != 0))
        if n_positive == 0 or n_positive == len(binary_labels):
            # A ROC curve needs both positives and negatives; including such a
            # class would turn the macro average into NaN.
            warnings.warn(f"Skipping ROC curve for class '{class_labels[i]}': "
                          f"it has no positive or no negative samples.")
            continue
        fpr, tpr, _ = roc_curve(binary_labels, preds[:, i])
        auc_score = auc(fpr, tpr)
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        tprs.append(interp_tpr)
        sns.lineplot(x=fpr, y=tpr, ax=ax1, label=f'{class_labels[i]} (AUC = {auc_score:.2f})')

    sns.lineplot(x=[0, 1], y=[0, 1], ax=ax1, color="gray", linestyle='--', label='Random Classifier')
    ax1.set_title('ROC Curves for Each Class')
    # FPR equals 1 - specificity, not specificity itself.
    ax1.set_xlabel('False Positive Rate (1 - Specificity)')
    ax1.set_ylabel('True Positive Rate (Sensitivity)')
    ax1.legend(loc='lower right', fontsize='small', title_fontsize='medium')

    # Calculate and plot the micro-average ROC curve
    micro_fpr, micro_tpr, _ = roc_curve(micro_labels, micro_preds)
    micro_auc = auc(micro_fpr, micro_tpr)
    sns.lineplot(x=micro_fpr, y=micro_tpr, ax=axes[1], color='blue', label=f'Micro-average ROC (AUC = {micro_auc:.2f})')

    # Macro-average ROC curve over the classes that could be evaluated
    if tprs:
        mean_tpr = np.mean(tprs, axis=0)
        # Anchor the curve at the origin before integrating so the reported
        # AUC belongs to the curve that is actually drawn.
        mean_tpr[0] = 0.0
        mean_auc = auc(mean_fpr, mean_tpr)
        sns.lineplot(x=mean_fpr, y=mean_tpr, ax=axes[1], color='red', label=f'Macro-average ROC (AUC = {mean_auc:.2f})')

    sns.lineplot(x=[0, 1], y=[0, 1], ax=axes[1], color="gray", linestyle='--', label='Random Classifier')
    axes[1].set_title('Macro and Micro-average ROC Curves')
    axes[1].set_xlabel('False Positive Rate (1 - Specificity)')
    axes[1].set_ylabel('True Positive Rate (Sensitivity)')
    axes[1].legend(loc='lower right', fontsize='small', title_fontsize='medium')

    plt.tight_layout()
    plt.show()

In [None]:
# Convert the raw logits to per-class probabilities, then draw the ROC curves.
logits_sm = torch.tensor(logits).softmax(dim=1)
plot_roc_curve(logits_sm, y_true, class_mapping)