In [1]:
import sys
sys.path.append('../..')

In [None]:
from model import ModelFx
from dataset_service.isic_multimodal.dataset import ISIC_MultiModal_DataModule
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, random_split
from backend_central_dev.utils.pytorch_utils import get_device

device = get_device()

model_list = []
dm_list = []
for i in tqdm(range(5)):
    model_list.append(
        ModelFx(i).model.to(device).eval()
    )
    dm = ISIC_MultiModal_DataModule(
        img_size=224,
        batch_size=32,
        data_loader_kwargs=dict(num_workers=8),
        dataset_init_kwargs=dict(
            fold=i
        ),
        suppress_aug_info_print=True
    )
    dm.setup('val')
    train_size = int(0.8 * len(dm.val))
    test_size = len(dm.val) - train_size

    # Split the dataset
    train_dataset, test_dataset = random_split(dm.val, [train_size, test_size])
    test_dataset.__se__ = None
    dm.val = test_dataset
    dm_list.append(dm)

print(dm_list[0].val.__len__())

  0%|          | 0/5 [00:00<?, ?it/s]

Downloading 9c_b4ns_448_ext_15ep-newfold_best_fold1.pth to /Users/yinnnyou/workspace/computing-solution-of-ml-and-xai/backend/model_service/isic_champion/weights
Downloading 9c_b4ns_448_ext_15ep-newfold_best_fold1.pth.zip to /Users/yinnnyou/workspace/computing-solution-of-ml-and-xai/backend/model_service/isic_champion/weights


100%|██████████| 62.7M/62.7M [00:04<00:00, 16.2MB/s]



Downloading 9c_b4ns_448_ext_15ep-newfold_best_fold2.pth to /Users/yinnnyou/workspace/computing-solution-of-ml-and-xai/backend/model_service/isic_champion/weights
Downloading 9c_b4ns_448_ext_15ep-newfold_best_fold2.pth.zip to /Users/yinnnyou/workspace/computing-solution-of-ml-and-xai/backend/model_service/isic_champion/weights


100%|██████████| 62.7M/62.7M [00:02<00:00, 24.9MB/s]



Downloading 9c_b4ns_448_ext_15ep-newfold_best_fold3.pth to /Users/yinnnyou/workspace/computing-solution-of-ml-and-xai/backend/model_service/isic_champion/weights
Downloading 9c_b4ns_448_ext_15ep-newfold_best_fold3.pth.zip to /Users/yinnnyou/workspace/computing-solution-of-ml-and-xai/backend/model_service/isic_champion/weights


100%|██████████| 62.7M/62.7M [00:01<00:00, 36.0MB/s]



Downloading 9c_b4ns_448_ext_15ep-newfold_best_fold4.pth to /Users/yinnnyou/workspace/computing-solution-of-ml-and-xai/backend/model_service/isic_champion/weights
Downloading 9c_b4ns_448_ext_15ep-newfold_best_fold4.pth.zip to /Users/yinnnyou/workspace/computing-solution-of-ml-and-xai/backend/model_service/isic_champion/weights


 93%|█████████▎| 58.0M/62.7M [00:02<00:00, 34.0MB/s]




100%|██████████| 62.7M/62.7M [00:02<00:00, 30.6MB/s]


2317


### Uncertainty


In [None]:
from torch_uncertainty.metrics.classification import BrierScore
from evaluation_service.model_eval.eval import get_brier_score

from torchmetrics.classification import (
    MulticlassAccuracy, MulticlassPrecision,
    MulticlassRecall, MulticlassF1Score,
    MulticlassCalibrationError,
    MulticlassAUROC
)


rs_data = []

