In [None]:
import os
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

## Utility functions

In [None]:
def get_path_labels(path):
    """List every file in ``path`` together with its integer label.

    Filenames are visited in sorted order. The label is the integer before the
    first ``_`` in the filename (e.g. ``3_12.jpg`` -> 3); a name without an
    integer prefix raises ``ValueError``.

    Returns:
        (imgpathes, imglabels): two lists of equal length, joined paths and labels.
    """
    names = sorted(os.listdir(path))
    imgpathes = [os.path.join(path, name) for name in names]
    imglabels = [int(name.split('_')[0]) for name in names]
    return imgpathes, imglabels

def normalize(image):
    """Min-max rescale ``image`` into [0, 1].

    Works on anything exposing ``min()``/``max()`` with array arithmetic
    (numpy arrays, torch tensors). A constant input has zero range; instead of
    dividing 0 by 0 (which yields NaNs and blank plots) it is returned shifted to
    all zeros.
    """
    shifted = image - image.min()
    value_range = image.max() - image.min()
    if value_range > 0:
        return shifted / value_range
    # zero (or NaN) range: nothing to scale, avoid the 0/0 division
    return shifted

def compute_saliency_maps(x,y,model,device):
    """Saliency maps: |d CrossEntropy(model(x), y) / d x|, min-max normalised per sample.

    Args:
        x: input batch fed to ``model`` (assumed (N, C, H, W) — TODO confirm).
        y: target class indices for ``x``.
        model: classifier returning logits; put in eval mode and moved to ``device``.
        device: device to run on.

    Returns:
        CPU tensor shaped like ``x``; each sample rescaled to [0, 1] by ``normalize``.

    The caller's ``x`` is not modified: gradients are taken with respect to a
    fresh leaf copy (on CPU ``x.to(device)`` alone would alias the caller's
    tensor and flip its ``requires_grad`` flag).
    """
    model.eval()
    model = model.to(device)
    # Fresh leaf tensor: guarantees x.grad is populated even if the caller passed
    # a non-leaf tensor, and keeps the caller's tensor untouched.
    x = x.to(device).detach().clone().requires_grad_(True)
    y = y.to(device)
    y_pred = model(x)
    criterion = torch.nn.CrossEntropyLoss()
    loss = criterion(y_pred,y)
    loss.backward()
    saliencies = x.grad.abs().detach().cpu()
    # normalise each sample independently so they are comparable when plotted
    saliencies = torch.stack([normalize(item) for item in saliencies])
    return saliencies


## Classifier

