In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input tree and print the full path of every file found.
for folder, _, names in os.walk('/kaggle/input'):
    for name in names:
        full_path = os.path.join(folder, name)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# !pip install captum==0.5.0
# !pip install h5py==3.7.0
# !pip install imbalanced-learn==0.9.1
# !pip install joblib==1.2.0
# !pip install matplotlib-base==3.5.3
# !pip install networkx==2.8.4
# !pip install numpy==1.23.4
# !pip install pandas==1.5.1
# !pip install pyg==2.2.0
# !pip install torch==1.13.0
# !pip install scikit-learn==1.1.3
# !pip install seaborn==0.12.1
# !pip install torchaudio==0.13.0
# !pip install torchvision==0.14.0
# !pip install tqdm==4.64.1
# !pip install build==0.10.0
# !pip install docopt==0.6.2
# !pip install kaggle==1.5.12
# !pip install pip-tools==6.12.3
# !pip install pipreqs==0.4.11
# !pip install pyproject-hooks==1.0.0
# !pip install python-slugify==7.0.0
# !pip install text-unidecode==1.3
# !pip install yarg==0.1.9
# !pip install pytorch-lightning==2.0.2
# !pip install e2cnn==0.2.3
# !pip install wandb==0.15.0
# !pip install nltk==3.8.1

In [None]:
# Clone the RobustXAI repository into the current working directory.
!git clone https://github.com/JonathanCrabbe/RobustXAI.git

In [None]:
!pip install torch_geometric
!pip install e2cnn==0.2.3
!pip install captum==0.5.0

In [None]:
# %cd /opt/conda/envs/robustxai

In [None]:
# %cd /kaggle/working/RobustXAI

In [None]:
# Print the shell's working directory to check where the repository was cloned.
!pwd

In [None]:
# !mv /kaggle/working/RobustXAI/* ..
# # !mv /kaggle/working/RobustXAI/.[!.]* .

