In [1]:
import torch
import os
import cv2
import sys
import matplotlib.pyplot as plt
import torchvision.transforms as T
import torch.nn.utils.prune as prune
from operator import itemgetter
import torch.nn.functional as F
import numpy as np
import zennit.image as zimage
import torchvision
import operator
from itertools import islice

sys.path.append('code')
import imp
Lamb = imp.load_source('lamb', '/home/ieisenbraun/Documents/FP/MAI_FINAL_MODEL_VGG/SOURCE/code/utils/lamb.py')
from models.pytorch_models import PretrainedCNN

from torchsummary import summary
from torchvision import datasets, transforms
from torch import nn
from PIL import Image, ImageDraw, ImageFont, ImageOps
from matplotlib.pyplot import imshow
from matplotlib import font_manager
from sklearn.model_selection import train_test_split

from crp.concepts import ChannelConcept
from crp.attribution import CondAttribution,AttributionGraph
from crp.helper import get_layer_names,abs_norm
from crp.graph import trace_model_graph

from zennit.composites import EpsilonPlusFlat
from zennit.canonizers import SequentialMergeBatchNorm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
torch.cuda.empty_cache()

#print(device)
#print(torch.__version__)

  import imp


In [2]:
state_dict = torch.load('state_dict.pth')['models_state_dict']
model = PretrainedCNN('vgg', num_classes=2)
model.load_state_dict(state_dict[0])
model = model.eval()
#model.eval()
#eval_loss, eval_accuracy = evaluate_model(model=model,test_loader=test_loader, device=device, criterion=criterion)
print(model)
#summary(model, (3, 224, 224))



