In [None]:
from func import MyNet2, get_patient_loader
import torch
import torchvision.models as models
import zennit as zen
import pandas as pd
import matplotlib.pyplot as plt

from torchvision.models.vgg import VGG
from torch.nn.modules.pooling import MaxPool2d, AdaptiveAvgPool2d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from zennit.rules import Epsilon, AlphaBeta
from zennit.types import Linear
from zennit.core import Composite
from zennit.attribution import Gradient
from torchvision.models.resnet import ResNet
from torch.nn.modules.activation import ReLU
from torch.nn.modules.container import Sequential
from torch.nn.modules.conv import Conv2d
import os
from restructure import *

"""
different components to the other model
<class 'func.MyNet'>
<class 'torchvision.models.resnet.ResNet'>
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
<class 'torchvision.models.resnet.Bottleneck'>
<class 'torch.nn.modules.batchnorm.BatchNorm1d'>
"""

# TODO: BatchNorm2d, Bottleneck, BatchNorm1d

In [None]:
# Pick the compute backend first so the checkpoint can be mapped onto it while loading.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# Build MyNet2 around a ResNet-50 initialised with the IMAGENET1K_V2 weights,
# then restore the state dict stored at `path`.
model = MyNet2(my_pretrained_model=models.resnet50(weights="IMAGENET1K_V2"))
path = "./data/05072024_single__5e-06resnet2.pt"
# map_location keeps the load working when the checkpoint was written from a
# different device type than the one available here (e.g. saved on CUDA, loaded on CPU).
model.load_state_dict(torch.load(path, map_location=device))
model.eval()
model.to(device)

# model as a sequential for the restructuring:
# the backbone followed by every layer of the gene1 head, in order.
modules = [model.pretrained, *model.gene1]

# Spelled as torch.nn because no bare `nn` name is explicitly imported in this file.
sequential = torch.nn.Sequential(*modules)


data_dir = "./data"
patient = "/p007"
base_path = data_dir + patient + "/Preprocessed_STDataset/"
merge = pd.read_csv(base_path + "merge.csv")
loader = get_patient_loader(data_dir, patient)

# Kept as the final expression of the cell so the preview is actually rendered.
merge.head()

In [None]:
# Print the layer layout of `sequential`, the flattened model assembled for the restructuring step.
print(sequential)

In [4]:
# Trial run of find_a_ref(method='flood') for the first tile on a CPU copy of `sequential`.
# NOTE(review): `copy` and `get_img_target_name` are only imported / defined in
# later cells of this file, so this cell relies on those cells having been run first.
i = 0  # only tile 0 is processed
model_copy = copy.deepcopy(sequential).to("cpu")
input, target, name = get_img_target_name(loader, device, i)
flood = find_a_ref(
    model_copy,  # already moved to the CPU above
    input.to("cpu"),
    y_ref=target.to("cpu"),
    method='flood',
    step_width=0.005,
    max_it=10e4,
    normalize_top=True,
)

In [None]:
def plot_relevance(att, filename = None):
    """Display a relevance map as an image in the notebook output.

    Args:
        att: attribution tensor; it is summed over dimension 1 and moved to the
            CPU. Ignored when ``filename`` is given.
        filename: optional path of an image file. When given, the file is read
            with ``plt.imread`` and shown (with a leading axis added) instead
            of ``att``.
    """
    if filename is not None:
        rel = torch.tensor(plt.imread(filename)).unsqueeze(0)
    else:
        rel = att.sum(1).cpu()

    # Convert to an image with a symmetric value range and the 'coldnhot'
    # colour map, then hand it to the notebook's display hook.
    display(zen.image.imgify(rel, symmetric=True, cmap='coldnhot'))
    
def get_img_target_name(loader, device, tile_no):
    """Fetch one sample from ``loader`` and prepare it for a forward pass.

    Args:
        loader: indexable collection whose items unpack into
            ``(image, target, name)``.
        device: device the returned tensors are moved to.
        tile_no: index of the sample to fetch.

    Returns:
        Tuple ``(image, target, name)``: the image as a float tensor with a
        leading batch axis, the first entry of the sample's target as a tensor,
        and the sample's name unchanged.
    """
    tile, labels, tile_name = loader[tile_no]
    batch = tile.unsqueeze(0).to(device).float()
    first_label = torch.tensor(labels[0]).to(device)
    return batch, first_label, tile_name

