In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define PyTorch model
class VertexVista(nn.Module):
    """Two-block residual CNN classifier (the PyTorch reference model).

    Block 1 maps 3 -> 64 channels at full resolution; block 2 maps 64 -> 128
    channels and halves the spatial size (stride 2). Each block runs
    conv-BN-ReLU-conv-BN-ReLU on its main path, adds a 1x1 conv + BN projection
    of the block input, and applies a final ReLU to the sum. The head is global
    average pooling, dropout (p=0.5) and a linear layer; ``forward`` returns raw
    logits (no softmax).

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Attribute names double as state_dict keys, and construction order fixes
        # the order of the random-initialisation draws -- keep both unchanged.

        # Block 1 main path: 3 -> 64 channels, spatial size preserved.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=64)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        # Block 1 projection shortcut: 1x1 conv lifts the input to 64 channels.
        self.shortcut1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=1, stride=1),
            nn.BatchNorm2d(num_features=64),
        )

        # Block 2 main path: 64 -> 128 channels, conv3 downsamples by 2.
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=128)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(num_features=128)
        # Block 2 projection shortcut: same stride as conv3 so the shapes line up.
        self.shortcut2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, stride=2),
            nn.BatchNorm2d(num_features=128),
        )

        # Head: global average pooling -> dropout -> linear classifier.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features=128, out_features=num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return class logits for a batch ``x`` of shape (N, 3, H, W)."""
        # Block 1 -- note the ReLU after bn2 is applied *before* the residual add.
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = F.relu(self.bn2(self.conv2(branch)))
        block1 = F.relu(branch + self.shortcut1(x))

        # Block 2 -- same pattern; the shortcut projects block 1's output.
        branch = F.relu(self.bn3(self.conv3(block1)))
        branch = F.relu(self.bn4(self.conv4(branch)))
        block2 = F.relu(branch + self.shortcut2(block1))

        # Head: (N, 128, 1, 1) -> (N, 128) -> dropout -> logits.
        features = self.pool(block2)
        features = features.view(features.size(0), -1)
        return self.fc(self.dropout(features))

# Instantiate the (untrained) PyTorch model and report its trainable-parameter count.
num_classes = 10
model = VertexVista(num_classes)

trainable_params = (param for param in model.parameters() if param.requires_grad)
n_parameters = sum(param.numel() for param in trainable_params)
print(f"Number of Params: {n_parameters / 1_000_000:.1f}M")

Number of Params: 0.3M


In [10]:
# Persist only the PyTorch state_dict (parameters + BatchNorm buffers), not the module object.
# NOTE(review): nothing trains `model` between its creation and this save, so these
# appear to be freshly initialised weights -- confirm that is intended.
torch.save(obj=model.state_dict(), f='vertex_vista.pth')

In [11]:
import tensorflow as tf
from tensorflow.keras import layers, models

# Define TensorFlow model
def VertexVista_tf(num_classes, input_shape=(224, 224, 3), output_activation='softmax',
                   match_pytorch=False):
    """Build the Keras counterpart of the PyTorch ``VertexVista`` model.

    Weight-carrying layers are named after the PyTorch submodules they mirror
    (``conv1``/``bn1`` ... ``conv4``/``bn4``, ``shortcut1_conv``/``shortcut1_bn``,
    ``shortcut2_conv``/``shortcut2_bn``, ``fc``) so they can be fetched with
    ``model.get_layer(name)``.

    Args:
        num_classes: number of units in the final Dense layer.
        input_shape: shape of one channels-last input image. Only the channel
            count (3) is fixed by the weights; global average pooling lets the
            network accept any spatial size, e.g. ``(None, None, 3)``.
        output_activation: activation of the final Dense layer. The default
            ``'softmax'`` returns probabilities, whereas ``VertexVista.forward``
            returns raw logits; pass ``None`` for directly comparable outputs.
        match_pytorch: when True, also reproduce two PyTorch details that need
            extra Keras layers: the ReLU applied after ``bn2``/``bn4`` *before*
            the residual add, and the symmetric 1-pixel padding of the stride-2
            ``conv3`` (Keras ``'same'`` padding with stride 2 on an even-sized
            input pads only bottom/right, shifting the sampling grid by one
            pixel). False builds the same layer list (count and order) as
            without this option, so code that indexes ``model.layers`` by
            position is unaffected, but outputs then differ from PyTorch.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    def batch_norm(name):
        # Mirror nn.BatchNorm2d defaults: eps=1e-5 (Keras defaults to 1e-3, which
        # perturbs every normalised activation once PyTorch statistics are copied
        # in) and PyTorch momentum=0.1, i.e. Keras momentum=0.9 (the weight Keras
        # gives to the *old* moving statistic).
        return layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name=name)

    inputs = tf.keras.Input(shape=input_shape)

    # First block (stride 1, 3 -> 64 channels) with a 1x1 projection shortcut.
    x = layers.Conv2D(64, (3, 3), padding='same', name='conv1')(inputs)
    x = batch_norm('bn1')(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(64, (3, 3), padding='same', name='conv2')(x)
    x = batch_norm('bn2')(x)
    if match_pytorch:
        x = layers.ReLU()(x)

    shortcut = layers.Conv2D(64, (1, 1), strides=1, name='shortcut1_conv')(inputs)
    shortcut = batch_norm('shortcut1_bn')(shortcut)

    x = layers.add([x, shortcut])
    block1_out = layers.ReLU()(x)

    # Second block (stride 2, 64 -> 128 channels). Like VertexVista.shortcut2,
    # the projection shortcut is applied to the first block's output.
    if match_pytorch:
        x = layers.ZeroPadding2D(padding=1)(block1_out)
        x = layers.Conv2D(128, (3, 3), padding='valid', strides=2, name='conv3')(x)
    else:
        x = layers.Conv2D(128, (3, 3), padding='same', strides=2, name='conv3')(block1_out)
    x = batch_norm('bn3')(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(128, (3, 3), padding='same', name='conv4')(x)
    x = batch_norm('bn4')(x)
    if match_pytorch:
        x = layers.ReLU()(x)

    shortcut = layers.Conv2D(128, (1, 1), strides=2, name='shortcut2_conv')(block1_out)
    shortcut = batch_norm('shortcut2_bn')(shortcut)

    x = layers.add([x, shortcut])
    x = layers.ReLU()(x)

    # Global Average Pooling, dropout and the fully connected classifier.
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(num_classes, activation=output_activation, name='fc')(x)

    return models.Model(inputs, outputs)


# Build the Keras counterpart and report its size. Keras' count_params() counts
# every weight, including the non-trainable BatchNorm moving mean/variance, unlike
# a requires_grad-filtered PyTorch count.
num_classes = 10
tf_model = VertexVista_tf(num_classes)
n_parameters = tf_model.count_params()
print(f"Number of Params: {n_parameters / 1_000_000:.1f}M")

Number of Params: 0.3M


In [12]:
# Rebuild the PyTorch model and load its weights from the same relative path that
# torch.save() wrote to (an absolute '/kaggle/working/...' path only resolves on Kaggle).
# map_location='cpu' lets the checkpoint load on a machine without a GPU.
pytorch_model = VertexVista(num_classes)
state_dict = torch.load('vertex_vista.pth', map_location=torch.device('cpu'))
pytorch_model.load_state_dict(state_dict)
# Inference mode: BatchNorm uses its running statistics and dropout is disabled,
# which is the behaviour the converted Keras model reproduces.
pytorch_model.eval()

def pt_to_np(tensor):
    """Convert a PyTorch tensor to a NumPy array.

    The tensor is first detached from the autograd graph and moved to host
    memory, because ``Tensor.numpy()`` rejects tensors that require grad or live
    on a GPU. For a tensor already on the CPU the result shares its memory.
    """
    host_tensor = tensor.detach().cpu()
    return host_tensor.numpy()

In [13]:
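# Debug aid: uncomment to print each Keras layer's position, name and weight
# shapes, e.g. when checking how PyTorch modules map onto tf_model.layers.
# for i, layer in enumerate(tf_model.layers):
#     print(i, layer.name, [w.shape for w in layer.get_weights()])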

In [14]:
def _conv_arrays(pt_conv):
    """PyTorch Conv2d params as Keras Conv2D weights: OIHW kernel -> HWIO kernel, plus bias."""
    return [pt_to_np(pt_conv.weight.permute(2, 3, 1, 0)), pt_to_np(pt_conv.bias)]


def _bn_arrays(pt_bn):
    """PyTorch BatchNorm2d state in Keras order: gamma, beta, moving mean, moving variance."""
    return [pt_to_np(t) for t in (pt_bn.weight, pt_bn.bias, pt_bn.running_mean, pt_bn.running_var)]


def transfer_weights(pt_model, keras_model):
    """Copy ``VertexVista`` (PyTorch) weights into a ``VertexVista_tf`` Keras model in place.

    Keras layers are located by graph structure instead of by position in
    ``keras_model.layers``. Every Conv2D kernel shape in this architecture is
    unique, so each PyTorch conv is matched to the Keras conv with the same
    (transposed) kernel shape. Its BatchNorm is the Keras BN layer fed directly
    by that conv. Positional indices shift whenever a layer is added, and
    ``set_weights`` cannot detect a mix-up between same-shaped BatchNorm layers
    (bn2 vs. shortcut1's BN, bn4 vs. shortcut2's BN).

    Raises:
        ValueError: if a PyTorch layer cannot be matched to exactly one Keras layer.
    """
    keras_convs = [l for l in keras_model.layers if isinstance(l, layers.Conv2D)]
    keras_bns = [l for l in keras_model.layers if isinstance(l, layers.BatchNormalization)]
    keras_denses = [l for l in keras_model.layers if isinstance(l, layers.Dense)]

    conv_bn_pairs = [
        (pt_model.conv1, pt_model.bn1),
        (pt_model.conv2, pt_model.bn2),
        (pt_model.shortcut1[0], pt_model.shortcut1[1]),
        (pt_model.conv3, pt_model.bn3),
        (pt_model.conv4, pt_model.bn4),
        (pt_model.shortcut2[0], pt_model.shortcut2[1]),
    ]
    for pt_conv, pt_bn in conv_bn_pairs:
        conv_weights = _conv_arrays(pt_conv)
        kernel_shape = conv_weights[0].shape
        conv_matches = [l for l in keras_convs if l.get_weights()[0].shape == kernel_shape]
        if len(conv_matches) != 1:
            raise ValueError(
                f"expected exactly one Keras Conv2D with kernel shape {kernel_shape}, "
                f"found {len(conv_matches)}")
        keras_conv = conv_matches[0]

        # The BN that consumes this conv's output tensor is its partner.
        bn_matches = [l for l in keras_bns if l.input is keras_conv.output]
        if len(bn_matches) != 1:
            raise ValueError(
                f"expected exactly one BatchNormalization fed by {keras_conv.name}, "
                f"found {len(bn_matches)}")

        keras_conv.set_weights(conv_weights)
        bn_matches[0].set_weights(_bn_arrays(pt_bn))

    if len(keras_denses) != 1:
        raise ValueError(f"expected exactly one Keras Dense layer, found {len(keras_denses)}")
    # nn.Linear stores weight as (out, in); Keras Dense expects (in, out).
    keras_denses[0].set_weights([pt_to_np(pt_model.fc.weight.T), pt_to_np(pt_model.fc.bias)])


transfer_weights(pytorch_model, tf_model)

In [15]:
# Persist the converted Keras model (architecture + weights). The '.h5' suffix
# selects Keras' legacy HDF5 format.
tf_h5_path = 'vertex_vista_tf.h5'
tf_model.save(tf_h5_path)

In [16]:
# Reload the converted model from disk to confirm the HDF5 file deserialises.
loaded_tf_model = tf.keras.models.load_model(filepath='vertex_vista_tf.h5')

In [17]:
# Report the reloaded model's size, then check the save/load round trip:
# the reloaded model must have the same parameter count and bit-identical weights.
n_parameters = loaded_tf_model.count_params()
print(f"Number of Params: {n_parameters / 1000000:.1f}M")

assert n_parameters == tf_model.count_params(), "parameter count changed after save/load"
saved_weights = tf_model.get_weights()
reloaded_weights = loaded_tf_model.get_weights()
assert len(saved_weights) == len(reloaded_weights), "number of weight arrays changed after save/load"
assert all(
    a.shape == b.shape and (a == b).all()
    for a, b in zip(saved_weights, reloaded_weights)
), "weights changed after save/load round trip"

Number of Params: 0.3M
