In [4]:
from __future__ import print_function
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from nets.light_cnn import LightCNN_9Layers
from nets.deepid import *
from nets.net_sphere import *
from nets.vgg import *
import numpy as np
import cv2

import numpy as np
import matplotlib.pyplot as plt

def vis_square(data, ax=None):
    """Tile a stack of 2-D maps into one roughly square grid image and show it.

    Each (height, width) slice of ``data`` becomes one tile in a grid of
    ceil(sqrt(n)) x ceil(sqrt(n)) tiles, separated by a 1-pixel white border.
    Unused grid slots are filled with white.

    Args:
        data: array-like of shape (n, height, width) or (n, height, width, 3).
            Values are min-max normalised to [0, 1] for display. A constant
            array is mapped to all zeros instead of NaN.
        ax: optional matplotlib Axes to draw on. When omitted, the image is
            drawn with the pyplot state machine (``plt.imshow``), as before.

    Raises:
        ValueError: if ``data`` has fewer than 3 dimensions or no maps.
    """
    data = np.asarray(data, dtype=float)
    if data.ndim < 3 or data.shape[0] == 0:
        raise ValueError('vis_square expects a non-empty array of shape '
                         '(n, h, w) or (n, h, w, 3), got shape %r' % (data.shape,))

    # Normalize data for display. Guard the zero-range case: dividing by
    # (max - min) == 0 would turn a constant input into an all-NaN image.
    lo, hi = data.min(), data.max()
    data = data - lo
    if hi > lo:
        data = data / (hi - lo)

    # force the number of filters to be square
    n = int(np.ceil(np.sqrt(data.shape[0])))
    padding = (((0, n ** 2 - data.shape[0]),
                (0, 1), (0, 1))                 # add some space between filters
               + ((0, 0),) * (data.ndim - 3))  # don't pad the last dimension (if there is one)
    data = np.pad(data, padding, mode='constant', constant_values=1)  # pad with ones (white)

    # Tile the filters into an image:
    # (n*n, h, w, ...) -> (n, n, h, w, ...) -> (n, h, n, w, ...) -> (n*h, n*w, ...)
    data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])

    if ax is None:
        plt.imshow(data); plt.axis('off')
    else:
        ax.imshow(data)
        ax.axis('off')
    
    
    
# Registry of constructable networks, keyed by the name accepted by
# get_network_fn. The values are the constructors imported at the top of the
# notebook from the project's `nets` package.
networks_map = {
    'LightCNN-9': LightCNN_9Layers,
    'DeepID_256': DeepID_256,
    'DeepID_256_gray': DeepID_256_gray,
    'DeepID_128_gray': DeepID_128_gray,
    'sphere20': sphere20,
    'vgg11': vgg11,
}


def get_network_fn(model_name, num_classes, weight_decay=0):
    """Build the network registered under ``model_name``.

    Args:
        model_name: a key of ``networks_map``.
        num_classes: forwarded to the constructor as ``num_classes=``.
        weight_decay: accepted for interface compatibility only; it is NOT
            forwarded to the constructor and has no effect here.

    Returns:
        The object returned by the registered constructor.

    Raises:
        ValueError: if ``model_name`` is not registered in ``networks_map``.
    """
    factory = networks_map.get(model_name)
    if factory is None:
        raise ValueError('Name of network unknown %s' % model_name)
    return factory(num_classes=num_classes)


    
# Let cuDNN benchmark and cache the fastest conv algorithms; the input size is fixed below.
cudnn.benchmark = True

resume = True
# NOTE(review): the checkpoint file name says "lightCNN_18" but the model built below is
# 'sphere20'; load_state_dict will fail on mismatched keys/shapes if these really differ — confirm.
CHECKPOINT_PATH = '/data/zeng/pytorch_model/lightCNN_18_checkpoint.pth.tar'
IMAGE_PATH = '256.jpg'

model = get_network_fn(model_name='sphere20', num_classes=10572, weight_decay=0)

# NOTE(review): eval mode is left disabled as in the original, so the forward pass below
# runs in training mode. Enable it if the network contains dropout / batch-norm layers.
#model.eval()

