In [1]:
import yaml
import pathlib
import time

import torch
import torch.nn

import mnn.vision.models.vision_transformer.encoder.config as mnn_encoder_config
import mnn.vision.config as mnn_config
from mnn.vision.models.vision_transformer.e2e import (
    MyVisionTransformer
)

## UTILITIES

In [2]:
def inference_test(image: torch.Tensor, model: torch.nn.Module) -> torch.Tensor:
    """Run one timed forward pass of ``model`` on ``image`` and report it.

    The pass runs in eval mode (dropout and similar layers disabled) and under
    ``torch.no_grad()`` (no autograd graph is built). An inference benchmark
    should measure this configuration. The model's original train/eval mode is
    restored afterwards, even if the forward pass raises.

    Args:
        image: Input batch, passed to ``model`` unchanged.
        model: Module to evaluate.

    Returns:
        The model's output tensor.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # perf_counter is monotonic and high-resolution, unlike time.time().
            # NOTE(review): on CUDA, kernels run asynchronously; call
            # torch.cuda.synchronize() around the pass for accurate GPU timings.
            t0 = time.perf_counter()
            output = model(image)
            t1 = time.perf_counter()
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)
    print("Time taken:", t1 - t0, "seconds")
    print("Model's output shape:", output.shape)
    return output

def count_parameters(model):
    """Return how many scalar values live in ``model``'s trainable parameters.

    Parameters with ``requires_grad`` set to False (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

def read_yaml_file(file_path: pathlib.Path) -> dict:
    """Parse the YAML document at ``file_path`` and return the loaded object.

    Args:
        file_path: Path to a YAML file.

    Returns:
        The parsed document. The config files read in this notebook are
        expected to be mappings (they are indexed by key / passed to
        ``from_dict``). PyYAML returns ``None`` for an empty file.
    """
    # Open in binary mode so PyYAML detects the UTF-8/UTF-16 encoding itself,
    # instead of decoding with the platform's locale default (e.g. cp1252 on
    # Windows).
    with file_path.open(mode="rb") as f:
        # PyYAML >= 6.0 requires an explicit Loader argument for yaml.load().
        # NOTE(review): FullLoader can construct some Python-specific types;
        # prefer yaml.SafeLoader if these files could come from untrusted sources.
        return yaml.load(f, Loader=yaml.FullLoader)

## CHOOSE EXPERIMENT

In [3]:
# Set to an int (e.g. 1) to make "Restart & Run All" reproducible and
# non-interactive; leave as None to be prompted for the experiment number.
EXPERIMENT_NUMBER = None

if EXPERIMENT_NUMBER is None:
    _experiment_answer = input("Choose experiment Number and press Enter:")
    try:
        experiment_number = int(_experiment_answer)
    except ValueError as exc:
        raise ValueError(
            f"Experiment number must be an integer, got {_experiment_answer!r}"
        ) from exc
else:
    experiment_number = int(EXPERIMENT_NUMBER)

# Configs are read from ./experiment<N>/ in the next cell; fail early with a
# clear message if that directory is missing.
experiment_name = f"experiment{experiment_number}"
if not pathlib.Path(experiment_name).is_dir():
    raise FileNotFoundError(
        f"Experiment directory {experiment_name!r} not found in {pathlib.Path.cwd()}"
    )

## INITIALIZATION

In [4]:
""" CONFIGURATION """
# Batch size of the synthetic input (kept as `n`; later cells use this name).
n = 1
# Seed for the synthetic input, so re-running the notebook yields the same tensor.
SEED = 0

experiment_dir = pathlib.Path(experiment_name)

model_config_as_dict = read_yaml_file(experiment_dir / "model.yaml")
model_config = mnn_encoder_config.MyVisionTransformerConfiguration.from_dict(
    model_config_as_dict["MyVisionTransformerConfiguration"]
)
encoder_config = model_config.encoder_config

hyperparameters_config_as_dict = read_yaml_file(experiment_dir / "hyperparameters.yaml")
hyperparameters_config = mnn_config.HyperparametersConfiguration.from_dict(
    hyperparameters_config_as_dict
)

# --- Model input ---
image_size = hyperparameters_config.image_size
# Going by the names, image rows are treated as the sequence and image columns
# as the embedding dimension (TODO confirm against MyVisionTransformer).
sequence_length = image_size.height
embedding_size = image_size.width
hidden_dim = embedding_size
# Synthetic RGB batch of shape (n, 3, height, width) with values in [0, 255).
# A dedicated generator keeps the seed from touching torch's global RNG.
# NOTE(review): the model is built with is_input_normalized=True, yet this
# input is in the 0-255 range; confirm which range the model expects.
input_generator = torch.Generator().manual_seed(SEED)
image_RGB = torch.rand(n, 3, sequence_length, embedding_size, generator=input_generator) * 255

In [5]:
# Build the end-to-end vision transformer from the parsed configs.
# NOTE(review): is_input_normalized=True while image_RGB is scaled to 0-255;
# confirm what this flag means for MyVisionTransformer.
my_transformer = MyVisionTransformer(
    encoder_config,
    image_size,
    n_high_level_layers=model_config.n_high_level_layers,
    is_input_normalized=True,
)
# Use the same batch size `n` that image_RGB was generated with.
my_transformer.set_batch_size(n)

In [6]:
# Time one forward pass on the synthetic batch, then report the model's size.
inference_test(image_RGB, my_transformer)
trainable_parameter_count = count_parameters(my_transformer)
print(f"Number of parameters: {trainable_parameter_count}")

Time taken: 0.5752787590026855 seconds
Model's output shape: torch.Size([1, 3, 1024, 1024])
Number of parameters: 56696832


In [7]:
# Show the module tree; the last expression is rendered by Jupyter's display.
my_transformer

MyVisionTransformer(
  (positional_encoder): MyVisionPositionalEncoding()
  (encoder_combinator_list): Sequential(
    (0): EncoderCombinator(
      (encoder): RawVisionTransformerEncoderRGB(
        (encoder_rgb): ModuleList(
          (0-2): 3 x RawVisionTransformerEncoder(
            (encoder_block): TransformerEncoderBlock(
              (block): Sequential(
                (0): TransformerEncoder(
                  (layers): ModuleList(
                    (0): TransformerEncoderLayer(
                      (self_attn): MultiheadAttention(
                        (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
                      )
                      (linear1): Linear(in_features=1024, out_features=1024, bias=True)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (linear2): Linear(in_features=1024, out_features=1024, bias=True)
                      (norm1): LayerNorm((1024,), eps=1e-05, elementw