In [1]:
import time
import cv2
import torch
import torch.nn
import mnn.vision.models.vision_transformer.encoder.config as mnn_config
import mnn.vision.image_size
from mnn.vision.models.vision_transformer.vit_encoder import (
    VisionTransformerEncoder,
    RawVisionTransformerEncoder,
    RawVisionTransformerEncoderRGB,
)

def inference_test(image: torch.Tensor, model: torch.nn.Module):
    """Run one forward pass of ``model`` on ``image`` and report its latency.

    Prints the elapsed time in seconds, followed by the shape of the model's
    output. Nothing is returned.

    Args:
        image: Input tensor, passed to the model unchanged.
        model: Module to call; its output must expose a ``.shape`` attribute.
    """
    # Gradients are not needed for a timing run. Disabling autograd keeps
    # graph-construction overhead and memory out of the measurement.
    with torch.no_grad():
        # perf_counter() is monotonic and high resolution, whereas time.time()
        # follows the wall clock and can jump if the system clock is adjusted.
        # NOTE: for CUDA inputs the kernels run asynchronously, so a
        # torch.cuda.synchronize() would be needed before reading the clock.
        t0 = time.perf_counter()
        output = model(image)
        t1 = time.perf_counter()
    print("Time taken:", t1 - t0, "seconds")
    print(output.shape)

def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Parameters whose ``requires_grad`` flag is false are skipped.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

In [2]:
""" CONFIGURATION """
# Fix the global RNG seed so the random test image (and any random
# initialisation performed by later cells) is identical on every run.
torch.manual_seed(0)

n = 1  # number of images in the test batch
feed_forward_dimensions = 4096

image_size = mnn.vision.image_size.ImageSize(height=1280, width=1280, channels=1)

# The image height is used as the sequence length and the image width as the
# embedding size, which is also passed to the encoder config as d_model.
sequence_length = image_size.height
embedding_size = image_size.width
hidden_dim = embedding_size
number_of_layers = 1

# Random test image as an (n, height, width) tensor with values in [0, 255).
image = torch.rand(n, sequence_length, embedding_size) * 255

encoder_config = mnn_config.VisionTransformerEncoderConfiguration(
    use_cnn=False,
    d_model=hidden_dim,
    number_of_layers=number_of_layers,
    feed_forward_dimensions=feed_forward_dimensions,
)

In [3]:
# Build both encoder variants from the same configuration and image size.
# is_input_normalized=False is passed because the test image holds raw
# values (scaled by 255) rather than values in [0, 1).
my_encoder = RawVisionTransformerEncoder(
    encoder_config,
    image_size,
    is_input_normalized=False,
)
my_encoder_rgb = RawVisionTransformerEncoderRGB(
    encoder_config,
    image_size,
    is_input_normalized=False,
)

In [4]:
# Time one forward pass per encoder on the same input tensor:
# first the plain variant, then the RGB variant.
inference_test(image=image, model=my_encoder)
inference_test(image=image, model=my_encoder_rgb)

Time taken: 0.287841796875 seconds
torch.Size([1, 1280, 1280])
Time taken: 0.8668558597564697 seconds
torch.Size([1, 3, 1280, 1280])


In [None]:
# Trainable-parameter counts, one per line: plain encoder first, then RGB.
print(
    count_parameters(my_encoder),
    count_parameters(my_encoder_rgb),
    sep="\n",
)

In [None]:
# Print the encoder's string form for inspection.
# NOTE(review): presumably a torch.nn.Module, in which case this lists the
# nested submodules — confirm against the encoder class.
print(my_encoder)