In [36]:
from crp.attribution import CondAttribution
from zennit.attribution import Gradient
from zennit.canonizers import CompositeCanonizer
from crp.concepts import ChannelConcept
from crp.helper import get_layer_names
from crp.visualization import FeatureVisualization
from crp.image import plot_grid

from zennit.composites import EpsilonPlusFlat
from zennit.canonizers import SequentialMergeBatchNorm
import zennit as zen
import torch.nn as nn
import torch
import zennit.torchvision as ztv
from crp.image import imgify


from relevance import plot_relevance
from model import get_vggs_and_path, get_resnets_and_path, get_remote_models_and_path
from plot_and_print import plot_tile
from data_loader import TileLoader
import os
import pandas as pd
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
from data_loader import get_data_loaders, get_dataset, STDataset

In [37]:
def plot_heatmaps(imgs, image_path, width=4, subplot_size=30):
    """Plot heatmaps in a grid with the original tile in the bottom-right panel.

    NOTE: a later cell of this notebook redefines ``plot_heatmaps``; that later
    definition is the one in effect after a top-to-bottom run.

    Args:
        imgs: sequence of entries whose item 0 is an image accepted by ``imshow``
            and item 1 is its title.
        image_path: path of the original tile, opened with PIL.
        width: number of grid columns.
        subplot_size: edge length in inches of the WHOLE (square) figure, kept
            for backward compatibility with the original behaviour.
    """
    # +1 panel for the original tile; exact integer ceiling division for the row count
    n_panels = len(imgs) + 1
    height = -(-n_panels // width)

    # squeeze=False keeps `ax` 2-D even for a single row/column, so ax[r, c] and
    # ax[-1, -1] always work. With the default squeeze=True a one-row grid is 1-D
    # and two-index access raises IndexError. No separate plt.figure() is created,
    # because that would leave an extra empty figure for plt.show() to render.
    _, ax = plt.subplots(height, width, figsize=(subplot_size, subplot_size), squeeze=False)

    # hide every axis up front so unused grid cells stay blank
    for panel in ax.flat:
        panel.axis('off')

    for i, entry in enumerate(imgs):
        panel = ax[i // width, i % width]
        panel.imshow(entry[0])
        panel.set_title(entry[1])

    # load the tile fully, then release the file handle
    with Image.open(image_path) as tile:
        original = tile.copy()
    ax[-1, -1].imshow(original)
    ax[-1, -1].set_title('original')
    plt.show()

In [38]:
tile_path = "../Test_Data/p026/tiles/p026_11_60.tiff"

# candidate models (entries are indexed as [model, path] further below) and the tile reader
models = get_remote_models_and_path()
loader = TileLoader()

# add a batch dimension and track gradients w.r.t. the input (requires_grad_ works in place and returns the tensor)
data = loader.open(tile_path).unsqueeze(0).requires_grad_(True)
# NOTE(review): `target` is not used by any visible cell — confirm whether it is still needed
target = 2.580166
plot_tile(tile_path)

In [39]:
model = models[3][0]
print(models[3][1])

# Residual block class of the selected ResNet: either torchvision.models.resnet.BasicBlock
# or torchvision.models.resnet.Bottleneck. CRP records layers of this type below.
bottleneck_type = type(model.encoder.layer1[0])

# zennit's EpsilonPlusFlat LRP composite, with the torchvision ResNet canonizer applied
composite_res = zen.composites.EpsilonPlusFlat(canonizers=[ztv.ResNetCanonizer()])
# TODO: check that the canonizer handles the block type printed below (BasicBlock / Bottleneck)
print(bottleneck_type)

In [40]:
import matplotlib.pyplot as plt
from PIL import Image

def plot_heatmaps(imgs, image_path, width=4, subplot_size=4):
    """Plot heatmaps in a grid with the original tile in the bottom-right panel.

    Args:
        imgs: sequence of entries whose item 0 is an image accepted by ``imshow``
            and item 1 is its title.
        image_path: path of the original tile, opened with PIL.
        width: number of grid columns.
        subplot_size: edge length in inches of each grid cell.
    """
    # +1 panel for the original tile; exact integer ceiling division for the row count
    n_panels = len(imgs) + 1
    height = -(-n_panels // width)

    # squeeze=False keeps `ax` 2-D even for a single row/column, so ax[r, c] and
    # ax[-1, -1] always work. With the default squeeze=True a one-row grid (e.g. a
    # single heatmap with width=4) is 1-D and two-index access raises IndexError.
    _, ax = plt.subplots(height, width, figsize=(width * subplot_size, height * subplot_size),
                         constrained_layout=True, squeeze=False)

    # hide every axis up front so unused grid cells stay blank
    for panel in ax.flat:
        panel.axis('off')

    for i, entry in enumerate(imgs):
        panel = ax[i // width, i % width]
        panel.imshow(entry[0])
        panel.set_title(entry[1])

    # the original tile goes into the last grid cell; load it fully, then close the file
    with Image.open(image_path) as tile:
        original = tile.copy()
    ax[-1, -1].imshow(original)
    ax[-1, -1].set_title('original')

    plt.show()

# Example call to the function (assuming `imgs` and `image_path` are defined)
# plot_heatmaps(imgs, image_path)


In [41]:
attribution = CondAttribution(model)

# each channel of a recorded layer is treated as one concept
cc = ChannelConcept()

# names of all residual blocks (modules of type bottleneck_type)
layer_names = get_layer_names(model, [bottleneck_type])
print(layer_names)

# Number of per-channel conditional heatmaps of encoder.layer4.1 to add to the grid.
# 0 disables them, so the grid only shows the unconditional LRP heatmap and the tile.
N_CHANNEL_HEATMAPS = 0

heatmaps = []
for channel in range(N_CHANNEL_HEATMAPS):
    # relevance of output 0 restricted to a single channel of encoder.layer4.1
    conditions = [{"y": [0], "encoder.layer4.1": [channel]}]
    attr = attribution(data, conditions, composite_res, record_layer=layer_names)
    heatmaps.append((attr.heatmap.squeeze(), "encoder.layer4.1 " + str(channel)))

# plain LRP heatmap of the whole model (no concept conditioning)
with Gradient(model, composite_res) as attributor:
    out, grad = attributor(data)
rel = plot_relevance(grad, filename=None, only_return=True)
heatmaps.append((rel, "unfiltered LRP"))

plot_heatmaps(heatmaps, tile_path, 4)

# Relative importance of each encoder.layer1.0 concept for output 0. The attribution is
# computed explicitly here instead of reusing `attr` from the loop above. That loop never
# binds `attr` when N_CHANNEL_HEATMAPS == 0 (NameError on a fresh kernel), and otherwise
# it would be conditioned on the last channel only.
attr = attribution(data, [{"y": [0]}], composite_res, record_layer=layer_names)
rel_c = cc.attribute(attr.relevances['encoder.layer1.0'])
concept_ids = torch.argsort(rel_c, descending=True)


----------------------------------------------------------------------------------------------------------------------------------------------------------------------------


In [8]:
# display the selected model's repr as the cell output
model

In [42]:
from crp.helper import get_layer_names

layer_names = get_layer_names(model, [bottleneck_type])

# attribute w.r.t. output 0 only (no concept condition), recording every residual block
conditions = [{'y': [0]}]
attr = attribution(data, conditions, composite_res, record_layer=layer_names)

print(attr.activations['encoder.layer4.1'].shape, attr.relevances['encoder.layer4.1'].shape)
# attr[1]["encoder.layer4.1"].shape, attr[2]["encoder.layer4.1"].shape # is equivalent
# NOTE(review): encoder.layer4.1 is assumed to have 512 channel concepts (later cells use
# torch.arange(0, 512) for it) — confirm against the shapes printed above
rel_c = cc.attribute(attr.relevances['encoder.layer4.1'], abs_norm=True)
print(rel_c.shape)
# the six most relevant concepts and their (abs-normalized) contribution to output 0 in percent
rel_values, concept_ids = torch.topk(rel_c[0], 6)
concept_ids, rel_values*100



In [43]:

print(concept_ids)

# one condition per top concept: relevance of output 0 restricted to that channel of encoder.layer4.1
conditions = [
    {'encoder.layer4.1': [channel_id], 'y': [0]}
    for channel_id in concept_ids
]

heatmap, _, _, _ = attribution(data, conditions, composite_res)

# show all conditional heatmaps side by side in one row
imgify(heatmap, symmetric=True, grid=(1, len(concept_ids)))



In [11]:
# One condition per channel of encoder.layer4.1, each w.r.t. output 0.
# NOTE(review): 512 is assumed to be the channel count of encoder.layer4.1 — confirm with
# the activation/relevance shapes printed earlier.
# No attributions are computed here: iterating attribution.generate over these conditions
# and discarding every batch only cost 512 backward passes. The masked-relevance cell below
# runs attribution.generate itself and actually uses the results.
conditions = [{'encoder.layer4.1': [channel], 'y': [0]} for channel in torch.arange(0, 512)]


In [45]:


# Binary 224x224 mask selecting the columns from index 180 to the right edge,
# created directly on the attribution device.
# NOTE(review): assumes heatmaps are 224x224 — confirm against the input's spatial size.
mask = torch.zeros(224, 224, device=attribution.device)
mask[..., 180:] = 1

imgify(mask, symmetric=True)



In [46]:
from crp.helper import abs_norm

# Rebuild the per-channel conditions here instead of relying on whichever cell last assigned
# the global `conditions` (several cells rebind it, e.g. to the top-6 concept list). This way
# result index k always corresponds to channel k of encoder.layer4.1.
# NOTE(review): 512 assumed to be the channel count of encoder.layer4.1 — confirm.
channel_conditions = [{'encoder.layer4.1': [channel], 'y': [0]} for channel in torch.arange(0, 512)]

masked_rel = []
for attr in attribution.generate(data, channel_conditions, composite_res, record_layer=layer_names, batch_size=10):
    # heatmaps are (batch, H, W); the (H, W) mask broadcasts over the batch
    masked = attr.heatmap * mask[None, :, :]
    masked_rel.append(torch.sum(masked, dim=(1, 2)))

# relevance of every channel inside the masked region, in channel order
rel_c = torch.cat(masked_rel)

indices = torch.topk(rel_c, 5).indices
# we norm here, so that we clearly see the contribution inside the masked region as percentage
indices, abs_norm(rel_c)[indices]*100

In [47]:
# Restrict the relevance of output 0 to channel 469 of encoder.layer4.1 and record encoder.layer2.1.
# NOTE(review): 469 is hard-coded; confirm it was selected from this model's results
# (e.g. the masked top-5 above) rather than carried over from an example.
conditions = [{"y": [0], "encoder.layer4.1": [469]}]

attr = attribution(data, conditions, composite_res, record_layer=["encoder.layer2.1"])

rel_c = cc.attribute(attr.relevances["encoder.layer2.1"], abs_norm=True)

# the five concepts in encoder.layer2.1 with the highest relevance when the relevance of
# output 0 is restricted to channel 469 of encoder.layer4.1
torch.argsort(rel_c, descending=True)[0, :5]


In [48]:
from crp.graph import trace_model_graph

# trace the model with `data` as example input; presumably yields the connectivity
# between the residual blocks in layer_names
graph = trace_model_graph(model, data, layer_names)
print(graph)

In [49]:
# presumably the recorded layers that feed into encoder.layer4.1 in the traced graph
graph.find_input_layers("encoder.layer4.1")

In [50]:
from crp.attribution import AttributionGraph

layer_names = get_layer_names(model, [bottleneck_type])
    
# every residual block uses the channel-wise concept definition
layer_map = {name: cc for name in layer_names}
print(layer_map)
attgraph = AttributionGraph(attribution, graph, layer_map)

# decompose concept 71 in encoder.layer4.1 w.r.t. output 0
# NOTE(review): 71 is hard-coded; confirm it is a relevant concept for this model
# width=[5, 2] returns first the 5 most relevant concepts in the previous lower-level layer
# and in the second iteration returns for each of the 5 most relevant concepts again the two
# most relevant concepts in the previous lower-level layer
nodes, connections = attgraph(data, composite_res, 71, "encoder.layer4.1", 0, width=[5, 2], abs_norm=True)
print("Nodes:\n", nodes, "\nConnections:\n", connections)


In [51]:
# nodes returned by the attribution-graph decomposition above
nodes

In [52]:
# connections recorded for concept 71 of encoder.layer4.1 (keys are (layer, concept id) tuples)
connections[("encoder.layer4.1", 71)]

In [53]:
# Device for the CRP runs below; the commented line would prefer CUDA, then Apple MPS, then CPU.
#device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
device = "cpu"
model.to(device)

In [54]:
data_dir = "../Training_Data/"
train_loader, val_loader = get_data_loaders(data_dir, 64)
dataset = get_dataset(data_dir)
from torchvision import transforms


class STDataset(torch.utils.data.Dataset):
    """Tile dataset used for CRP feature visualisation.

    NOTE: shadows the ``STDataset`` imported from ``data_loader`` at the top of the notebook.

    Each item is ``(image, 0)``: the tile file at ``dataframe.iloc[index]["tile"]``,
    converted to RGB, passed through ``transforms`` and moved to ``device``. The
    constant target 0 matches the output index (``"y": [0]``) used throughout.

    Args:
        dataframe: table with a ``"tile"`` column of image paths (presumably the
            pandas DataFrame returned by ``get_dataset`` — confirm).
        device: device every returned image tensor is moved to.
        transforms: callable applied to the RGB PIL image. The default resizes to
            224x224, converts to a tensor and normalizes with the dataset mean/std.
    """

    def __init__(self, dataframe, device="mps", transforms=transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # mean and std of the whole dataset
            transforms.Normalize([0.7406, 0.5331, 0.7059], [0.1651, 0.2174, 0.1574])
            ])):
        self.dataframe = dataframe
        self.transforms = transforms
        self.device = device

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        row = self.dataframe.iloc[index]
        # The context manager releases the file handle; convert() returns a new, fully loaded image.
        # (Gene values are not parsed here: they were never returned, so parsing them was pure overhead.)
        with Image.open(row["tile"]) as tile:
            image = tile.convert("RGB")
        # apply normalization transforms as for encoder colon classifier
        image = self.transforms(image)
        return image.to(self.device), 0


# Use the device selected earlier in the notebook instead of the class default "mps",
# which fails on machines without Apple MPS support.
datasetST = STDataset(dataset, device=device)

In [55]:
# NOTE(review): repeats the device selection and model move of the earlier device cell;
# print(model.to(...)) also dumps the full model repr
device = "cpu"
print(device)
print(model.to(device))

In [56]:
import torchvision
from crp.concepts import ChannelConcept
from crp.helper import get_layer_names
from crp.attribution import CondAttribution
from crp.visualization import FeatureVisualization
import torchvision.transforms as T

cc = ChannelConcept()

# every residual block, each with channel-wise concepts
layer_names = get_layer_names(model, [bottleneck_type])
layer_map = {layer : cc for layer in layer_names}
model.to(device)
print(next(model.parameters()).is_mps)
attribution = CondAttribution(model)

# datasetST already resizes and normalizes every tile with the dataset's own mean/std
# (STDataset's default transforms). Applying the ImageNet Normalize on top of that would
# normalize twice and feed the model inputs unlike those it was trained on, so no extra
# preprocessing is passed to FeatureVisualization.
preprocessing = None

fv_path = "output_crp"
fv = FeatureVisualization(attribution, datasetST, layer_map, preprocess_fn=preprocessing, path=fv_path)

# Collect reference samples over the whole dataset (long-running).
# NOTE(review): the positional args presumably are (data_start, data_end, batch_size, checkpoint) — confirm.
saved_files = fv.run(composite_res, 0, len(datasetST), 124, 100)



In [None]:
# number of 64-sample batches needed for datasetST (true division, so a float)
len(datasetST)/64



In [None]:
#%matplotlib inline
from crp.image import plot_grid

# Reference samples for five encoder.layer4.1 channels in "relevance" mode, reference range (0, 8).
# NOTE(review): this cell has never been executed (In [None]). Unlike the following cells it passes
# no composite — confirm the default plot_fn of get_max_reference works without one.
# NOTE(review): the channel ids are hard-coded; confirm they were selected from this model's results.
ref_c = fv.get_max_reference([469, 35, 89, 316, 161], "encoder.layer4.1", "relevance", (0, 8))

plot_grid(ref_c, figsize=(6, 5), padding=False)


In [57]:
# same reference samples, computed with the LRP composite and no plot function (plot_fn=None)
ref_c = fv.get_max_reference([469, 35, 89, 316, 161], "encoder.layer4.1", "relevance", (0, 8), composite=composite_res, plot_fn=None)

plot_grid(ref_c, figsize=(6, 9))

In [30]:
from crp.image import vis_opaque_img

# same reference samples rendered with vis_opaque_img, shown with a symmetric blue-white-red colormap
ref_c = fv.get_max_reference([469, 35, 89, 316, 161], "encoder.layer4.1", "relevance", (0, 8), composite=composite_res, plot_fn=vis_opaque_img)

plot_grid(ref_c, cmap="bwr", symmetric=True, figsize=(6, 5))


In [31]:
# as above but with rf=True (presumably restricting each reference to the channel's receptive field — confirm)
ref_c = fv.get_max_reference([469, 35, 89, 316, 161], "encoder.layer4.1", "relevance", (0, 8), rf=True, composite=composite_res, plot_fn=vis_opaque_img)

plot_grid(ref_c, figsize=(6, 5), padding=False)


In [32]:
# top-5 targets (and normalized relevances) recorded for channel 469 of encoder.layer4.1.
# NOTE(review): datasetST labels every sample with target 0, so presumably only target 0 can appear here.
targets, rel = fv.compute_stats(469, "encoder.layer4.1", "relevance", top_N=5, norm=True)
targets, rel 


In [33]:
# NOTE(review): `targets` was computed for channel 469 in the previous cell but is used here for
# channel 161 — confirm this pairing is intended
ref_t = fv.get_stats_reference(161, "encoder.layer4.1", targets, "relevance", (0, 8), rf=True, composite=composite_res, plot_fn=vis_opaque_img)

plot_grid(ref_t, figsize=(6, 5), padding=False)

In [34]:
from crp.cache import ImageCache

# image cache stored under the relative path "cache"
cache = ImageCache(path="cache")

# rebuild fv with the same attribution, dataset, layer map, preprocessing and output path, plus the cache
fv = FeatureVisualization(attribution, datasetST, layer_map, preprocess_fn=preprocessing, path=fv_path, cache=cache)


In [35]:
from crp.helper import get_output_shapes
import numpy as np

layer_names = get_layer_names(model, [bottleneck_type])
# per-layer output shapes from one dataset sample; out[0] is presumably the channel count — confirm
output_shape = get_output_shapes(model, fv.get_data_sample(0)[0], layer_names)
# every channel id of every recorded block
layer_id_map = {l_name: np.arange(0, out[0]) for l_name, out in output_shape.items()}

# precompute vis_opaque_img references (relevance mode, range (0, 16), rf=True) for all channels, without stats
fv.precompute_ref(layer_id_map,  plot_list=[vis_opaque_img], mode="relevance", r_range=(0, 16), composite=composite_res, rf=True, batch_size=32, stats=False)
