# Integrated Gradients using Captum Insights
https://captum.ai/docs/captum_insights


## Installation
* First, install Captum with conda: `conda install -c pytorch captum`
* Download the Captum source code locally into this folder.
* in ```captum.insights.attr_vis.features```, in class ```ImageFeature```, method ```visualize```, add this snippet of code at the start of the method: 
   ```
        class UnNormalize(object):
            """Undo a per-channel ``(x - mean) / std`` normalization, in place.

            ``mean`` and ``std`` are sequences holding one value per channel.
            Calling the instance computes ``x * std + mean`` channel by
            channel, modifies the given tensor in place and returns it.
            """

            def __init__(self, mean, std):
                self.mean = mean
                self.std = std

            def __call__(self, tensor):
                # Iterating a tensor walks its first dimension. For a 3-D
                # (C, H, W) image that is the channel axis, but for a 4-D
                # (N, C, H, W) batch it is the batch axis, which would pair
                # mean/std with images instead of channels. Handle the batch
                # case explicitly so each image is un-normalized per channel.
                # NOTE(review): confirm which of the two shapes `visualize`
                # actually receives in the installed Captum version.
                images = tensor if tensor.dim() == 4 else (tensor,)
                for image in images:
                    for t, m, s in zip(image, self.mean, self.std):
                        # `t` is a view, so the in-place ops update `tensor`.
                        t.mul_(s).add_(m)
                return tensor
                
        # Per-channel mean/std of the ImageNet normalization being undone.
        imagenet_mean = (0.485, 0.456, 0.406)
        imagenet_std = (0.229, 0.224, 0.225)
        unorm = UnNormalize(mean=imagenet_mean, std=imagenet_std)
        # Replace the method's `data` with its un-normalized version.
        data = unorm(data)
        print('data is unormalized')
   ```
    
    This snippet undoes the ImageNet normalization so that the integrated-gradients visualization is drawn on the original, un-normalized image.
    
* Then install the modified Captum in editable mode from `notebooks/captum/`: `pip install -e .`


In [1]:
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn.functional as F
import pkg_resources
import pandas as pd 


from captum.insights import AttributionVisualizer, Batch
from captum.insights.attr_vis.features import ImageFeature

from src.database import Database
from src.protocol import Protocol 
from src import utils



## Choose expert 

In [23]:
# Settings that select which trained expert is explained.
rbp = "TDP-43"  # appended to the protocol name
classifier = "als"  # key into the `classifiers` table
channels = ["TDP-43"]  # joined into the checkpoint file name and passed to the data loaders

In [24]:
# Short classifier name -> prefix of the protocol name it was trained with.
classifiers = dict(
    als='control_als_untreated',
    osmotic='control_untreated_osmotic',
    heat='control_untreated_heat',
    oxidative='control_untreated_oxidative',
)


In [25]:
# The protocol name is the classifier's prefix followed by the RBP name.
protocol_name = '{}_{}'.format(classifiers[classifier], rbp)
protocol = Protocol.from_name(protocol_name)
print(protocol)

# Class names (indexed by the predicted label) and the classification task:
# the ALS expert separates healthy from ALS, every other expert separates
# untreated from stressed.
classes, classification = (
    (['healthy', 'als'], 'als')
    if classifier == 'als'
    else (['untreated', 'stress'], 'stress')
)


labels: ['control', 'als']
conditions: ['untreated']
rbp: TDP-43


## Define functions for classification classes and pretrained model

In [29]:
def get_classes():
    """Return the module-level ``classes`` list of class names."""
    return classes