# DataParallel prefixes state_dict keys with 'module.'. Later cells rely on this
# (e.g. the 'module.conv1_1.weight' lookup); presumably the checkpoint was saved the same way.
model = torch.nn.DataParallel(model).cuda()

if resume:
    # torch.load unpickles the file, so only point CHECKPOINT_PATH at trusted checkpoints.
    checkpoint = torch.load(CHECKPOINT_PATH)
    model.load_state_dict(checkpoint['state_dict'])
else:
    print("=> resume disabled, not loading checkpoint '{}'".format(CHECKPOINT_PATH))

transform = transforms.Compose([transforms.ToTensor()])
count     = 0  # NOTE(review): unused in the visible cells; kept in case later cells read it
input     = torch.zeros(1, 1, 256, 256)

img = cv2.imread(IMAGE_PATH, cv2.IMREAD_GRAYSCALE)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None rather than raising.
    raise IOError("could not read image '{}'".format(IMAGE_PATH))
img = np.reshape(img, (256, 256, 1))  # H x W -> H x W x 1 so ToTensor yields a 1 x 256 x 256 tensor
img = transform(img)
input[0, :, :, :] = img

# Move the batch to the GPU (the model's parameters live there after .cuda()).
input = input.cuda()
# NOTE(review): `volatile=True` disables autograd only on PyTorch < 0.4; newer versions ignore
# it (with a warning) and the forward pass should be wrapped in `torch.no_grad()` instead.
input_var = torch.autograd.Variable(input, volatile=True)
_, features, conv1_map = model(input_var)

print('done')


done


In [5]:
# Parameter/buffer tensors of the DataParallel-wrapped model, keyed by name with a 'module.'
# prefix. These are references to the live tensors, not copies.
model_dict = model.state_dict()

In [6]:
# Weights of the first conv layer; the 'module.' prefix comes from the DataParallel wrapper.
x = model_dict['module.conv1_1.weight']
# Shown below as (64, 1, 7, 7). Presumably an nn.Conv2d: (out_channels, in_channels, kH, kW).
x.shape

torch.Size([64, 1, 7, 7])

In [7]:
# Re-derive from model_dict instead of overwriting `x` with a function of itself, so this
# cell can be re-run: a second `x.cpu()` would fail because a NumPy array has no .cpu().
# .cpu() copies the weights to host memory; .numpy() then views that host copy.
x = model_dict['module.conv1_1.weight'].cpu().numpy()


In [9]:
# conv1_1 weights as a NumPy array of shape (64, 1, 7, 7); NumPy abbreviates the repr with '...'.
x

array([[[ -7.70100916e-04,  -5.30923717e-02,  -8.27594772e-02, ...,
          -1.12971440e-01,  -9.19398665e-02,  -5.19350059e-02],
        [ -9.24979821e-02,  -1.34965777e-01,  -1.66005194e-01, ...,
          -2.05518723e-01,  -1.79196224e-01,  -1.38189226e-01],
        [ -4.59795147e-02,  -8.79488364e-02,  -1.14809044e-01, ...,
          -1.44203007e-01,  -1.20055206e-01,  -8.58046487e-02],
        ..., 
        [  1.48585346e-02,  -6.85511250e-03,  -2.20863689e-02, ...,
          -3.08924560e-02,  -1.20551707e-02,  -2.11474369e-03],
        [  7.58529408e-03,  -1.01346280e-02,  -1.78624522e-02, ...,
          -1.60192121e-02,  -5.92501508e-03,   7.74477096e-03],
        [  1.49266934e-02,   5.71597740e-03,   1.20110111e-02, ...,
           2.33823825e-02,   2.38838382e-02,   3.63039412e-02]],

       [[ -2.42462084e-02,  -2.83199083e-02,  -3.08440626e-02, ...,
          -3.20107751e-02,  -2.70665847e-02,  -1.96341742e-02],
        [ -2.63262633e-02,  -2.96616443e-02,  -3.38706039e-0