In [1]:
import sys
sys.path.append('../..')

In [1]:
from model import ModelFx
from dataset_service.isic_multimodal.dataset import ISIC_MultiModal_DataModule
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, random_split

device = torch.device('cuda')
# device = torch.device('mps')

model_list = []
dm_list = []
for i in tqdm(range(5)):
    model_list.append(
        ModelFx(i).model.to(device).eval()
    )
    dm = ISIC_MultiModal_DataModule(
        img_size=224,
        batch_size=32,
        data_loader_kwargs=dict(num_workers=8),
        dataset_init_kwargs=dict(
            fold=i
        ),
        suppress_aug_info_print=True
    )
    dm.setup('val')
    train_size = int(0.8 * len(dm.val))
    test_size = len(dm.val) - train_size

    # Split the dataset
    train_dataset, test_dataset = random_split(dm.val, [train_size, test_size])
    test_dataset.__se__ = None
    dm.val = test_dataset
    dm_list.append(dm)

print(dm_list[0].val.__len__())

  0%|          | 0/5 [00:00<?, ?it/s]

2317


### Uncertainty, General Eval


In [None]:
from torch_uncertainty.metrics.classification import BrierScore
from evaluation_service.model_eval.eval import get_brier_score

from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall, MulticlassF1Score


rs_data = []

for fold in range(5):
    model = model_list[fold]
    data_module = dm_list[fold]

    accuracy = MulticlassAccuracy(
        num_classes=data_module.num_classes).to(device)
    precision = MulticlassPrecision(
        num_classes=data_module.num_classes).to(device)
    recall = MulticlassRecall(num_classes=data_module.num_classes).to(device)
    f1 = MulticlassF1Score(num_classes=data_module.num_classes).to(device)
    brier_score = BrierScore(
        num_classes=data_module.num_classes
    ).to(device)

    for x, y in tqdm(data_module.val_dataloader(), desc=f'Fold {fold}'):
        x = x.to(device)
        y = y.to(device)
        with torch.no_grad():
            brier_score, pred, softmax_pred = get_brier_score(
                model, x, y, brier_score, data_module.num_classes)
            # Convert predictions to numpy array
            predicted_classes = pred.argmax(dim=1).cpu().numpy()
            true_classes = y.cpu().numpy()

            # Calculate metrics
            accuracy.update(pred, y)
            f1.update(pred, y)
            recall.update(pred, y)
            precision.update(pred, y)

            # print(bs)
    rs_data.append({
        'fold': fold,
        'bs': brier_score.compute().item(),
        'accuracy': accuracy.compute().item(),
        'f1': f1.compute().item(),
        'recall': recall.compute().item(),
        'precision': precision.compute().item(),
    })

Fold 0:   0%|          | 0/19 [00:00<?, ?it/s]

Fold 1:   0%|          | 0/19 [00:00<?, ?it/s]

Fold 2:   0%|          | 0/19 [00:00<?, ?it/s]

Fold 3:   0%|          | 0/19 [00:00<?, ?it/s]

Fold 4:   0%|          | 0/19 [00:00<?, ?it/s]

In [29]:
import pandas as pd

df = pd.DataFrame(rs_data)
df

Unnamed: 0,fold,bs,accuracy,f1,recall,precision
0,0,0.252084,0.70081,0.73007,0.70081,0.769346
1,1,0.203009,0.711097,0.702678,0.711097,0.710797
2,2,0.223311,0.744245,0.726691,0.744245,0.721449
3,3,0.239552,0.726333,0.695942,0.726333,0.682797
4,4,0.245369,0.71212,0.72066,0.71212,0.734468


### Grad-CAM, SHAP, LIME


In [2]:
from backend.xai_service.general_xai.general_xai import grad_cam, shap_map, lime_map
from xai_service.general_xai.gradient_methods import guided_absolute_grad
from backend_central_dev.utils import plotting_utils
from backend.evaluation_service.xai_eval.rcap import batch_rcap
from backend_central_dev.utils import data_utils

In [None]:
results = [
    dict(
        gag_rs=[],
        grad_cam_rs=[],
        shap_rs=[],
        lime_rs=[]
    )
    for i in range(5)
]