def relevance_and_plot(model, mapping_fn, composite, input = None):
    """Attribute ``model`` on one input and show the input next to its relevance map.

    Args:
        model: network handed to ``Gradient`` together with ``composite``.
        mapping_fn: currently unused; the composite is passed in ready-made.
        composite: zennit composite used for the attribution.
        input: input tensor; when ``None`` a random ``(1, 3, 224, 224)`` tensor
            is created on the notebook-level ``device``.

    Side effects:
        Draws the channel-summed input with matplotlib, displays the relevance
        map via ``plot_relevance`` and prints the model output.
    """
    if input is None:
        input = torch.randn(1, 3, 224, 224).to(device)

    with Gradient(model, composite) as attributor:
        out, grad = attributor(input)

    # Detach before converting: Tensor.numpy() refuses tensors that require
    # grad, and the attributor may have enabled requires_grad on `input`
    # (zennit's Gradient appears to do so — confirm for the installed version).
    imshow = input.detach().to('cpu').squeeze().numpy().sum(axis=0)
    plt.imshow(imshow)
    plot_relevance(grad)
    print("out: ", out)

In [9]:
# For every collected (out, grad, target, name) record: append one summary line
# to the log file, show the relevance map, show the tile image, echo the line.
for i, (out, grad, target, name) in enumerate(out_target):
    with open("./xai_log.txt", "a") as f:
        s = f"out: {out.item()}, target: {target.item()}, filename {name.replace('//', '/')}\n"
        f.write(s)

        plot_relevance(grad)
        img = plt.imread(name)
        plt.imshow(img)
        plt.show()

        print(s)

In [5]:
from restructure import find_a_ref, restructure_model
import copy

can_res = zen.torchvision.ResNetCanonizer()
can = can_res
composite = zen.composites.EpsilonPlusFlat(canonizers=[can])

# For each of the first 20 tiles, attribute the original model and four
# restructured variants. A tile is only recorded when every one of the five
# gradients contains at least one non-zero entry.
out_target = []
for i in range(20):
    input, target, name = get_img_target_name(loader,device,i)

    # Variant 1: gene1 head passed through restructure_model with torch.tensor(0).
    model_copy = copy.deepcopy(model)
    gene1 = restructure_model(model_copy.gene1, torch.tensor(0), in_layer=-3, out_layer=-1)
    model_copy.gene1 = gene1
    model_copy.to(device)
    with Gradient(model_copy, composite) as attributor:
        out_cpy, grad_cpy = attributor(input)
    if grad_cpy.count_nonzero() == 0:
        continue

    # Variant 2: gene1 head passed through restructure_model with the tile's target.
    model_copy = copy.deepcopy(model)
    gene1 = restructure_model(model_copy.gene1, target, in_layer=-3, out_layer=-1)
    model_copy.gene1 = gene1
    model_copy.to(device)
    with Gradient(model_copy, composite) as attributor:
        out_tar, grad_tar = attributor(input)
    # Check this variant's own gradient (previously grad_cpy was re-checked
    # here, which could never trigger after the guard above).
    if grad_tar.count_nonzero() == 0:
        continue

    # Variant 3: flood reference with y_ref=0, computed on a CPU copy of `sequential`.
    model_copy = copy.deepcopy(sequential).to("cpu")
    flood = find_a_ref(model_copy, input.to("cpu"), y_ref=0, method='flood', step_width=0.005, max_it=10e4, normalize_top=False)
    # NOTE(review): unlike the other variants, the whole Sequential copy (not a
    # `.gene1` head) is passed to restructure_model, and the result is then
    # attached to that Sequential as an attribute named `gene1`. If the result
    # is an nn.Module this registers it as an additional child of the
    # Sequential — confirm this is intended.
    gene1 = restructure_model(model_copy, flood, in_layer=-3, out_layer=-1)
    model_copy.gene1 = gene1
    model_copy.to(device)
    with Gradient(model_copy, composite) as attributor:
        out_flood0, grad_flood0 = attributor(input)
    if grad_flood0.count_nonzero() == 0:
        continue

    # Variant 4: flood reference with y_ref=target.
    # NOTE(review): here find_a_ref receives a copy of `model`, whereas variant 3
    # passes a copy of `sequential` — confirm find_a_ref accepts both.
    model_copy = copy.deepcopy(model)
    model_copy.to(device)
    flood = find_a_ref(model_copy, input, y_ref=target, method='flood', step_width=0.005, max_it=10e4, normalize_top=False)
    gene1 = restructure_model(model_copy.gene1, flood, in_layer=-3, out_layer=-1)
    model_copy.gene1 = gene1
    model_copy.to(device)
    with Gradient(model_copy, composite) as attributor:
        out_flood_tar, grad_flood_tar = attributor(input)
    if grad_flood_tar.count_nonzero() == 0:
        continue

    # Baseline: the unmodified model.
    with Gradient(model, composite) as attributor:
        out_ori, grad_ori = attributor(input)
    if grad_ori.count_nonzero() == 0:
        continue

    out_target.append((out_ori, out_cpy, out_tar, out_flood0, out_flood_tar, grad_ori, grad_cpy, grad_tar, grad_flood0, grad_flood_tar, target, name))


