### Interpretability of Attention-based MIL 

In this session, we will explore two post-hoc interpretability methods for understanding the behaviour of an Attention-based Multiple Instance Learning (MIL) model trained to classify lung cancer into its two most common subtypes: lung adenocarcinoma (LUAD) and lung squamous cell carcinoma (LUSC).


#### On the importance of interpretability in medical AI applications 

- **Trust & control**: Medical AI systems are often used in critical decision-making processes that directly impact patients' health and lives. Interpretability aims to give users the right level of insight into, and control over, the system so that they can build trust and confidence in it.


- **Insights and Understanding**: Interpretable models also reveal which underlying factors or features contribute to a prediction. In medical AI, this can help healthcare professionals understand disease mechanisms, identify risk factors, or discover novel biomarkers. Biomarker discovery rests on the assumption that the AI system may rely on features other than those in the current standard (e.g., the current grading criteria in cancer).


- **Error Analysis and Diagnosis**: Interpretability helps in error analysis, allowing the identification and understanding of the model's mistakes or mispredictions. In medical AI, where misdiagnoses can have severe consequences, interpretability enables clinicians to evaluate cases where the model failed and diagnose potential pitfalls or limitations. This feedback loop can guide improvements in the model, dataset, or feature engineering, leading to better performance and more reliable predictions.


And the wishful thinking considerations...


- **Legal and Ethical Considerations**: Interpretability could help address legal and ethical concerns. In healthcare, decisions made by AI systems need to be explainable to patients, healthcare professionals, regulatory bodies, and other stakeholders. Interpretability can help AI systems meet legal requirements such as the General Data Protection Regulation (GDPR), which is often read as granting individuals a right to an explanation for automated decisions that significantly affect them.


- **Safety and Robustness**: Deep learning models are susceptible to biases, adversarial attacks, or data distribution shifts that can lead to incorrect or unreliable predictions. Interpretability helps in detecting these issues and assessing the model's safety and robustness. By understanding the model's internal workings, it becomes possible to identify potential biases, investigate cases where the model may be overconfident or underperform, and design safeguards to mitigate risks.


- **Regulatory Compliance**: Interpretability is increasingly becoming a regulatory requirement in various domains, including healthcare. Regulatory bodies, such as the U.S. Food and Drug Administration (FDA), often demand explanations and justifications for the decisions made by AI systems before approving their deployment. Interpretability allows the model's behavior to be audited, validated, and aligned with regulatory standards, ensuring compliance and patient safety.


In [None]:
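# Load the previously trained ABMIL model.
# PATH is assumed to be defined earlier and to point to the saved model file.
import torch

# Note: PyTorch 2.6+ defaults to weights_only=True; a fully pickled model
# may then need torch.load(PATH, weights_only=False) (trusted files only).
model = torch.load(PATH)
model.eval()  # inference mode: disables dropout and batch-norm updates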

#### Interpreting Attention-based MIL with Attention weights

Attention weights indicate how much each part of the input contributes to the model's aggregated representation. In Attention-based MIL, the model assigns one weight to each instance (here, each patch of the slide), so the weights can be read as a per-patch importance score and visualised as a heatmap over the slide to show which regions drive the prediction.
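
To make the computation concrete, here is a minimal sketch of how attention weights can be obtained for one bag, written in plain Python so the arithmetic stays visible. It assumes the common ABMIL scoring form (a `tanh` layer followed by a linear projection and a softmax over instances). The names `attention_weights`, `bag`, `V` and `w` and all numbers are made up for illustration; in the notebook the weights would come from the attention branch of the trained `model`.

```python
import math

def attention_weights(instances, V, w):
    """Return one softmax-normalised attention weight per instance.

    instances: list of patch embeddings (each a list of floats)
    V:         hidden-layer matrix (list of rows), hypothetical values
    w:         projection vector mapping the hidden layer to a scalar score
    """
    scores = []
    for h in instances:
        # Hidden representation: tanh(V h)
        hidden = [math.tanh(sum(v * x for v, x in zip(row, h))) for row in V]
        # Unnormalised attention score: w^T tanh(V h)
        scores.append(sum(wi * ti for wi, ti in zip(w, hidden)))
    # Softmax over the instances of the bag (shifted for numerical stability)
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Toy bag: 4 patches, each with a 3-dimensional embedding (illustrative values)
bag = [[0.1, 0.2, -0.1], [2.0, -1.0, 0.5], [0.0, 0.0, 0.0], [-0.5, 0.3, 0.8]]
V = [[0.5, -0.2, 0.1], [0.3, 0.8, -0.5]]
w = [1.0, -1.0]

weights = attention_weights(bag, V, w)

# Rank patches from most to least attended; the top ones would be highlighted
ranking = sorted(range(len(bag)), key=lambda k: weights[k], reverse=True)
print(weights, ranking)
```

Because the weights are normalised within each bag, they are comparable between patches of the same slide but not directly between slides.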


In [None]:
# add the code for that here. 

#### Interpreting with Integrated Gradients

In this section, we will use Captum, an open-source package that provides off-the-shelf post-hoc interpretability techniques, including Integrated Gradients (IG). 

1. **Notation:**
   - Let's consider a deep learning model with an input vector `x` and output `f(x)`.
   - The baseline or reference point is denoted as `x'`, typically chosen as a point with low complexity (e.g., all zeros or random noise).
   - The attribution score for each input feature `i` is denoted as `A_i`.

2. **Gradient Calculation:**
   - Compute the gradients of the model's output with respect to the input features:
   
     $$\vec{\nabla} f(x) = \left(\frac{\partial f(x)}{\partial x_1}, \frac{\partial f(x)}{\partial x_2}, \ldots, \frac{\partial f(x)}{\partial x_n}\right)$$

3. **Integrated Gradients Formula:**
   - The Integrated Gradients score for each feature `i` is calculated as follows:
   
     $$A_i = (x_i - x'_i) \times \int_{\alpha=0}^1 \left(\frac{\partial f(x'+\alpha(x-x'))}{\partial x_i}\right) d\alpha$$

4. **Explanation:**
   - Integrated Gradients computes the contribution of each feature `i` by taking into account the difference between the input `x` and the baseline `x'`.
   - It then integrates the gradients of the model's output with respect to feature `i` along a straight path from the baseline `x'` to the input `x`.
   - In practice, the integral is approximated by a sum over a finite number of steps, with α taking evenly spaced values between 0 and 1 that interpolate between the baseline and the input.
   - The gradients at each interpolation point measure the sensitivity of the output to changes in feature `i` as we move from the baseline to the input.
   - The integrated gradient for feature `i` is then multiplied by the difference between the input and the baseline for that feature, which yields the share of the change in the model's output attributed to that feature.

5. **Implementation Steps:**
   - Select a baseline point `x'` (all zeros, random noise, or other relevant choice).
   - Define the number of steps or intervals for the integration.
   - For each step α from 0 to 1, calculate `x'+α(x-x')` as the intermediate input.
   - Compute the gradients of the model's output with respect to the intermediate input at each step.
   - Average the gradients over all steps and multiply the result by the difference between `x` and `x'` for each feature `i` to obtain the attribution scores `A_i`.
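
The steps above can be sketched from scratch on a toy model. The example below is only an illustration under stated assumptions: the "model" is a hypothetical logistic function with hand-picked weights and an analytic gradient, the baseline is all zeros, and the integral is approximated with a midpoint Riemann sum. None of these choices come from the ABMIL model used in this session.

```python
import math

# Hypothetical toy model: f(x) = sigmoid(w . x + b)
WEIGHTS = [1.5, -2.0, 0.5]
BIAS = 0.1

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def f(x):
    return sigmoid(sum(w * xi for w, xi in zip(WEIGHTS, x)) + BIAS)

def grad_f(x):
    # Analytic gradient of the sigmoid model: f(x) * (1 - f(x)) * w_i
    p = f(x)
    return [p * (1.0 - p) * w for w in WEIGHTS]

def integrated_gradients(x, baseline, n_steps=100):
    """Approximate A_i = (x_i - x'_i) * integral of df/dx_i along the path."""
    avg_grad = [0.0] * len(x)
    for step in range(n_steps):
        alpha = (step + 0.5) / n_steps  # midpoint of the current interval
        # Intermediate input x' + alpha * (x - x')
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_f(point)):
            avg_grad[i] += g / n_steps  # running average of the gradients
    # Scale the averaged gradient by the input-baseline difference
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]

x = [1.0, 0.5, 1.0]
baseline = [0.0, 0.0, 0.0]

attributions = integrated_gradients(x, baseline)

# Sanity check: attributions should sum to f(x) - f(baseline)
completeness_gap = sum(attributions) - (f(x) - f(baseline))
print(attributions, completeness_gap)
```

The final check uses a known property of IG: the attributions sum to the difference between the model output at the input and at the baseline, up to the error of the numerical approximation. A large gap usually means that more integration steps are needed.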

The Integrated Gradients technique provides feature-level attribution scores, allowing us to understand the importance of each input feature in the model's predictions. By visualizing or analyzing these scores, we can gain insights into which features are influential or critical for the model's decision-making process.


#### Discussion

- What is the main advantage of IG over attention for model interpretability?


- What are the limitations of feature-attribution methods? 


- What's the difference between interpretability and explainability? Link it to the notion of control.


- If you were a clinical pathologist looking at these visualizations, what insights or concerns would you have about letting an AI algorithm assist you in medical diagnoses?

The demo at [http://clam.mahmoodlab.org](http://clam.mahmoodlab.org) visualizes high-attention heatmaps for LUAD vs LUSC subtyping via CLAM (similar to `ABMIL`), along with confidence scores for each slide.