In [1]:
import torch
import os
import cv2
import sys
import matplotlib.pyplot as plt
import torchvision.transforms as T
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import numpy as np
import zennit.image as zimage
import torchvision

sys.path.append('code')
from models.pytorch_models import PretrainedCNN

from torchsummary import summary
from torchvision import datasets, transforms
from torch import nn
from PIL import Image, ImageDraw, ImageFont, ImageOps
from matplotlib.pyplot import imshow
from matplotlib import font_manager
from sklearn.model_selection import train_test_split

from crp.concepts import ChannelConcept
from crp.attribution import CondAttribution,AttributionGraph
from crp.helper import get_layer_names,abs_norm
from crp.graph import trace_model_graph

from zennit.composites import EpsilonPlusFlat
from zennit.canonizers import SequentialMergeBatchNorm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
torch.cuda.empty_cache()

#print(device)
#print(torch.__version__)

In [2]:
state_dict = torch.load('state_dict.pth')['models_state_dict']
model = PretrainedCNN('vgg', num_classes=2)
model.load_state_dict(state_dict[0])
model = model.eval()
#model.eval()
#eval_loss, eval_accuracy = evaluate_model(model=model,test_loader=test_loader, device=device, criterion=criterion)
#print(model)
#summary(model, (3, 224, 224))



In [3]:
model.to(device)
cc = ChannelConcept()
composite = EpsilonPlusFlat([SequentialMergeBatchNorm()])
attribution = CondAttribution(model)
transform = T.Compose([T.ToTensor()])
# find a font file
font = font_manager.FontProperties(family='sans-serif', weight='bold')
fontfile = font_manager.findfont(font)

In [4]:
def find_sample_attributions(path_to_file, prunable_channels,iteration,pruned_bool):
    splitpath = path_to_file.rsplit('/', 1)[-1]
    splitname = splitpath.rsplit(".",1)[0]
    #Attribute Concepts
    #Recording and Attributing Latent Concept Relevances
    true_imgs={}
    false_imgs={}
    image = Image.open(path_to_file).convert('RGB')
    sample = transform(image).unsqueeze(0)
    sample.requires_grad = True
    sample = sample.to(device)
    layer_names = get_layer_names(model, [torch.nn.Conv2d, torch.nn.Linear])
    true_conditions = [{'y': [1]}]
    false_conditions = [{'y': [0]}]
    true_attr = attribution(sample, true_conditions, composite, record_layer=layer_names)
    false_attr = attribution(sample, false_conditions, composite, record_layer=layer_names)
    torch.cuda.empty_cache()

    for model_layer in layer_names:
        name,section, _ = model_layer.split(".")
        single_layer = model_layer[model_layer.index('.') + 1 : ]
        for module_name, module in model._modules[name].named_modules():
            if single_layer == module_name:

                true_rel_c = cc.attribute(true_attr.relevances[model_layer], abs_norm=True)
                false_rel_c = cc.attribute(false_attr.relevances[model_layer], abs_norm=True)

                # the ten most irrelevant concepts and their percentage for y = 1 (true)
                true_concept_ids = torch.argsort(true_rel_c[0], descending=True)[:10]
                true_contributions =  abs_norm(true_rel_c[0])[true_concept_ids]*100
                # conditioned heatmap for specific layer and concept id in loop (for y=1)
                conditions = [{model_layer: [id], 'y': [1]} for id in true_concept_ids]
                heatmap, _, _, _ = attribution(sample, conditions, composite)
                true_img=(zimage.imgify(heatmap, symmetric=True, grid=(1, len(true_concept_ids)))) 
                true_imgs[module_name]=true_img
                torch.cuda.empty_cache()

                # the ten most relevant concepts and their percentage for y = 0 (false)
                false_concept_ids = torch.argsort(false_rel_c[0], descending=True)[:10]
                false_contributions =  abs_norm(false_rel_c[0])[false_concept_ids]*100
                conditions = [{model_layer: [id], 'y': [0]} for id in true_concept_ids]
                heatmap, _, _, _ = attribution(sample, conditions, composite)
                false_img=(zimage.imgify(heatmap, symmetric=True, grid=(1, len(false_concept_ids)))) 
                false_imgs[module_name]=false_img
                torch.cuda.empty_cache()
                #print(true_concept_ids)
                #print(false_concept_ids)

                #check which of the irrelevant true concepts are NOT the relevant false concepts
                #diff_concept_ids = list(set([int(t.item()) for t in true_concept_ids]) - set([int(t.item()) for t in false_concept_ids]))
                 
                for concept in true_concept_ids:
                    if ( (model_layer not in prunable_channels)
                       or
                       (concept not in prunable_channels[model_layer])) :
                        prunable_channels.setdefault(model_layer, dict())[concept] =1
                    else:
                        prunable_channels[model_layer][concept]+=1 
    vis_imgs(iteration,pruned_bool,true_imgs,false_imgs,image,splitname)
    return prunable_channels

