In [19]:
from __future__ import print_function
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from nets.light_cnn import LightCNN_9Layers
from nets.deepid import *
from nets.net_sphere import *
from nets.vgg import *
import numpy as np
import cv2

import matplotlib
matplotlib.use('Agg')


import numpy as np
import matplotlib.pyplot as plt

def vis_square(data):
    """Tile a stack of maps into a single, roughly square mosaic image.

    Parameters
    ----------
    data : array-like
        Shape ``(n, height, width)`` or ``(n, height, width, 3)``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(g * (height + 1), g * (width + 1))`` plus any
        trailing axes of the input, where ``g = ceil(sqrt(n))``. Values are
        min-max normalised to ``[0, 1]``; the one-pixel separators and any
        unused grid cells are filled with 1 (white).

    Raises
    ------
    ValueError
        If ``data`` has fewer than 3 dimensions or contains no elements.
    """
    data = np.asarray(data)
    if data.ndim < 3:
        raise ValueError(
            'vis_square expects at least 3 dimensions, got %d' % data.ndim)
    if data.size == 0:
        raise ValueError('vis_square needs a non-empty array')

    # Integer input would be floor-divided under Python 2 semantics, so make
    # the arithmetic floating point while keeping an existing float dtype.
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(np.float64)

    # normalize data for display
    lowest = data.min()
    span = data.max() - lowest
    if span == 0:
        # Constant input: the plain formula would evaluate 0/0 and return NaN.
        data = np.zeros_like(data)
    else:
        data = (data - lowest) / span

    # force the number of filters to be square
    count = data.shape[0]
    n = int(np.ceil(np.sqrt(count)))
    padding = (((0, n ** 2 - count),
                (0, 1), (0, 1))                 # add some space between filters
               + ((0, 0),) * (data.ndim - 3))   # don't pad the last dimension (if there is one)
    data = np.pad(data, padding, mode='constant', constant_values=1)  # pad with ones (white)

    # tile the filters into an image:
    # (grid_row, grid_col, h, w, ...) -> (grid_row, h, grid_col, w, ...)
    data = data.reshape((n, n) + data.shape[1:])
    data = data.transpose((0, 2, 1, 3) + tuple(range(4, data.ndim)))
    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])

    return data
    #plt.imshow(data); 
    #plt.axis('off')
    
    
    
# Registry of model names accepted by get_network_fn, mapped to the
# constructors imported from the nets package.
networks_map = {
    'LightCNN-9': LightCNN_9Layers,
    'DeepID_256': DeepID_256,
    'DeepID_256_gray': DeepID_256_gray,
    'DeepID_128_gray': DeepID_128_gray,
    'sphere20': sphere20,
    'vgg11': vgg11,
}

def get_network_fn(model_name, num_classes, weight_decay=0):
    """Build the network registered under ``model_name`` in ``networks_map``.

    Parameters
    ----------
    model_name : str
        Key to look up in ``networks_map``.
    num_classes : int
        Forwarded to the constructor as the ``num_classes`` keyword.
    weight_decay : number, optional
        Accepted but not used by this function.

    Returns
    -------
    Whatever the registered constructor returns.

    Raises
    ------
    ValueError
        If ``model_name`` is not a key of ``networks_map``.
    """
    if model_name in networks_map:
        constructor = networks_map[model_name]
        return constructor(num_classes=num_classes)
    raise ValueError('Name of network unknown %s' % model_name)


    
# Enable cuDNN's benchmark mode.
cudnn.benchmark = True

# Whether to load trained weights from CHECKPOINT_PATH before running the model.
resume = True

# NOTE(review): the file name mentions lightCNN while the model built below is
# 'sphere20' -- confirm this checkpoint really belongs to this architecture.
CHECKPOINT_PATH = '/data/zeng/pytorch_model/lightCNN_18_checkpoint.pth.tar'
IMAGE_PATH = '256.jpg'

model = get_network_fn(model_name='sphere20', num_classes=10572, weight_decay=0)

# NOTE(review): eval() is disabled, so the forward pass below runs with the
# module in its default training mode -- confirm this is intended.
#model.eval()

