In [23]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import zennit
from zennit.torchvision import VGGCanonizer
from zennit.composites import EpsilonPlusFlat
from crp.attribution import CondAttribution
from crp.concepts import ChannelConcept
from crp.helper import get_layer_names
from crp.visualization import FeatureVisualization
import torchvision.transforms as T
from PIL import Image


In [24]:
# --- Model and raw data ------------------------------------------------------
from data_loader import get_dataset
from torchvision import transforms
from torchvision.models import vgg16

# VGG16 with the IMAGENET1K_V1 weights, put into inference mode.
# nn.Module.eval() returns the module itself, so the call can be chained.
model = vgg16(weights="IMAGENET1K_V1").eval()
# NOTE(review): `.to()` is called without a device or dtype, which looks like a
# no-op -- pass the intended device here if the model is meant to be moved.
model = model.to()

# Tile table produced by the project-local loader.
# Presumably a pandas DataFrame with a "tile" column (that is how STDataset
# below uses it) -- verify against data_loader.get_dataset.
data_dir = "../Training_Data/"
dataset = get_dataset(data_dir)

class STDataset(torch.utils.data.Dataset):
    """Image-only dataset over a table of tiles.

    Each item is the transformed RGB image whose path is stored in the
    ``"tile"`` column of ``dataframe``, paired with the constant target ``0``.
    No other column of the table is read.

    Args:
        dataframe: Table with one row per tile. It must support ``len()`` and
            ``.iloc`` and provide a ``"tile"`` column holding a path that
            ``PIL.Image.open`` accepts.
        device: Device the transformed tensor is moved to. ``None`` skips the
            transfer and leaves the tensor where the transforms produced it.
        transforms: Callable applied to the PIL image. The default resizes to
            224x224, converts to a tensor and normalises with the mean and std
            of the whole dataset.
    """

    def __init__(self, dataframe, device="mps", transforms=transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # mean and std of the whole dataset
            transforms.Normalize([0.7406, 0.5331, 0.7059], [0.1651, 0.2174, 0.1574])
            ])):
        self.dataframe = dataframe
        self.transforms = transforms
        self.device = device

    def __len__(self):
        """Return the number of rows (tiles) in the underlying table."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Load, transform and return the tile at positional ``index``.

        Returns:
            A ``(tensor, 0)`` tuple: the transformed image and a fixed target.
        """
        tile_path = self.dataframe.iloc[index]["tile"]
        # Image.open is lazy and keeps the file open. convert() forces the
        # pixel data to load and returns an independent image, so the handle
        # can be released as soon as the conversion is done.
        with Image.open(tile_path) as tile:
            image = tile.convert("RGB")
        tensor = self.transforms(image)
        if self.device is not None:
            tensor = tensor.to(self.device)
        return tensor, 0
# zennit composite (EpsilonPlusFlat rule set) configured with the VGG canonizer.
composite = EpsilonPlusFlat(canonizers=[VGGCanonizer()])
# Conditional attribution wrapper around the pretrained network.
attribution = CondAttribution(model)
# Tile dataset built with STDataset's default device and transforms.
datasetST = STDataset(dataframe=dataset)


In [25]:
# One ChannelConcept instance per Conv2d layer reported by get_layer_names.
layer_names = get_layer_names(model, [torch.nn.Conv2d])
layer_map = dict((name, ChannelConcept()) for name in layer_names)

# Normalisation handed to FeatureVisualization as its preprocessing function.
# NOTE(review): these values match the usual ImageNet mean/std, while
# STDataset's default transforms already end in a Normalize with
# dataset-specific statistics -- inputs appear to be normalised twice.
# Confirm which of the two is intended.
preprocessing = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
fv = FeatureVisualization(
    attribution,
    datasetST,
    layer_map,
    preprocess_fn=preprocessing,
    path="../crp_out/tmp",
)

# Process every sample from index 0 up to len(dataset) in batches of 32.
# Needs to be run only once.
fv.run(composite, 0, len(dataset), batch_size=32)
print("CRP preprocessing done.")