In [None]:
# Move the cloned repository's contents up into /kaggle/working.
# Note: the `*` glob does not match dot-entries such as `.git`, so those stay behind.
!mv /kaggle/working/RobustXAI/* /kaggle/working/

In [None]:
!rm -d /kaggle/working/RobustXAI

In [None]:
!cd /kaggle/working/RobustXAI
!ls

In [None]:
# !python -m experiments.ecg --name feature_importance --train --plot

In [None]:
import numpy as np
import torch
import torch.nn as nn
import os
import logging
import argparse
import pandas as pd
import itertools
from pathlib import Path
from torch.utils.data import DataLoader, RandomSampler, Subset
# from datasets.loaders import ECGDataset
from models.time_series import AllCNN, StandardCNN
from utils.symmetries import Translation1D
from utils.misc import set_random_seed
# from utils.plots import (
#     single_robustness_plots,
#     relaxing_invariance_plots,
#     enforce_invariance_plot,
#     sensitivity_plot,
# )
from interpretability.robustness import (
    accuracy,
    InvariantExplainer,
    model_invariance_exact,
    explanation_invariance_exact,
    explanation_equivariance_exact,
    sensitivity,
)
from interpretability.example import (
    SimplEx,
    RepresentationSimilarity,
    TracIn,
    InfluenceFunctions,
)
from interpretability.feature import FeatureImportance
from interpretability.concept import CAR, CAV, ConceptExplainer
from captum.attr import (
    IntegratedGradients,
    GradientShap,
    FeaturePermutation,
    FeatureAblation,
    Occlusion,
    DeepLift
)

In [None]:
import logging

# Configure the root logger so that INFO messages are shown in the notebook.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Re-running this cell must not stack another console handler (every message
# would then be printed once per run). The handler is tagged with a name and
# only attached when no handler carrying that name is registered yet; handlers
# installed by the environment itself are left untouched.
_CONSOLE_HANDLER_NAME = "notebook_console"
if not any(h.get_name() == _CONSOLE_HANDLER_NAME for h in logger.handlers):
    console_handler = logging.StreamHandler()
    console_handler.set_name(_CONSOLE_HANDLER_NAME)
    console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

# Smoke-test the logging configuration
logger.info("This is an info 33message")


### datasets.loaders

In [None]:
import os
import random
import re
from abc import ABC, abstractmethod
from collections import Counter
from functools import partial
from pathlib import Path

import h5py
import networkx as nx
import nltk
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from imblearn.over_sampling import SMOTE
from joblib import Parallel, delayed
from torch.utils.data import Dataset, SubsetRandomSampler
from torch.utils.data.dataset import random_split
from torch_geometric.datasets import TUDataset
from torchvision.datasets import CIFAR100, STL10, FashionMNIST, ImageFolder
from torchvision.transforms import transforms
from tqdm import tqdm

from utils.misc import to_molecule

class ConceptDataset(ABC, Dataset):
    """Abstract dataset that can additionally produce concept example sets."""

    @property
    @abstractmethod
    def concept_names(self):
        """Names of the concepts exposed by the dataset (provided by subclasses)."""

    @abstractmethod
    def generate_concept_dataset(self, concept_id: int, concept_set_size: int) -> tuple:
        """Build the concept dataset for ``concept_id`` (provided by subclasses)."""

        
class ECGDataset(ConceptDataset):
    """Heartbeat (ECG) dataset served as ``(signal, label)`` tensor pairs."""

    def __init__(
        self,
        data_dir: Path,
        train: bool,
        balance_dataset: bool,
        random_seed: int = 42,
        binarize_label: bool = True,
    ):
        """
        Generate a ECG dataset
        Args:
            data_dir: directory where the dataset should be stored
            train: True if the training set should be returned, False for the testing set
            balance_dataset: True if the classes should be balanced with SMOTE
            random_seed: random seed for reproducibility
            binarize_label: True if the label should be binarized (0: normal heartbeat, 1: abnormal heartbeat)
        """
        self.data_dir = data_dir
        if not data_dir.exists():
            # Nothing on disk yet: create the directory and fetch the files.
            os.makedirs(data_dir)
            self.download()
        # Read CSV; extract features and labels
        file_path = data_dir / ("mitbih_train.csv" if train else "mitbih_test.csv")
        # The heartbeat CSV files have no header row. Without header=None pandas
        # would consume the first heartbeat as column names and silently drop it.
        df = pd.read_csv(file_path, header=None)
        X = df.iloc[:, :187].values  # first 187 columns: signal samples
        y = df.iloc[:, 187].values  # column 187: class label
        if balance_dataset:
            # Oversample each of the classes 1-4 up to a quarter of the class-0 count.
            n_normal = np.count_nonzero(y == 0)
            balancing_dic = {
                0: n_normal,
                1: int(n_normal / 4),
                2: int(n_normal / 4),
                3: int(n_normal / 4),
                4: int(n_normal / 4),
            }
            smote = SMOTE(random_state=random_seed, sampling_strategy=balancing_dic)
            X, y = smote.fit_resample(X, y)
        if binarize_label:
            y = np.where(y >= 1, 1, 0)
        # Add a channel axis so that each item has shape (1, n_samples).
        self.X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
        self.y = torch.tensor(y, dtype=torch.long)
        self.binarize_label = binarize_label

    def __len__(self):
        """Number of heartbeats in the dataset."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(signal, label)`` pair stored at ``idx``."""
        return self.X[idx], self.y[idx]

    def download(self) -> None:
        """Download and unzip the heartbeat dataset into ``self.data_dir`` with the Kaggle API."""
        import kaggle

        logging.info(f"Downloading ECG dataset in {self.data_dir}")
        kaggle.api.authenticate()
        kaggle.api.dataset_download_files(
            "shayanfazeli/heartbeat", path=self.data_dir, unzip=True
        )
        logging.info(f"ECG dataset downloaded in {self.data_dir}")

    def generate_concept_dataset(self, concept_id: int, concept_set_size: int) -> tuple:
        """
        Return a concept dataset with positive/negatives for ECG
        Args:
            concept_id: index of the concept; positives are the examples whose label is concept_id + 1
            concept_set_size: size of the positive and negative subset
        Returns:
            a concept dataset of the form X (features),C (concept labels)
        """
        # Concepts are defined on the original multi-class labels.
        assert not self.binarize_label
        mask = self.y == concept_id + 1
        positive_idx = torch.nonzero(mask).flatten()
        negative_idx = torch.nonzero(~mask).flatten()
        positive_loader = torch.utils.data.DataLoader(
            self, batch_size=concept_set_size, sampler=SubsetRandomSampler(positive_idx)
        )
        negative_loader = torch.utils.data.DataLoader(
            self, batch_size=concept_set_size, sampler=SubsetRandomSampler(negative_idx)
        )
        # A single batch from each loader is a random draw of at most
        # concept_set_size examples; the class labels are not needed.
        X_pos, _ = next(iter(positive_loader))
        X_neg, _ = next(iter(negative_loader))
        X = torch.cat((X_pos, X_neg), 0)
        # Size the concept labels by what was actually drawn: a class with fewer
        # than concept_set_size examples yields a shorter batch, and fixed-size
        # labels would then be misaligned with X.
        C = torch.cat((torch.ones(len(X_pos)), torch.zeros(len(X_neg))), 0)
        rand_perm = torch.randperm(len(X))
        return X[rand_perm], C[rand_perm]

    # NOTE(review): the base class declares `concept_names` as a property, while
    # this override is a plain method. It is left as a method because turning it
    # into a property would break any caller using `concept_names()` — confirm
    # against the call sites.
    def concept_names(self):
        """Return the names of the four concepts, ordered by concept id."""
        return ["Supraventricular", "Premature Ventricular", "Fusion Beats", "Unknown"]

### utils.plots

In [None]:
import argparse
import itertools
import json
import logging
import textwrap
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns

# Global seaborn look shared by the figures below.
sns.set_style("whitegrid")
sns.set_palette("colorblind")

# Matplotlib marker symbol assigned to each explanation method in scatter plots.
markers = dict(
    [
        ("DeepLift", "o"),
        ("Feature Ablation", "s"),
        ("Feature Occlusion", "X"),
        ("Feature Permutation", "D"),
        ("Gradient Shap", "v"),
        ("Integrated Gradients", "p"),
        ("Influence Functions", "^"),
        ("Rep. Similar-Lin1", "*"),
        ("SimplEx-Lin1", "H"),
        ("TracIn", ">"),
        ("CAR-Lin1", "<"),
        ("CAV-Lin1", "d"),
    ]
)


def single_robustness_plots(plot_dir: Path, dataset: str, experiment_name: str) -> None:
    """
    Draw one robustness box plot per model type from the saved metrics
    Args:
        plot_dir: directory containing metrics.csv; the PDF files are written there too
        dataset: dataset tag inserted in the output file names
        experiment_name: experiment tag inserted in the output file names
    Returns:
    """
    metrics_df = pd.read_csv(plot_dir / "metrics.csv")
    # Equivariance is plotted when that column exists, invariance otherwise.
    if "Explanation Equivariance" in metrics_df.columns:
        score_col = "Explanation Equivariance"
    else:
        score_col = "Explanation Invariance"
    for model_type in metrics_df["Model Type"].unique():
        per_model_df = metrics_df.loc[metrics_df["Model Type"] == model_type]
        ax = sns.boxplot(per_model_df, x="Explanation", y=score_col, showfliers=False)
        wrap_labels(ax, 10)
        plt.ylim(-1.1, 1.1)
        plt.tight_layout()
        model_tag = model_type.lower().replace(" ", "_")
        plt.savefig(plot_dir / f"{experiment_name}_{dataset}_{model_tag}.pdf")
        plt.close()


def global_robustness_plots(experiment_name: str) -> None:
    """
    Box plot comparing explanation robustness across all datasets for one experiment
    Args:
        experiment_name: sub-directory (below each dataset's results path) whose
            metrics.csv is read; also used in the name of the saved PDF
    Returns:
    """
    sns.set(font_scale=0.9)
    sns.set_style("whitegrid")
    sns.set_palette("colorblind")
    # results_dir.json maps each dataset name to its results directory
    # (relative to the current working directory).
    with open(Path.cwd() / "results_dir.json") as f:
        path_dic = json.load(f)
    global_df = []
    for dataset in path_dic:
        dataset_df = pd.read_csv(
            Path.cwd() / path_dic[dataset] / experiment_name / "metrics.csv"
        )
        dataset_df["Dataset"] = [dataset] * len(dataset_df)
        global_df.append(dataset_df)
    global_df = pd.concat(global_df)
    # Collapse the layer-specific explanation names used by the different
    # experiments onto common "-Inv" / "-Equiv" labels.
    rename_dic = {
        "SimplEx-Lin1": "SimplEx-Inv",
        "SimplEx-Conv3": "SimplEx-Equiv",
        "Representation Similarity-Lin1": "Rep. Similar-Inv",
        "Representation Similarity-Conv3": "Rep. Similar-Equiv",
        "CAR-Lin1": "CAR-Inv",
        "CAR-Conv3": "CAR-Equiv",
        "CAV-Lin1": "CAV-Inv",
        "CAV-Conv3": "CAV-Equiv",
        "SimplEx-Phi": "SimplEx-Equiv",
        "SimplEx-Rho": "SimplEx-Inv",
        "Representation Similarity-Phi": "Rep. Similar-Equiv",
        "Representation Similarity-Rho": "Rep. Similar-Inv",
        "CAR-Phi": "CAR-Equiv",
        "CAR-Rho": "CAR-Inv",
        "CAV-Phi": "CAV-Equiv",
        "CAV-Rho": "CAV-Inv",
        "CAR-Conv1": "CAR-Equiv",
        "CAV-Conv1": "CAV-Equiv",
        "SimplEx-Conv1": "SimplEx-Equiv",
        "Representation Similarity-Conv1": "Rep. Similar-Equiv",
        "CAR-Layer3": "CAR-Inv",
        "CAV-Layer3": "CAV-Inv",
        "SimplEx-Layer3": "SimplEx-Inv",
        "Representation Similarity-Layer3": "Rep. Similar-Inv",
        "CAR-Embedding": "CAR-Inv",
        "CAV-Embedding": "CAV-Inv",
        "SimplEx-Embedding": "SimplEx-Inv",
        "Representation Similarity-Embedding": "Rep. Similar-Inv",
    }
    global_df = global_df.replace(rename_dic)
    # Keep only the rows of the model types listed here.
    global_df = global_df[
        (global_df["Model Type"] == "All-CNN")
        | (global_df["Model Type"] == "GNN")
        | (global_df["Model Type"] == "Deep-Set")
        | (global_df["Model Type"] == "D8-Wide-ResNet")
        | (global_df["Model Type"] == "bow_classifier")
    ]
    # Plot equivariance when that column is present, invariance otherwise.
    y = (
        "Explanation Equivariance"
        if "Explanation Equivariance" in global_df.columns
        else "Explanation Invariance"
    )
    ax = sns.boxplot(global_df, x="Dataset", hue="Explanation", y=y, showfliers=False)
    wrap_labels(ax, 10)
    plt.ylim(-1.1, 1.1)
    box_patches = [
        patch for patch in ax.patches if type(patch) == matplotlib.patches.PathPatch
    ]
    if (
        len(box_patches) == 0
    ):  # in matplotlib older than 3.5, the boxes are stored in ax2.artists
        box_patches = ax.artists
    num_patches = len(box_patches)
    # Assumes the Line2D artists are stored in equal-sized consecutive groups, one per box.
    # NOTE(review): this division raises ZeroDivisionError when no box was drawn.
    lines_per_boxplot = len(ax.lines) // num_patches
    for i, patch in enumerate(box_patches):
        # Set the linecolor on the patch to the facecolor, and set the facecolor to None
        col = patch.get_facecolor()
        patch.set_edgecolor(col)
        patch.set_facecolor("None")

        # Each box has associated Line2D objects (to make the whiskers, fliers, etc.)
        # Loop over them here, and use the same color as above
        for line in ax.lines[i * lines_per_boxplot : (i + 1) * lines_per_boxplot]:
            line.set_color(col)
            line.set_mfc(col)  # facecolor of fliers
            line.set_mec(col)  # edgecolor of fliers

    # Also fix the legend
    for legpatch in ax.legend_.get_patches():
        col = legpatch.get_facecolor()
        legpatch.set_edgecolor(col)
        legpatch.set_facecolor("None")
    sns.despine(left=True)
    plt.tight_layout()
    plt.savefig(Path.cwd() / f"results/{experiment_name}_global_robustness.pdf")
    plt.close()


def relaxing_invariance_plots(
    plot_dir: Path, dataset: str, experiment_name: str
) -> None:
    """
    Scatter plot of explanation robustness against model invariance
    Args:
        plot_dir: directory containing metrics.csv; the PDF is written there too
        dataset: dataset tag inserted in the output file name
        experiment_name: experiment tag inserted in the output file name
    Returns:
    """
    sns.set(font_scale=1.2)
    sns.set_style("whitegrid")
    sns.set_palette("colorblind")
    metrics_df = pd.read_csv(plot_dir / "metrics.csv")
    # The Conv3 variants of the explanations are left out of this plot.
    excluded = [
        "SimplEx-Conv3",
        "Representation Similarity-Conv3",
        "CAR-Conv3",
        "CAV-Conv3",
    ]
    metrics_df = metrics_df[~metrics_df["Explanation"].isin(excluded)]
    rename_dic = {"Representation Similarity-Lin1": "Rep. Similar-Lin1"}
    metrics_df = metrics_df.replace(rename_dic)
    # Plot equivariance when that column is present, invariance otherwise.
    y = (
        "Explanation Equivariance"
        if "Explanation Equivariance" in metrics_df.columns
        else "Explanation Invariance"
    )
    grouped = metrics_df.groupby(["Model Type", "Explanation"])
    # numeric_only keeps the aggregation working if the CSV carries text columns.
    plot_df = grouped.mean(numeric_only=True)
    # Select the two score columns explicitly so that the confidence intervals
    # (two standard errors) are matched to the right columns whatever other
    # columns, or column order, the CSV has.
    plot_df[["Model Invariance CI", f"{y} CI"]] = (
        2 * grouped[["Model Invariance", y]].sem()
    )
    # Only pass seaborn the markers of the explanations that are actually plotted.
    unique_explanations = metrics_df["Explanation"].unique()
    trimmed_markers = {
        key: value for key, value in markers.items() if key in unique_explanations
    }
    sns.scatterplot(
        plot_df,
        x="Model Invariance",
        y=y,
        hue="Model Type",
        edgecolor="black",
        alpha=0.5,
        style="Explanation",
        markers=trimmed_markers,
        s=100,
    )
    plt.errorbar(
        x=plot_df["Model Invariance"],
        y=plot_df[y],
        xerr=plot_df["Model Invariance CI"],
        yerr=plot_df[f"{y} CI"],
        ecolor="black",
        elinewidth=1.7,
        linestyle="",
        capsize=1.7,
        capthick=1.7,
    )
    plt.xscale("linear")
    # Dotted diagonal: explanation robustness equal to model invariance.
    plt.axline((0, 0), slope=1, color="gray", linestyle="dotted")
    plt.xlim(0, 1.1)
    plt.ylim(0, 1.1)
    plt.tight_layout()
    plt.savefig(plot_dir / f"{experiment_name}_{dataset}_relaxing_invariance.pdf")
    plt.close()


def mc_convergence_plot(plot_dir: Path, dataset: str, experiment_name: str) -> None:
    """
    Plot each Monte Carlo estimator against the number of samples, with a two-SEM band
    Args:
        plot_dir: directory containing metrics.csv; the PDF is written there too
        dataset: dataset tag inserted in the output file name
        experiment_name: experiment tag inserted in the output file name
    Returns:
    """
    metrics_df = pd.read_csv(plot_dir / "metrics.csv")
    estimator_names = metrics_df["Estimator Name"].unique()
    for estimator_name in estimator_names:
        rows = metrics_df.loc[metrics_df["Estimator Name"] == estimator_name]
        n_samples = rows["Number of MC Samples"]
        estimate = rows["Estimator Value"]
        half_width = 2 * rows["Estimator SEM"]
        plt.plot(n_samples, estimate, label=estimator_name)
        plt.fill_between(
            n_samples, estimate - half_width, estimate + half_width, alpha=0.2
        )
    plt.legend()
    plt.xlabel(r"$N_{\mathrm{samp}}$")
    plt.ylabel("Monte Carlo Estimator")
    plt.ylim(-1, 1)
    plt.tight_layout()
    output_file = plot_dir / f"{experiment_name}_{dataset}.pdf"
    plt.savefig(output_file)
    plt.close()


def understanding_randomness_plots(plot_dir: Path, dataset: str) -> None:
    """
    Show a 2D density of (y1, y2) per model type with the baseline points overlaid
    Args:
        plot_dir: directory containing data.csv
        dataset: not used by this function
    Returns:
    """
    data_df = pd.read_csv(plot_dir / "data.csv")
    baseline_flag = data_df["Baseline"]
    # The non-baseline rows are only printed; the density below uses every row.
    print(data_df[baseline_flag == False])
    sns.kdeplot(data=data_df, x="y1", y="y2", hue="Model Type", fill=True)
    for model_type in data_df["Model Type"].unique():
        is_model = data_df["Model Type"] == model_type
        baseline_rows = data_df[is_model & (baseline_flag == True)]
        plt.plot(
            baseline_rows["y1"],
            baseline_rows["y2"],
            marker="x",
            linewidth=0,
            label=f"Baseline {model_type}",
        )
    # Draw both coordinate axes through the origin.
    plt.axhline(0, color="black")
    plt.axvline(0, color="black")
    plt.xlabel(r"$y_1$")
    plt.ylabel(r"$y_2$")
    plt.legend()
    plt.show()


def enforce_invariance_plot(plot_dir: Path, dataset: str) -> None:
    """
    Line plot of explanation invariance against N_inv, one line per explanation
    Args:
        plot_dir: directory containing metrics.csv; the PDF is written there too
        dataset: dataset tag inserted in the output file name
    Returns:
    """
    sns.set(font_scale=1.3)
    sns.set_style("whitegrid")
    sns.set_palette("colorblind")
    csv_path = plot_dir / "metrics.csv"
    sns.lineplot(
        data=pd.read_csv(csv_path),
        x="N_inv",
        y="Explanation Invariance",
        hue="Explanation",
    )
    plt.legend()
    plt.xlabel(r"$N_{\mathrm{inv}}$")
    plt.tight_layout()
    pdf_path = plot_dir / f"enforce_invariance_{dataset}.pdf"
    plt.savefig(pdf_path)
    plt.close()


def sensitivity_plot(plot_dir: Path, dataset: str) -> None:
    """
    Scatter plot of explanation equivariance against explanation sensitivity
    Args:
        plot_dir: directory containing metrics.csv; the PDF is written there too
        dataset: dataset tag inserted in the output file name
    Returns:
    """
    scatter_style = {"hue": "Explanation", "alpha": 0.5, "s": 10}
    sns.scatterplot(
        data=pd.read_csv(plot_dir / "metrics.csv"),
        x="Explanation Sensitivity",
        y="Explanation Equivariance",
        **scatter_style,
    )
    plt.legend()
    plt.tight_layout()
    pdf_path = plot_dir / f"sensitivity_comparison_{dataset}.pdf"
    plt.savefig(pdf_path)
    plt.close()


def draw_molecule(g, edge_mask=None, draw_edge_labels=False):
    """
    Draw a molecule graph, optionally colouring the edges by importance
    Args:
        g: graph whose nodes carry a "name" attribute used as label
        edge_mask: optional mapping from (u, v) edges to a value used for edge colour and width
        draw_edge_labels: if True (and edge_mask is given), print the mask values on the edges
    Returns:
    """
    g = g.copy().to_undirected()
    node_labels = {node: attrs["name"] for node, attrs in g.nodes(data=True)}
    # The planar embedding is the starting point of the spring layout.
    layout = nx.spring_layout(g, pos=nx.planar_layout(g))
    if edge_mask is not None:
        edge_color = [edge_mask[(u, v)] for u, v in g.edges()]
        widths = [weight * 10 for weight in edge_color]
    else:
        edge_color = "black"
        widths = None
    nx.draw(
        g,
        pos=layout,
        labels=node_labels,
        width=widths,
        edge_color=edge_color,
        edge_cmap=plt.cm.Blues,
        node_color="azure",
    )

    if draw_edge_labels and edge_mask is not None:
        edge_labels = {}
        for edge, value in edge_mask.items():
            edge_labels[edge] = "%.2f" % value
        nx.draw_networkx_edge_labels(
            g, layout, edge_labels=edge_labels, font_color="red"
        )
    plt.show()


def wrap_labels(ax, width, break_long_words=False, do_y: bool = False) -> None:
    """
    Re-flow the tick labels of a figure over several lines
    Args:
        ax: figure axes
        width: maximal number of characters per line
        break_long_words: if True, allow breaks in the middle of a word
        do_y: if True, the y tick labels are wrapped too
    Returns:
    """

    def _wrapped(tick_labels):
        # Wrapped text of every tick label, in the original order.
        return [
            textwrap.fill(
                tick.get_text(), width=width, break_long_words=break_long_words
            )
            for tick in tick_labels
        ]

    ax.set_xticklabels(_wrapped(ax.get_xticklabels()), rotation=0)
    if not do_y:
        return
    ax.set_yticklabels(_wrapped(ax.get_yticklabels()), rotation=0)


def global_relax_invariance() -> None:
    """
    Grid of scatter plots (one row per dataset, one column per experiment) showing
    explanation robustness against model invariance; saved as
    results/global_relax_invariance.pdf below the current working directory
    Returns:
    """
    sns.set(font_scale=1.2)
    sns.set_style("whitegrid")
    # results_dir.json maps each dataset name to its results directory
    # (relative to the current working directory).
    with open(Path.cwd() / "results_dir.json") as f:
        path_dic = json.load(f)
    global_df = []
    # Gather the metrics of every (dataset, experiment) pair in one long frame.
    for dataset, experiment_name in itertools.product(
        ["ECG", "Fa.MNIST"],
        ["feature_importance", "example_importance", "concept_importance"],
    ):
        dataset_df = pd.read_csv(
            Path.cwd() / path_dic[dataset] / experiment_name / "metrics.csv"
        )
        dataset_df["Dataset"] = [dataset] * len(dataset_df)
        dataset_df["Experiment"] = [experiment_name] * len(dataset_df)
        # The Conv3 variants of the explanations are left out of this figure.
        dataset_df = dataset_df.drop(
            dataset_df[
                (dataset_df.Explanation == "SimplEx-Conv3")
                | (dataset_df.Explanation == "Representation Similarity-Conv3")
                | (dataset_df.Explanation == "CAR-Conv3")
                | (dataset_df.Explanation == "CAV-Conv3")
            ].index
        )
        rename_dic = {"Representation Similarity-Lin1": "Rep. Similar-Lin1"}
        dataset_df = dataset_df.replace(rename_dic)
        global_df.append(dataset_df)
    global_df = pd.concat(global_df)

    n_datasets = len(global_df["Dataset"].unique())

    # Create a grid of plots
    fig, axs = plt.subplots(nrows=n_datasets, ncols=3, figsize=(17, 9), sharex=True)

    datasets = global_df["Dataset"].unique()
    # y-axis title of each column; indexed with the experiment position j below.
    y_titles = [
        "Feature Importance Equivariance",
        "Example Importance Invariance",
        "Concept Importance Invariance",
    ]
    experiments = global_df["Experiment"].unique()
    style_handles = []
    style_labels = []
    # Loop over the subplots and plot the data
    for i, dataset in enumerate(datasets):  # rows
        for j, experiment in enumerate(experiments):  # columns
            ax = axs[i, j]
            metrics_df = global_df[
                (global_df["Dataset"] == dataset)
                & (global_df["Experiment"] == experiment)
            ]
            # Feature importance is scored by equivariance, the others by invariance.
            y = (
                "Explanation Equivariance"
                if "feature" in experiment
                else "Explanation Invariance"
            )
            plot_df = metrics_df.groupby(["Model Type", "Explanation"]).mean(
                numeric_only=True
            )
            # Confidence intervals: two standard errors of the mean per group.
            plot_df[["Model Invariance CI", f"{y} CI"]] = 2 * metrics_df.groupby(
                ["Model Type", "Explanation"]
            )[["Model Invariance", y]].apply("sem")
            sns.scatterplot(
                ax=ax,
                data=plot_df,
                x="Model Invariance",
                y=y,
                hue="Model Type",
                edgecolor="black",
                alpha=0.5,
                style="Explanation",
                markers=markers,
                s=200,
            )
            ax.errorbar(
                x=plot_df["Model Invariance"],
                y=plot_df[y],
                xerr=plot_df["Model Invariance CI"],
                yerr=plot_df[f"{y} CI"],
                ecolor="black",
                elinewidth=1.7,
                linestyle="",
                capsize=1.7,
                capthick=1.7,
            )
            ax.set_xscale("linear")
            # Dotted diagonal: explanation robustness equal to model invariance.
            ax.axline((0, 0), slope=1, color="gray", linestyle="dotted")
            ax.set_xlim(0, 1.1)
            ax.set_ylim(0, 1.1)
            ax.set_ylabel(y_titles[j])
            # Get handles and labels for hue and style legends
            handles, labels = ax.get_legend_handles_labels()
            # NOTE(review): for the columns after the first the cut is moved one
            # entry further — presumably to skip the "Explanation" legend title
            # that the first column already contributed; confirm.
            explanation_cut = labels.index("Explanation") + int(j > 0)
            # Create separate legends for hue and style
            if i == 0 and j == 0:
                hue_handles = handles[
                    :explanation_cut
                ]  # first half of handles are for hue
                hue_labels = labels[:explanation_cut]
            # Style entries are collected from the last row of subplots only.
            if i == len(datasets) - 1:
                style_handles.extend(handles[explanation_cut:])
                style_labels.extend(labels[explanation_cut:])

            # The per-axes legend is replaced by the single figure legend below.
            ax.legend().remove()
            if j == 1:
                # The dataset name is shown once per row, above the middle column.
                ax.set_title(dataset)
    # hue_handles / hue_labels are bound in the first subplot iteration above.
    fig.legend(
        hue_handles + style_handles,
        hue_labels + style_labels,
        loc="lower center",
        ncol=5,
        bbox_to_anchor=(0.5, -0.1),
    )
    # fig.tight_layout()

    plt.savefig(
        Path.cwd() / f"results/global_relax_invariance.pdf", bbox_inches="tight"
    )
    plt.close()


def training_dynamic_plot(
    data_path: Path = Path.cwd() / "results/d8-wideresnet-training_dynamics.csv",
) -> None:
    """
    Line plot of model invariance and gradient equivariance over the training epochs
    Args:
        data_path: CSV file with one row per epoch and one column per (dataset, metric) run
    Returns:
    """
    sns.set(font_scale=1.0)
    sns.set_style("whitegrid")
    # Raw column name -> readable name; the keys are also the columns that are kept.
    rename_cols = {
        "epoch": "Epoch",
        "cifar100_d8_wideresnet_seed42 - model_invariance": "CIFAR100 Model Invariance",
        "stl10_d8_wideresnet_seed42 - model_invariance": "STL10 Model Invariance",
        "cifar100_d8_wideresnet_seed42 - gradient_equivariance": "CIFAR100 Gradient Equivariance",
        "stl10_d8_wideresnet_seed42 - gradient_equivariance": "STL10 Gradient Equivariance",
    }
    df = pd.read_csv(data_path)[list(rename_cols)].rename(columns=rename_cols)
    # Reshape to long form: one record per (dataset, property, epoch).
    records = [
        {"Dataset": dataset, "Property": prop, "Epoch": epoch, "Score": score}
        for dataset in ("CIFAR100", "STL10")
        for prop in ("Model Invariance", "Gradient Equivariance")
        for epoch, score in df[["Epoch", f"{dataset} {prop}"]].values
    ]

    plot_df = pd.DataFrame(records)
    sns.lineplot(data=plot_df, x="Epoch", y="Score", hue="Dataset", style="Property")
    plt.savefig(Path.cwd() / "results/training_dynamics.pdf", bbox_inches="tight")


# if __name__ == "__main__":
#     logging.basicConfig(
#         level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
#     )
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--experiment_name", type=str, default="feature_importance")
#     parser.add_argument("--plot_name", type=str, default="relax_invariance")
#     parser.add_argument("--dataset", type=str, default="ecg")
#     parser.add_argument("--model", type=str, default="cnn32_seed42")
#     parser.add_argument("--concept", type=str, default=None)
#     args = parser.parse_args()
#     with open(Path.cwd() / "results_dir.json") as f:
#         path_dic = json.load(f)
#     dataset_full_names = {
#         "ecg": "ECG",
#         "mut": "Muta.",
#         "mnet": "M.Net40",
#         "fashion_mnist": "Fa.MNIST",
#     }
#     plot_path = (
#         (Path.cwd() / path_dic[dataset_full_names[args.dataset]] / args.experiment_name)
#         if "global" not in args.plot_name and args.plot_name != "training_dynamics"
#         else Path.cwd() / "results"
#     )

#     logging.info(f"Saving {args.plot_name} plot in {str(plot_path)}")
#     match args.plot_name:
#         case "robustness":
#             single_robustness_plots(plot_path, args.dataset, args.experiment_name)
#         case "global_robustness":
#             global_robustness_plots(args.experiment_name)
#         case "relax_invariance":
#             relaxing_invariance_plots(plot_path, args.dataset, args.experiment_name)
#         case "mc_convergence":
#             mc_convergence_plot(plot_path, args.dataset, args.experiment_name)
#         case "enforce_invariance":
#             enforce_invariance_plot(plot_path, args.dataset)
#         case "sensitivity_comparison":
#             sensitivity_plot(plot_path, args.dataset)
#         case "global_relax_invariance":
#             global_relax_invariance()
#         case "training_dynamics":
#             training_dynamic_plot()
#         case other:
#             raise ValueError("Unknown plot name")


In [None]:
# Display the kernel's current working directory; the default results paths in
# this notebook are built from it.
Path.cwd()

### ecg main

In [None]:

def train_ecg_model(
    random_seed: int,
    latent_dim: int,
    batch_size: int,
    model_name: str = "model",
    model_dir: Path = Path.cwd() / f"results/ecg/",
    data_dir: Path = Path('../input/heartbeat'),
) -> None:
    """
    Fit the three ECG classifiers (All-CNN, Standard-CNN, Augmented-CNN)
    Args:
        random_seed: seed passed to set_random_seed
        latent_dim: latent dimension passed to the model constructors
        batch_size: batch size of the train and test loaders
        model_name: prefix of the model names and name of the output sub-directory
        model_dir: parent directory; the models are fitted with model_dir / model_name as output directory
        data_dir: directory holding the ECG CSV files
    Returns:
    """
    logging.info("Fitting the ECG classifiers")
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    set_random_seed(random_seed)
    model_dir = model_dir / model_name
    if not model_dir.exists():
        os.makedirs(model_dir)
    models = {
        "All-CNN": AllCNN(latent_dim, f"{model_name}_allcnn"),
        "Standard-CNN": StandardCNN(latent_dim, f"{model_name}_standard"),
        "Augmented-CNN": StandardCNN(latent_dim, f"{model_name}_augmented"),
    }
    # Only the training split is rebalanced.
    train_set = ECGDataset(data_dir, train=True, balance_dataset=True)
    test_set = ECGDataset(data_dir, train=False, balance_dataset=False)
    train_loader = DataLoader(train_set, batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size, shuffle=True)
    for model_type, model in models.items():
        logging.info(f"Now fitting a {model_type} classifier")
        # Augmentation is switched on for the "Augmented-CNN" entry only.
        use_augmentation = model_type == "Augmented-CNN"
        model.fit(
            device,
            train_loader,
            test_loader,
            model_dir,
            augmentation=use_augmentation,
            checkpoint_interval=10,
        )

def feature_importance(
    random_seed: int,
    latent_dim: int,
    batch_size: int,
    plot: bool,
    model_name: str = "model",
    model_dir: Path = Path.cwd() / f"results/ecg/",
    data_dir: Path = Path('../input/heartbeat'),
    n_test: int = 1000,
) -> None:
    """
    Score the trained ECG classifiers and their feature importance explanations
    under 1D translations, and save the per-example scores in
    model_dir / model_name / feature_importance / metrics.csv
    Args:
        random_seed: seed passed to set_random_seed
        latent_dim: latent dimension passed to the model constructors
        batch_size: batch size of the test loader
        plot: if True, also produce the robustness and relaxing-invariance plots
        model_name: prefix of the model names and name of the sub-directory holding the weights
        model_dir: parent directory of the trained models
        data_dir: directory holding the ECG CSV files
        n_test: maximal number of test examples that are evaluated
    Returns:
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    set_random_seed(random_seed)
    test_set = ECGDataset(data_dir, train=False, balance_dataset=False)
    # Evaluate on a random subset of at most n_test test examples.
    test_subset = Subset(test_set, torch.randperm(len(test_set))[:n_test])
    test_loader = DataLoader(test_subset, batch_size)
    models = {
        "All-CNN": AllCNN(latent_dim, f"{model_name}_allcnn"),
        "Standard-CNN": StandardCNN(latent_dim, f"{model_name}_standard"),
        "Augmented-CNN": StandardCNN(latent_dim, f"{model_name}_augmented"),
    }
    attr_methods = {
        "DeepLift": DeepLift,
        "Integrated Gradients": IntegratedGradients,
        "Gradient Shap": GradientShap,
        "Feature Permutation": FeaturePermutation,
        "Feature Ablation": FeatureAblation,
        "Feature Occlusion": Occlusion,
    }
    model_dir = model_dir / model_name
    save_dir = model_dir / "feature_importance"
    if not save_dir.exists():
        os.makedirs(save_dir)
    translation = Translation1D()
    metrics = []
    for model_type in models:
        logging.info(f"Now working with {model_type} classifier")
        model = models[model_type]
        model.load_metadata(model_dir)
        # map_location lets weights that were saved from a GPU run be loaded in
        # a CPU-only session (and the other way round); without it torch.load
        # tries to restore the tensors on the device they were saved from.
        state_dict = torch.load(model_dir / f"{model.name}.pt", map_location=device)
        model.load_state_dict(state_dict, strict=False)
        model.to(device).eval()
        model_inv = model_invariance_exact(model, translation, test_loader, device)
        logging.info(f"Model invariance: {torch.mean(model_inv):.3g}")
        for attr_name in attr_methods:
            logging.info(f"Now working with {attr_name} explainer")
            feat_importance = FeatureImportance(attr_methods[attr_name](model))
            explanation_equiv = explanation_equivariance_exact(
                feat_importance, translation, test_loader, device
            )
            # One row per test example: model score paired with explanation score.
            for inv, equiv in zip(model_inv, explanation_equiv):
                metrics.append([model_type, attr_name, inv.item(), equiv.item()])
            logging.info(
                f"Explanation equivariance: {torch.mean(explanation_equiv):.3g}"
            )
    metrics_df = pd.DataFrame(
        data=metrics,
        columns=[
            "Model Type",
            "Explanation",
            "Model Invariance",
            "Explanation Equivariance",
        ],
    )
    metrics_df.to_csv(save_dir / "metrics.csv", index=False)
    if plot:
        single_robustness_plots(save_dir, "ecg", "feature_importance")
        relaxing_invariance_plots(save_dir, "ecg", "feature_importance")
        
def example_importance(
    random_seed: int,
    latent_dim: int,
    batch_size: int,
    plot: bool,
    model_name: str = "model",
    model_dir: Path = Path.cwd() / "results/ecg/",
    data_dir: Path = Path.cwd() / "datasets/ecg",
    n_test: int = 1000,
    n_train: int = 100,
    recursion_depth: int = 100,
) -> None:
    """Evaluate the invariance of example-importance explainers on the ECG models.

    For each classifier (All-CNN, Standard-CNN, Augmented-CNN) the saved weights
    are loaded from ``model_dir / model_name``, the model invariance with respect
    to ``Translation1D`` is computed on a random test subset, and the explanation
    invariance of SimplEx, Representation Similarity, TracIn and Influence
    Functions is measured. SimplEx and Representation Similarity are evaluated
    once per layer in ``{"Lin1": model.fc1, "Conv3": model.cnn3}``. All values are
    written to ``<model_dir>/<model_name>/example_importance/metrics.csv``.

    Args:
        random_seed: Seed passed to ``set_random_seed`` before any sampling.
        latent_dim: Latent dimension used to construct every model.
        batch_size: Batch size of the test loader and of the with-replacement
            training loader.
        plot: If true, produce the robustness plots after saving the metrics.
        model_name: Sub-directory of ``model_dir`` and prefix of the model names.
        model_dir: Root directory holding the trained models.
        data_dir: Directory from which ``ECGDataset`` reads its data.
        n_test: Maximum number of randomly selected test examples.
        n_train: Size of the single shuffled training batch handed to the
            explainers as ``X_train`` / ``Y_train``.
        recursion_depth: Forwarded to TracIn / Influence Functions; together with
            ``batch_size`` it also sizes the with-replacement sampler.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    set_random_seed(random_seed)
    train_set = ECGDataset(data_dir, train=True, balance_dataset=False)
    # One shuffled batch of n_train examples is the corpus given to every explainer.
    train_loader = DataLoader(train_set, n_train, shuffle=True)
    X_train, Y_train = next(iter(train_loader))
    X_train, Y_train = X_train.to(device), Y_train.to(device)
    # num_samples / batch_size == recursion_depth batches drawn with replacement.
    train_sampler = RandomSampler(
        train_set, replacement=True, num_samples=recursion_depth * batch_size
    )
    train_loader_replacement = DataLoader(train_set, batch_size, sampler=train_sampler)
    test_set = ECGDataset(data_dir, train=False, balance_dataset=False)
    test_subset = Subset(test_set, torch.randperm(len(test_set))[:n_test])
    test_loader = DataLoader(test_subset, batch_size)
    models = {
        "All-CNN": AllCNN(latent_dim, f"{model_name}_allcnn"),
        "Standard-CNN": StandardCNN(latent_dim, f"{model_name}_standard"),
        "Augmented-CNN": StandardCNN(latent_dim, f"{model_name}_augmented"),
    }
    attr_methods = {
        "SimplEx": SimplEx,
        "Representation Similarity": RepresentationSimilarity,
        "TracIn": TracIn,
        "Influence Functions": InfluenceFunctions,
    }
    model_dir = model_dir / model_name
    save_dir = model_dir / "example_importance"
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    save_dir.mkdir(parents=True, exist_ok=True)
    translation = Translation1D()
    metrics = []
    for model_type in models:
        logging.info(f"Now working with {model_type} classifier")
        model = models[model_type]
        model.load_metadata(model_dir)
        weights_path = model_dir / f"{model.name}.pt"
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
        model.load_state_dict(
            torch.load(weights_path, map_location=device), strict=False
        )
        model.to(device).eval()
        model_inv = model_invariance_exact(model, translation, test_loader, device)
        logging.info(f"Model invariance: {torch.mean(model_inv):.3g}")
        model_layers = {"Lin1": model.fc1, "Conv3": model.cnn3}
        for attr_name in attr_methods:
            logging.info(f"Now working with {attr_name} explainer")
            # Weights are restored before every explainer, as in the original code
            # (presumably because some explainers may alter the model -- TODO confirm).
            model.load_state_dict(
                torch.load(weights_path, map_location=device), strict=False
            )
            if attr_name in {"TracIn", "Influence Functions"}:
                ex_importance = attr_methods[attr_name](
                    model,
                    X_train,
                    Y_train=Y_train,
                    train_loader=train_loader_replacement,
                    loss_function=nn.CrossEntropyLoss(),
                    save_dir=save_dir / model.name,
                    recursion_depth=recursion_depth,
                )
                explanation_inv = explanation_invariance_exact(
                    ex_importance, translation, test_loader, device
                )
                for inv_model, inv_expl in zip(model_inv, explanation_inv):
                    metrics.append(
                        [model_type, attr_name, inv_model.item(), inv_expl.item()]
                    )
                logging.info(
                    f"Explanation invariance: {torch.mean(explanation_inv):.3g}"
                )
            else:
                # Representation-based explainers are evaluated once per layer.
                for layer_name in model_layers:
                    ex_importance = attr_methods[attr_name](
                        model, X_train, Y_train=Y_train, layer=model_layers[layer_name]
                    )
                    explanation_inv = explanation_invariance_exact(
                        ex_importance, translation, test_loader, device
                    )
                    # Detach the explainer's hook before the next explainer is built.
                    ex_importance.remove_hook()
                    for inv_model, inv_expl in zip(model_inv, explanation_inv):
                        metrics.append(
                            [
                                model_type,
                                f"{attr_name}-{layer_name}",
                                inv_model.item(),
                                inv_expl.item(),
                            ]
                        )
                    logging.info(
                        f"Explanation invariance for {layer_name}: {torch.mean(explanation_inv):.3g}"
                    )
    metrics_df = pd.DataFrame(
        data=metrics,
        columns=[
            "Model Type",
            "Explanation",
            "Model Invariance",
            "Explanation Invariance",
        ],
    )
    metrics_df.to_csv(save_dir / "metrics.csv", index=False)
    if plot:
        single_robustness_plots(save_dir, "ecg", "example_importance")
        relaxing_invariance_plots(save_dir, "ecg", "example_importance")


def concept_importance(
    random_seed: int,
    latent_dim: int,
    batch_size: int,
    plot: bool,
    model_name: str = "model",
    model_dir: Path = Path.cwd() / "results/ecg/",
    data_dir: Path = Path.cwd() / "datasets/ecg",
    n_test: int = 1000,
    concept_set_size: int = 100,
) -> None:
    """Evaluate the invariance of concept-based explainers (CAV, CAR) on the ECG models.

    For each classifier (All-CNN, Standard-CNN, Augmented-CNN) the saved weights
    are loaded from ``model_dir / model_name`` and the model invariance with
    respect to ``Translation1D`` is computed on a random test subset. Every
    (layer, explainer) pair from ``{"Lin1", "Conv3"} x {"CAV", "CAR"}`` is then
    fitted, its per-concept accuracy is logged, and its explanation invariance is
    measured with ``accuracy`` as the similarity. All values are written to
    ``<model_dir>/<model_name>/concept_importance/metrics.csv``.

    Args:
        random_seed: Seed passed to ``set_random_seed`` before any sampling.
        latent_dim: Latent dimension used to construct every model.
        batch_size: Batch size of the test loader.
        plot: If true, produce the robustness plots after saving the metrics.
        model_name: Sub-directory of ``model_dir`` and prefix of the model names.
        model_dir: Root directory holding the trained models.
        data_dir: Directory from which ``ECGDataset`` reads its data.
        n_test: Maximum number of randomly selected test examples.
        concept_set_size: Passed to the explainers' ``fit`` and
            ``concept_accuracy`` methods.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    set_random_seed(random_seed)
    # Both splits are loaded with binarize_label=False.
    train_set = ECGDataset(
        data_dir, train=True, binarize_label=False, balance_dataset=False
    )
    test_set = ECGDataset(
        data_dir, train=False, balance_dataset=False, binarize_label=False
    )
    test_subset = Subset(test_set, torch.randperm(len(test_set))[:n_test])
    test_loader = DataLoader(test_subset, batch_size)
    models = {
        "All-CNN": AllCNN(latent_dim, f"{model_name}_allcnn"),
        "Standard-CNN": StandardCNN(latent_dim, f"{model_name}_standard"),
        "Augmented-CNN": StandardCNN(latent_dim, f"{model_name}_augmented"),
    }
    attr_methods = {"CAV": CAV, "CAR": CAR}
    model_dir = model_dir / model_name
    save_dir = model_dir / "concept_importance"
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    save_dir.mkdir(parents=True, exist_ok=True)
    translation = Translation1D()
    metrics = []
    for model_type in models:
        logging.info(f"Now working with {model_type} classifier")
        model = models[model_type]
        model.load_metadata(model_dir)
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
        model.load_state_dict(
            torch.load(model_dir / f"{model.name}.pt", map_location=device),
            strict=False,
        )
        model.to(device).eval()
        model_inv = model_invariance_exact(model, translation, test_loader, device)
        logging.info(f"Model invariance: {torch.mean(model_inv):.3g}")
        model_layers = {"Lin1": model.fc1, "Conv3": model.cnn3}
        for layer_name, attr_name in itertools.product(model_layers, attr_methods):
            logging.info(
                f"Now working with {attr_name} explainer on layer {layer_name}"
            )
            conc_importance = attr_methods[attr_name](
                model, train_set, n_classes=2, layer=model_layers[layer_name]
            )
            conc_importance.fit(device, concept_set_size)
            # Concept accuracy is measured on the full test set, not the subset.
            concept_acc = conc_importance.concept_accuracy(
                test_set, device, concept_set_size=concept_set_size
            )
            for concept_name in concept_acc:
                logging.info(
                    f"Concept {concept_name} accuracy: {concept_acc[concept_name]:.2g}"
                )
            explanation_inv = explanation_invariance_exact(
                conc_importance, translation, test_loader, device, similarity=accuracy
            )
            # Detach the explainer's hook before the next explainer is built.
            conc_importance.remove_hook()
            for inv_model, inv_expl in zip(model_inv, explanation_inv):
                metrics.append(
                    [
                        model_type,
                        f"{attr_name}-{layer_name}",
                        inv_model.item(),
                        inv_expl.item(),
                    ]
                )
            logging.info(f"Explanation invariance: {torch.mean(explanation_inv):.3g}")
    metrics_df = pd.DataFrame(
        data=metrics,
        columns=[
            "Model Type",
            "Explanation",
            "Model Invariance",
            "Explanation Invariance",
        ],
    )
    metrics_df.to_csv(save_dir / "metrics.csv", index=False)
    if plot:
        single_robustness_plots(save_dir, "ecg", "concept_importance")
        relaxing_invariance_plots(save_dir, "ecg", "concept_importance")


def enforce_invariance(
    random_seed: int,
    latent_dim: int,
    batch_size: int,
    plot: bool,
    model_name: str = "model",
    model_dir: Path = Path.cwd() / "results/ecg/",
    data_dir: Path = Path.cwd() / "datasets/ecg",
    n_test: int = 1000,
    concept_set_size: int = 100,
    n_inv_values: tuple = (1, 5, 20, 50, 100, 187),
) -> None:
    """Measure explanation invariance of CAV/CAR wrapped in ``InvariantExplainer``.

    The All-CNN weights are loaded from ``model_dir / model_name``. Each
    explainer is built on ``model.cnn3`` (and fitted when it is a
    ``ConceptExplainer``), then wrapped in ``InvariantExplainer`` for every
    ``N_inv`` in ``n_inv_values``; the resulting explanation invariance is
    computed with ``accuracy`` as the similarity. All values are written to
    ``<model_dir>/<model_name>/enforce_invariance/metrics.csv``.

    Args:
        random_seed: Seed passed to ``set_random_seed`` before any sampling.
        latent_dim: Latent dimension used to construct the model.
        batch_size: Batch size of the test loader.
        plot: If true, call ``enforce_invariance_plot`` after saving the metrics.
        model_name: Sub-directory of ``model_dir`` and prefix of the model name.
        model_dir: Root directory holding the trained models.
        data_dir: Directory from which ``ECGDataset`` reads its data.
        n_test: Maximum number of randomly selected test examples.
        concept_set_size: Passed to the concept explainers' ``fit`` method.
        n_inv_values: Values of ``N_inv`` handed to ``InvariantExplainer``, in
            order. The default reproduces the original hard-coded sweep; the
            largest value (187) presumably equals the ECG signal length --
            TODO confirm.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    set_random_seed(random_seed)
    train_set = ECGDataset(
        data_dir, train=True, binarize_label=False, balance_dataset=False
    )
    test_set = ECGDataset(data_dir, train=False, balance_dataset=False)
    test_subset = Subset(test_set, torch.randperm(len(test_set))[:n_test])
    test_loader = DataLoader(test_subset, batch_size)
    models = {"All-CNN": AllCNN(latent_dim, f"{model_name}_allcnn")}
    attr_methods = {"CAV": CAV, "CAR": CAR}
    model_dir = model_dir / model_name
    save_dir = model_dir / "enforce_invariance"
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    save_dir.mkdir(parents=True, exist_ok=True)
    translation = Translation1D()
    metrics = []
    for model_type in models:
        logging.info(f"Now working with {model_type} classifier")
        model = models[model_type]
        model.load_metadata(model_dir)
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
        model.load_state_dict(
            torch.load(model_dir / f"{model.name}.pt", map_location=device),
            strict=False,
        )
        model.to(device).eval()
        # Model invariance is only logged here; it is not stored in the metrics.
        model_inv = model_invariance_exact(model, translation, test_loader, device)
        logging.info(f"Model invariance: {torch.mean(model_inv):.3g}")
        for attr_name in attr_methods:
            logging.info(f"Now working with {attr_name} explainer")
            attr_method = attr_methods[attr_name](
                model, train_set, n_classes=2, layer=model.cnn3
            )
            # Evaluated once: decides both whether to fit and the wrapper's flag.
            is_concept_explainer = isinstance(attr_method, ConceptExplainer)
            if is_concept_explainer:
                attr_method.fit(device, concept_set_size)
            for N_inv in n_inv_values:
                logging.info(
                    f"Now working with invariant explainer with N_inv = {N_inv}"
                )
                inv_method = InvariantExplainer(
                    attr_method,
                    translation,
                    N_inv,
                    is_concept_explainer,
                )
                explanation_inv = explanation_invariance_exact(
                    inv_method, translation, test_loader, device, similarity=accuracy
                )
                logging.info(
                    f"N_inv = {N_inv} - Explanation invariance = {torch.mean(explanation_inv):.3g}"
                )
                for inv_expl in explanation_inv:
                    metrics.append(
                        [model_type, f"{attr_name}-Equiv", N_inv, inv_expl.item()]
                    )
    metrics_df = pd.DataFrame(
        data=metrics,
        columns=["Model Type", "Explanation", "N_inv", "Explanation Invariance"],
    )
    metrics_df.to_csv(save_dir / "metrics.csv", index=False)
    if plot:
        enforce_invariance_plot(save_dir, "ecg")


def sensitivity_comparison(
    random_seed: int,
    latent_dim: int,
    batch_size: int,
    plot: bool,
    model_name: str = "model",
    model_dir: Path = Path.cwd() / "results/ecg/",
    data_dir: Path = Path.cwd() / "datasets/ecg",
    n_test: int = 1000,
) -> None:
    """Compare explanation sensitivity with explanation equivariance.

    The Augmented-CNN weights are loaded from ``model_dir / model_name``. For
    each feature-importance method the ``sensitivity`` metric and the
    equivariance with respect to ``Translation1D`` are computed on a random test
    subset, their correlation coefficient is logged, and the paired values are
    written to ``<model_dir>/<model_name>/sensitivity/metrics.csv``.

    Args:
        random_seed: Seed passed to ``set_random_seed`` before any sampling.
        latent_dim: Latent dimension used to construct the model.
        batch_size: Batch size of the test loader.
        plot: If true, call ``sensitivity_plot`` after saving the metrics.
        model_name: Sub-directory of ``model_dir`` and prefix of the model name.
        model_dir: Root directory holding the trained models.
        data_dir: Directory from which ``ECGDataset`` reads its data.
        n_test: Maximum number of randomly selected test examples.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    set_random_seed(random_seed)
    test_set = ECGDataset(data_dir, train=False, balance_dataset=False)
    test_subset = Subset(test_set, torch.randperm(len(test_set))[:n_test])
    test_loader = DataLoader(test_subset, batch_size)
    models = {"Augmented-CNN": StandardCNN(latent_dim, f"{model_name}_augmented")}
    attr_methods = {
        "Integrated Gradients": IntegratedGradients,
        "Gradient Shap": GradientShap,
        "Feature Permutation": FeaturePermutation,
        "Feature Ablation": FeatureAblation,
        "Feature Occlusion": Occlusion,
    }
    model_dir = model_dir / model_name
    save_dir = model_dir / "sensitivity"
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    save_dir.mkdir(parents=True, exist_ok=True)
    translation = Translation1D()
    metrics = []
    for model_type in models:
        logging.info(f"Now working with {model_type} classifier")
        model = models[model_type]
        model.load_metadata(model_dir)
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
        model.load_state_dict(
            torch.load(model_dir / f"{model.name}.pt", map_location=device),
            strict=False,
        )
        model.to(device).eval()
        for attr_name in attr_methods:
            logging.info(f"Now working with {attr_name} explainer")
            attr_method = attr_methods[attr_name](model)
            feat_importance = FeatureImportance(attr_method)
            # Sensitivity takes the raw attribution method; equivariance takes
            # the FeatureImportance wrapper around it.
            explanation_sens = (
                sensitivity(attr_method, test_loader, device).cpu().numpy()
            )
            explanation_equiv = (
                explanation_equivariance_exact(
                    feat_importance, translation, test_loader, device
                )
                .cpu()
                .numpy()
            )
            # Off-diagonal entry of the 2x2 correlation matrix of the two metrics.
            corr = np.corrcoef(explanation_sens, explanation_equiv)
            logging.info(f"Metrics correlation: {corr[0, 1].item():.3g}")
            for sens, equiv in zip(explanation_sens, explanation_equiv):
                metrics.append([model_type, attr_name, sens, equiv])
    metrics_df = pd.DataFrame(
        data=metrics,
        columns=[
            "Model Type",
            "Explanation",
            "Explanation Sensitivity",
            "Explanation Equivariance",
        ],
    )
    metrics_df.to_csv(save_dir / "metrics.csv", index=False)
    if plot:
        sensitivity_plot(save_dir, "ecg")

In [None]:
# !pip install networkx

In [None]:
import sys
import argparse

# Timestamped INFO-level records make the experiment functions' progress visible.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# sys.argv = sys.argv[:1]
# parser = argparse.ArgumentParser()
# parser.add_argument("--name", type=str, default="feature_importance")
# parser.add_argument("--seed", type=int, default=42)
# parser.add_argument("--batch_size", type=int, default=500)
# parser.add_argument("--latent_dim", type=int, default=32)
# parser.add_argument("--train", action="store_true")
# parser.add_argument("--plot", action="store_true")
# parser.add_argument("--n_test", type=int, default=1000)
# args = parser.parse_args()

class Args:
    """Hard-coded experiment settings used throughout this notebook.

    Every setting is a plain instance attribute, so later cells can override
    it after construction (e.g. ``args.train = True``).
    """

    def __init__(self) -> None:
        # Experiment selection and reproducibility.
        self.name = "feature_importance"
        self.seed = 42
        # Model and data-loading sizes.
        self.latent_dim = 32
        self.batch_size = 500
        self.n_test = 1000
        # Workflow switches.
        self.train = False  # training is off by default
        self.plot = True  # plotting is on by default

# Build the notebook-wide settings object.
args = Args()
# Override the default (False) so that the guarded training cell runs.
args.train = True

print(args.train)

In [None]:
# Run identifier built from the latent dimension and the seed
# ("cnn32_seed42" with the default Args values).
model_name = f"cnn{args.latent_dim}_seed{args.seed}"
# Training only happens when args.train is set.
if args.train:
    train_ecg_model(
        args.seed, args.latent_dim, args.batch_size, model_name=model_name
    )

### feature_importance

In [None]:
# Show which experiment name is currently stored in args.
print(args.name)

In [None]:
# Runs only while args.name still holds "feature_importance" (the Args default).
if args.name == "feature_importance":
    feature_importance(
        args.seed,
        args.latent_dim,
        args.batch_size,
        args.plot,
        model_name,
        n_test=args.n_test,
    )

### example_importance

In [None]:
# Record the experiment about to run. This must be an assignment: a bare `==`
# comparison is evaluated and discarded, leaving args.name unchanged.
args.name = "example_importance"
print(args.name)

In [None]:
# Run the example-importance experiment with the notebook settings.
example_importance(
    random_seed=args.seed,
    latent_dim=args.latent_dim,
    batch_size=args.batch_size,
    plot=args.plot,
    model_name=model_name,
    n_test=args.n_test,
)

### concept_importance

In [None]:
# Record the experiment about to run (assignment, not a discarded `==`
# comparison), then run it.
args.name = "concept_importance"
concept_importance(
    args.seed,
    args.latent_dim,
    args.batch_size,
    args.plot,
    model_name,
    n_test=args.n_test,
)

### enforce_invariance

In [None]:
# Record the experiment about to run (assignment, not a discarded `==`
# comparison), then run it.
args.name = "enforce_invariance"
enforce_invariance(
    args.seed,
    args.latent_dim,
    args.batch_size,
    args.plot,
    model_name,
    n_test=args.n_test,
)


### sensitivity_comparison

In [None]:
# Record the experiment about to run, then run it. A bare comparison followed
# by a colon is a SyntaxError, so the name is assigned instead.
args.name = "sensitivity_comparison"
sensitivity_comparison(
    args.seed,
    args.latent_dim,
    args.batch_size,
    args.plot,
    model_name,
    n_test=args.n_test,
)

In [None]:
# import sys
# sys.path.append('/kaggle/working')
# from datasets.ecg.loaders import ECGDataset

# # %% [code] {"jupyter":{"outputs_hidden":false}}