# Convert TFLite model to PyTorch

This uses the model **face_detection_front.tflite** from [MediaPipe](https://github.com/google/mediapipe/tree/master/mediapipe/models).

Using a conda environment:
```
conda create -c pytorch -c conda-forge -n BlazeConv 'pytorch=1.6' jupyter opencv matplotlib
```
```
conda activate BlazeConv
```
```
pip install tflite
```

## Convert front camera TFLite model

In [1]:
import os
import numpy as np
from collections import OrderedDict

### Get the weights from the TFLite file

Download the TFLite model, then load it with the `tflite` package (Python bindings generated from the TFLite FlatBuffers schema):

In [2]:
!wget -N https://github.com/google/mediapipe/raw/master/mediapipe/models/face_detection_front.tflite

--2021-02-09 23:17:46--  https://github.com/google/mediapipe/raw/master/mediapipe/models/face_detection_front.tflite
Resolving github.com (github.com)... 140.82.121.3
Connecting to github.com (github.com)|140.82.121.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/google/mediapipe/master/mediapipe/models/face_detection_front.tflite [following]
--2021-02-09 23:17:46--  https://raw.githubusercontent.com/google/mediapipe/master/mediapipe/models/face_detection_front.tflite
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.120.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.120.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 229032 (224K) [application/octet-stream]
Saving to: ‘face_detection_front.tflite’


Last-modified header missing -- time-stamps turned off.
2021-02-09 23:17:46 (18.7 MB/s) - ‘face_detection_fro

In [3]:
from tflite import Model

with open("./face_detection_front.tflite", "rb") as f:
    front_data = f.read()
front_model = Model.GetRootAsModel(front_data, 0)

In [4]:
front_subgraph = front_model.Subgraphs(0)
front_subgraph.Name()

b'keras2tflite_facedetector-front.tflite.generated'

In [5]:
def get_shape(tensor):
    return [tensor.Shape(i) for i in range(tensor.ShapeLength())]

List all the tensors in the graph:

In [6]:
def print_graph(graph):
    for i in range(0, graph.TensorsLength()):
        tensor = graph.Tensors(i)
        print("%3d %30s %d %2d %s" % (i, tensor.Name(), tensor.Type(), tensor.Buffer(),
                                      get_shape(tensor)))

print_graph(front_subgraph)

  0                       b'input' 0  0 [1, 128, 128, 3]
  1               b'conv2d/Kernel' 1  1 [24, 5, 5, 3]
  2                 b'conv2d/Bias' 1  2 [24]
  3                      b'conv2d' 0  0 [1, 64, 64, 24]
  4                  b'activation' 0  0 [1, 64, 64, 24]
  5     b'depthwise_conv2d/Kernel' 1  3 [1, 3, 3, 24]
  6       b'depthwise_conv2d/Bias' 1  4 [24]
  7            b'depthwise_conv2d' 0  0 [1, 64, 64, 24]
  8             b'conv2d_1/Kernel' 1  5 [24, 1, 1, 24]
  9               b'conv2d_1/Bias' 1  6 [24]
 10                    b'conv2d_1' 0  0 [1, 64, 64, 24]
 11         b'add__xeno_compat__1' 0  0 [1, 64, 64, 24]
 12                b'activation_1' 0  0 [1, 64, 64, 24]
 13   b'depthwise_conv2d_1/Kernel' 1  7 [1, 3, 3, 24]
 14     b'depthwise_conv2d_1/Bias' 1  8 [24]
 15          b'depthwise_conv2d_1' 0  0 [1, 64, 64, 24]
 16             b'conv2d_2/Kernel' 1  9 [28, 1, 1, 24]
 17               b'conv2d_2/Bias' 1 10 [28]
 18                    b'conv2d_2' 0  0 [1, 64, 64, 28

Make a look-up table that lets us get the tensor index based on the tensor name:

In [7]:
front_tensor_dict = {(front_subgraph.Tensors(i).Name().decode("utf8")): i 
               for i in range(front_subgraph.TensorsLength())}

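Grab only the tensors that represent weights and biases. In a TFLite model, buffer 0 is by convention an empty sentinel: tensors that hold activations point to buffer 0, while constant tensors such as weights and biases point to a non-zero buffer.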

In [8]:
def get_parameters(graph):
    parameters = {}
    for i in range(graph.TensorsLength()):
        tensor = graph.Tensors(i)
        if tensor.Buffer() > 0:
            name = tensor.Name().decode("utf8")
            parameters[name] = tensor.Buffer()
    return parameters

front_parameters = get_parameters(front_subgraph)
len(front_parameters)

85

The buffers are plain byte arrays. As the TFLite schema documentation says:

> The data_buffer itself is an opaque container, with the assumption that the
> target device is little-endian. In addition, all builtin operators assume
> the memory is ordered such that if `shape` is [4, 3, 2], then index
> [i, j, k] maps to `data_buffer[i*3*2 + j*2 + k]`.
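To make the indexing rule concrete, here is a small, illustrative NumPy check (not part of the conversion itself). It builds a flat buffer, reshapes it to `[4, 3, 2]`, and confirms that element `[i, j, k]` sits at flat offset `i*3*2 + j*2 + k`. NumPy's default C (row-major) order follows the same rule, which is why a plain `reshape` in `get_weights` is enough. `flat_index` is a hypothetical helper introduced only for this example.

```python
import numpy as np

# A flat "data buffer" holding 4*3*2 = 24 values.
flat = np.arange(24, dtype=np.float32)

# Reshape in C (row-major) order, which is the NumPy default.
shape = (4, 3, 2)
tensor = flat.reshape(shape)

def flat_index(i, j, k, shape=shape):
    """Hypothetical helper: offset of element [i, j, k] in a row-major buffer."""
    _, d1, d2 = shape
    return i * d1 * d2 + j * d2 + k

# Every element lands where the TFLite docs say it should.
for i in range(shape[0]):
    for j in range(shape[1]):
        for k in range(shape[2]):
            assert tensor[i, j, k] == flat[flat_index(i, j, k)]

print(tensor[2, 1, 0], flat_index(2, 1, 0))  # prints: 14.0 14
```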

For float32 weights and biases, every 4 bytes must be interpreted as one float. On my machine the native byte order is already little-endian, so no byte swapping is needed.

Some weights and biases turned out to be stored as float16 rather than float32: these tensors have Type 1 (FLOAT16) instead of Type 0 (FLOAT32), and each of their values occupies 2 bytes instead of 4.
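The sketch below shows the byte reinterpretation on synthetic data rather than on a real model buffer. `decode_buffer` is a hypothetical helper, not part of the `tflite` package; it assumes the TFLite `TensorType` codes 0 (FLOAT32) and 1 (FLOAT16), and it spells out little-endian dtypes (`'<f4'`, `'<f2'`) so that the result would not depend on the host's byte order. It also upcasts to float32, which gives a writable copy in the dtype the PyTorch model expects.

```python
import numpy as np

# TFLite TensorType codes used in this notebook (0 = FLOAT32, 1 = FLOAT16).
# Explicit little-endian dtypes, because TFLite buffers are little-endian.
TFLITE_FLOAT_DTYPES = {0: np.dtype("<f4"), 1: np.dtype("<f2")}

def decode_buffer(raw, tensor_type, shape):
    """Hypothetical helper: reinterpret a uint8 buffer as float weights.

    raw         -- 1-D uint8 array (what DataAsNumpy() returns)
    tensor_type -- TFLite TensorType code (0 or 1)
    shape       -- target tensor shape
    """
    dtype = TFLITE_FLOAT_DTYPES[tensor_type]
    if raw.size % dtype.itemsize != 0:
        raise ValueError("buffer size is not a multiple of the element size")
    # view() reinterprets the bytes without copying; astype() then makes a
    # writable float32 copy in native byte order.
    return raw.view(dtype).reshape(shape).astype(np.float32)

# Synthetic example: the same values stored as float32 and as float16.
values = np.array([1.0, -2.5, 0.5, 3.0])
raw32 = np.frombuffer(values.astype("<f4").tobytes(), dtype=np.uint8)
raw16 = np.frombuffer(values.astype("<f2").tobytes(), dtype=np.uint8)

print(raw32.size, raw16.size)  # 4 vs 2 bytes per element, prints: 16 8
print(decode_buffer(raw16, 1, (2, 2)))
```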

In [9]:
def get_weights(model, graph, tensor_dict, tensor_name):
    i = tensor_dict[tensor_name]
    tensor = graph.Tensors(i)
    buffer = tensor.Buffer()
    shape = get_shape(tensor)
    assert tensor.Type() in (0, 1)  # FLOAT32 (0) or FLOAT16 (1)
    
    W = model.Buffers(buffer).DataAsNumpy()
    if tensor.Type() == 0:
        W = W.view(dtype=np.float32)
    elif tensor.Type() == 1:
        W = W.view(dtype=np.float16)
    W = W.reshape(shape)
    return W

In [10]:
W = get_weights(front_model, front_subgraph, front_tensor_dict, "conv2d/Kernel")
b = get_weights(front_model, front_subgraph, front_tensor_dict, "conv2d/Bias")
W.shape, b.shape

((24, 5, 5, 3), (24,))

Now we can get the weights for all the layers and copy them into our PyTorch model.

### Convert the weights to PyTorch format

In [11]:
import torch
from blazeface import BlazeFace

In [12]:
front_net = BlazeFace()

In [13]:
front_net

BlazeFace(
  (backbone1): Sequential(
    (0): Conv2d(3, 24, kernel_size=(5, 5), stride=(2, 2))
    (1): ReLU(inplace=True)
    (2): BlazeBlock(
      (convs): Sequential(
        (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)
        (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))
      )
      (act): ReLU(inplace=True)
    )
    (3): BlazeBlock(
      (convs): Sequential(
        (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)
        (1): Conv2d(24, 28, kernel_size=(1, 1), stride=(1, 1))
      )
      (act): ReLU(inplace=True)
    )
    (4): BlazeBlock(
      (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (convs): Sequential(
        (0): Conv2d(28, 28, kernel_size=(3, 3), stride=(2, 2), groups=28)
        (1): Conv2d(28, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (act): ReLU(inplace=True)
    )
    (5): BlazeBlock(
      (convs): Sequential(
        (0

Make a lookup table that maps the layer names between the two models. This assumes the weight tensors appear in the same order in both models. If they don't, loading the state dict will most likely fail with a shape mismatch, although two layers with identical shapes could still be swapped without any error.
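Because the mapping is purely positional, a cheap safeguard is to compare shapes before loading anything. The sketch below works on plain shape tuples so it runs without PyTorch or a model file; `expected_torch_shape` and `check_convert` are hypothetical helpers, and the permutations assume the same TFLite-to-PyTorch layouts as the transposes used in `build_state_dict`. It can only catch misalignments that change a shape.

```python
def expected_torch_shape(tflite_shape):
    """Shape a TFLite tensor should have after conversion to PyTorch layout.

    Assumes 4-D kernels are either depthwise (1, H, W, C) -> (C, 1, H, W)
    or regular (O, H, W, I) -> (O, I, H, W); other ranks are unchanged.
    """
    shape = tuple(tflite_shape)
    if len(shape) == 4:
        if shape[0] == 1:  # depthwise conv kernel
            return (shape[3], 1, shape[1], shape[2])
        return (shape[0], shape[3], shape[1], shape[2])  # regular conv
    return shape

def check_convert(convert, torch_shapes, tflite_shapes):
    """Hypothetical helper: return the (dst, src) pairs whose shapes disagree."""
    mismatches = []
    for dst, src in convert.items():
        if tuple(torch_shapes[dst]) != expected_torch_shape(tflite_shapes[src]):
            mismatches.append((dst, src))
    return mismatches

# Toy example using shapes from the first layers of the front model.
torch_shapes = {
    "backbone1.0.weight": (24, 3, 5, 5),
    "backbone1.0.bias": (24,),
    "backbone1.2.convs.0.weight": (24, 1, 3, 3),
}
tflite_shapes = {
    "conv2d/Kernel": (24, 5, 5, 3),
    "conv2d/Bias": (24,),
    "depthwise_conv2d/Kernel": (1, 3, 3, 24),
}
good = {"backbone1.0.weight": "conv2d/Kernel",
        "backbone1.0.bias": "conv2d/Bias",
        "backbone1.2.convs.0.weight": "depthwise_conv2d/Kernel"}
shifted = {"backbone1.0.weight": "conv2d/Bias",
           "backbone1.0.bias": "depthwise_conv2d/Kernel",
           "backbone1.2.convs.0.weight": "conv2d/Kernel"}

print(check_convert(good, torch_shapes, tflite_shapes))          # prints: []
print(len(check_convert(shifted, torch_shapes, tflite_shapes)))  # prints: 3
```

In the notebook, the inputs could be built from `{k: tuple(v.shape) for k, v in front_net.state_dict().items()}` and from `get_shape` applied to each TFLite tensor.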

In [14]:
def get_probable_names(graph):
    probable_names = []
    for i in range(0, graph.TensorsLength()):
        tensor = graph.Tensors(i)
        if tensor.Buffer() > 0 and (tensor.Type() == 0 or tensor.Type() == 1):
            probable_names.append(tensor.Name().decode("utf-8"))
    return probable_names

front_probable_names = get_probable_names(front_subgraph)
        
front_probable_names[:5]

['conv2d/Kernel',
 'conv2d/Bias',
 'depthwise_conv2d/Kernel',
 'depthwise_conv2d/Bias',
 'conv2d_1/Kernel']

In [15]:
def get_convert(net, probable_names):
    convert = {}
    i = 0
    for name, params in net.state_dict().items():
        convert[name] = probable_names[i]
        i += 1
    return convert

front_convert = get_convert(front_net, front_probable_names)

Copy the weights into the layers.

TFLite and PyTorch order the axes of convolution weights differently, so the arrays need to be transposed.

Convolution weights:

    TFLite:  (out_channels, kernel_height, kernel_width, in_channels)
    PyTorch: (out_channels, in_channels, kernel_height, kernel_width)

Depthwise convolution weights:

    TFLite:  (1, kernel_height, kernel_width, channels)
    PyTorch: (channels, 1, kernel_height, kernel_width)
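A small NumPy sketch of the two permutations, applied to random arrays with the shapes of this model (synthetic data, not the real weights). Besides the shapes, it checks that a single kernel tap ends up at the expected position after the transpose.

```python
import numpy as np

rng = np.random.default_rng(0)

# Regular conv kernel in TFLite layout: (out, kh, kw, in).
w_tfl = rng.standard_normal((24, 5, 5, 3)).astype(np.float32)
# PyTorch layout: (out, in, kh, kw).
w_pt = w_tfl.transpose((0, 3, 1, 2))
assert w_pt.shape == (24, 3, 5, 5)
# Tap (o, y, x, i) in TFLite becomes tap (o, i, y, x) in PyTorch.
assert w_pt[7, 2, 4, 1] == w_tfl[7, 4, 1, 2]

# Depthwise kernel in TFLite layout: (1, kh, kw, channels).
dw_tfl = rng.standard_normal((1, 3, 3, 24)).astype(np.float32)
# PyTorch depthwise layout (groups=channels): (channels, 1, kh, kw).
dw_pt = dw_tfl.transpose((3, 0, 1, 2))
assert dw_pt.shape == (24, 1, 3, 3)
assert dw_pt[5, 0, 2, 1] == dw_tfl[0, 2, 1, 5]

print(w_pt.shape, dw_pt.shape)
```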

In [16]:
def build_state_dict(model, graph, tensor_dict, net, convert):
    new_state_dict = OrderedDict()

    for dst, src in convert.items():
        W = get_weights(model, graph, tensor_dict, src)
        print(dst, src, W.shape, net.state_dict()[dst].shape)

        if W.ndim == 4:
            if W.shape[0] == 1:
                W = W.transpose((3, 0, 1, 2))  # depthwise conv
            else:
                W = W.transpose((0, 3, 1, 2))  # regular conv
    
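        # astype() returns a writable float32 copy: the raw FlatBuffers data is
        # read-only (torch.from_numpy warns about that), and some tensors are float16.
        new_state_dict[dst] = torch.from_numpy(W.astype(np.float32))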
    return new_state_dict

front_state_dict = build_state_dict(front_model, front_subgraph, front_tensor_dict, front_net, front_convert)

backbone1.0.weight conv2d/Kernel (24, 5, 5, 3) torch.Size([24, 3, 5, 5])
backbone1.0.bias conv2d/Bias (24,) torch.Size([24])
backbone1.2.convs.0.weight depthwise_conv2d/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])
backbone1.2.convs.0.bias depthwise_conv2d/Bias (24,) torch.Size([24])
backbone1.2.convs.1.weight conv2d_1/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])
backbone1.2.convs.1.bias conv2d_1/Bias (24,) torch.Size([24])
backbone1.3.convs.0.weight depthwise_conv2d_1/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])
backbone1.3.convs.0.bias depthwise_conv2d_1/Bias (24,) torch.Size([24])
backbone1.3.convs.1.weight conv2d_2/Kernel (28, 1, 1, 24) torch.Size([28, 24, 1, 1])
backbone1.3.convs.1.bias conv2d_2/Bias (28,) torch.Size([28])
backbone1.4.convs.0.weight depthwise_conv2d_2/Kernel (1, 3, 3, 28) torch.Size([28, 1, 3, 3])
backbone1.4.convs.0.bias depthwise_conv2d_2/Bias (28,) torch.Size([28])
backbone1.4.convs.1.weight conv2d_3/Kernel (32, 1, 1, 28) torch.Size([32, 28, 1, 1])
backb

In [17]:
front_net.load_state_dict(front_state_dict, strict=True)

<All keys matched successfully>

No errors? Then every key and every tensor shape matched, and the weights were copied successfully.
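`load_state_dict` only checks that names and shapes agree, so it cannot tell whether the axes were permuted correctly. One framework-free way to convince yourself of the layout change is to compute convolution outputs both ways on synthetic data: NHWC input with the TFLite (out, kh, kw, in) kernel, and NCHW input with the transposed PyTorch (out, in, kh, kw) kernel. This is only a sketch of a layout check (stride 1, no padding, random data); comparing the converted PyTorch model against the TFLite interpreter on real images would be the fuller test.

```python
import numpy as np

rng = np.random.default_rng(42)
H = W = 8
C_IN, C_OUT, K = 3, 4, 5

x_nhwc = rng.standard_normal((1, H, W, C_IN))
k_ohwi = rng.standard_normal((C_OUT, K, K, C_IN))  # TFLite conv layout

x_nchw = x_nhwc.transpose((0, 3, 1, 2))  # PyTorch activation layout
k_oihw = k_ohwi.transpose((0, 3, 1, 2))  # PyTorch conv layout

def out_nhwc(row, col, o):
    """Output at (row, col, channel o): 'valid' conv, stride 1, NHWC / OHWI."""
    patch = x_nhwc[0, row:row + K, col:col + K, :]  # (K, K, C_IN)
    return np.sum(patch * k_ohwi[o])

def out_nchw(row, col, o):
    """The same output computed in NCHW / OIHW layout."""
    patch = x_nchw[0, :, row:row + K, col:col + K]  # (C_IN, K, K)
    return np.sum(patch * k_oihw[o])

for o in range(C_OUT):
    for row in range(H - K + 1):
        for col in range(W - K + 1):
            assert np.isclose(out_nhwc(row, col, o), out_nchw(row, col, o))
print("layouts agree")
```

Both PyTorch's `Conv2d` and TensorFlow's convolution compute cross-correlation, as this sketch does, so no kernel flip is needed on top of the transpose.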

### Save the checkpoint

In [18]:
torch.save(front_net.state_dict(), "blazeface.pth")

## Convert back camera TFLite model

In [19]:
!wget -N https://github.com/google/mediapipe/raw/master/mediapipe/models/face_detection_back.tflite

--2021-02-09 23:19:58--  https://github.com/google/mediapipe/raw/master/mediapipe/models/face_detection_back.tflite
Resolving github.com (github.com)... 140.82.121.3
Connecting to github.com (github.com)|140.82.121.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/google/mediapipe/master/mediapipe/models/face_detection_back.tflite [following]
--2021-02-09 23:19:58--  https://raw.githubusercontent.com/google/mediapipe/master/mediapipe/models/face_detection_back.tflite
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.120.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.120.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 315332 (308K) [application/octet-stream]
Saving to: ‘face_detection_back.tflite’


Last-modified header missing -- time-stamps turned off.
2021-02-09 23:19:58 (17.0 MB/s) - ‘face_detection_back.tf

In [20]:
with open("./face_detection_back.tflite", "rb") as f:
    back_data = f.read()
back_model = Model.GetRootAsModel(back_data, 0)
back_subgraph = back_model.Subgraphs(0)
back_subgraph.Name()

b'keras2tflite_facedetector-back.tflite.generated'

In [21]:
print_graph(back_subgraph)

  0                       b'input' 0  0 [1, 256, 256, 3]
  1               b'conv2d/Kernel' 1  1 [24, 5, 5, 3]
  2                 b'conv2d/Bias' 1  2 [24]
  3                      b'conv2d' 0  0 [1, 128, 128, 24]
  4                  b'activation' 0  0 [1, 128, 128, 24]
  5     b'depthwise_conv2d/Kernel' 1  3 [1, 3, 3, 24]
  6       b'depthwise_conv2d/Bias' 1  4 [24]
  7            b'depthwise_conv2d' 0  0 [1, 128, 128, 24]
  8             b'conv2d_1/Kernel' 1  5 [24, 1, 1, 24]
  9               b'conv2d_1/Bias' 1  6 [24]
 10                    b'conv2d_1' 0  0 [1, 128, 128, 24]
 11                         b'add' 0  0 [1, 128, 128, 24]
 12                b'activation_1' 0  0 [1, 128, 128, 24]
 13   b'depthwise_conv2d_1/Kernel' 1  7 [1, 3, 3, 24]
 14     b'depthwise_conv2d_1/Bias' 1  8 [24]
 15          b'depthwise_conv2d_1' 0  0 [1, 128, 128, 24]
 16             b'conv2d_2/Kernel' 1  9 [24, 1, 1, 24]
 17               b'conv2d_2/Bias' 1 10 [24]
 18                    b'conv2d_2' 0  0 

In [22]:
back_tensor_dict = {(back_subgraph.Tensors(i).Name().decode("utf8")): i 
               for i in range(back_subgraph.TensorsLength())}

In [23]:
back_parameters = get_parameters(back_subgraph)
len(back_parameters)

140

In [24]:
W = get_weights(back_model, back_subgraph, back_tensor_dict, "conv2d/Kernel")
b = get_weights(back_model, back_subgraph, back_tensor_dict, "conv2d/Bias")
W.shape, b.shape

((24, 5, 5, 3), (24,))

In [25]:
back_net = BlazeFace(back_model=True)

In [26]:
back_net

BlazeFace(
  (backbone): Sequential(
    (0): Conv2d(3, 24, kernel_size=(5, 5), stride=(2, 2))
    (1): ReLU(inplace=True)
    (2): BlazeBlock(
      (convs): Sequential(
        (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)
        (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))
      )
      (act): ReLU(inplace=True)
    )
    (3): BlazeBlock(
      (convs): Sequential(
        (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)
        (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))
      )
      (act): ReLU(inplace=True)
    )
    (4): BlazeBlock(
      (convs): Sequential(
        (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)
        (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))
      )
      (act): ReLU(inplace=True)
    )
    (5): BlazeBlock(
      (convs): Sequential(
        (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24

In [27]:
back_probable_names = get_probable_names(back_subgraph)
back_probable_names[:5]

['conv2d/Kernel',
 'conv2d/Bias',
 'depthwise_conv2d/Kernel',
 'depthwise_conv2d/Bias',
 'conv2d_1/Kernel']

In [28]:
back_convert = get_convert(back_net, back_probable_names)

In [29]:
back_state_dict = build_state_dict(back_model, back_subgraph, back_tensor_dict, back_net, back_convert)

backbone.0.weight conv2d/Kernel (24, 5, 5, 3) torch.Size([24, 3, 5, 5])
backbone.0.bias conv2d/Bias (24,) torch.Size([24])
backbone.2.convs.0.weight depthwise_conv2d/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])
backbone.2.convs.0.bias depthwise_conv2d/Bias (24,) torch.Size([24])
backbone.2.convs.1.weight conv2d_1/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])
backbone.2.convs.1.bias conv2d_1/Bias (24,) torch.Size([24])
backbone.3.convs.0.weight depthwise_conv2d_1/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])
backbone.3.convs.0.bias depthwise_conv2d_1/Bias (24,) torch.Size([24])
backbone.3.convs.1.weight conv2d_2/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])
backbone.3.convs.1.bias conv2d_2/Bias (24,) torch.Size([24])
backbone.4.convs.0.weight depthwise_conv2d_2/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])
backbone.4.convs.0.bias depthwise_conv2d_2/Bias (24,) torch.Size([24])
backbone.4.convs.1.weight conv2d_3/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])
backbone.4.convs.1

In [30]:
back_net.load_state_dict(back_state_dict, strict=True)

<All keys matched successfully>

In [31]:
torch.save(back_net.state_dict(), "blazefaceback.pth")