PretrainedCNN(
  (model_ft): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU(inplace=True)
      (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU(inplace=True)
      (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=

In [3]:
model.to(device)
cc = ChannelConcept()
composite = EpsilonPlusFlat([SequentialMergeBatchNorm()])
attribution = CondAttribution(model)
transform = T.Compose([T.ToTensor()])
layer_names = get_layer_names(model, [torch.nn.Conv2d, torch.nn.Linear])
# find a font file
print(layer_names)
font = font_manager.FontProperties(family='sans-serif', weight='bold')
fontfile = font_manager.findfont(font)

['model_ft.features.0', 'model_ft.features.4', 'model_ft.features.8', 'model_ft.features.11', 'model_ft.features.15', 'model_ft.features.18', 'model_ft.features.22', 'model_ft.features.25', 'model_ft.classifier.0', 'model_ft.classifier.3', 'model_ft.classifier.6']


In [4]:
def iterate_images(path,iteration,is_pruned,folder,descending=True):
    condition = [{'y': [1]}]
    concept_atlas={}
  
    for file in os.listdir(path):
        imgs={}
        path_to_file = os.path.join(path,file)
        splitpath = path_to_file.rsplit('/', 1)[-1]
        splitname = splitpath.rsplit(".",1)[0]
        image = Image.open(path_to_file).convert('RGB')
        sample = transform(image).unsqueeze(0)
        sample.requires_grad = True
        sample = sample.to(device)
        attr = attribution(sample, condition, composite, record_layer=layer_names)
        torch.cuda.empty_cache()
        
        for model_layer in layer_names:
            name,section, _ = model_layer.split(".")
            single_layer = model_layer[model_layer.index('.') + 1 : ]
            for module_name, module in model._modules[name].named_modules():
                if single_layer == module_name:
                    rel_c = cc.attribute(attr.relevances[model_layer], abs_norm=True)
                    concept_ids = torch.argsort(rel_c[0], descending=descending)[:10]
                    
                    conditions = [{model_layer: [id], 'y': [1]} for id in concept_ids]
                    heatmap, _, _, _ = attribution(sample, conditions, composite)
                    img=(zimage.imgify(heatmap, symmetric=True, grid=(1, len(concept_ids)))) 
                    imgs[module_name]=img
                    
                    del conditions
                    
                    for concept in concept_ids:
                        concept = concept.cpu().item()
                        if ( (single_layer not in concept_atlas)
                           or
                           (concept not in concept_atlas[single_layer])) :
                            concept_atlas.setdefault(single_layer, dict())[concept] =1
                        else:
                            concept_atlas[single_layer][concept]+=1 
                            
                    del concept_ids
                    torch.cuda.empty_cache() 
                    
        del attr
        del sample
        torch.cuda.empty_cache() 
        
        vis_imgs(iteration,is_pruned,imgs,image,splitname,folder)
    return concept_atlas

In [5]:
import collections
def nested_dict_iter(nested):
    for key, value in nested.items():
        if isinstance(value, collections.Mapping):
            for inner_key, inner_value in nested_dict_iter(value):
                yield key, inner_key, inner_value
        else:
            yield key, value

In [6]:
def find_sample_attributions(iteration,is_pruned):
    prunable_channels={}
    true_images_path='examples_ds_from_MP.train.HTW.train/True'
    false_images_path='examples_ds_from_MP.train.HTW.train/False'
    true_concept_atlas = iterate_images(true_images_path, iteration,is_pruned, 'true_attributions', False)
    false_concept_atlas=iterate_images(false_images_path,iteration,is_pruned,'false_attributions',True)
    #print(false_concept_atlas)
    torch.cuda.empty_cache()
    irrelevant_channels_dict={}
    most_irrelevant_channels_dict={}
    prunable_channels ={}
    
    for layer in false_concept_atlas:
        #print(false_concept_atlas[layer])
        final_dict = dict(true_concept_atlas[layer].items() & false_concept_atlas[layer].items())
        #print ("final dictionary", str(final_dict))
        irrelevant_channels_dict[layer] = {k: false_concept_atlas[layer][k] for k in false_concept_atlas[layer]
                                      if k not in final_dict}   
        #print(irrelevant_channels_dict[layer])
        irrelevant_channels_dict[layer] = dict( sorted(irrelevant_channels_dict[layer].items(),
                                                       key=operator.itemgetter(1),reverse=True))
        #print(irrelevant_channels_dict[layer])
      
    del true_concept_atlas
    del false_concept_atlas
    torch.cuda.empty_cache()
    irrelevant_channels_dict_lst = list(nested_dict_iter((irrelevant_channels_dict)))
    sorting_record = sorted(irrelevant_channels_dict_lst, key = lambda i: i[2], reverse = True)[0:10]
    
    for item in sorting_record:
        if (item[0] not in prunable_channels):
            prunable_channels[item[0]] =[item[1]]
        else:
            prunable_channels[item[0]].append(item[1])
         
    #for layer in irrelevant_channels_dict:
        
        #most_irrelevant_channels_dict[layer] = take(2, irrelevant_channels_dict[layer].items())
        #print(most_irrelevant_channels_dict[layer])
        #prunable_channels[layer] = list(most_irrelevant_channels_dict[layer].keys())

    return prunable_channels

In [7]:
def evaluate_model(eval_model):
    # All Mitoses
    path = 'examples_ds_from_MP.train.HTW.train/True'
    model_pred =0
    for file in os.listdir(path):
        path_to_file = os.path.join(path,file)  
        img = cv2.imread(path_to_file)
        # Preprocess for the model
        original = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img2 = original[144:144+224,144:144+224]
        img = img2 / 255.
        img -= [0.86121925, 0.86814779, 0.88314296] # MEAN-Values
        img /= [0.13475281, 0.10909398, 0.09926313] # STD-Values
        imgtensor = torch.tensor(img.swapaxes(1,2).swapaxes(0,1),dtype=torch.float32)[None]
        pred = eval_model.predict(imgtensor.to(device),logits=True).detach().cpu()
        #print(pred)
        pred = float(torch.nn.functional.softmax(pred)[0,1].numpy())
        #print(pred)
        model_pred+=pred

        
    model_pred = model_pred / len(os.listdir(path))
    return model_pred

In [8]:
def take(n, iterable):
    "Return first n items of the iterable as a list"
    return dict(islice(iterable, n))

In [9]:
def return_non_empty(my_dict):
    temp_dict = {}
    for k, v in my_dict.items():
        if v:
            if isinstance(v, dict):
                return_dict = return_non_empty(v)
                if return_dict:
                    temp_dict[k] = return_dict
                else:
                    temp_dict[k] = v
    return temp_dict

In [10]:
def get_max_ids(test_dict: dict):
    max_vote_ids={}
    while sum(len(lst) for lst in max_vote_ids.values())<10:
        
        for layer in test_dict:
            max_votes =  take(2, test_dict[layer].items())
            max_vote_ids[layer].append(k_inner)
            print(max_votes, layer)
                
    return max_vote_ids

In [11]:
def prune_selected_concepts(max_vote_ids,model):
    for k,v in max_vote_ids.items():
        section, no = k.split(".")
        #v = v.cpu().data.numpy()
        name = "model_ft"
        #print(name,section,no)
        for module_name, module in model._modules[name].named_modules():
            #print(module_name)
            #print(k)
            nr = int(k.split('.', 1)[1])
            mod = k.split('.', 1)[0]
            if mod == module_name:
                #print(k)
                for concept in v:
                    #print(concept)

                    mask_tensor = torch.ones(module[nr].weight.shape, device='cuda:0')
                    #print(module[nr])
                    #print(module[nr].weight.shape)
                    
                    #pruning differences between Conv2d and Linear layer
                    if isinstance(module[nr], nn.Conv2d):
                        #masking channel by id so nothing is passed further
                        mask_tensor[concept, :, :, :] = torch.zeros(module[nr].weight[0].shape)
                    elif isinstance(module[nr], nn.Linear):
                        mask_tensor[concept,:] = torch.zeros(module[nr].weight[0].shape)
                        #mask_tensor=mask_tensor

                    m = prune.custom_from_mask(module[nr], name='weight', mask=mask_tensor)
                    #print(m.weight_mask)
                    prune.remove(module[nr],"weight")
            torch.cuda.empty_cache()
    return model

In [12]:
def irrelevant_concepts(prunable_channels,iteration):
    path = 'examples_ds_from_MP.train.HTW.train/True'
    
    for file in os.listdir(path):
        
  
        imgs={}
        
        path_to_file = os.path.join(path,file)
        splitpath = path_to_file.rsplit('/', 1)[-1]
        splitname = splitpath.rsplit(".",1)[0]
        image = Image.open(path_to_file).convert('RGB')
        sample = transform(image).unsqueeze(0)
        sample.requires_grad = True
        sample = sample.to(device)
        
        imgs = {}

        for layer in prunable_channels:
            conditions = [{'model_ft.'+layer: [torch.tensor(id).to(device)], 'y': [0]} for id in prunable_channels[layer]]
            if len(conditions):
                heatmap, _, _, _ = attribution(sample, conditions, composite)
                img=(zimage.imgify(heatmap, symmetric=True, grid=(1, len(prunable_channels[layer])))) 
                imgs[layer]=img
            del conditions
            del attribution
        
        del sample
        torch.cuda.empty_cache()
            
        vis_imgs(iteration,False,imgs,image,splitname,'irrelevant_concepts')

In [13]:
def vis_imgs(iteration,pruned_bool,true_imgs,sample,samplename,path):
    fontsize=1
    img_fraction=0.1
    font = ImageFont.truetype(fontfile, fontsize)
    
    if pruned_bool:
        flag = '_after_pruning_'
    else:
        flag = '_before_pruning_'
        
    min_img_width = min(i.width for i in true_imgs.values())
    total_height = 0
    img=0
    for i, key in enumerate(true_imgs.copy()):
        img = true_imgs[key]
        img.convert('RGB')
        img = ImageOps.expand(img, border=int(0.2*img.size[1]), fill=(255,255,255))
        while font.getsize(key)[0] < img_fraction*img.size[0]:
            # iterate until the text size is just larger than the criteria
            fontsize += 1
            font = ImageFont.truetype(fontfile, fontsize)
        fontsize -= 1
        font = ImageFont.truetype(fontfile, fontsize)
        d = ImageDraw.Draw(img)
        w,h = font.getsize(key)
        d.text((int((img.width-w)/2),0), key, fill=(0,0,0),font=font)
        if img.width > min_img_width:
            true_imgs[key] = img.resize((min_img_width, int(img.height / img.width * min_img_width)), Image.ANTIALIAS)
        total_height += true_imgs[key].height
        

    wpercent = (min_img_width/float(sample.size[0]))
    hsize = int((float(sample.size[1])*float(wpercent)))
    sample = sample.resize((min_img_width,hsize), Image.ANTIALIAS)
    
    sample = ImageOps.expand(sample,  border=int(0.1*sample.size[1]), fill=(255,255,255))
    d = ImageDraw.Draw(sample)
    w,h = font.getsize("Original")
    d.text((int((sample.width-w)/2),0), "Original", fill=(0,0,0),font=font)
    
    img_merge = Image.new(img.mode, (min_img_width+sample.width+(2*int(0.1*sample.size[1])), total_height)).convert('RGB')
    img_merge.paste((255,255,255), [0,0,min_img_width+sample.width+2*int(0.1*sample.size[1]),total_height])
    y = 0
    for image in true_imgs.values():
        img_merge.paste(image, (0, y))

        y += image.height
    img_merge.paste(sample, (sample.width, int(sample.height/2)-2*int(0.1*sample.size[1])))
    img_merge.save(path+'/iteration_'+str(iteration)+str(flag)+samplename+'.jpg')
    ##########################

In [14]:
def fine_tune_model(train_loader,test_loader,model,optimizer, criterion,epochs):
    
    train_len = len(train_loader)    
    test_len = len(test_loader)
    for epoch in range(epochs):
        train_loss = 0.0
        total_val_loss = 0.0
        model_pred =0
        # Training the model
        model.train()
        counter = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            output = model(inputs)
            log_prob = torch.nn.functional.log_softmax(output, dim=1)
            loss = torch.nn.functional.nll_loss(log_prob, labels)
            #optimizer.zero_grad()
            #outputs = model.forward(inputs)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * inputs.size(0)

        # Evaluating the model
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)  
                
                
                output = model(inputs)
                val_log_prob = torch.nn.functional.log_softmax(output, dim=1)
                val_loss = torch.nn.functional.nll_loss(val_log_prob, labels)
            
                total_val_loss += val_loss.item() * inputs.size(0)

        train_loss = train_loss/train_len
        valid_loss = total_val_loss/test_len
        #print('[%d] Training Loss: %.6f, Validation Loss: %.6f'  % (epoch + 1, train_loss, valid_loss))
    torch.cuda.empty_cache()
    return model


In [15]:
def train_val_dataset(dataset, val_split=0.25):
    train_idx, val_idx = train_test_split(list(range(len(dataset))), test_size=val_split)
    datasets = {}
    datasets['train'] = torch.utils.data.Subset(dataset, train_idx)
    datasets['val'] = torch.utils.data.Subset(dataset, val_idx)
    return datasets


In [16]:
path = 'examples_ds_from_MP.train.HTW.train'
train_transforms = transforms.Compose([transforms.ToTensor(),
                                       transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0),
                                       #transforms.Resize(317),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomRotation(30),
                                       #transforms.CenterCrop(224),
                                       transforms.Normalize(
                                           mean=[0.86121925, 0.86814779, 0.88314296], 
                                           std=[0.13475281, 0.10909398, 0.09926313])
                                      ])

test_transforms = transforms.Compose([
                                     #transforms.Resize(317),
                                      # transforms.CenterCrop(224),
                                      transforms.Normalize(
                                           mean=[0.86121925, 0.86814779, 0.88314296], 
                                           std=[0.13475281, 0.10909398, 0.09926313]),
                                     transforms.ToTensor()])


train_data = torchvision.datasets.ImageFolder(path, transform=train_transforms)
test_data = torchvision.datasets.ImageFolder(path, transform=test_transforms)


#print(len(train_data))
datasets = train_val_dataset(train_data)
print(train_data.class_to_idx)
#print(datasets['val'][5])
# The original dataset is available in the Subset class
#print(datasets['train'].dataset)

dataloaders = {x:torch.utils.data.DataLoader(datasets[x],32, shuffle=True, num_workers=4) for x in ['train','val']}
#x,y = next(iter(dataloaders['train']))

{'False': 0, 'True': 1}


In [None]:

train_loader = dataloaders['train']
test_loader = dataloaders['val']

trainloader = torch.utils.data.DataLoader(train_loader, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(test_loader, batch_size=32,shuffle=True)
optimizer = Lamb.Lamb(model.parameters(), lr=0.00025, weight_decay=0.1)
criterion = nn.CrossEntropyLoss()
epochs = 10

f = open('./result_log.txt', 'w') 
f.write("Concept Pruning log")
f.write("\n=========================================================================================")
pred= evaluate_model(model)
print("\nInitial pred: {}".format(pred))
f.write("\nInitial pred: {}".format(pred))
for i in range(0,5):
    f.write("\n=========================================================================================")
    f.write("\niteration {} , finding irrelevant concepts...".format(i))
    prunable_channels=find_sample_attributions(i,False)
    print("Voted least relevant concept ids to prune (max vote ids): {}".format(prunable_channels))
    f.write("\nVoted least relevant concept ids to prune (max vote ids): {}".format(prunable_channels))
   
    #print("Visualising irrelevant concepts...")
    #f.write("Visualising irrelevant concepts...")
    #irrelevant_concepts(prunable_channels,i)
    
    model = prune_selected_concepts(prunable_channels,model)
    pred = evaluate_model(model)
    print("iteration {} , pred after pruning:  {}".format(i, pred))
    f.write("\niteration {} , pred after pruning:  {}".format(i, pred))
    
    
    model = fine_tune_model(train_loader,test_loader,model,optimizer, criterion,epochs)
 
    pred = evaluate_model(model)
    print("iteration {} , pred after pruning AND finetuning:  {}".format(i, pred))
    f.write("\niteration {} , pred after pruning AND finetuning:  {}".format(i, pred))
    
    #prunable_channels=find_sample_attributions(i,True)
    



    f.write("\n=========================================================================================")
f.close()


  pred = float(torch.nn.functional.softmax(pred)[0,1].numpy())



Initial pred: 0.8473105363547802


  array = (array - vmin) / (vmax - vmin)
  if isinstance(value, collections.Mapping):


Voted least relevant concept ids to prune (max vote ids): {'features.0': [30, 29, 24, 26, 19, 18, 20, 23, 35, 33]}
iteration 0 , pred after pruning:  0.8418153431266546
iteration 0 , pred after pruning AND finetuning:  0.9132532857358455
Voted least relevant concept ids to prune (max vote ids): {'features.15': [222], 'classifier.3': [983], 'classifier.0': [2061, 3898, 3571, 1524, 1153], 'features.4': [55], 'features.0': [22, 53]}
iteration 1 , pred after pruning:  0.9099357143044472
iteration 1 , pred after pruning AND finetuning:  0.9030077219009399
