# Copy weights from TF Hub
This notebook is based on the official TF Hub BigGAN generation colab: https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/biggan_generation_with_tf_hub.ipynb

In [None]:
# TF Hub handle of the BigGAN generator whose weights will be copied.
# Exactly one of the lines below should be active; uncomment a different one
# to pick another output resolution.
# module_path = 'https://tfhub.dev/deepmind/biggan-128/2'  # 128x128 BigGAN
module_path = 'https://tfhub.dev/deepmind/biggan-256/2'  # 256x256 BigGAN
# module_path = 'https://tfhub.dev/deepmind/biggan-512/2'  # 512x512 BigGAN

In [None]:
import io
import IPython.display
import numpy as np
import PIL.Image
from scipy.stats import truncnorm
import tensorflow as tf
import tensorflow_hub as hub

## Load a BigGAN generator module from TF Hub

In [None]:
# Start from an empty default graph so this cell can be re-run from scratch.
tf.reset_default_graph()
print('Loading BigGAN module from:', module_path)
module = hub.Module(module_path)

# Build one placeholder per module input, reusing the dtype, static shape
# and name that the module advertises for that input.
inputs = dict(
    (name, tf.placeholder(info.dtype, info.get_shape().as_list(), name))
    for name, info in module.get_input_info_dict().items()
)
output = module(inputs)

# Show the input signature and the resulting output tensor.
print()
print('Inputs:\n', '\n'.join('  {}: {}'.format(name, tensor)
                             for name, tensor in inputs.items()))
print()
print('Output:', output)

## Define some functions for sampling and displaying BigGAN images

In [None]:
# Pull out the three placeholders created from the module's input signature.
input_z, input_y, input_trunc = (inputs[name] for name in ('z', 'y', 'truncation'))

# Latent width and number of label classes, read from axis 1 of the static
# placeholder shapes (axis 0 is presumably the batch axis — confirm).
dim_z, vocab_size = (tensor.shape.as_list()[1] for tensor in (input_z, input_y))

def truncated_z_sample(batch_size, truncation=1., seed=None, z_dim=None):
    """Draw latent vectors from a truncated normal and scale them.

    Samples come from a standard normal truncated to [-2, 2] and are then
    multiplied by ``truncation``.

    Args:
        batch_size: number of latent vectors (rows) to draw.
        truncation: scale factor applied to every sample.
        seed: if not None, a fresh ``np.random.RandomState(seed)`` is used so
            the draw is reproducible; otherwise NumPy's global random state
            is used.
        z_dim: latent width. Defaults to the module-level ``dim_z`` (read from
            the generator's ``z`` placeholder); pass it explicitly to sample
            without the TF graph.

    Returns:
        ndarray of shape (batch_size, z_dim).
    """
    if z_dim is None:
        # Looked up at call time so the default tracks the loaded module.
        z_dim = dim_z
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim), random_state=state)
    return truncation * values

def one_hot(index, vocab_size=vocab_size):
    """Encode class indices as float32 one-hot rows.

    Args:
        index: a scalar class index or a 1-D sequence of class indices.
        vocab_size: row width (number of classes). Defaults to the
            module-level ``vocab_size`` captured when this function is defined.

    Returns:
        float32 ndarray of shape (num_indices, vocab_size) with a single 1 per
        row; a scalar index yields shape (1, vocab_size).

    Raises:
        ValueError: if ``index`` has more than one dimension. This is an
            explicit check rather than ``assert`` so it still runs under
            ``python -O``.
    """
    # Promote a scalar to a length-1 vector; 1-D input is left unchanged.
    index = np.atleast_1d(np.asarray(index))
    if index.ndim != 1:
        raise ValueError(
            'index must be a scalar or 1-D, got shape {}'.format(index.shape))
    num = index.shape[0]
    output = np.zeros((num, vocab_size), dtype=np.float32)
    # Set column index[i] of row i to 1.
    output[np.arange(num), index] = 1
    return output

def one_hot_if_needed(label, vocab_size=vocab_size):
    """Return ``label`` as a 2-D label matrix, one-hot encoding it if needed.

    Args:
        label: either class indices (scalar or 1-D), which are one-hot encoded
            via ``one_hot``, or an already 2-D label matrix, which is returned
            as-is.
        vocab_size: number of classes used when one-hot encoding. Defaults to
            the module-level ``vocab_size`` captured at definition time.

    Returns:
        A 2-D ndarray.

    Raises:
        ValueError: if ``label`` has more than two dimensions. This is an
            explicit check rather than ``assert`` so it still runs under
            ``python -O``.
    """
    label = np.asarray(label)
    if label.ndim <= 1:
        label = one_hot(label, vocab_size)
    if label.ndim != 2:
        raise ValueError(
            'label must have at most 2 dimensions, got shape {}'.format(label.shape))
    return label

## Create a TensorFlow session and initialize variables

In [None]:
# Build the variable-initialization op, open a session, and run the op so the
# module's variables have values before they are fetched later on
# (for hub.Module this presumably loads the stored checkpoint values — confirm).
initializer, sess = tf.global_variables_initializer(), tf.Session()
sess.run(initializer)

## Copy weights from TF Hub

In [None]:
# Mapping from variable name to the module's variable object(s). Per the TF Hub
# docs, an entry may be a list for a partitioned variable — confirm for BigGAN.
variables = module.variable_map

In [None]:
# Fix one iteration order over the variable names and collect the matching
# variables in that order, so names and fetched values can be zipped later.
keys = variables.keys()
values = list(map(variables.__getitem__, keys))

In [None]:
# Dummy inputs: class 0, a single seeded latent vector, truncation 1.
label = one_hot_if_needed(np.asarray([0]), vocab_size)
noise = np.asarray(truncated_z_sample(1, 1, 0))

# NOTE(review): fetching variable values presumably does not depend on these
# placeholders, so the feed looks unnecessary — confirm before removing it.
feed_dict = {input_z: noise, input_y: label, input_trunc: 1}
weights = sess.run(values, feed_dict=feed_dict)

In [None]:
import pickle

# Pair each variable name with its fetched value and save the mapping.
weights_dict = dict(zip(keys, weights))
with open("BIGGAN_weights.pkl", "wb") as f:
    pickle.dump(weights_dict, f)