model = torch.nn.DataParallel(model).cuda()

if resume:
    checkpoint = torch.load(CHECKPOINT_PATH)
    model.load_state_dict(checkpoint['state_dict'])
else:
    # The previous message formatted `args.resume`, but no `args` object is
    # defined in this notebook, so this branch raised NameError.
    print("=> resume disabled; checkpoint '{}' not loaded".format(CHECKPOINT_PATH))

transform = transforms.Compose([transforms.ToTensor()])
count     = 0
# Single-image batch: (batch, channel, height, width).
input     = torch.zeros(1, 1, 256, 256)

img   = cv2.imread(IMAGE_PATH, cv2.IMREAD_GRAYSCALE)
if img is None:
    # cv2.imread signals a missing/unreadable file by returning None.
    raise IOError("could not read image '{}'".format(IMAGE_PATH))
#img   = cv2.resize(img, (128,128))
# Add the channel axis; this also fails loudly if the image is not 256x256.
img   = np.reshape(img, (256, 256, 1))
img   = transform(img)
input[0,:,:,:] = img

# Was `nput = input.cuda()`: the typo left the GPU copy in an unused name.
input = input.cuda()
# NOTE(review): `volatile` is the legacy way to disable autograd; newer
# PyTorch releases use `with torch.no_grad():` instead.
input_var   = torch.autograd.Variable(input, volatile=True)
_, features, conv1_map = model(input_var)

print('done')


because the backend has already been chosen;
matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
or matplotlib.backends is imported for the first time.



done


In [20]:
# Fetch the state dict of the DataParallel-wrapped model so individual weight
# tensors can be looked up by name in the following cells.
model_dict = model.state_dict()

In [21]:
# Pick one weight tensor by name (presumably the first convolution layer, going
# by the key); the 'module.' prefix comes from the DataParallel wrapper.
x = model_dict['module.conv1_1.weight']
# Display its shape.
x.shape

torch.Size([64, 1, 7, 7])

In [22]:
# Copy the selected weights to host memory as a NumPy array. Reading from
# model_dict (rather than from x itself) keeps this cell re-runnable: an
# ndarray has no .cpu(), so `x = x.cpu().numpy()` failed on a second run.
x = model_dict['module.conv1_1.weight'].cpu().numpy()

In [23]:
# Remove all size-1 axes (here the single input channel) so the filters form
# the (n, height, width) stack that vis_square accepts.
x = np.squeeze(x)
x.shape
#vis_square(x)

(64, 7, 7)

In [24]:
# Tile the filters into one normalised mosaic image.
data = vis_square(x)

In [25]:
# 64 filters -> 8x8 grid; each 7x7 filter gains one padding row/column -> 64x64.
data.shape

(64, 64)

In [26]:
# NOTE(review): repeats the squeeze from an earlier cell; x has no size-1 axes
# left at this point, so this is a no-op and the cell could be removed.
x = np.squeeze(x)

In [27]:
# Shape check after the repeated squeeze (unchanged from before).
x.shape

(64, 7, 7)

In [28]:
# NOTE(review): recomputes the same mosaic as the earlier vis_square(x) cell.
data = vis_square(x)

In [29]:
# Mosaic dimensions (same 64x64 result as before).
data.shape

(64, 64)

In [31]:
# Show the normalised mosaic values; entries equal to 1 on the last row/column
# are the white padding added by vis_square.
data

array([[ 0.39445651,  0.32021371,  0.27811751, ...,  0.32452577,
         0.32295936,  1.        ],
       [ 0.26429904,  0.20403926,  0.15999581, ...,  0.3057771 ,
         0.30058306,  1.        ],
       [ 0.3303065 ,  0.27075404,  0.2326407 , ...,  0.27441439,
         0.26661143,  1.        ],
       ..., 
       [ 0.41551644,  0.406591  ,  0.39262244, ...,  0.39546499,
         0.39559203,  1.        ],
       [ 0.43560123,  0.43309858,  0.43345582, ...,  0.39556533,
         0.39559284,  1.        ],
       [ 1.        ,  1.        ,  1.        , ...,  1.        ,
         1.        ,  1.        ]], dtype=float32)