In [5]:
def evaluate_model(eval_model,iteration,pruned_bool):
    # All Mitoses
    path = 'examples_ds_from_MP.train.HTW.train/True'
    model_pred =0
    prunable_channels={}
    for file in os.listdir(path):
        path_to_file = os.path.join(path,file)
        prunable_channels = find_sample_attributions(path_to_file,prunable_channels,iteration,pruned_bool)
        img = cv2.imread(path_to_file)
        # Preprocess for the model
        original = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img2 = original[144:144+224,144:144+224]
        img = img2 / 255.
        img -= [0.86121925, 0.86814779, 0.88314296] # MEAN-Values
        img /= [0.13475281, 0.10909398, 0.09926313] # STD-Values
        imgtensor = torch.tensor(img.swapaxes(1,2).swapaxes(0,1),dtype=torch.float32)[None]
        pred = eval_model.predict(imgtensor.to(device),logits=True).detach().cpu()
        pred = float(torch.nn.functional.softmax(pred)[0,1].numpy())
        #add overall prediction confidence
        model_pred+=pred
    model_pred = model_pred / len(os.listdir(path))
    return prunable_channels, model_pred

In [6]:
def return_non_empty(my_dict):
    temp_dict = {}
    for k, v in my_dict.items():
        if v:
            if isinstance(v, dict):
                return_dict = return_non_empty(v)
                if return_dict:
                    temp_dict[k] = return_dict
                else:
                    temp_dict[k] = v
    return temp_dict

In [7]:
def get_max_ids(test_dict: dict):
    max_vote_ids={}
    while sum(len(lst) for lst in max_vote_ids.values())<5:
        # step 1
        k_outer, v_outer = max(test_dict.items(), key=lambda x: max(x[1]),default=0)
        # step 2
        k_inner, v_inner = max(v_outer.items(), key=lambda x: x[1],default=0)
        if k_outer not in max_vote_ids:
            max_vote_ids.update( {k_outer : [k_inner]} )
        else:
            max_vote_ids[k_outer].append(k_inner)
        test_dict[k_outer].pop(k_inner, None)
        test_dict=return_non_empty(test_dict)      
    return max_vote_ids

In [8]:
def prune_selected_concepts(max_vote_ids,model):
    for k,v in max_vote_ids.items():
        name,section, no = k.split(".")
        #print(name,section,no)
        for module_name, module in model._modules[name].named_modules():
            #print(module_name)
            #print(k)
            if k.split('.', 1)[1] == module_name:
                #print(k)
                for concept in v:

                    mask_tensor = torch.ones(module.weight.shape, device='cuda:0')
                    #pruning differences between Conv2d and Linear layer
                    if isinstance(module, nn.Conv2d):
                        #masking channel by id so nothing is passed further
                        mask_tensor[concept, :, :, :] = torch.zeros(module.weight[0].shape)
                    elif isinstance(module, nn.Linear):
                        mask_tensor[concept,:] = torch.zeros(module.weight[0].shape)
                        #mask_tensor=mask_tensor

                    m = prune.custom_from_mask(module, name='weight', mask=mask_tensor)
                    #print(m.weight_mask)
                    prune.remove(module,"weight")
    return model

