# Introduction

* A notebook to load trained models and their learnable parameters (weights, biases, etc.). The models were trained using:
    1. Keras (or TensorFlow)
    2. PyTorch
* References: 
    - https://www.adrian.idv.hk/2020-12-31-torch2tf/
    - https://stackoverflow.com/questions/42703500/best-way-to-save-a-trained-model-in-pytorch


**By MGM, ORNL**

2022 May 15

In [1]:
import numpy as np

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'    # to hide TF info messages
import tensorflow as tf

import torch

# Model names

In [2]:
# Configuration: model size and the file-name stems used to locate the saved models.
# NOTE: `np.int` was deprecated in NumPy 1.20 and removed in 1.24 (AttributeError),
# so the builtin `int` is used instead; the resulting value and file names are unchanged.
Ntot            = int(1e6)   # number of total data points
nneurons        = 10         # number of neurons in MLP (***)
model_name      = 'SingleMLP'        # PyTorch model name: 'SingleMLP' 'ResNet' 'DenseNet'
# The Keras file name is fixed to the SingleMLP variant (it does not follow `model_name`).
fname_NNkeras   = f'Data_models/supercell_micro_Keras_SingleMLP_{Ntot}_Nneu{nneurons}'
fname_NNpytorch = f'Data_models/supercell_micro_PyTorch_{model_name}_{Ntot}_Nneu{nneurons}'

# Keras (or TensorFlow) model

In [3]:
# load json and create model
json_file = open(f'{fname_NNkeras}_modelparameters.json', 'r')
loaded_model_json = json_file.read()
json_file.close()
loaded_model = tf.keras.models.model_from_json(loaded_model_json)
# load weights into new model
loaded_model.load_weights(f'{fname_NNkeras}_modelparameters.h5')

for i,l in enumerate(loaded_model.layers):
    print("L{}:\tParameters: {}".format(i, ", ".join(str(w.shape) for w in loaded_model.layers[i].weights)))
    print(f"{l}\n\t{loaded_model.layers[i].weights}\n")

L0:	Parameters: (12, 10), (10,)
<tensorflow.python.keras.layers.core.Dense object at 0x7eff5fa8df70>
	[<tf.Variable 'dense/kernel:0' shape=(12, 10) dtype=float32, numpy=
array([[ 4.4839606e-02,  3.4142290e-03,  6.2647603e-02, -5.2080089e-03,
         6.5400094e-02,  1.1367138e-02, -1.2471851e-02,  2.2232479e-01,
        -1.2108028e-02,  9.6040681e-02],
       [ 6.2061793e-01, -2.2211903e-01,  6.1980444e-01,  1.8864895e-01,
        -2.6342615e-01,  5.2737045e-01,  2.2463527e-01, -4.9450159e-01,
         1.9757085e-01,  5.8157426e-01],
       [ 1.8169809e-02,  7.4737710e-03,  8.9990035e-02, -5.0105426e-02,
         1.3751762e-02,  9.8412503e-03, -8.3880676e-03,  2.3363391e-01,
         1.7302802e-02,  1.0582845e-01],
       [-6.7542037e-03, -9.7089246e-02,  9.3365818e-02, -1.4075327e-02,
         7.7508472e-02, -1.2855682e-01, -2.3073874e-02,  2.7188972e-02,
         1.3528147e-02,  3.5556637e-02],
       [-3.4937628e-02, -7.3892653e-01,  2.7074626e-01,  5.8355039e-01,
        -2.2066720

# PyTorch model

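* `model.state_dict()` contains the learnable parameters of the model.
* Two methods to access the model parameters (the `state_dict`) are shown below: from the full TorchScript model, and from a saved `state_dict` alone.
* **NOTE**: PyTorch stores weight tensors as the transpose of the corresponding TensorFlow/Keras kernels (e.g. `hidden1.weight` has shape `[10, 12]` below, while the first Keras kernel above has shape `(12, 10)`).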

In [4]:
# Method 1: from full model
# Load full model using TorchScript format
model = torch.jit.load(f'{fname_NNpytorch}_Cpp.pt')
# Print the name and shape of every entry in the model's state_dict.
print("Model's state_dict (from full model):")
# state_dict() constructs a fresh dictionary on every call, so fetch it once
# instead of rebuilding it for the iteration and again for every lookup.
full_model_state_dict = model.state_dict()
for param_tensor, tensor in full_model_state_dict.items():
    print(param_tensor, "\t", tensor.size())

Model's state_dict (from full model):
hidden1.weight 	 torch.Size([10, 12])
hidden1.bias 	 torch.Size([10])
hidden3.weight 	 torch.Size([4, 10])
hidden3.bias 	 torch.Size([4])


In [5]:
# Method 2: just from the stat_dict
# Load only model's state_dict to access learnable parameters
# map_location='cpu' lets a checkpoint that was saved from a GPU run be
# loaded on a CPU-only machine (the tensors printed below are CPU tensors).
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from trusted sources. PyTorch >= 1.13 also
# accepts weights_only=True to restrict unpickling to tensors/containers.
model_state_dict = torch.load(f'{fname_NNpytorch}_modelparameters.pkl', map_location='cpu')
# Print the name and values of each entry in the loaded state_dict.
print("Model's state_dict (from state_dict only):")
for param_tensor, tensor in model_state_dict.items():
    print(param_tensor, "\t", tensor)

Model's state_dict (from state_dict only):
hidden1.weight 	 tensor([[-0.4795, -0.4573, -0.4388,  0.9110, -0.6623,  0.0104, -0.3194, -0.6936,
         -0.5155, -1.0117,  1.9673, -0.3153],
        [ 0.1120, -0.3004,  0.5225, -0.1805,  0.5088,  0.1800,  0.4128, -0.1120,
         -0.4814, -0.2813,  0.3594,  0.2713],
        [ 0.3613,  0.4166,  0.2360, -0.7671, -0.0962,  0.1965,  0.2009, -0.0219,
         -0.2966,  0.3146,  0.5174,  0.1808],
        [-0.2181, -0.8262, -0.1929,  1.0687, -0.0543, -0.5197,  0.0689, -0.6304,
          0.0523, -0.6432,  1.6410, -0.2651],
        [ 0.0706,  0.3954, -0.2650, -0.1155,  0.2107, -0.1223, -0.2126,  0.5290,
          0.4203,  0.5050,  0.5425, -0.1520],
        [ 0.5908, -0.6495, -0.7403, -1.6795,  0.3208,  0.5091, -0.9094,  1.0736,
         -0.6483,  0.0156, -0.9285,  0.8217],
        [ 0.3941, -0.0728,  0.2308, -0.4549,  0.2599, -0.0496,  0.2809, -0.4778,
         -0.0686,  0.5648,  0.3378, -0.0847],
        [ 0.3825, -0.4405,  0.4552,  0.3143, -0.159