In [1]:
import time
import cv2
import torch
import torch.nn
import mnn.vision.models.vision_transformer.encoder.config as mnn_config
import mnn.vision.image_size
from mnn.vision.models.vision_transformer.vit_encoder import (
    VisionTransformerEncoder,
    RawVisionTransformerEncoder,
    RawVisionTransformerEncoderRGB,
)

def inference_test(image: torch.Tensor, model: torch.nn.Module):
    """Run one forward pass of ``model`` on ``image`` and report how long it took.

    Prints the elapsed time in seconds, then the shape of the model output.
    Nothing is returned.

    The forward pass runs under ``torch.no_grad()`` so that no autograd graph
    is recorded while timing, and the duration is measured with
    ``time.perf_counter()``, a monotonic high-resolution clock meant for
    measuring short intervals.

    The model's train/eval mode is not touched here; call ``model.eval()``
    beforehand if train-only layers (e.g. dropout) should be inactive.

    Args:
        image: Input tensor, passed to ``model`` unchanged.
        model: Module to time; its output must expose ``.shape``.
    """
    with torch.no_grad():
        t0 = time.perf_counter()
        output = model(image)
        t1 = time.perf_counter()
    print("Time taken:", t1 - t0, "seconds")
    print(output.shape)
    
def count_parameters(model):
    """Return the total number of trainable scalar values in ``model``.

    Only parameters with ``requires_grad`` set are counted; frozen
    parameters are skipped.
    """
    trainable_total = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            trainable_total += parameter.numel()
    return trainable_total

In [2]:
""" CONFIGURATION """
# Fixed seed so the random input below (and any later random weight
# initialisation) is identical after Restart Kernel -> Run All.
seed = 0
torch.manual_seed(seed)

n = 1  # batch size: leading dimension of the input tensor
feed_forward_dimensions = 4096
number_of_layers = 1

image_size = mnn.vision.image_size.ImageSize(height=1280, width=1280, channels=1)

# The input is laid out as (batch, rows, columns): image rows act as
# sequence positions and image columns as the embedding dimension.
sequence_length = image_size.height
embedding_size = image_size.width
hidden_dim = embedding_size

# Random stand-in for an image, already in PyTorch tensor form.
image = torch.randn(n, sequence_length, embedding_size)

encoder_config = mnn_config.VisionTransformerEncoderConfiguration(
    use_cnn=False,
    d_model=hidden_dim,
    number_of_layers=number_of_layers,
    feed_forward_dimensions=feed_forward_dimensions,
)

In [3]:
# Build both encoder variants from the same configuration and image size.
my_encoder = RawVisionTransformerEncoder(encoder_config, image_size)
# Alternative kept for reference (disabled): two RGB encoders chained back to back.
# my_encoder_rgb = torch.nn.Sequential(RawVisionTransformerEncoderRGB(encoder_config, image_size), RawVisionTransformerEncoderRGB(encoder_config, image_size))
my_encoder_rgb = RawVisionTransformerEncoderRGB(encoder_config, image_size)

In [4]:
# Time a single forward pass of the RGB encoder on the random input and print the output shape.
inference_test(image, my_encoder_rgb)

Time taken: 0.9234600067138672 seconds
torch.Size([1, 3, 1280, 1280])


In [5]:
# Number of trainable parameters (requires_grad=True) in the RGB encoder.
count_parameters(my_encoder_rgb)

51172608

In [6]:
# Print the module tree of my_encoder (the non-RGB variant, not the my_encoder_rgb used in the cells above).
print(my_encoder)

RawVisionTransformerEncoder(
  (positional_encoder): PositionalEncoding()
  (encoder_block): TransformerEncoderBlock(
    (block): Sequential(
      (0): TransformerEncoder(
        (layers): ModuleList(
          (0): TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=1280, out_features=1280, bias=True)
            )
            (linear1): Linear(in_features=1280, out_features=4096, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (linear2): Linear(in_features=4096, out_features=1280, bias=True)
            (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (dropout1): Dropout(p=0.1, inplace=False)
            (dropout2): Dropout(p=0.1, inplace=False)
          )
        )
        (norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
)
