In [None]:
from __future__ import print_function

import numpy as np 
import tensorflow as tf
import torch

import torchvision.models as models
import torchvision


In [None]:
# Source model: torchvision's SqueezeNet 1.1 with pretrained weights. Its
# parameters are copied into the TF graph built further down.
model = torchvision.models.squeezenet1_1(pretrained=True)

In [None]:
# Dispatch table filled in below: PyTorch module class -> converter function.
type_lookups = dict()

def conv2d(c, **kwargs):
    """Rebuild a PyTorch ``Conv2d`` module as a TF convolution on NHWC input.

    The module's weight (and bias, if it has one) is copied into TF variables
    ``conv/weights`` and ``conv/bias`` under the current variable scope, so
    they are picked up by any ``tf.train.Saver`` created afterwards.

    Args:
        c: the ``torch.nn.Conv2d`` module to translate.
        inp (keyword): NHWC input tensor.

    Returns:
        The NHWC output tensor.

    Raises:
        NotImplementedError: for grouped or dilated convolutions, or a
            padding mode other than zero padding. This converter does not
            translate those.
    """
    if c.groups != 1 or tuple(c.dilation) != (1, 1):
        raise NotImplementedError('conv2d: groups=%s, dilation=%s not supported'
                                  % (c.groups, c.dilation))
    if getattr(c, 'padding_mode', 'zeros') != 'zeros':
        raise NotImplementedError('conv2d: padding_mode=%s not supported' % c.padding_mode)

    x = kwargs['inp']
    if isinstance(c.padding, str):
        # String padding ('same' / 'valid', accepted by newer PyTorch releases)
        # corresponds to the TF padding mode of the same name.
        tf_padding = c.padding.upper()
    else:
        # Numeric padding: PyTorch pads symmetrically by exactly c.padding.
        # TF 'SAME' instead derives its own (possibly asymmetric) padding from
        # the kernel size, stride and input size. Pad explicitly and then
        # convolve with 'VALID' so every padding/stride combination matches.
        tf_padding = 'VALID'
        pad_h, pad_w = c.padding
        if pad_h or pad_w:
            x = tf.pad(x, [[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]])

    # PyTorch stores conv weights as (out, in, kh, kw); TF wants (kh, kw, in, out).
    W = np.transpose(c.weight.data.cpu().numpy(), [2, 3, 1, 0])
    b = None if c.bias is None else c.bias.data.cpu().numpy()

    with tf.variable_scope("conv"):
        W = tf.get_variable('weights', shape=W.shape,
                            initializer=tf.constant_initializer(W))
        if b is not None:
            b = tf.get_variable('bias', shape=b.shape,
                                initializer=tf.constant_initializer(b))
    x = tf.nn.conv2d(x, W, [1, c.stride[0], c.stride[1], 1], tf_padding)
    if b is not None:
        x = tf.nn.bias_add(x, b)
    return x

def relu(c, **kwargs):
    """Rebuild a PyTorch ``ReLU``: element-wise ``tf.nn.relu`` of the input.

    The module argument ``c`` is not used.
    """
    inp = kwargs['inp']
    return tf.nn.relu(inp)
