In [59]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

import torch.nn as nn
import torch.nn.functional as F



class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.features = torch.nn.Sequential(
            
            #conv1
            nn.Conv2d(1, 10, kernel_size=5),   
            nn.ReLU(),                         #F.ReLU when given input already, nn.ReLU is idle (curry?)
            nn.MaxPool2d(2,stride = 2, return_indices = True),        
            
            #conv2
            nn.Conv2d(10, 20, kernel_size=5),
            #nn.Dropout2d()
            nn.ReLU(),
            nn.MaxPool2d(2,stride = 2, return_indices = True))
            
#             #fully connected layers
#             nn.Linear(320, 50),
#             nn.Linear(50, 10))
        
        self.feature_outputs = [0]*len(self.features)
        self.pool_indices = dict()
        
        self.classifier = torch.nn.Sequential(
            nn.Linear(320, 50),  # 224x244 image pooled down to 7x7 from features
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(50, 10))
        
        
        def forward_features(self,x):
            output = x
            for i,layer in enumerate(self.features):
                if isinstance(layer, nn.MaxPool2d):
                    print("im a maxpool!")
                    output, indices = layer(output)
                    self.features_output[i] = output
                    self.pool_indices[i] = indices
                else:
                    output = layer(output)
                    self.features_output[i] = output
            return output
                    
                

        def forward(self, x):
            output = self.forward_features(x)
            output = output.view(output.size()[0], -1)
            output = self.classifier(output)
            return output
            
