# Convert TFLite model to PyTorch

This uses the model **face_detection_front.tflite** from [MediaPipe](https://github.com/google/mediapipe/tree/master/mediapipe/models).

Prerequisites:

1) Clone the MediaPipe repo:

```
git clone https://github.com/google/mediapipe.git
```

2) Install **flatbuffers**:

```
git clone https://github.com/google/flatbuffers.git
cd flatbuffers
cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release
make -j

cd python
python setup.py install
cd ../..
```

3) Clone the TensorFlow repo. We only need it for the FlatBuffers schema file (you could also just download [schema.fbs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs) on its own).

```
git clone https://github.com/tensorflow/tensorflow.git
```

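4) Generate the Python bindings from the schema using **flatc**. The schema declares the `tflite` namespace, so this writes a `tflite` package into the current directory: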

```
./flatbuffers/flatc --python tensorflow/tensorflow/lite/schema/schema.fbs
```

Now we can use the Python FlatBuffer API to read the TFLite file!

In [1]:
import os
import numpy as np
from collections import OrderedDict

## Get the weights from the TFLite file

Load the TFLite model using the FlatBuffers library:

In [2]:
from tflite import Model

with open("./mediapipe/mediapipe/models/face_detection_front.tflite", "rb") as f:
    data = f.read()
model = Model.Model.GetRootAsModel(data, 0)

In [3]:
subgraph = model.Subgraphs(0)
subgraph.Name()

b'facedetector-front.tflite.no_meta'

In [4]:
def get_shape(tensor):
    return [tensor.Shape(i) for i in range(tensor.ShapeLength())]

List all the tensors in the graph:

In [5]:
for i in range(subgraph.TensorsLength()):
    tensor = subgraph.Tensors(i)
    print("%3d %30s %d %2d %s" % (i, tensor.Name(), tensor.Type(), tensor.Buffer(),
                                  get_shape(tensor)))

  0                       b'input' 0  0 [1, 128, 128, 3]
  1               b'conv2d/Kernel' 0  1 [24, 5, 5, 3]
  2                 b'conv2d/Bias' 0  2 [24]
  3                      b'conv2d' 0  0 [1, 64, 64, 24]
  4                  b'activation' 0  0 [1, 64, 64, 24]
  5     b'depthwise_conv2d/Kernel' 0  3 [1, 3, 3, 24]
  6       b'depthwise_conv2d/Bias' 0  4 [24]
  7            b'depthwise_conv2d' 0  0 [1, 64, 64, 24]
  8             b'conv2d_1/Kernel' 0  5 [24, 1, 1, 24]
  9               b'conv2d_1/Bias' 0  6 [24]
 10                    b'conv2d_1' 0  0 [1, 64, 64, 24]
 11                         b'add' 0  0 [1, 64, 64, 24]
 12                b'activation_1' 0  0 [1, 64, 64, 24]
 13   b'depthwise_conv2d_1/Kernel' 0  7 [1, 3, 3, 24]
 14     b'depthwise_conv2d_1/Bias' 0  8 [24]
 15          b'depthwise_conv2d_1' 0  0 [1, 64, 64, 24]
 16             b'conv2d_2/Kernel' 0  9 [28, 1, 1, 24]
 17               b'conv2d_2/Bias' 0 10 [28]
 18                    b'conv2d_2' 0  0 [1, 64, 64, 28]
...

Make a lookup table that gives us the tensor index for a given tensor name:

In [6]:
tensor_dict = {(subgraph.Tensors(i).Name().decode("utf8")): i 
               for i in range(subgraph.TensorsLength())}

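Grab only the tensors that represent weights and biases. These are the tensors with a nonzero buffer index: by TFLite convention, buffer 0 is an always-present empty buffer, and tensors whose values are computed at runtime (such as the activations listed above) point to it.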

In [7]:
parameters = {}
for i in range(subgraph.TensorsLength()):
    tensor = subgraph.Tensors(i)
    if tensor.Buffer() > 0:
        name = tensor.Name().decode("utf8")
        parameters[name] = tensor.Buffer()

len(parameters)

85

The buffers are simply arrays of bytes. As the comments in TFLite's `schema.fbs` put it:

> The data_buffer itself is an opaque container, with the assumption that the
> target device is little-endian. In addition, all builtin operators assume
> the memory is ordered such that if `shape` is [4, 3, 2], then index
> [i, j, k] maps to `data_buffer[i*3*2 + j*2 + k]`.

For weights and biases, we need to interpret every 4 bytes as one float32. On my machine the native byte order is already little-endian, so we don't need to do anything special about that.
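To make the quoted layout rule concrete, here is a small self-contained sketch that fakes a buffer for a `[4, 3, 2]` tensor (the values are made up, not taken from the model). Spelling the dtype as `"<f4"` (little-endian float32) makes the byte order explicit instead of relying on the host, and the assertion checks that NumPy's default row-major reshape agrees with the `i*3*2 + j*2 + k` formula.

```python
import numpy as np

# Made-up data standing in for a TFLite buffer: a float32 tensor of shape [4, 3, 2]
# serialized as little-endian bytes, which is what TFLite assumes.
shape = (4, 3, 2)
values = np.arange(24, dtype=np.float32)
raw = values.astype("<f4").tobytes()

# "<f4" means little-endian float32, so this decodes correctly on any host.
flat = np.frombuffer(raw, dtype="<f4")
W = flat.reshape(shape)  # NumPy's default row-major (C) order

# TFLite's rule: index [i, j, k] maps to data_buffer[i*3*2 + j*2 + k].
i, j, k = 2, 1, 1
assert W[i, j, k] == flat[i * 3 * 2 + j * 2 + k]
print(W[i, j, k])  # 15.0
```

On a little-endian machine, `view(dtype="<f4")` and `view(dtype=np.float32)` give the same result; the explicit form only matters if the code ever runs on a big-endian host.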

In [8]:
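def get_weights(tensor_name):
    """Return the data of a FLOAT32 constant tensor as a NumPy array of its shape."""
    tensor = subgraph.Tensors(tensor_dict[tensor_name])
    assert tensor.Type() == 0  # TensorType.FLOAT32

    # DataAsNumpy() returns the raw bytes as uint8; reinterpret every 4 bytes as a float32.
    W = model.Buffers(tensor.Buffer()).DataAsNumpy()
    W = W.view(dtype=np.float32)
    W = W.reshape(get_shape(tensor))
    # The array is a read-only view into `data`; copy it so torch.from_numpy() gets a writable array.
    return W.copy()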

In [9]:
W = get_weights("conv2d/Kernel")
b = get_weights("conv2d/Bias")
W.shape, b.shape

((24, 5, 5, 3), (24,))

Now we can get the weights for all the layers and copy them into our PyTorch model.

## Convert the weights to PyTorch format

In [10]:
import torch
from blazeface import BlazeFace

In [11]:
net = BlazeFace()

In [12]:
net

BlazeFace(
  (backbone1): Sequential(
    (0): Conv2d(3, 24, kernel_size=(5, 5), stride=(2, 2))
    (1): ReLU(inplace)
    (2): BlazeBlock(
      (convs): Sequential(
        (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)
        (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))
      )
      (act): ReLU(inplace)
    )
    (3): BlazeBlock(
      (convs): Sequential(
        (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)
        (1): Conv2d(24, 28, kernel_size=(1, 1), stride=(1, 1))
      )
      (act): ReLU(inplace)
    )
    (4): BlazeBlock(
      (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (convs): Sequential(
        (0): Conv2d(28, 28, kernel_size=(3, 3), stride=(2, 2), groups=28)
        (1): Conv2d(28, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (act): ReLU(inplace)
    )
    (5): BlazeBlock(
      (convs): Sequential(
        (0): Conv2d(32, 32, ...

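Make a lookup table that maps the layer names between the two models. We assume here that the weight tensors appear in the same order in both models. If they don't, we'll most likely get an error because the shapes don't match, although two tensors with identical shapes (say, two 24-element biases) could still end up swapped without any error.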

In [13]:
probable_names = []
for i in range(0, subgraph.TensorsLength()):
    tensor = subgraph.Tensors(i)
    if tensor.Buffer() > 0 and tensor.Type() == 0:
        probable_names.append(tensor.Name().decode("utf-8"))
        
probable_names[:5]

['conv2d/Kernel',
 'conv2d/Bias',
 'depthwise_conv2d/Kernel',
 'depthwise_conv2d/Bias',
 'conv2d_1/Kernel']

In [14]:
# Pair the i-th PyTorch parameter with the i-th TFLite weight tensor.
convert = {}
for i, name in enumerate(net.state_dict().keys()):
    convert[name] = probable_names[i]
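Because this mapping relies purely on position, it can be worth checking the shapes before copying anything. Below is a hypothetical helper, `pair_by_order` (not part of the notebook), that pairs names by position like the cell above and reports every pair whose shapes disagree after the layout change. It works on plain shape tuples so it runs on its own; the example shapes are the first four entries from the listings above.

```python
from collections import OrderedDict

def pt_shape_from_tflite(shape):
    """Expected PyTorch shape for a TFLite weight shape (same rule as the conversion cell)."""
    shape = tuple(shape)
    if len(shape) == 4:
        if shape[0] == 1:  # depthwise: (1, H, W, C) -> (C, 1, H, W)
            return (shape[3], 1, shape[1], shape[2])
        return (shape[0], shape[3], shape[1], shape[2])  # regular: (O, H, W, I) -> (O, I, H, W)
    return shape  # biases are 1-D

def pair_by_order(pt_shapes, tfl_shapes):
    """Pair tensors by position; return the mapping and the pairs whose shapes disagree."""
    if len(pt_shapes) != len(tfl_shapes):
        raise ValueError(f"{len(pt_shapes)} PyTorch tensors but {len(tfl_shapes)} TFLite tensors")
    convert, mismatches = OrderedDict(), []
    for (dst, dst_shape), (src, src_shape) in zip(pt_shapes.items(), tfl_shapes.items()):
        convert[dst] = src
        if pt_shape_from_tflite(src_shape) != tuple(dst_shape):
            mismatches.append((dst, src))
    return convert, mismatches

pt_shapes = OrderedDict([
    ("backbone1.0.weight", (24, 3, 5, 5)),
    ("backbone1.0.bias", (24,)),
    ("backbone1.2.convs.0.weight", (24, 1, 3, 3)),
    ("backbone1.2.convs.0.bias", (24,)),
])
tfl_shapes = OrderedDict([
    ("conv2d/Kernel", (24, 5, 5, 3)),
    ("conv2d/Bias", (24,)),
    ("depthwise_conv2d/Kernel", (1, 3, 3, 24)),
    ("depthwise_conv2d/Bias", (24,)),
])

convert, mismatches = pair_by_order(pt_shapes, tfl_shapes)
print(mismatches)  # []

# Swapping the two kernels is caught...
swapped = OrderedDict((k, tfl_shapes[k]) for k in
                      ["depthwise_conv2d/Kernel", "conv2d/Bias", "conv2d/Kernel", "depthwise_conv2d/Bias"])
print(pair_by_order(pt_shapes, swapped)[1])

# ...but swapping the two same-shaped biases is not.
bias_swap = OrderedDict((k, tfl_shapes[k]) for k in
                        ["conv2d/Kernel", "depthwise_conv2d/Bias", "depthwise_conv2d/Kernel", "conv2d/Bias"])
print(pair_by_order(pt_shapes, bias_swap)[1])  # []
```

The helper raises on a length mismatch, which is stricter than the notebook cell (that one only fails if the PyTorch model has more parameters than there are TFLite tensors). The last example shows why a shape check is a useful guard but not a proof of correctness.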

Copy the weights into the layers.

Note that TFLite and PyTorch order the axes of the weight tensors differently, so we need to transpose them:

Convolution weights:

    TFLite:  (out_channels, kernel_height, kernel_width, in_channels)
    PyTorch: (out_channels, in_channels, kernel_height, kernel_width)

Depthwise convolution weights:

    TFLite:  (1, kernel_height, kernel_width, channels)
    PyTorch: (channels, 1, kernel_height, kernel_width)
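The two transposes used in the conversion cell can be checked in isolation on random arrays. This sketch wraps them in a hypothetical helper, `to_pytorch_layout`, that follows the same rule as the notebook, and verifies both the resulting shapes and that one element ends up where the layouts above say it should.

```python
import numpy as np

def to_pytorch_layout(W: np.ndarray) -> np.ndarray:
    """Same rule as the conversion cell: permute the axes of 4-D weight tensors."""
    if W.ndim == 4:
        if W.shape[0] == 1:
            return W.transpose((3, 0, 1, 2))  # depthwise: (1, H, W, C) -> (C, 1, H, W)
        return W.transpose((0, 3, 1, 2))      # regular: (O, H, W, I) -> (O, I, H, W)
    return W                                  # biases are 1-D and need no change

rng = np.random.default_rng(0)

W_tfl = rng.random((24, 5, 5, 3), dtype=np.float32)
W_pt = to_pytorch_layout(W_tfl)
print(W_pt.shape)  # (24, 3, 5, 5)
# PyTorch element [o, i, h, w] comes from TFLite element [o, h, w, i].
assert W_pt[7, 2, 4, 1] == W_tfl[7, 4, 1, 2]

D_tfl = rng.random((1, 3, 3, 24), dtype=np.float32)
D_pt = to_pytorch_layout(D_tfl)
print(D_pt.shape)  # (24, 1, 3, 3)
# PyTorch element [c, 0, h, w] comes from TFLite element [0, h, w, c].
assert D_pt[5, 0, 2, 1] == D_tfl[0, 2, 1, 5]
```

One caveat of this rule: the `shape[0] == 1` test would also match a regular convolution with a single output channel. A model containing such a layer would need another way to tell the two apart, for example the operator type (`DEPTHWISE_CONV_2D` vs `CONV_2D`) recorded in the TFLite graph.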

In [15]:
new_state_dict = OrderedDict()

for dst, src in convert.items():
    W = get_weights(src)
    print(dst, src, W.shape, net.state_dict()[dst].shape)

    if W.ndim == 4:
        if W.shape[0] == 1:
            W = W.transpose((3, 0, 1, 2))  # depthwise conv
        else:
            W = W.transpose((0, 3, 1, 2))  # regular conv
    
    new_state_dict[dst] = torch.from_numpy(W)

backbone1.0.weight conv2d/Kernel (24, 5, 5, 3) torch.Size([24, 3, 5, 5])
backbone1.0.bias conv2d/Bias (24,) torch.Size([24])
backbone1.2.convs.0.weight depthwise_conv2d/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])
backbone1.2.convs.0.bias depthwise_conv2d/Bias (24,) torch.Size([24])
backbone1.2.convs.1.weight conv2d_1/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])
backbone1.2.convs.1.bias conv2d_1/Bias (24,) torch.Size([24])
backbone1.3.convs.0.weight depthwise_conv2d_1/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])
backbone1.3.convs.0.bias depthwise_conv2d_1/Bias (24,) torch.Size([24])
backbone1.3.convs.1.weight conv2d_2/Kernel (28, 1, 1, 24) torch.Size([28, 24, 1, 1])
backbone1.3.convs.1.bias conv2d_2/Bias (28,) torch.Size([28])
backbone1.4.convs.0.weight depthwise_conv2d_2/Kernel (1, 3, 3, 28) torch.Size([28, 1, 3, 3])
backbone1.4.convs.0.bias depthwise_conv2d_2/Bias (28,) torch.Size([28])
backbone1.4.convs.1.weight conv2d_3/Kernel (32, 1, 1, 28) torch.Size([32, 28, 1, 1])
...

In [16]:
net.load_state_dict(new_state_dict, strict=True)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

No errors? Then every parameter name and shape matches, which is a strong sign the conversion was successful!
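`load_state_dict(strict=True)` only checks that names and shapes line up; it can't tell whether the values landed in the right place. For example, a permutation that swaps the kernel height and width axes leaves the shape of a square 3x3 kernel unchanged and would load without complaint. One way to gain more confidence is to compare a single layer numerically. The sketch below uses random data rather than the real model: it runs a naive NHWC convolution with a TFLite-layout kernel and compares it against `torch.nn.functional.conv2d` with the transposed kernel. The helper `conv2d_nhwc_reference` is hypothetical and only handles stride 1 with no padding.

```python
import numpy as np
import torch
import torch.nn.functional as F

def conv2d_nhwc_reference(x, W, b):
    """Naive stride-1, no-padding convolution in TFLite layouts.

    x: (H, W, I) input, W: (O, KH, KW, I) kernel, b: (O,) bias. Returns (H', W', O).
    """
    H, Wd, _ = x.shape
    O, KH, KW, _ = W.shape
    out = np.zeros((H - KH + 1, Wd - KW + 1, O), dtype=np.float32)
    for y in range(out.shape[0]):
        for z in range(out.shape[1]):
            patch = x[y:y + KH, z:z + KW, :]  # (KH, KW, I)
            out[y, z] = np.tensordot(W, patch, axes=([1, 2, 3], [0, 1, 2])) + b
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8, 3)).astype(np.float32)         # one image, batch dim dropped
W_tfl = rng.standard_normal((24, 5, 5, 3)).astype(np.float32)  # TFLite (O, H, W, I)
b = rng.standard_normal(24).astype(np.float32)

expected = conv2d_nhwc_reference(x, W_tfl, b)  # (4, 4, 24)

# Same convolution in PyTorch: NCHW input and (O, I, H, W) weights.
x_pt = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0)
W_pt = torch.from_numpy(W_tfl.transpose((0, 3, 1, 2)).copy())
y_pt = F.conv2d(x_pt, W_pt, torch.from_numpy(b))  # (1, 24, 4, 4)
actual = y_pt[0].permute(1, 2, 0).numpy()          # back to (H, W, C)

print("max abs difference:", np.abs(actual - expected).max())
assert np.allclose(actual, expected, atol=1e-4)
```

Both sides compute a cross-correlation, which is what TFLite's `CONV_2D` and PyTorch's `conv2d` implement, so the results should agree up to float32 rounding.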

## Save the checkpoint

In [17]:
torch.save(net.state_dict(), "blazeface.pth")
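The file holds only the state dict, not the architecture, so loading it later means building the model first and then loading the parameters into it. Here is a minimal round trip, using a small stand-in `nn.Sequential` instead of `BlazeFace` so it runs on its own (the `make_net` name and the temporary path are illustrative):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for BlazeFace; any nn.Module is saved and restored the same way.
def make_net():
    return nn.Sequential(nn.Conv2d(3, 24, kernel_size=5, stride=2), nn.ReLU(inplace=True))

net = make_net()
path = os.path.join(tempfile.mkdtemp(), "blazeface.pth")
torch.save(net.state_dict(), path)

# Later: rebuild the architecture, then load the saved parameters into it.
restored = make_net()
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # switch to inference mode

for (k1, v1), (k2, v2) in zip(net.state_dict().items(), restored.state_dict().items()):
    assert k1 == k2 and torch.equal(v1, v2)
print("restored", len(restored.state_dict()), "tensors")
```

With the real model, the same pattern is `net = BlazeFace()` followed by `net.load_state_dict(torch.load("blazeface.pth"))`.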