for fold in range(5):
    model = model_list[fold]
    data_module = dm_list[fold]

    accuracy = MulticlassAccuracy(
        num_classes=data_module.num_classes).to(device)
    precision = MulticlassPrecision(
        num_classes=data_module.num_classes).to(device)
    recall = MulticlassRecall(num_classes=data_module.num_classes).to(device)
    f1 = MulticlassF1Score(num_classes=data_module.num_classes).to(device)
    brier_score = BrierScore(
        num_classes=data_module.num_classes
    ).to(device)

    ce = MulticlassCalibrationError(
        num_classes=data_module.num_classes,
    )

    auc_roc = MulticlassAUROC(
        num_classes=data_module.num_classes,
    )

    for x, y in tqdm(data_module.val_dataloader(), desc=f'Fold {fold}'):
        x = x.to(device)
        y = y.to(device)
        with torch.no_grad():
            # brier_score, pred, softmax_pred = get_brier_score(
            #     model, x, y, brier_score, data_module.num_classes)

            pred = model(x)
            softmax_pred = torch.softmax(pred, dim=1)
            # return None, pred, softmax_pred
            oh_y = torch.nn.functional.one_hot(
                y, num_classes=data_module.num_classes).float()
            brier_score.update(softmax_pred.detach(), oh_y.detach())

            # Calculate metrics
            accuracy.update(pred, y)
            f1.update(pred, y)
            recall.update(pred, y)
            precision.update(pred, y)
            ce.update(pred, y)
            auc_roc.update(pred, y)

            # print(bs)
        # break
    rs_data.append({
        'fold': fold,
        'accuracy': accuracy.compute().item(),
        'f1': f1.compute().item(),
        'recall': recall.compute().item(),
        'precision': precision.compute().item(),

        'bs': brier_score.compute().item(),
        'ce': ce.compute().item(),
        'aucroc': auc_roc.compute().item()
    })
    # break

Fold 0:   0%|          | 0/73 [00:00<?, ?it/s]

Fold 1:   0%|          | 0/73 [00:00<?, ?it/s]

Fold 2:   0%|          | 0/73 [00:00<?, ?it/s]

Fold 3:   0%|          | 0/73 [00:00<?, ?it/s]

Fold 4:   0%|          | 0/73 [00:00<?, ?it/s]

In [14]:
import pandas as pd

pd.DataFrame(rs_data)

Unnamed: 0,fold,accuracy,f1,recall,precision,bs,ce,aucroc
0,0,0.676384,0.697692,0.676384,0.732585,0.234661,0.061825,0.974895
1,1,0.711441,0.700029,0.711441,0.696856,0.224076,0.065237,0.976862
2,2,0.746143,0.743993,0.746143,0.750476,0.208678,0.049996,0.97613
3,3,0.706099,0.696685,0.706099,0.694608,0.236249,0.06446,0.974334
4,4,0.663625,0.68542,0.663625,0.716713,0.239822,0.065058,0.973772


### Grad-CAM, SHAP, LIME


In [None]:
from backend.xai_service.general_xai.general_xai import grad_cam, shap_map, lime_map
from xai_service.general_xai.gradient_methods import guided_absolute_grad
from backend_central_dev.utils import plotting_utils
from backend.evaluation_service.xai_eval.rcap import batch_rcap
from backend_central_dev.utils import data_utils

In [None]:
results = [
    dict(
        gag_rs=[],
        grad_cam_rs=[],
        shap_rs=[],
        lime_rs=[]
    )
    for i in range(5)
]

for fold in range(5):
    model = model_list[fold]
    data_module = dm_list[fold]
    with tqdm(total=len(data_module.val_dataloader()), desc=f"Fold {fold}") as pbar:
        for x, y in data_module.val_dataloader():
            # s = 20
            # e = s + 1
            # x = x.to(device)[s:e]
            # y = y.to(device)[s:e]
            x = x.to(device)
            y = y.to(device)

            # plotting_utils.plot_hor([xx.transpose(1, 2, 0)
            #                         for xx in data_utils.denormm_i_t(x.clone()).cpu().numpy()])

            # cam_map = grad_cam(model, x, y, getattr(model.enet.blocks, '6'))

            # plotting_utils.plot_hor([one_map.cpu().numpy()
            #                         for one_map in cam_map])

            # gag_map = guided_absolute_grad(model, x, y, num_samples=10, blur=True)

            # plotting_utils.plot_hor([one_map.cpu().numpy()
            #                          for one_map in gag_map])

            # shap_values = shap_map(
            #     model,
            #     x,
            #     y,
            #     lambda i: torch.clamp(i, 0, 1).cpu(
            #     ).numpy().transpose(0, 2, 3, 1),
            #     ("blur(128,128)", (224, 224, 3)),
            #     device,
            #     shap_params=dict(
            #         max_evals=500,
            #         batch_size=20,
            #     )
            # )

            # plotting_utils.plot_hor(shap_values)

            # explanations = lime_map(
            #     model,
            #     x,
            #     y,
            #     device,
            #     lime_params=dict(
            #         num_samples=10,
            #         progress_bar=True,
            #     )
            # )

            # plotting_utils.plot_hor(explanations)

            # ======== RCAP =========

            # pbar.set_postfix_str("Grad Cam")
            # grad_cam_rcap = batch_rcap(model, (x, y), grad_cam, dict(
            #     target_layers=getattr(model.enet.blocks, '6')))

            # results[fold]['grad_cam_rs'].append(grad_cam_rcap)

            # pbar.set_postfix_str("GAG")
            # gag_map_rcap = batch_rcap(model, (x, y), guided_absolute_grad, dict(
            #     blur=True
            # ))
            # results[fold]['gag_rs'].append(gag_map_rcap)

            # pbar.set_postfix_str("SHAP")
            # shap_rcap = batch_rcap(model, (x, y), shap_map, dict(
            #     image_processor=lambda i: torch.clamp(
            #         i, 0, 1).cpu().numpy().transpose(0, 2, 3, 1),
            #     masker_params=("blur(128,128)", (224, 224, 3)),
            #     device=device,
            #     norm_output=True,
            #     # shap_params=dict(
            #     #     max_evals=100,
            #     #     batch_size=10
            #     # )
            # ))
            # results[fold]['shap_rs'].append(shap_rcap)

            pbar.set_postfix_str("lime")
            lime_rcap = batch_rcap(model, (x, y), lime_map, dict(
                device=device,
                lime_params=dict(
                    num_samples=100,
                )
            ))
            results[fold]['lime_rs'].append(lime_rcap)
            pbar.update(1)
    #         break
    # break