class DNN(nn.Module):
    def __init__(self):
        super(DNN,self).__init__()

        self.deconv_features = torch.nn.Sequential(
            nn.MaxUnpool2d(2, stride=2),
            nn.ConvTranspose2d(20, 10, kernel_size=5, padding=1),
            nn.MaxUnpool2d(2, stride=2),
            nn.ConvTranspose2d(10, 1, kernel_size=5, padding=1))

        # not the most elegant, given that I don't need the MaxUnpools here
        self.deconv_first_layers = torch.nn.ModuleList([
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(1, 10, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(1, 1, 3, padding=1) ])


    def forward(self, x, layer_number, map_number, pool_indices):
        start_idx = self.conv2DeconvIdx[layer_number]
        if not isinstance(self.deconv_first_layers[start_idx], torch.nn.ConvTranspose2d):
            raise ValueError('Layer '+str(layer_number)+' is not of type Conv2d')
        # set weight and bias
        self.deconv_first_layers[start_idx].weight.data = self.deconv_features[start_idx].weight[map_number].data[None, :, :, :]
        self.deconv_first_layers[start_idx].bias.data = self.deconv_features[start_idx].bias.data        
        # first layer will be single channeled, since we're picking a particular filter
        output = self.deconv_first_layers[start_idx](x)

        # transpose conv through the rest of the network
        for i in range(start_idx+1, len(self.deconv_features)):
            if isinstance(self.deconv_features[i], torch.nn.MaxUnpool2d):
                output = self.deconv_features[i](output, pool_indices[self.unpool2PoolIdx[i]])
            else:
                output = self.deconv_features[i](output)
        return output
                    
            
      
                        



In [60]:
from math import sqrt, ceil
import numpy as np

def visualize_grid(Xs, ubound=255.0, padding=1):
  """
  Reshape a 4D tensor of image data to a grid for easy visualization.
  Inputs:
  - Xs: Data of shape (N, H, W, C)
  - ubound: Output grid will have values scaled to the range [0, ubound]
  - padding: The number of blank pixels between elements of the grid
  """
  (N, H, W, C) = Xs.shape
  grid_size = int(ceil(sqrt(N)))
  grid_height = H * grid_size + padding * (grid_size - 1)
  grid_width = W * grid_size + padding * (grid_size - 1)
  grid = np.zeros((grid_height, grid_width, C))
  next_idx = 0
  y0, y1 = 0, H
  for y in range(grid_size):
    x0, x1 = 0, W
    for x in range(grid_size):
      if next_idx < N:
        img = Xs[next_idx]
        low, high = np.min(img), np.max(img)
        grid[y0:y1, x0:x1] = ubound * (img - low) / (high - low)
        # grid[y0:y1, x0:x1] = Xs[next_idx]
        next_idx += 1
      x0 += W + padding
      x1 += W + padding
    y0 += H + padding
    y1 += H + padding
  # grid_max = np.max(grid)
  # grid_min = np.min(grid)
  # grid = ubound * (grid - grid_min) / (grid_max - grid_min)
  return grid

def vis_grid(Xs):
  """ visualize a grid of images """
  (N, H, W, C) = Xs.shape
  A = int(ceil(sqrt(N)))
  G = np.ones((A*H+A, A*W+A, C), Xs.dtype)
  G *= np.min(Xs)
  n = 0
  for y in range(A):
    for x in range(A):
      if n < N:
        G[y*H+y:(y+1)*H+y, x*W+x:(x+1)*W+x, :] = Xs[n,:,:,:]
        n += 1
  # normalize to [0,1]
  maxg = G.max()
  ming = G.min()
  G = (G - ming)/(maxg-ming)
  return G
  
def vis_nn(rows):
  """ visualize array of arrays of images """
  N = len(rows)
  D = len(rows[0])
  H,W,C = rows[0][0].shape
  Xs = rows[0][0]
  G = np.ones((N*H+N, D*W+D, C), Xs.dtype)
  for y in range(N):
    for x in range(D):
      G[y*H+y:(y+1)*H+y, x*W+x:(x+1)*W+x, :] = rows[y][x]
  # normalize to [0,1]
  maxg = G.max()
  ming = G.min()
  G = (G - ming)/(maxg-ming)
  return G

In [61]:
#!/usr/bin/env python3


import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import sys

def vis_layer(activ_map):
    plt.clf()
    plt.subplot(121)
    plt.imshow(activ_map[:,:,0], cmap='gray')

def decon_img(layer_output):
    raw_img = layer_output.data.numpy()[0].transpose(1,2,0)
    img = (raw_img-raw_img.min())/(raw_img.max()-raw_img.min())*255
    img = img.astype(np.uint8)
    return img

if __name__ == '__main__':
#     if len(sys.argv) < 2:
#         print('Usage: '+sys.argv[0]+' img_file')
#         sys.exit(0)

    img_filename = '6.png'

    n_classes = 1000 # using ImageNet pretrained weights

    vgg16_c = Net()
    conv_layer_indices = [0,3]

    img = np.asarray(Image.open(img_filename).resize((28,28)))
    
    
    img_var = torch.autograd.Variable(torch.FloatTensor(img.transpose(1,0)[np.newaxis,:,:].astype(float)))

    conv_out = vgg16_c(img_var)
    print('VGG16 model:')
    print(vgg16_c)

    plt.ion() # remove blocking
    plt.figure(figsize=(10,5))
    vgg16_d = DNN()
    done = False
    while not done:
        layer = input('Layer to view (0-30, -1 to exit): ')
        try:
            layer = int(layer)
        except ValueError:
            continue
            
        if layer < 0:
            sys.exit(0)
        activ_map = vgg16_c.feature_outputs[layer].data.numpy()
        activ_map = activ_map.transpose(1,2,3,0)
        activ_map_grid = vis_grid(activ_map)
        vis_layer(activ_map_grid)

        # only transpose convolve from Conv2d or ReLU layers
        conv_layer = layer
        if conv_layer not in conv_layer_indices:
            conv_layer -= 1
            if conv_layer not in conv_layer_indices:
                continue

        n_maps = activ_map.shape[0]

        marker = None
        while True:
            choose_map = input('Select map?  (y/[n]): ') == 'y'
            if marker != None:
                marker.pop(0).remove()

            if not choose_map:
                break

            _, map_x_dim, map_y_dim, _ = activ_map.shape
            map_img_x_dim, map_img_y_dim, _ = activ_map_grid.shape
            x_step = map_img_x_dim//(map_x_dim+1)

            print('Click on an activation map to continue')
            x_pos, y_pos = plt.ginput(1)[0]
            x_index = x_pos // (map_x_dim+1)
            y_index = y_pos // (map_y_dim+1)
            map_idx = int(x_step*y_index + x_index)

            if map_idx >= n_maps:
                print('Invalid map selected')
                continue

            decon = vgg16_d(vgg16_c.feature_outputs[layer][0][map_idx][None,None,:,:], conv_layer, map_idx, vgg16_c.pool_indices)
            img = decon_img(decon)
            plt.subplot(121)
            marker = plt.plot(x_pos, y_pos, marker='+', color='red')
            plt.subplot(122)
            plt.imshow(img)

NotImplementedError: 