# Append one summary line per recorded tile to the log, show the relevance maps
# of the original model and of variant 1, then the tile image itself.
for i in range(len(out_target)):
    with open("./xai_log.txt", "a") as f:
        out_ori, out_cpy, out_tar, out_flood0, out_flood_tar, grad_ori, grad_cpy, grad_tar, grad_flood0, grad_flood_tar, target, name = out_target[i]
        s = (
            "out_ori: " + str(out_ori.item())
            + ", out_ref0: " + str(out_cpy.item())
            + ", out_tar: " + str(out_tar.item())
            + ", out_flood0: " + str(out_flood0.item())
            + ", out_flood_tar: " + str(out_flood_tar.item())
            + ", target: " + str(target.item())
            + ", filename: " + os.path.basename(name.replace("//", "/"))
            + "\n"
        )
        f.write(s)
        plot_relevance(grad_ori)
        plot_relevance(grad_cpy)
        img = plt.imread(name)
        plt.imshow(img)
        plt.show()

        print(s)

In [45]:
# Scratch cell contrasting two readings of `[:-1]`: slicing the model itself
# versus slicing the model's output. Uses `model_copy` and `input` as left
# behind by an earlier cell.
model_copy.to(device)
y_t_ = model_copy.forward(input)
# Slice the model, then run the input through the slice. This requires
# model_copy to support slicing (e.g. an nn.Sequential — presumably the case here).
a_ref_ = model_copy[:-1](input)
# Run the full model, then slice the returned value with [:-1].
b_ref_ = model_copy(input)[:-1]
# Full model output, unsliced, for comparison.
c_ref_ = model_copy(input)



print(a_ref_)
print(b_ref_)
print(c_ref_)
print(c_ref_[:-1])
# Plain-Python check of the `x[:-1] and x[-1]` idiom: for a one-element list
# the slice is [], which is falsy, so `and` yields [] and that is what gets printed.
xxx = ["a"]
print(xxx[:-1] and xxx[-1])

#print(y_t_.item(), " ", a_ref_.item())
#model_copy

In [10]:
log = False  # when True, the module-map functions print the decision taken for a module


def module_map_orig(ctx, name, module):
    """Choose a zennit rule for ``module``, or ``None`` to leave it without a hook.

    Only leaf modules (modules without children) that are instances of
    ``Linear`` receive a rule. The number of such modules seen so far is kept
    in ``ctx['leafnum']`` (starting at 0): the first ten get
    ``AlphaBeta(alpha=2, beta=1)``, every later one ``Epsilon(epsilon=1e-3)``.

    Args:
        ctx: mutable mapping shared between calls; holds the running counter.
        name: name of the module (not used by this function).
        module: the module to decide on.
    """
    # A module with at least one child is a container, not a leaf.
    has_children = next(module.children(), None) is not None
    if has_children or not isinstance(module, Linear):
        return None

    # First matching leaf gets index 0, every further one increments the counter.
    ctx['leafnum'] = ctx['leafnum'] + 1 if 'leafnum' in ctx else 0

    use_alpha_beta = ctx['leafnum'] < 10
    if log:
        print(type(module), " -> AlphaBeta(alpha=2, beta=1)" if use_alpha_beta else " -> Epsilon(epsilon=1e-3)")
    return AlphaBeta(alpha=2, beta=1) if use_alpha_beta else Epsilon(epsilon=1e-3)


