In [1]:
import torch
import torch.nn as nn

import tensorflow.compat.v2 as tf
import tensorflow.keras.layers as layers

from collections import OrderedDict

In [2]:
# helper functions converting numpy and pytorch weights to each other
# from timm @https://github.com/rwightman/pytorch-image-models/blob/01a0e25a67305b94ea767083f4113ff002e4435c/timm/models/vision_transformer.py#L608
@torch.no_grad()
def n2p(w, t=True):
    if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
        w = w.flatten()
    if t:
        if w.ndim == 4:
            w = w.transpose([3, 2, 0, 1])
        elif w.ndim == 3:
            w = w.transpose([2, 0, 1])
        elif w.ndim == 2:
            w = w.transpose([1, 0])
    return torch.from_numpy(w)
@torch.no_grad()
def p2n(w, t=True):
    if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
        w = w.flatten()
    if t:
        if w.ndim == 4:
            w = w.permute([2, 3, 1, 0])
        elif w.ndim == 3:
            w = w.permute([1, 2, 0])
        elif w.ndim == 2:
            w = w.permute([1, 0])
    return w.numpy()

In [3]:
class BottleneckPytorch(nn.Module):
  expansion = 4
  def __init__(self, inplanes, planes, stride=1):
    super().__init__()
    self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
    self.bn1 = nn.BatchNorm2d(planes)
    self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(planes)
    self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
    self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
    self.bn3 = nn.BatchNorm2d(planes * self.expansion)
    self.relu = nn.ReLU(inplace=True)
    self.downsample = None
    self.stride = stride
    if stride > 1 or inplanes != planes * BottleneckPytorch.expansion:
      self.downsample = nn.Sequential(OrderedDict([
        ("-1", nn.AvgPool2d(stride)),
        ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
        ("1", nn.BatchNorm2d(planes * self.expansion)) ]))
  def forward(self, x: torch.Tensor):
    identity = x
    out = self.relu(self.bn1(self.conv1(x)))
    out = self.relu(self.bn2(self.conv2(out)))
    out = self.avgpool(out)
    out = self.bn3(self.conv3(out))
    if self.downsample is not None:
      identity = self.downsample(x)
    out += identity
    out = self.relu(out)
    return out

class BottleneckTF(layers.Layer):
  expansion = 4
  def __init__(self, inplanes, planes, stride=1):
    super().__init__()
    self.conv1 = layers.Conv2D(planes, 1, padding='valid', use_bias=False, kernel_initializer='he_uniform')
    self.bn1 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)
    self.conv2 = layers.Conv2D(planes, 3, padding='same', use_bias=False, kernel_initializer='he_uniform')
    self.bn2 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)
    self.avgpool = layers.AveragePooling2D(stride, padding='valid') if stride > 1 else layers.Lambda(tf.identity)
    self.conv3 = layers.Conv2D(planes * self.expansion, 1, padding='valid',
      use_bias=False, kernel_initializer='he_uniform')
    self.bn3 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)
    self.relu = layers.ReLU()
    self.downsample = None
    self.stride = stride
    if stride > 1 or inplanes != planes * BottleneckTF.expansion:
      self.downsample = tf.keras.Sequential([
        layers.AveragePooling2D(stride, padding='valid'),
        layers.Conv2D(planes * self.expansion, 1, padding='valid', use_bias=False, kernel_initializer='he_uniform'),
        layers.BatchNormalization(momentum=0.9, epsilon=1e-5) ])
  def call(self, x):
    identity = x
    out = self.relu(self.bn1(self.conv1(x)))
    out = self.relu(self.bn2(self.conv2(out)))
    out = self.avgpool(out)
    out = self.bn3(self.conv3(out))
    if self.downsample is not None:
      identity = self.downsample(x)
    out += identity
    out = self.relu(out)
    return out

In [4]:
# pytorch weights →　tf weights
bp = BottleneckPytorch(3, 6)
bt = BottleneckTF(3, 6)
bp.eval()

# run tf model to create variables
img = torch.rand(1, 3, 8, 8)
imgTF = tf.Variable(img.permute(0, 2, 3, 1).numpy())

with torch.no_grad():
    print("Difference BEFORE loading weights")
    print( (bp(img).permute(0, 2, 3, 1).numpy() - bt(imgTF).numpy() ).sum() )
    
# from pytorch weights to tensorflow weights
pytorch_weights = []
for weight in bp.parameters():
    if len(weight.size()) > 0:
        pytorch_weights.append(weight)
for weight in bp.buffers():
    if len(weight.size()) > 0:
        pytorch_weights.append(weight)

for var, weight in zip(bt.variables, pytorch_weights):
    var.assign(p2n(weight.data))

bt.trainable = False

# verify outputs are same
img = torch.rand(1, 3, 8, 8)
imgTF = tf.Variable(img.permute(0, 2, 3, 1).numpy())

with torch.no_grad():
    print("Difference AFTER loading weights")
    print( (bp(img).permute(0, 2, 3, 1).numpy() - bt(imgTF).numpy() ).sum() )

Difference BEFORE loading weights
-538.1633
Difference AFTER loading weights
1.1597294e-05


In [5]:
# tf weights →　pytorch weights
bp = BottleneckPytorch(3, 6)
bt = BottleneckTF(3, 6)
bp.eval()

# run tf model to create variables
img = torch.rand(1, 3, 8, 8)
imgTF = tf.Variable(img.permute(0, 2, 3, 1).numpy())

with torch.no_grad():
    print("Difference BEFORE loading weights")
    print( (bp(img).permute(0, 2, 3, 1).numpy() - bt(imgTF).numpy() ).sum() )
    
# from tensorflow weights to pytorch weights
pytorch_weights = []
for weight in bp.parameters():
    if len(weight.size()) > 0:
        pytorch_weights.append(weight)
for weight in bp.buffers():
    if len(weight.size()) > 0:
        pytorch_weights.append(weight)

for var, weight in zip(bt.variables, pytorch_weights):
    weight.data = n2p( var.numpy() )

bt.trainable = False

# verify outputs are same
img = torch.rand(1, 3, 8, 8)
imgTF = tf.Variable(img.permute(0, 2, 3, 1).numpy())

with torch.no_grad():
    print("Difference AFTER loading weights")
    print( (bp(img).permute(0, 2, 3, 1).numpy() - bt(imgTF).numpy() ).sum() )

Difference BEFORE loading weights
-374.39005
Difference AFTER loading weights
3.7865713e-05