def max_pool(c, **kwargs):
    """Rebuild a PyTorch ``MaxPool2d`` module as ``tf.nn.max_pool`` on NHWC input.

    ``kernel_size``, ``stride`` and ``padding`` may each be an int or an
    (h, w) pair, since PyTorch accepts either. Only unpadded, undilated
    pooling is translated. TF is always run with 'VALID' padding.

    Args:
        c: the ``torch.nn.MaxPool2d`` module to translate.
        inp (keyword): NHWC input tensor.

    Returns:
        The pooled NHWC tensor.

    Raises:
        NotImplementedError: in these cases:
            - non-zero padding or dilation (TF 'SAME' padding is not
              equivalent to PyTorch's symmetric padding);
            - ``ceil_mode`` would give a different output size than TF's
              floor-based 'VALID' pooling for the statically known input
              size.
    """
    def _pair(v):
        # PyTorch pooling attributes are either an int or an (h, w) sequence.
        return tuple(v) if isinstance(v, (tuple, list)) else (v, v)

    x = kwargs['inp']
    kernel, stride = _pair(c.kernel_size), _pair(c.stride)
    dilation = getattr(c, 'dilation', 1)
    if _pair(c.padding) != (0, 0) or _pair(dilation) != (1, 1):
        raise NotImplementedError('max_pool: padding=%s, dilation=%s not supported'
                                  % (c.padding, dilation))
    shape = x.get_shape()
    if getattr(c, 'ceil_mode', False) and shape.ndims == 4:
        for size, k, s in zip(shape.as_list()[1:3], kernel, stride):
            # Ceil- and floor-based output sizes agree only when the windows
            # tile the input exactly.
            if size is not None and (size - k) % s != 0:
                raise NotImplementedError('max_pool: ceil_mode changes the output '
                                          'size for input size %d' % size)
    x = tf.nn.max_pool(x, [1, kernel[0], kernel[1], 1],
                       strides=[1, stride[0], stride[1], 1], padding='VALID')
    return x
def avg_pool(c, **kwargs):
    """Rebuild a PyTorch ``AvgPool2d`` module as ``tf.nn.avg_pool`` on NHWC input.

    ``kernel_size``, ``stride`` and ``padding`` may each be an int or an
    (h, w) pair, since PyTorch accepts either. Only unpadded pooling is
    translated. TF is always run with 'VALID' padding.

    Args:
        c: the ``torch.nn.AvgPool2d`` module to translate.
        inp (keyword): NHWC input tensor.

    Returns:
        The pooled NHWC tensor.

    Raises:
        NotImplementedError: in these cases:
            - non-zero padding (TF 'SAME' padding differs from PyTorch's
              symmetric padding and counts padded cells differently);
            - ``ceil_mode`` would give a different output size than TF's
              floor-based 'VALID' pooling for the statically known input
              size.
    """
    def _pair(v):
        # PyTorch pooling attributes are either an int or an (h, w) sequence.
        return tuple(v) if isinstance(v, (tuple, list)) else (v, v)

    x = kwargs['inp']
    kernel, stride = _pair(c.kernel_size), _pair(c.stride)
    if _pair(c.padding) != (0, 0):
        raise NotImplementedError('avg_pool: padding=%s not supported' % (c.padding,))
    shape = x.get_shape()
    if getattr(c, 'ceil_mode', False) and shape.ndims == 4:
        for size, k, s in zip(shape.as_list()[1:3], kernel, stride):
            # Ceil- and floor-based output sizes agree only when the windows
            # tile the input exactly.
            if size is not None and (size - k) % s != 0:
                raise NotImplementedError('avg_pool: ceil_mode changes the output '
                                          'size for input size %d' % size)
    x = tf.nn.avg_pool(x, [1, kernel[0], kernel[1], 1],
                       strides=[1, stride[0], stride[1], 1], padding='VALID')
    return x
def dropout(c, **kwargs):
    """Rebuild a PyTorch ``Dropout`` for inference: the input passes through unchanged."""
    passthrough = kwargs['inp']
    return passthrough
def fire_module(c, **kwargs):
    """Rebuild a torchvision SqueezeNet ``Fire`` module in TF.

    Runs the module's ``squeeze`` conv followed by ReLU. The result feeds two
    parallel branches, ``expand1x1`` and ``expand3x3``, each also followed by
    ReLU. The two branch outputs are concatenated on axis 3 (channels, NHWC).

    Variable scopes are ``fire/squeeze``, ``fire/e11`` and ``fire/e33``. The
    submodules are accessed by name because the generic ``Sequential`` walk
    does not apply to Fire's branching structure.
    """
    with tf.variable_scope("fire"):
        with tf.variable_scope("squeeze"):
            squeezed = tf.nn.relu(conv2d(c.squeeze, inp=kwargs['inp']))
        branches = []
        for scope_name, expand_conv in (("e11", c.expand1x1), ("e33", c.expand3x3)):
            with tf.variable_scope(scope_name):
                branches.append(tf.nn.relu(conv2d(expand_conv, inp=squeezed)))
    return tf.concat(branches, 3)