In [9]:
def vis_imgs(iteration,pruned_bool,true_imgs, false_imgs,sample,samplename):
    fontsize=1
    img_fraction=0.1
    font = ImageFont.truetype(fontfile, fontsize)
    
    if pruned_bool:
        flag = '_after_pruning_'
    else:
        flag = '_before_pruning_'
        
    min_img_width = min(i.width for i in true_imgs.values())
    total_height = 0
    img=0
    for i, key in enumerate(true_imgs.copy()):
        img = true_imgs[key]
        img.convert('RGB')
        img = ImageOps.expand(img, border=int(0.2*img.size[1]), fill=(255,255,255))
        while font.getsize(key)[0] < img_fraction*img.size[0]:
            # iterate until the text size is just larger than the criteria
            fontsize += 1
            font = ImageFont.truetype(fontfile, fontsize)
        fontsize -= 1
        font = ImageFont.truetype(fontfile, fontsize)
        d = ImageDraw.Draw(img)
        w,h = font.getsize(key)
        d.text((int((img.width-w)/2),0), key, fill=(0,0,0),font=font)
        if img.width > min_img_width:
            true_imgs[key] = img.resize((min_img_width, int(img.height / img.width * min_img_width)), Image.ANTIALIAS)
        total_height += true_imgs[key].height
        

    wpercent = (min_img_width/float(sample.size[0]))
    hsize = int((float(sample.size[1])*float(wpercent)))
    sample = sample.resize((min_img_width,hsize), Image.ANTIALIAS)
    
    sample = ImageOps.expand(sample,  border=int(0.1*sample.size[1]), fill=(255,255,255))
    d = ImageDraw.Draw(sample)
    w,h = font.getsize("Original")
    d.text((int((sample.width-w)/2),0), "Original", fill=(0,0,0),font=font)
    
    img_merge = Image.new(img.mode, (min_img_width+sample.width+(2*int(0.1*sample.size[1])), total_height)).convert('RGB')
    img_merge.paste((255,255,255), [0,0,min_img_width+sample.width+2*int(0.1*sample.size[1]),total_height])
    y = 0
    for image in true_imgs.values():
        img_merge.paste(image, (0, y))

        y += image.height
    img_merge.paste(sample, (sample.width, int(sample.height/2)-2*int(0.1*sample.size[1])))
    img_merge.save('true_attributions/true_iteration_'+str(iteration)+str(flag)+samplename+'.jpg')
    ##########################

    min_img_width = min(i.width for i in false_imgs.values())
    total_height = 0
    img=0
    for i, key in enumerate(false_imgs.copy()):
        img = false_imgs[key]
        img.convert('RGB')
        img = ImageOps.expand(img, border=int(0.2*img.size[1]), fill=(255,255,255))
        while font.getsize(key)[0] < img_fraction*img.size[0]:
            # iterate until the text size is just larger than the criteria
            fontsize += 1
            font = ImageFont.truetype(fontfile, fontsize)
        fontsize -= 1
        font = ImageFont.truetype(fontfile, fontsize)
        d = ImageDraw.Draw(img)
        w,h = font.getsize(key)
        d.text((int((img.width-w)/2),0), key, fill=(0,0,0),font=font)
        if img.width > min_img_width:
            false_imgs[key] = img.resize((min_img_width, int(img.height / img.width * min_img_width)), Image.ANTIALIAS)
        total_height += false_imgs[key].height
    
    img_merge = Image.new(img.mode, (min_img_width+sample.width+(2*int(0.1*sample.size[1])), total_height)).convert('RGB')
    img_merge.paste((255,255,255), [0,0,min_img_width+sample.width+2*int(0.1*sample.size[1]),total_height])
    y = 0
    for image in false_imgs.values():
        img_merge.paste(image, (0, y))

        y += image.height
    img_merge.paste(sample, (sample.width, int(sample.height/2)-2*int(0.1*sample.size[1])))
    img_merge.save('false_attributions/false_iteration_'+str(iteration)+str(flag)+samplename+'.jpg')

In [10]:
def fine_tune_model(train_loader,test_loader,model,optimizer, criterion,epochs):
    model.train()
    train_len = len(train_loader)    
    test_len = len(test_loader)
    for epoch in range(epochs):
        train_loss = 0.0
        val_loss = 0.0

        # Training the model
        model.train()
        counter = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model.forward(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * inputs.size(0)

        # Evaluating the model
        model.eval()
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)           
                output = model.forward(inputs)            
                valloss = criterion(output, labels)
                val_loss += valloss.item() * inputs.size(0)

        train_loss = train_loss/train_len
        valid_loss = val_loss/test_len
        print('[%d] Training Loss: %.6f, Validation Loss: %.6f'  % (epoch + 1, train_loss, valid_loss))
    torch.cuda.empty_cache()
    return model