def module_map_debug(ctx, name, module):
    """Debug variant of the module map with explicit per-type decisions.

    Decisions, in order:
      * the listed container / parameter-free types -> ``None`` (no hook);
      * ``Conv2d`` (exact type) -> ``Epsilon(epsilon=1e-3)``;
      * ``torch.nn.Linear`` (exact type) -> ``AlphaBeta(alpha=2, beta=1)`` for
        the first ten such layers (counted in ``ctx['leafnum']``), afterwards
        ``Epsilon(epsilon=1e-3)``;
      * everything else -> ``None``, with optional logging of the reason.

    Args:
        ctx: mutable mapping shared between calls; holds the running counter.
        name: name of the module (not used by this function).
        module: the module to decide on.
    """
    # The wrapper check uses MyNet2, the model class imported at the top of this
    # file; a bare `MyNet` is not among the visible imports.
    skipped_types = (MyNet2, ResNet, ReLU, Sequential, VGG, MaxPool2d, Dropout, AdaptiveAvgPool2d)
    if type(module) in skipped_types:
        return None

    if type(module) is Conv2d:
        return Epsilon(epsilon=1e-3)

    # Compare against torch's layer class by its full name: the bare name
    # `Linear` is imported twice at the top of the file (torch's layer, then
    # zennit's type), and an identity comparison with the latter presumably
    # never matches a concrete layer, which left this branch unreachable.
    if type(module) is torch.nn.Linear:
        # count the number of Linear layers processed so far in 'leafnum'
        ctx['leafnum'] = ctx['leafnum'] + 1 if 'leafnum' in ctx else 0
        if ctx['leafnum'] < 10:
            if log:
                print(type(module), " -> AlphaBeta(alpha=2, beta=1)")
            return AlphaBeta(alpha=2, beta=1)
        if log:
            print(type(module), " -> Epsilon(epsilon=1e-3)")
        return Epsilon(epsilon=1e-3)

    # Anything that reaches this point gets no rule; the remainder only logs why.
    has_children = next(module.children(), None) is not None
    if has_children:
        if log:
            print(type(module), " -> None")
        return None

    if log:
        print(type(module), " isinstance(Linear): ", isinstance(module, Linear))
        if not isinstance(module, Linear):
            print(type(module), " -> None")
    return None
def module_map_my_net_debug(ctx, name, module):
    """Debug module map for the custom network.

    Decisions, in order:
      * ``MyNet2``, ``ResNet``, ``ReLU`` and ``Sequential`` (exact types) ->
        ``None`` (no hook);
      * ``Conv2d`` (exact type) -> ``Epsilon(epsilon=1e-3)``;
      * non-leaf modules and leaves that are not instances of ``Linear`` ->
        ``None``;
      * remaining ``Linear`` leaves -> ``AlphaBeta(alpha=2, beta=1)`` for the
        first ten (counted in ``ctx['leafnum']``), afterwards
        ``Epsilon(epsilon=1e-3)``.

    Args:
        ctx: mutable mapping shared between calls; holds the running counter.
        name: name of the module (not used by this function).
        module: the module to decide on.
    """
    # The wrapper check uses MyNet2, the model class imported at the top of this
    # file; a bare `MyNet` is not among the visible imports.
    if type(module) in (MyNet2, ResNet, ReLU, Sequential):
        return None

    if type(module) is Conv2d:
        return Epsilon(epsilon=1e-3)

    # A module with at least one child is a container, not a leaf; only
    # Linear leaves are assigned a hook.
    is_leaf = next(module.children(), None) is None
    if not is_leaf or not isinstance(module, Linear):
        if log:
            print(type(module), " -> None")
        return None

    # count the number of the leaves processed yet in 'leafnum'
    ctx['leafnum'] = ctx['leafnum'] + 1 if 'leafnum' in ctx else 0

    # the first 10 leaf-modules which are of type Linear get the Alpha2Beta1 rule
    if ctx['leafnum'] < 10:
        if log:
            print("leafnum ", ctx['leafnum'], type(module) , " -> AlphaBeta(alpha=2, beta=1)")
        return AlphaBeta(alpha=2, beta=1)

    # all other Linear leaves get Epsilon
    if log:
        print(type(module), " -> Epsilon(epsilon=1e-3)")
    return Epsilon(epsilon=1e-3)