def seq_container(c, **kwargs):
    """Rebuild a PyTorch ``Sequential`` by chaining its children's converters.

    Child ``i`` is translated inside variable scope ``layer<i>`` by the
    function registered for its class in ``type_lookups``. Each child's
    output is the next child's input.

    Args:
        c: the ``torch.nn.Sequential`` module to translate.
        inp (keyword): input tensor for the first child.

    Returns:
        The output tensor of the last child.

    Raises:
        NotImplementedError: if a child's class has no registered converter.
            Skipping the child would silently produce a graph that differs
            from the PyTorch model.
    """
    x = kwargs['inp']
    for idx, child in enumerate(c.children()):
        converter = type_lookups.get(child.__class__)
        if converter is None:
            raise NotImplementedError('no TF converter registered for %s' % child.__class__)
        with tf.variable_scope('layer' + str(idx)):
            x = converter(child, inp=x)
    return x
def batch_norm(c, **kwargs):
    """Rebuild a PyTorch ``BatchNorm2d`` module for inference on NHWC input.

    Computes ``(x - running_mean) / sqrt(running_var + eps) * weight + bias``
    via ``tf.nn.batch_normalization``. The ``weight``/``bias`` affine part is
    applied only when ``c.affine`` is set. The statistics and affine
    parameters are copied into TF variables under a ``bn`` variable scope.
    Channels are the last axis, so the 1-D per-channel values broadcast
    directly.

    Args:
        c: the ``torch.nn.BatchNorm2d`` module to translate.
        inp (keyword): NHWC input tensor.

    Returns:
        The normalized NHWC tensor.

    Raises:
        NotImplementedError: if the module tracks no running statistics.
    """
    running_mean = getattr(c, 'running_mean', None)
    running_var = getattr(c, 'running_var', None)
    if running_mean is None or running_var is None:
        raise NotImplementedError('batch_norm: module has no running statistics')

    def _var(name, value):
        # Copy a numpy array into a TF variable so it is saved with the checkpoint.
        return tf.get_variable(name, shape=value.shape,
                               initializer=tf.constant_initializer(value))

    with tf.variable_scope('bn'):
        mean = _var('mean', running_mean.cpu().numpy())
        variance = _var('variance', running_var.cpu().numpy())
        if c.affine:
            scale = _var('gamma', c.weight.data.cpu().numpy())
            offset = _var('beta', c.bias.data.cpu().numpy())
        else:
            scale = offset = None
    return tf.nn.batch_normalization(kwargs['inp'], mean, variance, offset, scale, c.eps)
# Dispatch table: PyTorch module class -> function that rebuilds it in TF.
type_lookups.update({
    torch.nn.modules.conv.Conv2d: conv2d,
    torch.nn.modules.activation.ReLU: relu,
    torch.nn.modules.container.Sequential: seq_container,
    torch.nn.modules.pooling.MaxPool2d: max_pool,
    torch.nn.modules.pooling.AvgPool2d: avg_pool,
    torch.nn.modules.dropout.Dropout: dropout,
    torchvision.models.squeezenet.Fire: fire_module,
    torch.nn.modules.batchnorm.BatchNorm2d: batch_norm,
})

# Start from an empty graph so re-running this cell does not try to
# recreate variables that tf.get_variable already made.
tf.reset_default_graph()
input_image = tf.placeholder('float', shape=[None, 224, 224, 3])  # NHWC batch
top_modules = list(model.children())
if len(top_modules) == 2:
    # Model split into a feature extractor and a classifier (as in SqueezeNet).
    # Each part gets a variable scope named after its role.
    features_module, classifier_module = top_modules
    with tf.variable_scope('features'):
        features = type_lookups[features_module.__class__](features_module, inp=input_image)
    with tf.variable_scope('classifier'):
        classifier = type_lookups[classifier_module.__class__](classifier_module, inp=features)
        # Expected (batch, 1, 1, num_classes) -> (batch, num_classes).
        # Squeezing the spatial axes avoids hard-coding the class count. It also
        # fails loudly if the pooled map is not 1x1, whereas a reshape would
        # silently fold the spatial positions into the batch.
        classifier = tf.squeeze(classifier, [1, 2])