In [11]:
def train_val_dataset(dataset, val_split=0.25):
    train_idx, val_idx = train_test_split(list(range(len(dataset))), test_size=val_split)
    datasets = {}
    datasets['train'] = torch.utils.data.Subset(dataset, train_idx)
    datasets['val'] = torch.utils.data.Subset(dataset, val_idx)
    return datasets

path = 'examples_ds_from_MP.train.HTW.train'
train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                      transforms.RandomResizedCrop(224),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor()])

test_transforms = transforms.Compose([transforms.RandomRotation(30),
                                     transforms.RandomResizedCrop(224),
                                     transforms.ToTensor()])


train_data = torchvision.datasets.ImageFolder(path, transform=train_transforms)
test_data = torchvision.datasets.ImageFolder(path, transform=test_transforms)


print(len(train_data))
datasets = train_val_dataset(train_data)
print(len(datasets['train']))
print(datasets['val'][5])
# The original dataset is available in the Subset class
print(datasets['train'].dataset)

dataloaders = {x:torch.utils.data.DataLoader(datasets[x],32, shuffle=True, num_workers=4) for x in ['train','val']}
x,y = next(iter(dataloaders['train']))

80
60
(tensor([[[0.9961, 0.9961, 0.9961,  ..., 0.9294, 0.9255, 0.9294],
         [0.9961, 0.9961, 0.9961,  ..., 0.9255, 0.9255, 0.9294],
         [1.0000, 1.0000, 1.0000,  ..., 0.9216, 0.9176, 0.9255],
         ...,
         [0.9647, 0.9569, 0.9686,  ..., 0.9686, 0.9451, 0.9373],
         [0.9843, 0.9647, 0.9686,  ..., 0.9765, 0.9608, 0.9569],
         [0.9961, 0.9804, 0.9725,  ..., 0.9804, 0.9765, 0.9765]],

        [[0.9922, 0.9922, 0.9922,  ..., 0.9216, 0.9216, 0.9255],
         [0.9922, 0.9922, 0.9922,  ..., 0.9176, 0.9216, 0.9255],
         [0.9961, 0.9961, 0.9961,  ..., 0.9098, 0.9137, 0.9216],
         ...,
         [0.9490, 0.9333, 0.9373,  ..., 0.9373, 0.9255, 0.9294],
         [0.9804, 0.9529, 0.9490,  ..., 0.9529, 0.9451, 0.9529],
         [1.0000, 0.9725, 0.9647,  ..., 0.9647, 0.9647, 0.9725]],

        [[0.9765, 0.9765, 0.9765,  ..., 0.9059, 0.9137, 0.9176],
         [0.9765, 0.9765, 0.9765,  ..., 0.9020, 0.9137, 0.9176],
         [0.9804, 0.9804, 0.9804,  ..., 0.8941, 0.9

In [None]:
train_loader = dataloaders['train']
test_loader = dataloaders['val']

trainloader = torch.utils.data.DataLoader(train_loader, batch_size=32)
testloader = torch.utils.data.DataLoader(test_loader, batch_size=32)

optimizer= torch.optim.Adam(model.parameters(), lr=0.01, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
criterion = nn.CrossEntropyLoss()
epochs = 100

with open('result_log.txt', 'w') as f:  
    for i in range(0,5):
        model=fine_tune_model(train_loader,test_loader,model,optimizer, criterion,epochs)
        prunable_channels,pred= evaluate_model(model,i,False)
        print("iteration {} , pred before pruning:  {}".format(i, pred))
        f.write("iteration {} , pred before pruning:  {}".format(i, pred))
        res = {key : dict(sorted(val.items(), key = lambda ele: ele[1],reverse=True)[:2])
           for key, val in prunable_channels.items()}
        max_vote_ids = get_max_ids(res)
        print("Voted least relevant concept ids to prune (max vote ids): {}".format(max_vote_ids))
        f.write("Voted least relevant concept ids to prune (max vote ids): {}".format(max_vote_ids))
        model = prune_selected_concepts(max_vote_ids,model)
        prunable_channels, pred = evaluate_model(model,i,True)
        print("iteration {} , pred after pruning:  {}".format(i, pred))
        f.write("iteration {} , pred after pruning:  {}".format(i, pred))
        model=fine_tune_model(train_loader,test_loader,model,optimizer, criterion,epochs)
    f.close()


[1] Training Loss: 475.555823, Validation Loss: 9864751.875000
[2] Training Loss: 3609.126465, Validation Loss: 1554092.500000
[3] Training Loss: 943.928491, Validation Loss: 19215.383301
[4] Training Loss: 144.360062, Validation Loss: 22218.823242
[5] Training Loss: 157.043590, Validation Loss: 3568.167114
[6] Training Loss: 62.932612, Validation Loss: 2506.511688
[7] Training Loss: 95.111670, Validation Loss: 12637.155762
[8] Training Loss: 84.453060, Validation Loss: 4253.510742
[9] Training Loss: 39.183821, Validation Loss: 805.597992
[10] Training Loss: 66.666698, Validation Loss: 226.292782
[11] Training Loss: 57.670331, Validation Loss: 78.798585
[12] Training Loss: 39.600875, Validation Loss: 44.777136
[13] Training Loss: 39.447758, Validation Loss: 130.885715
[14] Training Loss: 21.647552, Validation Loss: 54.817972
[15] Training Loss: 22.418296, Validation Loss: 99.592152
[16] Training Loss: 24.822658, Validation Loss: 14.387201
[17] Training Loss: 24.030066, Validation Loss:

  array = (array - vmin) / (vmax - vmin)
  while font.getsize(key)[0] < img_fraction*img.size[0]:
  w,h = font.getsize(key)
  true_imgs[key] = img.resize((min_img_width, int(img.height / img.width * min_img_width)), Image.ANTIALIAS)
  sample = sample.resize((min_img_width,hsize), Image.ANTIALIAS)
  w,h = font.getsize("Original")
  while font.getsize(key)[0] < img_fraction*img.size[0]:
  w,h = font.getsize(key)
  false_imgs[key] = img.resize((min_img_width, int(img.height / img.width * min_img_width)), Image.ANTIALIAS)
  pred = float(torch.nn.functional.softmax(pred)[0,1].numpy())


In [None]:
summary(model, (3, 224, 224))

In [None]:
# compute heatmap wrt. output 46 (green lizard class)
conditions = [{"y": 1}]

# or use a dictionary for mask_map
layer_names = get_layer_names(model, [torch.nn.Conv2d, torch.nn.Linear])
mask_map = {name: cc.mask for name in layer_names}
attr = attribution(sample, conditions, composite, mask_map=mask_map)
zimage.imgify(attr.heatmap, symmetric=True)

In [None]:
conditions = [{"model_ft.features.0": [2],"y": 1}]
heatmap, _, _, _ = attribution(sample, conditions, composite)
zimage.imgify(heatmap, symmetric=True)

In [None]:
conditions = [{"model_ft.features.11": [55]}]
heatmap, _, _, _ = attribution(sample, conditions, composite, start_layer="model_ft.features.11")
zimage.imgify(heatmap, symmetric=True)


In [None]:
# All Mitoses
path = 'examples_ds_from_MP.train.HTW.train/True'
model_pred =0
for file in os.listdir(path):
    path_to_file = os.path.join(path,file)
    #print(path_to_file)
    image = Image.open(path_to_file).convert('RGB')
    sample = transform(image).unsqueeze(0).to(device)
    
    
    img = cv2.imread(path_to_file)
    
    # Preprocess for the model
    # -1 convert from to RGB
    original = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # 1. Centercrop
    img2 = original[144:144+224,144:144+224]
    
    # 2. Normalize, so that the mean value for each channel for
    # the complete train dataset is 0 and the std 1
    img = img2 / 255.
    img -= [0.86121925, 0.86814779, 0.88314296] # MEAN-Values
    img /= [0.13475281, 0.10909398, 0.09926313] # STD-Values
    
    # 3. swap axes, convert to tensor and load to GPU
    imgtensor = torch.tensor(img.swapaxes(1,2).swapaxes(0,1),dtype=torch.float32)[None]
    imgtensor = imgtensor.cuda()
    
    # 4. Get model prediction
    pred = model.predict(imgtensor,logits=True).detach().cpu()
    pred = float(torch.nn.functional.softmax(pred)[0,1].numpy())
    model_pred+=pred
    
    #print(torch.exp(model.predict(imgtensor,logits=True).detach().cpu()))
model_pred = model_pred / len(os.listdir(path))
print(model_pred)
    
    #plt.title(f'Pred: {pred}')
    #plt.imshow(img2)
    #plt.show()

In [None]:
class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

In [None]:
min_img_width = min(i.width for i in true_imgs)
total_height = 0
for i, img in enumerate(true_imgs):
    img.convert('RGB')
    if img.width > min_img_width:
        true_imgs[i] = img.resize((min_img_width, int(img.height / img.width * min_img_width)), Image.ANTIALIAS)
    total_height += true_imgs[i].height
img_merge = Image.new(true_imgs[0].mode, (min_img_width, total_height)).convert('RGB')
y = 0
for img in true_imgs:
    img_merge.paste(img, (0, y))

    y += img.height
img_merge.save('true.jpg')
##########################

min_img_width = min(i.width for i in false_imgs)
total_height = 0
for i, img in enumerate(false_imgs):
    img.convert('RGB')
    if img.width > min_img_width:
        false_imgs[i] = img.resize((min_img_width, int(img.height / img.width * min_img_width)), Image.ANTIALIAS)
    total_height += false_imgs[i].height
img_merge = Image.new(false_imgs[0].mode, (min_img_width, total_height)).convert('RGB')
y = 0
for img in false_imgs:
    img_merge.paste(img, (0, y))

    y += img.height
img_merge.save('false.jpg')

In [None]:
cc = ChannelConcept()
composite = EpsilonPlusFlat([SequentialMergeBatchNorm()])
attribution = CondAttribution(model)

conditions = [{"model_ft.features.18": [30], "y": [0]}, {"model_ft.features.15": [10], "y": [0]}]
heatmaps, _, _, _ = attribution(sample, conditions, composite)


zimage.imgify(heatmaps, symmetric=True, grid=(1, len(heatmaps)))



In [None]:
conditions = [{"model_ft.features.25": [19]}]

heatmap, _, _, _ = attribution(sample, conditions, composite, start_layer="model_ft.features.25")

zimage.imgify(heatmap, symmetric=True)

In [None]:
layer_names = get_layer_names(model, [torch.nn.Conv2d, torch.nn.Linear])

conditions = [{'y': [0]}]
attr = attribution(sample, conditions, composite, record_layer=layer_names)

attr.activations['model_ft.features.4'].shape, attr.relevances['model_ft.features.4'].shape


In [None]:
# layer features.40 has 512 channel concepts
rel_c = cc.attribute(attr.relevances['model_ft.features.4'], abs_norm=True)
rel_c.shape

# the five most relevant concepts
concept_ids = torch.argsort(rel_c[0], descending=True)[:6]
concept_ids, abs_norm(rel_c[0])[concept_ids]*100

In [None]:
conditions = [{'model_ft.features.4': [id], 'y': [0]} for id in concept_ids]

heatmap, _, _, _ = attribution(sample, conditions, composite)

zimage.imgify(heatmap, symmetric=True, grid=(1, len(concept_ids)))

In [None]:
torch.cuda.memory_summary(device=None, abbreviated=False)

In [None]:
conditions = [{'model_ft.features.4': [id], 'y': [0]} for id in np.arange(0, 128)]

for attr in attribution.generate(sample, conditions, composite, record_layer=layer_names, batch_size=1):
    pass

torch.cuda.empty_cache()


In [None]:
mask = torch.zeros(512, 512)
mask[:, 180:] = 1

zimage.imgify(mask, symmetric=True)
rel_c = []
for attr in attribution.generate(sample, conditions, composite, record_layer=layer_names, batch_size=1):
    
    masked = attr.heatmap * mask[None, :, :]
    rel_c.append(torch.sum(masked, dim=(1, 2)))

rel_c = torch.cat(rel_c)

indices = torch.argsort(rel_c, descending=True)[:5]
# we norm here, so that we clearly see the contribution inside the masked region as percentage
indices, abs_norm(rel_c)[indices]*100

In [None]:
conditions = [{"y": [0], 'model_ft.features.4': [9]}]

attr = attribution(sample, conditions, composite, record_layer=["model_ft.features.2"])

rel_c = cc.attribute(attr.relevances["model_ft.features.2"], abs_norm=True)

# five concepts in features.37 that contributed the most to the activation of channel 469 in features.40
# while being relevant for the classification of the lizard class
torch.argsort(rel_c, descending=True)[0, :5]



In [None]:
graph = trace_model_graph(model, sample, layer_names)
print(graph)

In [None]:
graph.find_input_layers('model_ft.features.4')

In [None]:
layer_names = get_layer_names(model, [torch.nn.Conv2d, torch.nn.Linear])

layer_map = {name: cc for name in layer_names}
attgraph = AttributionGraph(attribution, graph, layer_map)

# decompose concept 71 in features.40 w.r.t. target 46 (lizard class)
# width=[5, 2] returns first the 5 most relevant concepts in the first lower-level layer
# and in the second iteration returns for each of the 5 most relevant concepts again the two
# most relevant concepts in the following lower-level layer
nodes, connections = attgraph(sample, composite, 82, 'model_ft.features.4', 0, width=[5, 2], abs_norm=True)
print("Nodes:\n", nodes, "\nConnections:\n", connections)

In [None]:
connections[('model_ft.features.4', 82)]

In [None]:
# All Mitoses
path = 'examples_ds_from_MP.train.HTW.train/True'
for file in os.listdir(path):
    path_to_file = os.path.join(path,file)
    print(path_to_file)
    img = cv2.imread(path_to_file)
    
    # Preprocess for the model
    # -1 convert from to RGB
    original = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # 1. Centercrop
    img2 = original[144:144+224,144:144+224]
    
    # 2. Normalize, so that the mean value for each channel for
    # the complete train dataset is 0 and the std 1
    img = img2 / 255.
    img -= [0.86121925, 0.86814779, 0.88314296] # MEAN-Values
    img /= [0.13475281, 0.10909398, 0.09926313] # STD-Values
    
    # 3. swap axes, convert to tensor and load to GPU
    imgtensor = torch.tensor(img.swapaxes(1,2).swapaxes(0,1),dtype=torch.float32)[None]
    
    # 4. Get model prediction
    pred = model.predict(imgtensor,logits=True).detach()
    pred = float(torch.nn.functional.softmax(pred)[0,1].numpy())
    
    print(torch.exp(model.predict(imgtensor,logits=True).detach()))
    print(pred)
    
    plt.title(f'Pred: {pred}')
    plt.imshow(img2)
    plt.show()

In [None]:
# All Non-Mitoses
path = 'examples_ds_from_MP.train.HTW.train/False'
for file in os.listdir(path):
    path_to_file = os.path.join(path,file)
    print(path_to_file)
    img = cv2.imread(path_to_file)
    
    # Preprocess for the model
    # -1 convert from to RGB
    original = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # 1. Centercrop
    img2 = original[144:144+224,144:144+224]
    
    # 2. Normalize, so that the mean value for each channel for
    # the complete train dataset is 0 and the std 1
    img = img2 / 255.
    img -= [0.86121925, 0.86814779, 0.88314296] # MEAN-Values
    img /= [0.13475281, 0.10909398, 0.09926313] # STD-Values
    
    # 3. swap axes, convert to tensor and load to GPU
    imgtensor = torch.tensor(img.swapaxes(1,2).swapaxes(0,1),dtype=torch.float32)[None]
    
    # 4. Get model prediction
    pred = model.predict(imgtensor,logits=True).detach()
    pred = float(torch.nn.functional.softmax(pred)[0,1].numpy())
    
    print(torch.exp(model.predict(imgtensor,logits=True).detach())
    print(pred)
    
    plt.title(f'Pred: {pred}')
    plt.imshow(img2)
    plt.show()