<a href="https://colab.research.google.com/github/komazawa-deep-learning/komazawa-deep-learning.github.io/blob/master/notebooks/2020_0514komazawa_visualise_first_layers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PyTorch で定義されている第一層目の結合の視覚化

<font size="+2" color="teal"><strong>ヒューベルとウィーゼルの受容野，線分の方位選択性，色ブロッブが形成されていることを確認する</strong></font>

- date: 2020-0514
- author: 浅川伸一
- note: 定義済モデルの第一層を視覚化するデモ


In [0]:
from IPython.display import clear_output, Image, display
import numpy as np
import PIL.Image

import matplotlib.pyplot as plt
%matplotlib inline

In [0]:
import IPython.display

In [0]:
import torch
import torchvision.models as models
from matplotlib import pyplot as plt
import torchvision

In [0]:
#help(dir)

In [5]:
# ここで Pytorch にはどのようなモデルが定義されているか探してみましょう。
print(dir(models))

['AlexNet', 'DenseNet', 'GoogLeNet', 'GoogLeNetOutputs', 'Inception3', 'InceptionOutputs', 'MNASNet', 'MobileNetV2', 'ResNet', 'ShuffleNetV2', 'SqueezeNet', 'VGG', '_GoogLeNetOutputs', '_InceptionOutputs', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_utils', 'alexnet', 'densenet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'detection', 'googlenet', 'inception', 'inception_v3', 'mnasnet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet', 'mobilenet_v2', 'quantization', 'resnet', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101_32x8d', 'resnext50_32x4d', 'segmentation', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0', 'shufflenetv2', 'squeezenet', 'squeezenet1_0', 'squeezenet1_1', 'utils', 'vgg', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'video', 'wide_resnet101_2', 'wide_r

In [0]:
# ResNet を見てみましょう
model = models.resnet50(pretrained=True)
w = model.conv1.weight.data

grid = torchvision.utils.make_grid(w, nrow=8, normalize=True, scale_each=True)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(grid.permute(1, 2, 0))

In [0]:
# 続いて AlexNet
model = models.alexnet(pretrained=True)
w = model.features[0].weight.data

grid = torchvision.utils.make_grid(w, nrow=8, normalize=True, scale_each=True)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(grid.permute(1, 2, 0))

In [0]:
# こんどは MNasNet です
model = models.mnasnet1_0(pretrained=True)

w = model.layers[0].weight.data
grid = torchvision.utils.make_grid(w, nrow=8, normalize=True, scale_each=True)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(grid.permute(1, 2, 0))

In [0]:
model = models.mnasnet0_5(pretrained=True)

w = model.layers[0].weight.data
grid = torchvision.utils.make_grid(w, nrow=8, normalize=True, scale_each=True)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(grid.permute(1, 2, 0))

In [10]:
model = models.vgg16(pretrained=True)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/checkpoints/vgg16-397923af.pth


HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))




In [0]:
#print(dir(model.features[0]))
#print(model.features[0].weight.data)
w = model.features[0].weight.data
grid = torchvision.utils.make_grid(w, nrow=8, normalize=True, scale_each=True)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(grid.permute(1, 2, 0))

In [0]:
model.features

In [0]:
!wget https://raw.githubusercontent.com/komazawa-deep-learning/komazawa-deep-learning.github.io/master/assets/1991Felleman_VanEssen_fig2.jpg


In [0]:
filename = "1991Felleman_VanEssen_fig2.jpg"

plt.figure(figsize=(10,10))
plt.axis('off')
plt.imshow(plt.imread(filename))


In [0]:
#!ls *.{jpg,png} 

#plt.imshow(plt.imread('hogehoge.png'))

##display(PIL.Image.open('keras-vgg16-model_09.png'))
#IPython.display.Image(filename='hogehoge.png')
##help(clear_output)
##help(Image)
##help(display)

#import IPython.display
#help(IPython.display)
#help(Image)
# help(display)

#plt.imshow(plt.imread('deer.jpg'));plt.show()
#display(PIL.Image.open('deer.jpg'))
#IPython.display.Image(filename="deer.jpg")

#!ls ${HOME}/study/2020chuo/assets/*.pdf

In [0]:
# source: https://discuss.pytorch.org/t/understanding-deep-network-visualize-weights/2060/7
import torch
import torchvision.models as models
from matplotlib import pyplot as plt

#def plot_kernels(tensor, num_cols=6):
def plot_kernels(tensor, num_cols=8):
    if not tensor.ndim==4:
        raise Exception("assumes a 4D tensor")
    if not tensor.shape[-1]==3:
        raise Exception("last dim needs to be 3 to plot")
    num_kernels = tensor.shape[0]
    num_rows = 1+ num_kernels // num_cols
    fig = plt.figure(figsize=(num_cols,num_rows))
    for i in range(tensor.shape[0]):
        ax1 = fig.add_subplot(num_rows,num_cols,i+1)
        #ax1.imshow(tensor[i])
        ax1.imshow((255 * tensor[i]).astype('uint8'))
        ax1.axis('off')
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])

    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()
    
  
vgg = models.vgg16(pretrained=True)
mm = vgg.double()
filters = mm.modules
body_model = [i for i in mm.children()][0]
layer1 = body_model[0]
tensor = layer1.weight.data.numpy()
plot_kernels(tensor)