Reproduce example from Medium article [link](https://towardsdatascience.com/converting-a-simple-deep-learning-model-from-pytorch-to-tensorflow-b6b353351f5d).

In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import numpy as np

import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import onnx
from onnx_tf.backend import prepare
import tensorflow as tf
import pickle



The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.







## Check `onnx` and `onnx-tf` version

In [3]:
import subprocess
pip_packages = subprocess.run(['pip', 'freeze'], stdout=subprocess.PIPE)
pip_packages_list = pip_packages.stdout.decode('utf-8').strip().split('\n')

In [4]:
print(list(filter(lambda x: 'onnx' in x, pip_packages_list)))

if 'onnx==1.5.0' not in pip_packages_list:
    print("onnx version is incorrect!")
    print("Target version: 1.5.0")
    print("Re-Installing packages...")
    os.system('pip install onnx==1.5.0')
else:
    print("Correct onnx version!")
    
if 'onnx-tf==1.2.1' not in pip_packages_list:
    print("onnx-tf version is incorrect!")
    print("Target version: 1.2.1")
    print("Re-Installing packages...")
    os.system('pip install onnx-tf==1.2.1')
else:
    print("Correct onnx-tf version!")
    
!pip freeze |grep onnx

['onnx==1.5.0', 'onnx-tf==1.2.1']
Correct onnx version!
Correct onnx-tf version!
onnx==1.5.0
onnx-tf==1.2.1


# Global setup

In [5]:
train_size = 8000
test_size = 2000

input_size = 20
hidden_sizes = [50, 50]
output_size = 1
num_classes = 2

# Set device

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device used:', device)

Device used: cpu


# Generate simulated data

In [7]:
if not os.path.exists('data/X_train.pk'):
    X_train = np.random.randn(train_size, input_size).astype(np.float32)
    with open('data/X_train.pk', 'wb') as p:
        pickle.dump(X_train, p)
else: 
    with open('data/X_train.pk', 'rb') as p:
        X_train = pickle.load(p)
        
if not os.path.exists('data/X_test.pk'):
    X_test = np.random.randn(test_size, input_size).astype(np.float32)
    with open('data/X_test.pk', 'wb') as p:
        pickle.dump(X_test, p)
else: 
    with open('data/X_test.pk', 'rb') as p:
        X_test = pickle.load(p)

if not os.path.exists('data/y_train.pk'):
    y_train = np.random.randint(num_classes, size=train_size)
    with open('data/y_train.pk', 'wb') as p:
        pickle.dump(y_train, p)
else: 
    with open('data/y_train.pk', 'rb') as p:
        y_train = pickle.load(p)
    
if not os.path.exists('data/y_test.pk'):
    y_test = np.random.randint(num_classes, size=train_size)
    with open('data/y_test.pk', 'wb') as p:
        pickle.dump(y_test, p)
else: 
    with open('data/y_test.pk', 'rb') as p:
        y_test = pickle.load(p)


print('Shape of X_train:', X_train.shape)
print('Shape of X_train:', X_test.shape)
print('Shape of y_train:', y_train.shape)
print('Shape of y_test:', y_test.shape)

Shape of X_train: (8000, 20)
Shape of X_train: (2000, 20)
Shape of y_train: (8000,)
Shape of y_test: (8000,)


# Build or load a pytorch toy model

In [8]:
FORCE_RETRAIN = False

In [9]:
class SimpleModel(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(SimpleModel, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.fcs = []  # List of fully connected layers
        in_size = input_size
        
        for i, next_size in enumerate(hidden_sizes):
            fc = nn.Linear(in_features=in_size, out_features=next_size)
            in_size = next_size
            self.__setattr__('fc{}'.format(i), fc)  # set name for each fullly connected layer
            self.fcs.append(fc)
            
        self.last_fc = nn.Linear(in_features=in_size, out_features=output_size)
        
    def forward(self, x):
        for i, fc in enumerate(self.fcs):
            x = fc(x)
            x = nn.ReLU()(x)
        out = self.last_fc(x)
        return nn.Sigmoid()(out)
    
class SimpleDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
    
if os.path.exists('models/model_simple.pt') == False or FORCE_RETRAIN == True:
    print("Building model...")
    
    # Create DataLoaders for training and test set, for batch training and evaluation
    train_loader = DataLoader(dataset=SimpleDataset(X_train, y_train), batch_size=8, shuffle=True)
    test_loader = DataLoader(dataset=SimpleDataset(X_test, y_test), batch_size=8, shuffle=False)
    
    # Initialize the model and set device to be used
    model_pytorch = SimpleModel(input_size=input_size, hidden_sizes=hidden_sizes, output_size=output_size)
    model_pytorch = model_pytorch.to(device)
    
    # Set binary cross entropy loss since 2 classes only
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model_pytorch.parameters(), lr=1e-3)
       
    # Train model
    num_epochs = 20

    time_start = time.time()

    for epoch in range(num_epochs):
        model_pytorch.train() # Turn on training mode

        train_loss_total = 0 # A flag to record the total loss for each epoch

        for data, target in train_loader:
            data, target = data.to(device), target.float().to(device) # Prepare data (features/target)
            optimizer.zero_grad() # Initialize paramters
            output = model_pytorch(data) # Forward propogation
            train_loss = criterion(output, target) # Compute the loss
            train_loss.backward() # Back-propogate the loss
            optimizer.step() # Update the weights/biases
            train_loss_total += train_loss.item() * data.size(0) # Add up the loss for each batch

        print('Epoch {} completed. Train loss is {:.3f}'.format(epoch + 1, train_loss_total / train_size))
    print('Time taken to completed {} epochs: {:.2f} minutes'.format(num_epochs, (time.time() - time_start) / 60))
    
    # Save model
    torch.save(model_pytorch.state_dict(), './models/model_simple.pt')
    print('Model has been saved to ./models/model_simple.pt.')
else:
    # Load model
    print("Loading model...")
    model_pytorch = SimpleModel(input_size=input_size, hidden_sizes=hidden_sizes, output_size=output_size)
    model_pytorch.load_state_dict(torch.load('./models/model_simple.pt'))
    
# Display model
display(model_pytorch)

Loading model...


SimpleModel(
  (fc0): Linear(in_features=20, out_features=50, bias=True)
  (fc1): Linear(in_features=50, out_features=50, bias=True)
  (last_fc): Linear(in_features=50, out_features=1, bias=True)
)

## Note: 
`model.train()` and `model.eval()` will affect the process for only particular models. Check the [source code](https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.train) for more information.

# Evalulate model

In [10]:
eval_pytorch_model = False

In [11]:
if eval_pytorch_model:
    model_pytorch.eval()

    test_loss_total = 0
    total_num_corrects = 0
    threshold = 0.5
    time_start = time.time()

    for data, target in test_loader:
        data, target = data.to(device), target.float().to(device)
        optimizer.zero_grad()
        output = model_pytorch(data)
        train_loss = criterion(output, target)
        train_loss.backward()
        optimizer.step()
        train_loss_total += train_loss.item() * data.size(0)

        pred = (output >= threshold).view_as(target)  # to make pred have same shape as target
        num_correct = torch.sum(pred == target.byte()).item()
        total_num_corrects += num_correct

    print('Evaluation completed. Test loss is {:.3f}'.format(test_loss_total / test_size))
    print('Test accuracy is {:.3f}'.format(total_num_corrects / test_size))
    print('Time taken to complete evaluation: {:.2f} seconds'.format((time.time() - time_start)))

# Converting the model to ONNX

In [12]:
FORCE_to_ONNX = False

In [13]:
if os.path.exists('./models/model_simple.onnx') == False or FORCE_to_ONNX == True:
    # Single pass of dummy variable required
    dummy_input = torch.from_numpy(X_test[0].reshape(1, -1)).float().to(device)
    dummy_output = model_pytorch(dummy_input)
    print(dummy_output)
    
    # Export to ONNX format
    torch.onnx.export(
        model_pytorch, dummy_input, 
        './models/model_simple.onnx', input_names=['input'], output_names=['output']
    )

# Load ONNX model and convert to TensorFlow format

In [14]:
FORCE_to_TF = False

In [15]:
if os.path.exists('./models/model_simple.pb') == False or FORCE_to_TF == True:
    model_onnx = onnx.load('./models/model_simple.onnx')

    tf_rep = prepare(model_onnx)

    # Print out tensors and placeholders in model (helpful during inference in TensorFlow)
    print(tf_rep.tensor_dict)

    # Export model as .pb file
    tf_rep.export_graph('./models/model_simple.pb')

# Doing inference in TensorFlow

In [16]:
def load_pb(path_to_pb):
    with tf.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

In [17]:
tf_graph = load_pb('./models/model_simple.pb')
sess = tf.Session(graph=tf_graph)

In [18]:
tf_graph.get_operations()

[<tf.Operation 'Const' type=Const>,
 <tf.Operation 'Const_1' type=Const>,
 <tf.Operation 'Const_2' type=Const>,
 <tf.Operation 'Const_3' type=Const>,
 <tf.Operation 'Const_4' type=Const>,
 <tf.Operation 'Const_5' type=Const>,
 <tf.Operation 'input' type=Placeholder>,
 <tf.Operation 'flatten/Reshape/shape' type=Const>,
 <tf.Operation 'flatten/Reshape' type=Reshape>,
 <tf.Operation 'transpose/perm' type=Const>,
 <tf.Operation 'transpose' type=Transpose>,
 <tf.Operation 'MatMul' type=MatMul>,
 <tf.Operation 'mul/x' type=Const>,
 <tf.Operation 'mul' type=Mul>,
 <tf.Operation 'mul_1/x' type=Const>,
 <tf.Operation 'mul_1' type=Mul>,
 <tf.Operation 'add' type=Add>,
 <tf.Operation 'Relu' type=Relu>,
 <tf.Operation 'flatten_1/Reshape/shape' type=Const>,
 <tf.Operation 'flatten_1/Reshape' type=Reshape>,
 <tf.Operation 'transpose_1/perm' type=Const>,
 <tf.Operation 'transpose_1' type=Transpose>,
 <tf.Operation 'MatMul_1' type=MatMul>,
 <tf.Operation 'mul_2/x' type=Const>,
 <tf.Operation 'mul_2' t

In [19]:
output_tensor = tf_graph.get_tensor_by_name('Sigmoid:0')
input_tensor = tf_graph.get_tensor_by_name('input:0')

In [20]:
input_tensor

<tf.Tensor 'input:0' shape=(1, 20) dtype=float32>

In [21]:
output_tensor

<tf.Tensor 'Sigmoid:0' shape=(1, 1) dtype=float32>

In [43]:
total_test_cases = 0
correct_cases = 0
for test_id in range(2000):
    total_test_cases += 1
    tf_output = sess.run(output_tensor, feed_dict={input_tensor: X_test[test_id].reshape(1, -1)})
    pt_output = model_pytorch(torch.from_numpy(X_test[test_id].reshape(1, -1)).float())
    
    tf_data = tf_output[0][0]
    pt_data = pt_output.data.numpy()[0][0]
    
    if tf_data - pt_data < 1e-5:
        correct_cases += 1
        
print("# of total cases: %d" %total_test_cases)
print("# of correct cases: %d" %correct_cases)

# of total cases: 2000
# of correct cases: 2000