In [None]:
class Classifier(nn.Module):
    """CNN classifier for 3x128x128 images over 11 classes.

    Five conv stages (Conv3x3 -> BatchNorm -> ReLU -> MaxPool2) take the input
    from [3,128,128] to [512,4,4]; an MLP head maps the flattened 512*4*4
    features to 11 logits. Modules are created and registered in a fixed order,
    so state-dict keys (``cnn.0.weight`` ... ``fc.6.bias``) and positional
    indices such as ``model.cnn[8]`` always refer to the same layers.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels) per stage; every stage halves H and W:
        # 128 -> 64 -> 32 -> 16 -> 8 -> 4
        stage_channels = [(3, 64), (64, 128), (128, 256), (256, 512), (512, 512)]
        conv_layers = []
        for in_ch, out_ch in stage_channels:
            conv_layers.extend([
                nn.Conv2d(in_ch, out_ch, 3, 1, 1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2, 2, 0),
            ])
        self.cnn = nn.Sequential(*conv_layers)

        self.fc = nn.Sequential(
            nn.Linear(512 * 4 * 4, 1024),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(512, 11),
        )

    def forward(self, x):
        features = self.cnn(x)
        # keep the batch dimension, flatten everything else for the MLP head
        flat = features.reshape(len(features), -1)
        return self.fc(flat)

## Dataset definition

In [None]:
from PIL import Image

class FoodDataset(Dataset):
    """Image dataset over aligned (path, label) lists.

    Args:
        pathes: image file paths.
        labels: integer labels aligned with ``pathes``.
        model: transform *mode* (the parameter name is kept for backward
            compatibility; it is not a network). ``'train'`` adds a random
            horizontal flip and a random rotation of up to 15 degrees; any other
            value only resizes. Both end with ``ToTensor``.
    """

    def __init__(self, pathes, labels, model):
        self.pathes = pathes
        self.labels = labels
        trainTransform = transforms.Compose([
            transforms.Resize(size=(128, 128)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
        ])
        evalTransform = transforms.Compose([
            transforms.Resize(size=(128, 128)),
            transforms.ToTensor(),
        ])
        self.transform = trainTransform if model == 'train' else evalTransform

    def __len__(self):
        return len(self.pathes)

    def __getitem__(self, index):
        """Load, transform and return ``(image_tensor, label)`` for ``index``."""
        # Force 3-channel RGB: grayscale / RGBA / palette files would otherwise
        # produce 1- or 4-channel tensors that break torch.stack and the
        # 3-channel model. The context manager releases the file handle.
        with Image.open(self.pathes[index]) as img:
            X = self.transform(img.convert('RGB'))
        y = self.labels[index]
        return X, y

    def getbatch(self, indices):
        """Load ``indices`` and return (images stacked on dim 0, labels as a tensor)."""
        images, labels = [], []
        for index in indices:
            X, y = self.__getitem__(index)
            images.append(X)
            labels.append(y)
        return torch.stack(images), torch.tensor(labels)

## Global settings

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Loading Model

In [None]:
# Build the network on the target device, then restore the trained weights.
# map_location remaps the checkpoint's tensors onto `device`, so a GPU-saved
# checkpoint can also be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model_path = '../input/hw3-model/model_hw3.pth'
model = Classifier().to(device)
model.load_state_dict(
    torch.load(model_path, map_location=device)
)

## Saliency Maps

In [None]:
dataset_dir = '../input/ml2020spring-hw3/food-11/'

# select images to get saliencies
img_indices = [83, 4218, 4707, 8598]

# Load the selected training images with the deterministic 'eval' transforms.
# With 'train', every getbatch() call applies a fresh random flip/rotation, so
# the images explained here would differ from the ones re-fetched in later cells.
imgpathes, imglabels = get_path_labels(os.path.join(dataset_dir,'training'))
train_dataset = FoodDataset(imgpathes,imglabels,'eval')
images,labels = train_dataset.getbatch(img_indices)

# computing saliency maps
saliencies = compute_saliency_maps(images,labels,model,device)

# squeeze=False keeps `axes` 2-D even when only one image is selected
fig,axes = plt.subplots(2,len(images),figsize=(15,8),squeeze=False)
for row,(target,row_name) in enumerate(zip([images,saliencies],['image','saliency'])):
    for column,img in enumerate(target):
        ax = axes[row][column]
        # .detach(): `images` may carry requires_grad from the saliency computation
        ax.imshow(img.permute(1,2,0).detach())
        ax.set_title(f'{row_name} (label {labels[column].item()})')
        ax.axis('off')
plt.show()

可以看到，我的模型并没有很好地找到目标的特征：它把注意力集中在目标的边缘，这可能就是我的模型效果不够好的原因。

下面我来看看，CNN中的filter到底看到了些什么，还有什么样的图片最能激活filter。

In [None]:
layer_activations = None  # global slot the forward hook writes the hooked layer's output into


def filter_explaination(x,model,cnnid,filterid,device,iteration,lr):
    '''Explain the behaviour of one filter inside ``model.cnn``.

    ``cnnid`` selects the layer ``model.cnn[cnnid]`` and ``filterid`` the output
    channel of that layer to analyse. Two artifacts are produced:

    1. the filter's activations for the input batch ``x``;
    2. a filter visualization: starting from ``x``, the input is optimised with
       Adam on the negated summed activation, i.e. pushed toward images that
       activate the filter as strongly as possible.

    Args:
        x: input batch for ``model`` (assumed (N, C, H, W) — TODO confirm).
        model: network with an indexable ``cnn`` attribute.
        cnnid: index of the layer in ``model.cnn`` to hook.
        filterid: channel of that layer's output to analyse.
        device: device to run on.
        iteration: number of optimisation steps.
        lr: Adam learning rate.

    Returns:
        (filter_activations, filter_visualization), both CPU tensors:
        ``layer_output[:, filterid]`` for ``x``, and the optimised version of
        the FIRST image of the batch.

    The caller's ``x`` is never modified: optimisation runs on a copy. The hook
    is always removed, even if the forward/backward pass raises.
    '''
    model = model.to(device)
    model.eval()
    # Detached view on the target device; never optimised in place (on CPU
    # `.to(device)` alone would alias the caller's tensor).
    x = x.detach().to(device)

    def hook(module,input,output):
        # Save the hooked layer's output so it can be read after the forward pass.
        global layer_activations
        layer_activations = output

    hook_handle = model.cnn[cnnid].register_forward_hook(hook)
    try:
        # 1) Activation maps: a plain forward pass, no autograd graph needed.
        with torch.no_grad():
            model(x)
        # keep only the requested filter; detach + CPU because it is only plotted
        filter_activations = layer_activations[:,filterid,:,:].detach().cpu()

        # 2) Find an input that maximally activates the filter by optimising a
        #    copy of x (it could also start from random noise).
        x_opt = x.clone().requires_grad_(True)
        optimizer = optim.Adam([x_opt],lr=lr)
        for it in range(iteration):
            optimizer.zero_grad()  # gradients accumulate otherwise
            model(x_opt)

            # Objective: summed output of the filter. Negated because Adam
            # minimises and we want the maximum. Read it from the hook output
            # (still in the graph); filter_activations is detached.
            objective = -layer_activations[:,filterid,:,:].sum()

            objective.backward()
            optimizer.step()

            print(f'iter: {it+1}/{iteration} objective: {objective}',end='\r')
    finally:
        hook_handle.remove()

    # Index the batch dimension explicitly: `.squeeze()[0]` returned a single
    # channel instead of an image when the batch held only one element.
    filter_visualization = x_opt.detach().cpu()[0]

    return filter_activations,filter_visualization

# Pass a detached copy so the optimisation inside filter_explaination can never
# overwrite `images` in place (on CPU, `x.to(device)` aliases the caller's tensor).
filter_activations,filter_visualization = filter_explaination(images.detach().clone(),model,cnnid=8,filterid=0,
                                                              device=device,iteration=100,lr=1)

In [None]:
# show the layer_activations
# Re-fetch the original images in case an earlier call optimised `images` in
# place. These match the activations only when the dataset uses deterministic
# (non-'train') transforms.
images,labels = train_dataset.getbatch(img_indices)
# squeeze=False keeps `axes` 2-D even when only one image is selected
fig,axes = plt.subplots(2,len(images),figsize=(15,8),squeeze=False)
for i,img in enumerate(images):
    axes[0][i].imshow(img.permute(1,2,0).detach())
    axes[0][i].set_title(f'image (label {labels[i].item()})')
    axes[0][i].axis('off')
for i,img in enumerate(filter_activations):
    axes[1][i].imshow(img)
    axes[1][i].set_title('filter activation')
    axes[1][i].axis('off')
plt.show()

plt.figure()
plt.imshow(normalize(filter_visualization.permute(1,2,0)))
plt.title('filter visualization')
plt.axis('off')
plt.show()

## Lime
LIME 是一个现成的套件：只要实现两个函数（分类函数 `classifier_fn` 和分割函数 `segmentation_fn`），就可以用它来解释模型的行为。

In [None]:
def predict(input):
    """LIME ``classifier_fn``: score a batch of images with the global ``model``.

    Args:
        input: numpy array of shape (batches, height, width, channels), as
            produced by LIME's perturbation step.

    Returns:
        numpy array with one row of model outputs per image. These are the raw
        model outputs (logits), not softmax probabilities.

    NOTE(review): LIME documents ``classifier_fn`` as returning probabilities.
    Raw outputs are kept here to preserve existing explanations; consider a
    softmax (and re-tuning ``min_weight``) if probabilities are wanted.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.to(device)
    model.eval()

    # LIME supplies NHWC float64; PyTorch convolutions expect float32 NCHW.
    batch = torch.as_tensor(input, dtype=torch.float32).permute(0, 3, 1, 2).to(device)

    # Inference only: skip building an autograd graph for every perturbed batch.
    with torch.no_grad():
        output = model(batch)

    return output.cpu().numpy()