In [None]:
# Compare the attribution's shape (raw and squeezed) with the shape of the
# image stored at `filename`, then render that image through plot_relevance.
# `grad` and `filename` are whatever other cells last assigned to those names.
print(grad.shape, grad.squeeze().shape, sep="\n")
r = plt.imread(filename)
print(r.shape)
plot_relevance(None, filename)

In [None]:
# Print the target stored in every collected (out, grad, target, name) record.
for i, (out, grad, target, name) in enumerate(out_target):
    print(target)

In [None]:
# Unpack the second collected record (index 1) and print its file name.
# Expects `out_target` to hold 4-tuples of (out, grad, target, name).
out, grad, target, name = out_target[1]
print(name)

In [None]:
class ResNetCanonizer(zen.torchvision.CompositeCanonizer):
    '''Composite canonizer for torchvision.models.resnet* type models.

    Combines SequentialMergeBatchNorm with the Bottleneck and BasicBlock
    canonizers. Per the original description, the latter add a Sum module to
    the residual blocks and overload their forward method to use that module
    instead of adding two tensors directly, so that forward and backward hooks
    can be applied.
    '''

    def __init__(self):
        # Order is kept as in the original tuple: batch-norm merging first,
        # then the two residual-block canonizers.
        canonizers = (
            zen.torchvision.SequentialMergeBatchNorm(),
            zen.torchvision.ResNetBottleneckCanonizer(),
            zen.torchvision.ResNetBasicBlockCanonizer(),
        )
        super().__init__(canonizers)

In [None]:
# All direct children of `model` except the last one, chained into a Sequential.
newmodel = torch.nn.Sequential(*(list(model.children())[:-1]))
# `pretrained=True` is the deprecated torchvision spelling; request the same
# ImageNet weights explicitly, matching the `weights=` style used where `model` is built.
model_drop = models.resnet50(weights="IMAGENET1K_V1")
print(newmodel)

In [8]:
# `vgg`, `res` and `current_model` are set up here but not used by the rest of this cell.
vgg = models.vgg16()
res = models.resnet50()
current_model = model
img, target, filename = get_img_target_name(loader, device, 2)

can_res = zen.torchvision.ResNetCanonizer()
composite = zen.composites.EpsilonPlusFlat(canonizers=[can_res])

# Attribute `model` on the first 20 tiles and keep every tile whose gradient
# has at least one non-zero entry.
out_target = []
for i in range(20):
    input, target, name = get_img_target_name(loader, device, i)
    with Gradient(model, composite) as attributor:
        out, grad = attributor(input)
    if grad.count_nonzero() > 0:
        out_target.append((out, grad, target, name))


#relevance_and_plot(current_model.to(device), module_map_debug, composite, img)
#print("target: ", target)

In [8]:
# NOTE(review): this cell repeats the preceding cell verbatim; one of the two can likely be dropped.
# `vgg`, `res` and `current_model` are set up here but not used by the rest of this cell.
vgg = models.vgg16()
res = models.resnet50()
current_model = model
img, target, filename = get_img_target_name(loader, device, 2)

can_res = zen.torchvision.ResNetCanonizer()
composite = zen.composites.EpsilonPlusFlat(canonizers=[can_res])

# Attribute `model` on the first 20 tiles and keep every tile whose gradient
# has at least one non-zero entry.
out_target = []
for i in range(20):
    input, target, name = get_img_target_name(loader, device, i)
    with Gradient(model, composite) as attributor:
        out, grad = attributor(input)
    if grad.count_nonzero() > 0:
        out_target.append((out, grad, target, name))


#relevance_and_plot(current_model.to(device), module_map_debug, composite, img)
#print("target: ", target)

In [7]:
# Same collection loop as for `model`, but attributing the flattened `sequential`
# model; relies on `composite` from an earlier cell.
out_target = []
for i in range(20):
    input, target, name = get_img_target_name(loader, device, i)
    with Gradient(sequential, composite) as attributor:
        out, grad = attributor(input)
    # Tiles whose gradient is entirely zero are not recorded.
    if grad.count_nonzero() > 0:
        out_target.append((out, grad, target, name))