def get_pretrained_model():
    """Build the MobileNetV2 two-class classifier and load its trained weights.

    The network is torchvision's MobileNetV2 whose last classifier layer is
    replaced by a 2-output linear layer. The weights are read from the fold-0
    checkpoint selected by the module-level ``classifier``, ``protocol_name``
    and ``channels`` settings, and are mapped onto the CPU.

    Returns:
        The model, switched to evaluation mode.
    """
    # No ImageNet weights are requested: `load_state_dict` below is strict by
    # default, so every parameter and buffer is overwritten by the checkpoint
    # and downloading pretrained weights first would be wasted work.
    model = models.mobilenet_v2()
    # `requires_grad` is a per-parameter flag. Assigning it on a module only
    # creates an unrelated attribute and freezes nothing, so set it on the
    # parameters of the feature extractor themselves.
    for param in model.features.parameters():
        param.requires_grad = False
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(num_ftrs, 2)
    checkpoint_path = f'../models/{classifier}_models/state_dict_{protocol_name}_{"_".join(channels)}_fold_0.pt'
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


def baseline_func(input):
    """Return the attribution baseline for ``input``: ``input`` scaled by zero."""
    return 0 * input


def formatted_data_iter():
    """Yield a single Captum ``Batch`` of confidently classified test images.

    Walks the fold-0 test loader and keeps up to 10 images per predicted
    class whose top softmax score is above 0.95. The kept images and the
    labels returned by the loader are concatenated into one ``Batch``.

    The module-level ``model`` is read when the generator is first advanced,
    so it must be defined by then. The code needs one image per loader batch
    (``.item()`` is called on the prediction).

    Yields:
        Batch: the selected images (with ``requires_grad`` enabled) and labels.

    Raises:
        RuntimeError: if no test image reaches the confidence threshold.
    """
    database = Database()
    config = utils.get_config('notebook')
    train, test = database.cross_validation(classification, protocol, fold=0)

    # NOTE(review): `train_loader` is never used below; the call is kept in
    # case `create_train_dataloader` has side effects - confirm and remove.
    train_loader = utils.create_train_dataloader(config, database, train, classification, protocol, channels, fold=0)
    test_loader, test_dataset = utils.create_test_dataloader(config, database, test, classification, protocol, channels, fold=0)

    dataloader = iter(test_loader)

    # select 10 images of each class
    limit_images = 10
    min_score = 0.95
    images = []
    labels = []
    n_images = {classes[0]: 0, classes[1]: 0}

    print(len(dataloader))
    # A `for` loop ends cleanly when the loader is exhausted. Calling
    # `.next()`/`next()` by hand would leak StopIteration out of this
    # generator, which Python turns into an opaque RuntimeError.
    for image, label, _ in dataloader:
        if image is None:
            print('image is None')
            continue

        # The selection pass only needs scores, so skip building a graph.
        with torch.no_grad():
            output = F.softmax(model(image), dim=1)
        prediction_score, pred_label_idx = torch.topk(output, 1)
        predicted = classes[pred_label_idx.item()]
        # To inspect uncertain predictions instead, select scores between
        # 0.4 and 0.6 here.
        if prediction_score.item() > min_score and n_images[predicted] < limit_images:
            n_images[predicted] += 1
            # Attribution methods need gradients with respect to the input.
            image.requires_grad = True
            images.append(image)
            labels.append(label)
            if all(count == limit_images for count in n_images.values()):
                break

    if not images:
        raise RuntimeError(
            f'no test image was classified with a score above {min_score}'
        )
    if any(count < limit_images for count in n_images.values()):
        # Test set ran out before the quota was met; show what was found.
        print(f'only found {n_images} confident images')

    images = torch.cat(images)
    labels = torch.cat(labels)
    yield Batch(inputs=images, labels=labels)
    


## Run the visualizer and render inside notebook 

In [30]:
# `formatted_data_iter` reads this global when the dataset is first consumed.
model = get_pretrained_model()
visualizer = AttributionVisualizer(
    models=[model],
    # Convert the raw model outputs into per-class probabilities.
    score_func=lambda logits: F.softmax(logits, dim=1),
    classes=get_classes(),
    features=[
        ImageFeature("Photo", baseline_transforms=[baseline_func], input_transforms=[]),
    ],
    dataset=formatted_data_iter(),
)


In [31]:
# Render the Captum Insights widget inside the notebook.
visualizer.render()

CaptumInsights(insights_config={'classes': ['healthy', 'als'], 'methods': ['Deconvolution', 'Deep Lift', 'Guid…

Output()