In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from models import LeNet
import os

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_dir = '../data/'
# Discover class names the same way ImageFolder does: only sub-directories
# count, so stray files (e.g. .DS_Store) cannot shift the label indices.
classes = sorted(entry.name
                 for entry in os.scandir(os.path.join(data_dir, 'test'))
                 if entry.is_dir())
batch_size = 64

In [74]:
# Per-channel (R, G, B) normalisation statistics applied to the test images.
# NOTE(review): presumably computed on the training split - confirm.
mean = [0.44947562, 0.46524084, 0.40037745]
std = [0.18456618, 0.16353698, 0.20014246]

# Test-time preprocessing: PIL image -> float tensor -> channel-wise normalisation.
data_transforms = {
    'test': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]),
}

# Labels come from the sub-directory names under <data_dir>/test.
test_images = datasets.ImageFolder(
    os.path.join(data_dir, 'test'),
    data_transforms['test'],
)

# Deterministic order (shuffle=False) so the same batch is visualised on every run.
test_dataloader = DataLoader(
    test_images,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

In [67]:
# NOTE: with the load below commented out, the network keeps its random
# initialisation, so every filter / feature map plotted later is untrained.
model = LeNet()
#model.load_state_dict(torch.load('model', map_location=str(device)))
# Put the weights on the same device the input batches are sent to;
# otherwise layer(inputs) fails on CUDA machines (CPU weights, GPU inputs).
model = model.to(device)
model.eval()

LeNet(
  (features): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=59536, out_features=120, bias=True)
    (1): ReLU(inplace)
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU(inplace)
    (4): Linear(in_features=84, out_features=10, bias=True)
  )
)

In [84]:
def imshow(filters, num_cols=None, title=None):
    """Plot a stack of 2-D maps (kernels or feature maps) in a subplot grid.

    Args:
        filters: iterable of 2-D maps, e.g. a (C, H, W) tensor or array.
            Torch tensors may live on any device; they are copied to CPU.
        num_cols: number of grid columns. Defaults to one column per map.
            Extra rows are added when there are more maps than columns.
        title: optional figure-level title.
    """
    if isinstance(filters, torch.Tensor):
        # np.asarray cannot read CUDA tensors or tensors that require grad.
        filters = filters.detach().cpu().numpy()
    filters = np.asarray(filters)
    # Undo the input normalisation using the channel-averaged statistics;
    # per-channel values don't apply because the maps need not be RGB.
    filters = np.asarray(std).mean() * filters + np.asarray(mean).mean()
    filters = np.clip(filters, 0, 1)

    num_maps = len(filters)
    if num_cols is None:
        num_cols = max(num_maps, 1)
    # Ceil-divide so every map gets a slot (add_subplot raises otherwise).
    num_rows = max(1, -(-num_maps // num_cols))
    fig = plt.figure(figsize=(num_cols * 5, num_rows * 5))
    for idx, kernel in enumerate(filters):
        ax = fig.add_subplot(num_rows, num_cols, idx + 1)
        ax.imshow(kernel, interpolation='none')
        ax.axis('off')  # also hides tick labels
    if title is not None:
        # suptitle labels the whole figure, not just the last subplot.
        fig.suptitle(title)
    plt.show()

In [85]:
# Visualise the feature maps produced by every conv layer for the first test image.
with torch.no_grad():
    inputs, _ = next(iter(test_dataloader))
    # Follow the model's actual device so inputs and weights always match.
    model_device = next(model.parameters()).device
    activations = inputs.to(model_device)
    for layer in model.get_features:
        activations = layer(activations)
        if isinstance(layer, nn.Conv2d):
            # First sample of the batch: (channels, H, W), moved to CPU for plotting.
            first_sample = activations[0].cpu()
            imshow(first_sample, first_sample.size(0))

ValueError: operands could not be broadcast together with shapes (252,252) (3,) 

<Figure size 2160x360 with 0 Axes>

In [None]:
def plot_kernels(tensor, num_cols):
    """Plot each convolution kernel of a conv weight tensor as one small image.

    Args:
        tensor: Conv2d weight, shape (out_channels, in_channels, kH, kW);
            ``tensor[i]`` is drawn in subplot i. Any device is accepted.
        num_cols: number of subplot columns in the single row; must be
            >= len(tensor).

    Kernels with 3 input channels are drawn as RGB. Any other channel count
    is averaged over the input channels and drawn as a single-channel image,
    because ToPILImage cannot build an image from e.g. 6 channels.
    """
    tensor = tensor.detach().cpu()
    # Undo the input normalisation with the channel-averaged statistics.
    tensor = tensor * float(np.asarray(std).mean()) + float(np.asarray(mean).mean())
    # ToPILImage multiplies floats by 255 and casts to uint8; clamp first so
    # out-of-range weights saturate instead of wrapping around.
    tensor = tensor.clamp(0, 1)

    num_rows = 1
    fig = plt.figure(figsize=(num_cols, num_rows))
    to_pil = transforms.ToPILImage()
    for idx, kernel in enumerate(tensor):
        if kernel.size(0) != 3:
            kernel = kernel.mean(dim=0, keepdim=True)
        ax = fig.add_subplot(num_rows, num_cols, idx + 1)
        ax.imshow(to_pil(kernel), interpolation='none')
        ax.axis('off')  # also hides tick labels
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()
    
# Draw the kernels of every Conv2d layer, one figure per layer.
conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
for conv in conv_layers:
    kernels = conv.weight.data
    plot_kernels(kernels, kernels.size(0))


In [None]:
# Plot filter by filter: one figure per output channel, showing that filter's
# 2-D kernel for each input channel side by side.
conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
for layer_idx, conv in enumerate(conv_layers):
    # detach + cpu so imshow's np.asarray works even when the model is on CUDA.
    weights = conv.weight.detach().cpu()
    out_channels, in_channels = weights.shape[0], weights.shape[1]
    for filter_idx in range(out_channels):
        # weights[filter_idx] holds in_channels maps, so use that many columns
        # (not out_channels, which left most slots empty).
        imshow(weights[filter_idx], in_channels,
               title=f'conv{layer_idx + 1} filter {filter_idx}')