In [1]:
import yaml
import pathlib
import time

import torch
import torch.nn

import mnn.vision.image_size
import mnn.vision.models.vision_transformer.encoder.config as mnn_encoder_config
import mnn.vision.config as mnn_config
from mnn.vision.models.vision_transformer.e2e import (
    MyVisionTransformer
)
from mnn.vision.models.vision_transformer.tasks.object_detection import ObjectDetectionOrdinalHead 

## UTILITIES

In [2]:
def inference_test(image: torch.Tensor, model: torch.nn.Module):
    """Time one forward pass of ``model`` on ``image`` and return a TorchScript trace.

    The timed pass runs under ``torch.no_grad()`` so no autograd graph is built;
    its output is only used to report its shape. The model's ``forward`` is then
    traced with ``torch.jit.trace`` on the same example input, with
    ``check_trace`` enabled.

    Args:
        image: example input passed to ``model``.
        model: module to time and trace.

    Returns:
        The traced module produced by ``torch.jit.trace``.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time(), so it
    # is the appropriate clock for measuring short durations.
    with torch.no_grad():
        start = time.perf_counter()
        output = model(image)
        end = time.perf_counter()
    print("Time taken:", end - start, "seconds")
    print("Model's output shape:", output.shape)
    traced_model = torch.jit.trace(model.forward, image, check_trace=True, strict=True)
    return traced_model

def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for parameter in model.parameters():
        if not parameter.requires_grad:
            continue
        total += parameter.numel()
    return total

def read_yaml_file(file_path: pathlib.Path) -> dict:
    """Parse the YAML document at ``file_path`` and return the loaded object.

    The file is opened in binary mode so PyYAML detects the stream encoding
    itself (UTF-8 by default, UTF-16 via BOM). Text mode would use the
    platform's locale-dependent default encoding.

    Args:
        file_path: path to the YAML file.

    Returns:
        The loaded document, which is a dict when the top level is a mapping.
        NOTE: PyYAML returns ``None`` for an empty document, despite the
        annotation.
    """
    with file_path.open(mode="rb") as f:
        # PyYAML >= 6.0 requires an explicit Loader argument to yaml.load;
        # this is a PyYAML requirement, not a Python-version one. FullLoader is
        # kept for compatibility with existing config files. Prefer
        # yaml.safe_load if these files could ever come from an untrusted source.
        return yaml.load(f, Loader=yaml.FullLoader)

## INITIALIZATION

In [4]:
""" CONFIGURATION """
# Seed torch's global RNG so that the random dummy input below, and any random
# weight initialisation performed afterwards, are reproducible across
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model_config_as_dict = read_yaml_file(pathlib.Path("model.yaml"))
model_config = mnn_encoder_config.MyVisionTransformerConfiguration.from_dict(
    model_config_as_dict["MyVisionTransformer"]
)
# NOTE(review): `encoder_config` is extracted here, but the network is later
# constructed with `encoder_config=model_config`. Confirm which object
# MyVisionTransformer expects and drop the unused one.
encoder_config = model_config.encoder_config
head_config = mnn_encoder_config.VisionTransformerEncoderConfiguration.from_dict(
    model_config_as_dict["MyVisionTransformer"]["VisionTransformerHead"]
)

hyperparameters_config_as_dict = read_yaml_file(pathlib.Path("hyperparameters.yaml"))
hyperparameters_config = mnn_config.HyperparametersConfiguration.from_dict(hyperparameters_config_as_dict)

# Aliases: the image height doubles as the sequence length, and the width as
# the embedding size / hidden dimension.
image_size = hyperparameters_config.image_size
sequence_length = image_size.height
embedding_size = image_size.width
hidden_dim = embedding_size
# Random dummy batch of shape (batch_size, 3, height, width), presumably
# (batch, RGB channels, H, W), with values in [0, 255).
# NOTE(review): the model is later built with is_input_normalized=True, yet this
# input is not scaled to [0, 1]. Confirm what the encoder expects.
image_RGB = torch.rand(hyperparameters_config.batch_size, 3, sequence_length, image_size.width) * 255


In [6]:
class VitObjectDetectionNetwork(torch.nn.Module):
    """A MyVisionTransformer encoder followed by an ObjectDetectionOrdinalHead.

    Args:
        encoder_config: forwarded as the first positional argument to
            ``MyVisionTransformer``. NOTE(review): the notebook passes a
            ``MyVisionTransformerConfiguration`` here rather than the annotated
            ``VisionTransformerEncoderConfiguration``. Confirm which type
            ``MyVisionTransformer`` actually expects.
        head_config: configuration for the ``ObjectDetectionOrdinalHead``.
        image_size: forwarded to the encoder.
        is_input_normalized: forwarded to the encoder.
        dtype: forwarded to the encoder; the head is also cast to this dtype.
        batch_size: batch size set on the encoder at construction time.
    """

    def __init__(
        self,
        encoder_config: mnn_encoder_config.VisionTransformerEncoderConfiguration,
        head_config: mnn_encoder_config.VisionTransformerEncoderConfiguration,
        image_size: mnn.vision.image_size.ImageSize,
        is_input_normalized: bool,
        dtype: torch.dtype,
        batch_size: int
    ):
        super().__init__()
        # The encoder is registered before the head, which keeps the
        # parameter / state_dict ordering stable.
        self.encoder = MyVisionTransformer(encoder_config, image_size, is_input_normalized, dtype)
        self.encoder.set_batch_size(batch_size)
        # Module.to() returns the module itself, so the cast is chained onto construction.
        self.head = ObjectDetectionOrdinalHead(config=head_config).to(dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the encoder, then the detection head, to ``x``."""
        return self.head(self.encoder(x))

    def set_batch_size(self, batch_size: int):
        """Pass a new batch size to the encoder; the head is not touched."""
        self.encoder.set_batch_size(batch_size)

# Assemble the detector from the configuration objects parsed in the config cell.
# NOTE(review): `encoder_config=` receives the whole `model_config`
# (a MyVisionTransformerConfiguration), not `model_config.encoder_config`.
# Confirm which one MyVisionTransformer expects.
object_detection_model = VitObjectDetectionNetwork(
    encoder_config=model_config,
    head_config=head_config,
    image_size=image_size,
    # NOTE(review): the dummy input is built as rand * 255. Confirm this flag.
    is_input_normalized=True,
    dtype=hyperparameters_config.floating_point_precision,
    batch_size=hyperparameters_config.batch_size,
)
# Cast the dummy batch to the same floating-point precision the model was built with.
image_RGB = image_RGB.to(dtype=hyperparameters_config.floating_point_precision)


### Visualize the model

In [7]:
import mnn.visualize
# Forward pass on the dummy batch. `output` is consumed by the make_dot call in
# the following cell.
# NOTE(review): the last recorded run of this line raised
# "RuntimeError: The size of tensor a (3) must match the size of tensor b (16)
# at non-singleton dimension 1". The cause is not visible in this notebook.
# Check the input layout and batch size that MyVisionTransformer expects.
output = object_detection_model(image_RGB)

RuntimeError: The size of tensor a (3) must match the size of tensor b (16) at non-singleton dimension 1

In [None]:
# Visualise the graph that produced `output`, passing the detector's named
# parameters to make_dot, and render it to a PNG.
mnn.visualize.make_dot(output, params=dict(object_detection_model.named_parameters())).render("my_transformer", format="png")