for fold in range(5):
    model = model_list[fold]
    data_module = dm_list[fold]
    with tqdm(total=len(data_module.val_dataloader()), desc=f"Fold {fold}") as pbar:
        for x, y in data_module.val_dataloader():
            # s = 20
            # e = s + 1
            # x = x.to(device)[s:e]
            # y = y.to(device)[s:e]
            x = x.to(device)
            y = y.to(device)

            # plotting_utils.plot_hor([xx.transpose(1, 2, 0)
            #                         for xx in data_utils.denormm_i_t(x.clone()).cpu().numpy()])

            # cam_map = grad_cam(model, x, y, getattr(model.enet.blocks, '6'))

            # plotting_utils.plot_hor([one_map.cpu().numpy()
            #                         for one_map in cam_map])

            # gag_map = guided_absolute_grad(model, x, y, num_samples=10, blur=True)

            # plotting_utils.plot_hor([one_map.cpu().numpy()
            #                          for one_map in gag_map])

            # shap_values = shap_map(
            #     model,
            #     x,
            #     y,
            #     lambda i: torch.clamp(i, 0, 1).cpu(
            #     ).numpy().transpose(0, 2, 3, 1),
            #     ("blur(128,128)", (224, 224, 3)),
            #     device,
            #     shap_params=dict(
            #         max_evals=500,
            #         batch_size=20,
            #     )
            # )

            # plotting_utils.plot_hor(shap_values)

            # explanations = lime_map(
            #     model,
            #     x,
            #     y,
            #     device,
            #     lime_params=dict(
            #         num_samples=10,
            #         progress_bar=True,
            #     )
            # )

            # plotting_utils.plot_hor(explanations)

            # ======== RCAP =========

            # pbar.set_postfix_str("Grad Cam")
            # grad_cam_rcap = batch_rcap(model, (x, y), grad_cam, dict(
            #     target_layers=getattr(model.enet.blocks, '6')))

            # results[fold]['grad_cam_rs'].append(grad_cam_rcap)

            # pbar.set_postfix_str("GAG")
            # gag_map_rcap = batch_rcap(model, (x, y), guided_absolute_grad, dict(
            #     blur=True
            # ))
            # results[fold]['gag_rs'].append(gag_map_rcap)

            # pbar.set_postfix_str("SHAP")
            # shap_rcap = batch_rcap(model, (x, y), shap_map, dict(
            #     image_processor=lambda i: torch.clamp(
            #         i, 0, 1).cpu().numpy().transpose(0, 2, 3, 1),
            #     masker_params=("blur(128,128)", (224, 224, 3)),
            #     device=device,
            #     norm_output=True,
            #     # shap_params=dict(
            #     #     max_evals=100,
            #     #     batch_size=10
            #     # )
            # ))
            # results[fold]['shap_rs'].append(shap_rcap)

            pbar.set_postfix_str("lime")
            lime_rcap = batch_rcap(model, (x, y), lime_map, dict(
                device=device,
                lime_params=dict(
                    num_samples=100,
                )
            ))
            results[fold]['lime_rs'].append(lime_rcap)
            pbar.update(1)
    #         break
    # break

Fold 0:   0%|          | 0/73 [00:00<?, ?it/s]

Fold 1:   0%|          | 0/73 [00:00<?, ?it/s]

Fold 2:   0%|          | 0/73 [00:00<?, ?it/s]

Fold 3:   0%|          | 0/73 [00:00<?, ?it/s]

Fold 4:   0%|          | 0/73 [00:00<?, ?it/s]

In [4]:
import numpy as np

np.save('results_lime.npy', results)

In [13]:
import numpy as np
np.set_printoptions(suppress=True)


def print_rs(rs):
    target_aggregated_rcap_keys = [
        # 'original_pred_score',
        # 'recovered_pred_score',
        # 'original_pred_prob',
        'recovered_pred_prob',
        # 'local_heat_mean',
        # 'local_heat_sum',
        # 'overall_heat_mean',
        # 'overall_heat_sum',
        # 'all_original_pred_prob_full',
        # 'all_recovered_pred_prob_full',
        'overall_rcap'
    ]
    rs_data = []
    for fold, fold_exp in enumerate(rs):
        for xai_key, xai_rcap_result_of_all_batches in fold_exp.items():
            recovered_pred_prob_list = []
            rcap_list = []
            visual_noize_level = []
            for xai_rcap_result_of_one_batch in xai_rcap_result_of_all_batches:
                recovered_pred_prob_of_batch = xai_rcap_result_of_one_batch['recovered_pred_prob']
                recovered_pred_prob_list.append(
                    np.array(recovered_pred_prob_of_batch).mean())
                rcap_list.append(
                    np.array(
                        xai_rcap_result_of_one_batch['overall_rcap']['RCAP']).mean()
                )
                visual_noize_level.append(
                    np.array(
                        xai_rcap_result_of_one_batch['overall_rcap']['visual_noise_level']).mean()
                )
            # print(len(recovered_pred_prob_list))
            # print(len(rcap_list))
            # print(len(visual_noize_level))
            rs_data.append({
                'fold': fold,
                'xai_key': xai_key,
                'localization': np.array(recovered_pred_prob_list).mean(),
                'rcap': np.array(rcap_list).mean(),
                'visual_noize_level': np.array(visual_noize_level).mean()
            })
        break
    display(pd.DataFrame(rs_data))

In [14]:
print_rs(np.load('results_cam_gag_shap.npy', allow_pickle=True))

  'localization': np.array(recovered_pred_prob_list).mean(),
  ret = ret.dtype.type(ret / rcount)
  'rcap': np.array(rcap_list).mean(),
  'visual_noize_level': np.array(visual_noize_level).mean()


Unnamed: 0,fold,xai_key,localization,rcap,visual_noize_level
0,0,gag_rs,0.723225,0.587459,0.811923
1,0,grad_cam_rs,0.751662,0.552296,0.732856
2,0,shap_rs,0.754528,0.382498,0.502029
3,0,lime_rs,,,
