In [1]:
import yaml
import pathlib
import time

import torch
import torch.nn

import mnn.vision.models.vision_transformer.encoder.config as mnn_config
import mnn.vision.image_size
from mnn.vision.models.vision_transformer.vision_transformer import (
    MyVisionTransformer
)



## UTILITIES

In [None]:
def inference_test(image: torch.Tensor, model: torch.nn.Module):
    """Run a single forward pass of ``model`` on ``image`` and report on it.

    Prints the elapsed wall-clock time of the call and the shape of the
    returned output. Nothing is returned.

    Args:
        image: Input tensor, handed to ``model`` unchanged.
        model: Module to call; its output must expose a ``.shape`` attribute.
    """
    # The output is only inspected for its shape and never back-propagated,
    # so gradient tracking is switched off for the duration of the call.
    # The module's train/eval mode is deliberately left untouched.
    with torch.no_grad():
        # perf_counter is monotonic, so the difference cannot be skewed by
        # system clock adjustments the way time.time() can.
        t0 = time.perf_counter()
        output = model(image)
        t1 = time.perf_counter()
    print("Time taken:", t1 - t0, "seconds")
    print("Model's output shape:", output.shape)

def count_parameters(model):
    """Return the total number of scalar entries in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted; frozen ones are
    skipped.
    """
    trainable_total = 0
    for parameter in model.parameters():
        if not parameter.requires_grad:
            continue
        trainable_total += parameter.numel()
    return trainable_total

def read_yaml_file(file_path: pathlib.Path) -> dict:
    """Parse the YAML document stored at ``file_path``.

    Args:
        file_path: Location of the YAML file to read.

    Returns:
        The parsed top-level node of the document. Callers in this notebook
        expect a mapping; an empty file parses to ``None``.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Explicit encoding so parsing does not depend on the platform's default locale.
    with file_path.open(mode="r", encoding="utf-8") as f:
        # safe_load builds only plain Python types and works on every PyYAML
        # release; a bare yaml.load(f) without a Loader raises TypeError on
        # PyYAML >= 6 and could construct arbitrary objects on older versions.
        return yaml.safe_load(f)

## CHOOSE EXPERIMENT

In [1]:
# Ask interactively which experiment to load; the answer must parse as an integer.
experiment_number = int(input("Choose experiment Number and press Enter:"))
# Name of the experiment, later used as the directory holding its YAML files.
experiment_name = "experiment" + str(experiment_number)

## INITIALIZATION

In [2]:
""" CONFIGURATION """
# Batch size of the random inputs created below.
n = 1

# Both YAML files are read from the directory named after the chosen experiment.
experiment_dir = pathlib.Path(experiment_name)
model_config_as_dict = read_yaml_file(experiment_dir / "model.yaml")
model_config = mnn_config.VisionTransformerEncoderConfiguration.from_dict(
    model_config_as_dict
)
hyperparameters_config_as_dict = read_yaml_file(
    experiment_dir / "hyperparameters.yaml"
)
# NOTE(review): in the cells visible here the model is built from the
# hand-written `encoder_config` below, not from `model_config` or the
# hyperparameters loaded above — confirm this is intended.

""" MODEL """

feed_forward_dimensions = 512

image_size = mnn.vision.image_size.ImageSize(height=1024, width=1024, channels=1)
# Image height is used as the sequence length and image width as the
# embedding size, which in turn sets the encoder's d_model.
sequence_length = image_size.height
embedding_size = image_size.width
hidden_dim = embedding_size
number_of_layers = 1
# Random stand-ins for real images, scaled to [0, 255): one without a channel
# axis and one with three channels.
image = torch.rand(n, sequence_length, embedding_size) * 255
image_RGB = torch.rand(n, 3, sequence_length, embedding_size) * 255

encoder_config = mnn_config.VisionTransformerEncoderConfiguration(
    use_cnn=False,
    d_model=hidden_dim,
    number_of_layers=number_of_layers,
    feed_forward_dimensions=feed_forward_dimensions,
)

In [3]:
# Build the vision transformer from the hand-written encoder configuration.
my_transformer = MyVisionTransformer(
    encoder_config,
    image_size,
    n_high_level_layers=5,
    is_input_normalized=False,
)
# Tell the model which batch size the random inputs use.
my_transformer.set_batch_size(n)

In [4]:
# Time one forward pass on the random three-channel batch, then report model size.
inference_test(image_RGB, my_transformer)
trainable_parameter_count = count_parameters(my_transformer)
print("Number of parameters:", trainable_parameter_count)

Time taken: 4.706727027893066 seconds
Model's output shape: torch.Size([1, 3, 1280, 1280])
