# 0. IMPORTS

In [1]:
import yaml
import pathlib
import time

import torch
import torch.nn

import mnn.vision.image_size
import mnn.vision.models.vision_transformer.encoder.config as mnn_encoder_config
import mnn.vision.config as mnn_config
from mnn.vision.models.vision_transformer.e2e import MyVisionTransformer
from mnn.vision.models.vision_transformer.tasks.object_detection import (
    ObjectDetectionOrdinalHead,
)

# 1. UTILITIES


In [2]:
def inference_test(image: torch.Tensor, model: torch.nn.Module):
    """
    Time a single forward pass and return a TorchScript trace of the model.

    The wall-clock duration of ``model(image)`` and the shape of its output
    are printed; afterwards ``model.forward`` is traced with ``image`` as the
    example input.

    Args:
        image: Example input, forwarded through ``model`` and reused for tracing.
        model: Module to time and trace. Its output must expose ``.shape``.

    Returns:
        Whatever ``torch.jit.trace`` returns for ``model.forward``.
    """
    started_at = time.time()
    prediction = model(image)
    finished_at = time.time()

    print("Time taken:", finished_at - started_at, "seconds")
    print("Model's output shape:", prediction.shape)

    # The trace is requested with both check_trace and strict enabled.
    return torch.jit.trace(
        model.forward,
        image,
        check_trace=True,
        strict=True,
    )


def count_parameters(model):
    """
    Count the scalar elements of every parameter that requires gradients.

    Args:
        model: Object exposing ``parameters()``; each yielded item must
            provide ``requires_grad`` and ``numel()``.

    Returns:
        Sum of ``numel()`` over the parameters whose ``requires_grad`` is true.
    """
    trainable_total = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            trainable_total += parameter.numel()
    return trainable_total


def read_yaml_file(file_path: pathlib.Path) -> dict:
    """
    Parse the YAML document stored at ``file_path``.

    Args:
        file_path: Location of the YAML file; opened in text read mode.

    Returns:
        The object produced by ``yaml.load`` for the file's content.
    """
    with file_path.open(mode="r") as yaml_stream:
        # The Loader is passed explicitly (original note: "Python 3.11 need Loader").
        # NOTE(review): FullLoader is kept as-is; consider yaml.safe_load if
        # these configuration files could ever come from an untrusted source.
        parsed_document = yaml.load(yaml_stream, Loader=yaml.FullLoader)
    return parsed_document

# 3. MODEL DEFINITION


## 3.1 CONFIGURATION

In [6]:
""" CONFIGURATION """
def load_model_config(yaml_path: pathlib.Path):
    """
    Build the backbone, encoder and head configurations from a model YAML file.

    The file is read with ``read_yaml_file``. Its ``network.backbone`` section
    feeds ``MyBackboneVitConfiguration.from_dict`` and its
    ``network.head.VisionTransformerHead`` section feeds
    ``VisionTransformerEncoderConfiguration.from_dict``.

    Args:
        yaml_path: Path of the model description YAML file.

    Returns:
        A 3-tuple ``(backbone configuration, its encoder_config attribute,
        head configuration)``.
    """
    raw_config = read_yaml_file(yaml_path)

    backbone_section = raw_config["network"]["backbone"]
    backbone_config = mnn_encoder_config.MyBackboneVitConfiguration.from_dict(
        backbone_section
    )
    backbone_encoder_config = backbone_config.encoder_config

    head_section = raw_config["network"]["head"]["VisionTransformerHead"]
    head_config = mnn_encoder_config.VisionTransformerEncoderConfiguration.from_dict(
        head_section
    )

    return backbone_config, backbone_encoder_config, head_config

def load_hyperparameters_config(yaml_path: pathlib.Path):
    """
    Build the hyperparameters configuration from a YAML file.

    Args:
        yaml_path: Path of the hyperparameters YAML file.

    Returns:
        The result of ``HyperparametersConfiguration.from_dict`` applied to the
        parsed content of ``yaml_path``.
    """
    return mnn_config.HyperparametersConfiguration.from_dict(
        read_yaml_file(yaml_path)
    )