from skimage.segmentation import slic
from lime import lime_image

def segmentation(input):
    """LIME ``segmentation_fn``: split an image into superpixels with skimage's SLIC.

    ``n_segments=100`` is SLIC's target segment count (the result may differ
    slightly). Returns SLIC's integer label map for ``input``.
    """
    slic_kwargs = {'n_segments': 100, 'compactness': 1, 'sigma': 1}
    return slic(input, **slic_kwargs)

img_indices = [83, 4218, 4707, 8598]

images, labels = train_dataset.getbatch(img_indices)

# make the experiment reproducible (seeds numpy's global RNG)
np.random.seed(16)

# squeeze=False keeps `axes` 2-D regardless of how many images are selected
fig, axes = plt.subplots(1, len(images), figsize=(15, 8), squeeze=False)

for idx, (image, label) in enumerate(zip(images.permute(0, 2, 3, 1).numpy(), labels)):
    # lime works on numpy arrays
    x = image.astype(np.double)

    explainer = lime_image.LimeImageExplainer()
    # explain_instance only explains the top_labels (default 5) highest-scoring
    # classes, so get_image_and_mask fails for a true label outside that set.
    # Requesting the true label explicitly (top_labels=None) replaces the old
    # workaround of overwriting `labels` with hand-picked classes.
    # classifier_fn: how an image batch is turned into model predictions;
    # segmentation_fn: how an image is split into superpixels.
    # doc: https://lime-ml.readthedocs.io/en/latest/lime.html?highlight=explain_instance#lime.lime_image.LimeImageExplainer.explain_instance
    explaination = explainer.explain_instance(
        image=x,
        classifier_fn=predict,
        segmentation_fn=segmentation,
        labels=(label.item(),),
        top_labels=None,
    )

    # turn the explanation into an image
    # doc: https://lime-ml.readthedocs.io/en/latest/lime.html?highlight=get_image_and_mask#lime.lime_image.ImageExplanation.get_image_and_mask
    lime_img, mask = explaination.get_image_and_mask(
        label=label.item(),
        positive_only=False,
        hide_rest=False,
        num_features=11,
        min_weight=0.05,
    )

    axes[0][idx].imshow(lime_img)
    axes[0][idx].set_title(f'LIME (label {label.item()})')
    axes[0][idx].axis('off')
plt.show()