else:
    x = input_image
    for c in top_modules:
        x = type_lookups[c.__class__](c, inp=x)
    # Later cells read `classifier`; expose the final output under that name.
    classifier = x
In [None]:
# Sanity check on the converted graph's output: (batch, num_classes). The
# batch dimension is unknown (None) because the input placeholder leaves it open.
classifier.get_shape()

In [None]:
from PIL import Image
from scipy.misc import imresize
import os

# labels[i] is the text for class index i (indexed by the argmax of the scores).
# NOTE(review): parsing assumes each line of labels.txt looks like
# "<idx>: <label>" followed by exactly two trailing characters (e.g. ",\n"),
# which are stripped. Confirm this against the actual file.
with open('labels.txt') as label_file:
    labels = [line[:-2].split(':')[1] for line in label_file]
def get_img(filename):
    """Load an image and preprocess it the way the PyTorch model expects.

    Steps:
        - convert the image to 3-channel RGB;
        - resize it to 224x224 with bilinear interpolation;
        - scale it to [0, 1];
        - normalize it with the per-channel mean/std below.

    Args:
        filename: path to an image file readable by PIL.

    Returns:
        float32 array of shape (224, 224, 3) (HWC, RGB channel order).
    """
    # convert('RGB') makes grayscale, CMYK or RGBA files yield exactly 3
    # channels. Otherwise the per-channel normalization below breaks.
    # PIL's resize replaces scipy.misc.imresize, which was removed in SciPy 1.3;
    # both resize the uint8 image bilinearly. PIL takes (width, height).
    with Image.open(filename) as src:
        img = src.convert('RGB').resize((224, 224), Image.BILINEAR)
    vec = np.asarray(img, dtype=np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    return (vec - mean) / std
    
# Classify every JPEG in img_dir with the converted TF graph. File names are
# sorted so the printed order (and the rows of `scores`) are reproducible,
# since os.listdir returns entries in arbitrary order.
img_dir = '.'
img_names = sorted(name for name in os.listdir(img_dir)
                   if name.lower().endswith(('.jpeg', '.jpg')))
imgs = [get_img(os.path.join(img_dir, name)) for name in img_names]

saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
scores = sess.run(classifier,
                  feed_dict={input_image: np.array(imgs).reshape([-1, 224, 224, 3])})
for name, class_idx in zip(img_names, np.argmax(scores, 1)):
    print(name, labels[class_idx])

In [None]:
# Write the converted TF variables to a checkpoint; the returned path is shown as the cell output.
saver.save(sess=sess, save_path='squeezenet.ckpt')

In [None]:
from torch.autograd import Variable
# Reference scores from the original PyTorch model on the same preprocessed
# images. The NHWC batch is transposed to PyTorch's NCHW layout. The transpose
# is a non-contiguous view whose dtype depends on get_img, so build an explicit
# contiguous float32 array before handing it to torch.
input_data = torch.from_numpy(
    np.ascontiguousarray(np.transpose(np.array(imgs), [0, 3, 1, 2]), dtype=np.float32))
model.eval()  # inference mode (e.g. dropout off), matching the TF graph
pyt_scores = model(Variable(input_data))
scores_ref = pyt_scores.data.numpy()

In [None]:
def rel_error(x, y):
    """Largest element-wise relative difference between arrays ``x`` and ``y``.

    Each element contributes ``|x - y| / max(1e-8, |x| + |y|)``. The 1e-8
    floor avoids dividing by zero where both entries are 0.
    """
    abs_diff = np.abs(x - y)
    denom = np.maximum(1e-8, np.abs(x) + np.abs(y))
    return np.max(abs_diff / denom)
# Largest element-wise relative difference between the TF and PyTorch scores.
max_rel_err = rel_error(scores, scores_ref)
print(max_rel_err)