# Both YAML files are resolved relative to the notebook's working directory.
model_config, encoder_config, head_config = load_model_config(
    pathlib.Path("model.yaml")
)
hyperparameters_config = load_hyperparameters_config(
    pathlib.Path("hyperparameters.yaml")
)

batch_size = hyperparameters_config.batch_size

# The dummy-image geometry is read from the RGB-combinator settings:
# width <- d_model, height <- feed_forward_dimensions.
embedding_size = model_config.rgb_combinator_config.d_model
hidden_dim = embedding_size
sequence_length = model_config.rgb_combinator_config.feed_forward_dimensions
image_size = mnn.vision.image_size.ImageSize(
    height=sequence_length,
    width=embedding_size,
)

# Random stand-in for a batch of 3-channel images: torch.rand samples scaled by 255.
image_RGB = 255 * torch.rand(batch_size, 3, image_size.height, image_size.width)


## 3.2 NETWORK DEFINITION

In [7]:
class VitObjectDetectionNetwork(torch.nn.Module):
    """
    Chains a ``MyVisionTransformer`` encoder and an ``ObjectDetectionOrdinalHead``.

    Attributes:
        expected_image_size: Tuple ``(3, -1, encoder_config.d_model)`` recorded
            at construction time; it is not used by ``forward``.
        encoder: ``MyVisionTransformer`` built from ``encoder_config`` and
            ``image_size``.
        head: ``ObjectDetectionOrdinalHead`` built from ``head_config``.
    """

    def __init__(
        self,
        encoder_config: mnn_encoder_config.VisionTransformerEncoderConfiguration,
        head_config: mnn_encoder_config.VisionTransformerEncoderConfiguration,
        image_size: mnn.vision.image_size.ImageSize
    ):
        """
        Args:
            encoder_config: Configuration handed to the encoder; its ``d_model``
                is recorded as the expected image width.
            head_config: Configuration handed to the detection head.
            image_size: Image size handed to the encoder.
        """
        super().__init__()
        # (channels, height, width). NOTE(review): -1 presumably marks the
        # height as unconstrained -- confirm.
        self.expected_image_size = (3, -1, encoder_config.d_model)
        self.encoder = MyVisionTransformer(encoder_config, image_size)
        self.head = ObjectDetectionOrdinalHead(config=head_config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the encoder, then the head, and return the head's output."""
        return self.head(self.encoder(x))

def preprocess_tensor(x: torch.Tensor, expected_image_width: int) -> torch.Tensor:
    """
    Return ``x`` with its width equal to ``expected_image_width``.

    Expecting tensors of shape (3, H, W)

    The previous implementation fell through an empty branch and returned
    ``None`` on every path, despite the ``torch.Tensor`` return annotation.

    Args:
        x: Image tensor; dimension 1 is read as the height and dimension 2 as
            the width.
        expected_image_width: Width the returned tensor must have.

    Returns:
        ``x`` itself when its width already matches; otherwise a bilinearly
        resized copy whose width is ``expected_image_width`` and whose height
        is scaled by the same factor (at least 1).

    NOTE(review): aspect-preserving bilinear resizing is an assumed policy for
    the previously unfinished mismatch branch -- confirm it matches the
    intended preprocessing. It is written with floating-point tensors in mind;
    verify the behaviour for integer image tensors.
    """
    x_h, x_w = x.shape[1], x.shape[2]
    if x_w == expected_image_width:
        return x

    # Scale the height by the same ratio as the width so the image is not distorted.
    scaled_height = max(1, round(x_h * expected_image_width / x_w))

    # interpolate works on batched input, so add and then drop a leading batch dimension.
    resized = torch.nn.functional.interpolate(
        x.unsqueeze(0),
        size=(scaled_height, expected_image_width),
        mode="bilinear",
        align_corners=False,
    )
    return resized.squeeze(0)

## 3.3 MODEL INSTANTIATION

In [None]:
# Build the detection network. VitObjectDetectionNetwork.__init__ accepts only
# encoder_config, head_config and image_size, so no dtype keyword is passed
# here (it raised a TypeError); the precision is applied by the .to(...) call below.
# NOTE(review): model_config (the backbone configuration) is passed where the
# class annotation names VisionTransformerEncoderConfiguration -- confirm this
# is the object MyVisionTransformer expects.
object_detection_model = VitObjectDetectionNetwork(
    encoder_config=model_config,
    head_config=head_config,
    image_size=image_size,
)

# Use the first CUDA device when one is available, otherwise stay on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Move the model and the dummy batch to the same device and floating-point precision.
object_detection_model.to(device=device, dtype=hyperparameters_config.floating_point_precision)
image_RGB = image_RGB.to(device=device, dtype=hyperparameters_config.floating_point_precision)

## 3.4 INFERENCE PROFILING


In [8]:
import mnn.visualize
import time

# Time ten forward passes, each on a freshly sampled random batch, with
# autograd disabled.
with torch.no_grad():
    for _ in range(10):
        image_RGB = (
            255
            * torch.rand(
                hyperparameters_config.batch_size,
                3,
                image_size.height,
                image_size.width,
            )
        ).to(device=device, dtype=hyperparameters_config.floating_point_precision)

        t0 = time.time()
        output = object_detection_model(image_RGB)
        # The copy to a host NumPy array happens inside the timed window,
        # so the reported duration includes it.
        out = output.detach().cpu().numpy()
        t1 = time.time()

        print(
            "Time taken:",
            t1 - t0,
            "seconds | image_shape:",
            image_RGB.shape,
            "output_shape:",
            output.shape,
        )

Time taken: 0.011376142501831055 seconds | image_shape: torch.Size([1, 3, 512, 512]) output_shape: torch.Size([1, 512, 512])
Time taken: 0.009015083312988281 seconds | image_shape: torch.Size([1, 3, 512, 512]) output_shape: torch.Size([1, 512, 512])
Time taken: 0.011157035827636719 seconds | image_shape: torch.Size([1, 3, 512, 512]) output_shape: torch.Size([1, 512, 512])
Time taken: 0.009406089782714844 seconds | image_shape: torch.Size([1, 3, 512, 512]) output_shape: torch.Size([1, 512, 512])
Time taken: 0.008602142333984375 seconds | image_shape: torch.Size([1, 3, 512, 512]) output_shape: torch.Size([1, 512, 512])
Time taken: 0.011742830276489258 seconds | image_shape: torch.Size([1, 3, 512, 512]) output_shape: torch.Size([1, 512, 512])
Time taken: 0.011185407638549805 seconds | image_shape: torch.Size([1, 3, 512, 512]) output_shape: torch.Size([1, 512, 512])
Time taken: 0.00852823257446289 seconds | image_shape: torch.Size([1, 3, 512, 512]) output_shape: torch.Size([1, 512, 512])
T

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

# The notebook falls back to the CPU when CUDA is unavailable, so only request
# CUDA activity (and sort by CUDA time) when a CUDA device actually exists.
profiler_activities = [ProfilerActivity.CPU]
profiler_sort_key = "cpu_time_total"
if torch.cuda.is_available():
    profiler_activities.append(ProfilerActivity.CUDA)
    profiler_sort_key = "cuda_time_total"

# Profile a single forward pass with autograd disabled, as in the timing loop.
with torch.no_grad():
    with profile(activities=profiler_activities, record_shapes=True) as prof:
        with record_function("model_inference"):
            object_detection_model(image_RGB)

print(prof.key_averages().table(sort_by=profiler_sort_key, row_limit=10))

## 3.5 MODEL VISUALIZATION

In [None]:
# `output` from the timing loop was computed inside torch.no_grad(), so it
# carries no autograd history. Run one forward pass with autograd enabled and
# render that result instead.
# NOTE(review): assumes mnn.visualize.make_dot walks the tensor's autograd
# graph (torchviz-style) -- confirm against its implementation.
graph_output = object_detection_model(image_RGB)

mnn.visualize.make_dot(
    graph_output, params=dict(object_detection_model.named_parameters())
).render("my_transformer", format="png")