In [1]:
import numpy as np

# Load
tf_weights = np.load('my_file.npy', allow_pickle=True).item()

In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
from MnasNet_Functional import build_mnasnet_model

In [6]:
model = build_mnasnet_model('mnasnet-a1')
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

In [7]:
def verify_and_set(layer, block_id, tf_weights, tf_keys, curr_index, verbose):
    name = layer.name
    start = 'mnasnet-a1'
    block_name = f'mnas_blocks_{block_id}'
    
    if name.endswith('depthwise_conv'):
        key = '{}/mnas_net_model/{}/depthwise_conv2d'.format(start, block_name)
        forward = 1
    
    elif name.endswith('se_reduce_conv') or name.endswith('se_expand_conv'):
        key = '{}/mnas_net_model/{}/se/conv2d'.format(start, block_name)
        forward = 2
    
    elif name.endswith('stem_conv') or name.endswith('head_conv'):
        key = '{}/mnas_net_model/mnas_{}/conv2d'.format(start, name[:4])
        forward = 1
    
    elif name.endswith('stem_conv'):
        key = '{}/mnas_net_model/mnas_stem/conv2d'.format(start)
        forward = 1
    
    elif name.endswith('conv'):
        key = '{}/mnas_net_model/{}/conv2d'.format(start, block_name)
        forward = 1

    elif name.endswith('stem_conv_BN') or name.endswith('head_conv_BN'):
        key = '{}/mnas_{}/batch_normalization'.format(start, name[:4])
        forward = 4

    elif name.endswith('BN'):
        key = '{}/{}/batch_normalization'.format(start, block_name)
        forward = 4
    
    elif name == 'FC':
        key = '{}/mnas_net_model/mnas_head/dense'.format(start)
        forward = 2

    else:
        if layer.variables != []:
            raise ValueError(f'Layer "{name}" is not supported')
        return curr_index
    
    weights = []
    for i in range(curr_index, curr_index+forward):
        if not tf_keys[i].startswith(key):
            msg = 'For layer={}, an exception occurred\n'
            msg += "\ttf_index:\t{}\n\taccess_key:\t{}\n\treal_key:\t{}"
            raise ValueError(msg.format(name, i, key, tf_keys[i]))
        weights.append(tf_weights[tf_keys[i]])
    layer.set_weights(weights)  

    if verbose:
        print(f'Processesd "{name}"')
    return curr_index+forward


def keras_set_weights_from_tf_model(model, tf_weights, verbose=False):
    tf_index = 1
    block_id = 0
    b_chars = ('0',) * 2
    tf_keys = list(tf_weights.keys())
    
    for layer in model.layers:
        name = layer.name
        if name.startswith('block'):
            if b_chars != (name[6], name[12]):
                block_id += 1
                b_chars = (name[6], name[12])
        
        tf_index = verify_and_set(
            layer,
            block_id,
            tf_weights,
            tf_keys,
            tf_index,
            verbose
            )      

In [9]:
keras_set_weights_from_tf_model(model, tf_weights);

In [11]:
from IPython import display
import pylab
import PIL
import numpy as np
filename = 'panda.jpg'
img = np.array(PIL.Image.open(filename).resize((224, 224))).astype(np.float)

In [8]:
# MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
# STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]

# img2 = np.copy(img)
# img2 -= MEAN_RGB
# img2 /= STDDEV_RGB

In [13]:
from scipy.special import softmax

In [14]:
import common.imagenet as imagenet
import keras

logits = model.predict(img[np.newaxis,...])
top_class = np.argmax(logits)
probs = softmax(logits)

print("Top class: ", top_class, " with Probability= ", probs[0][top_class])
label_map = imagenet.create_readable_names_for_imagenet_labels()  
for idx, label_id in enumerate(reversed(list(np.argsort(probs)[0][-5:]))):
  print("Top %d Prediction: %d, %s, probs=%f" % (idx+1, label_id, label_map[label_id], probs[0][label_id]))
  

Top class:  388  with Probability=  0.8776357
Top 1 Prediction: 388, lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens, probs=0.877636
Top 2 Prediction: 245, Tibetan mastiff, probs=0.002865
Top 3 Prediction: 384, Madagascar cat, ring-tailed lemur, Lemur catta, probs=0.002584
Top 4 Prediction: 296, American black bear, black bear, Ursus americanus, Euarctos americanus, probs=0.001733
Top 5 Prediction: 222, Irish water spaniel, probs=0.001599