In [20]:
import numpy as np
np.set_printoptions(suppress=True)


def print_rs(rs):
    target_aggregated_rcap_keys = [
        # 'original_pred_score',
        # 'recovered_pred_score',
        # 'original_pred_prob',
        'recovered_pred_prob',
        # 'local_heat_mean',
        # 'local_heat_sum',
        # 'overall_heat_mean',
        # 'overall_heat_sum',
        # 'all_original_pred_prob_full',
        # 'all_recovered_pred_prob_full',
        'overall_rcap'
    ]
    rs_data = []
    for fold, fold_exp in enumerate(rs):
        for xai_key, xai_rcap_result_of_all_batches in fold_exp.items():
            recovered_pred_prob_list = []
            rcap_list = []
            visual_noize_level = []
            for xai_rcap_result_of_one_batch in xai_rcap_result_of_all_batches:
                recovered_pred_prob_of_batch = xai_rcap_result_of_one_batch['recovered_pred_prob']
                recovered_pred_prob_list.append(
                    np.array(recovered_pred_prob_of_batch).mean())
                rcap_list.append(
                    np.array(
                        xai_rcap_result_of_one_batch['overall_rcap']['RCAP']).mean()
                )
                visual_noize_level.append(
                    np.array(
                        xai_rcap_result_of_one_batch['overall_rcap']['visual_noise_level']).mean()
                )
            # print(len(recovered_pred_prob_list))
            # print(len(rcap_list))
            # print(len(visual_noize_level))
            vnl = np.array(visual_noize_level).mean()
            rs_data.append({
                'fold': fold,
                'xai_key': xai_key,
                'localization': np.array(recovered_pred_prob_list).mean(),
                'rcap': np.array(rcap_list).mean(),
                'visual_noize_level': np.nan if vnl == 1.0 else vnl
            })
    display(pd.DataFrame(rs_data).style.background_gradient())

In [16]:
results_cam_gag_shap = np.load('results_cam_gag_shap.npy', allow_pickle=True)
results_lime = np.load('results_lime.npy', allow_pickle=True)

In [21]:
complete_rs = []
for i in range(5):
    complete_rs.append({
        'grad_cam_rs': results_cam_gag_shap[i]['grad_cam_rs'],
        'gag_rs': results_cam_gag_shap[i]['gag_rs'],
        'shap_rs': results_cam_gag_shap[i]['shap_rs'],
        'lime_rs': results_lime[i]['lime_rs']
    })

print_rs(complete_rs)

Unnamed: 0,fold,xai_key,localization,rcap,visual_noize_level
0,0,grad_cam_rs,0.751662,0.552296,0.732856
1,0,gag_rs,0.723225,0.587459,0.811923
2,0,shap_rs,0.754528,0.382498,0.502029
3,0,lime_rs,0.664844,0.664844,
4,1,grad_cam_rs,0.770895,0.55218,0.716362
5,1,gag_rs,0.735774,0.573955,0.785039
6,1,shap_rs,0.75533,0.372595,0.495204
7,1,lime_rs,0.678972,0.678972,
8,2,grad_cam_rs,0.742714,0.521746,0.701936
9,2,gag_rs,0.708103,0.56092,0.79038
