In [1]:
import time
import torch
import torch.nn
import mnn.vision.models.vision_transformer.encoder.config as mnn_config
import mnn.vision.image_size
from mnn.vision.models.vision_transformer.vit_encoder import (
    RawVisionTransformerEncoder,
    RawVisionTransformerEncoderRGB,
    ThreeChannelsCombinatorToThreeChannels
)

def inference_test(image: torch.Tensor, model: torch.nn.Module):
    """Time a single forward pass of ``model`` on ``image`` and print the output shape.

    The pass runs under ``torch.no_grad()`` so no autograd graph is recorded; this is an
    inference benchmark, and recording the graph would only add time and memory.
    The model's train/eval mode is not changed.

    Args:
        image: input passed unchanged to ``model``.
        model: module to run.

    Prints ``Time taken: <seconds> seconds`` followed by the output's ``.shape``.
    """
    # CUDA kernels launch asynchronously; synchronize before and after the call so the
    # measured interval covers the actual device work, not just the kernel launch.
    if getattr(image, "is_cuda", False):
        sync = torch.cuda.synchronize
    else:
        def sync():
            return None

    with torch.no_grad():
        sync()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        output = model(image)
        sync()
        t1 = time.perf_counter()
    print("Time taken:", t1 - t0, "seconds")
    print(output.shape)

def count_parameters(model: torch.nn.Module) -> int:
    """Return the total number of scalar elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad=True`` are counted; frozen ones are skipped.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

In [3]:
""" CONFIGURATION """
# Seed the RNG so the random test image (and every model initialised in later cells)
# is reproducible under Restart & Run All.
torch.manual_seed(0)

n = 1  # leading dimension of the random test image (batch size)
feed_forward_dimensions = 1024

# NOTE(review): channels=1 here, yet this same image_size is used to build the RGB
# encoders in later cells — confirm that is intended.
image_size = mnn.vision.image_size.ImageSize(height=1280, width=1280, channels=1)

# The image rows form the sequence (length = height); each row's pixels form the
# embedding (size = width).
sequence_length = image_size.height
embedding_size = image_size.width
hidden_dim = embedding_size
number_of_layers = 1

# Random "image" of raw pixel values in [0, 255); the encoders are built with
# is_input_normalized=False.
image = torch.rand(n, sequence_length, embedding_size) * 255

encoder_config = mnn_config.VisionTransformerEncoderConfiguration(
    use_cnn=False,
    d_model=hidden_dim,
    number_of_layers=number_of_layers,
    feed_forward_dimensions=feed_forward_dimensions,
)

In [None]:
# Build a plain and an RGB encoder from the same configuration (raw, un-normalized
# inputs) so their parameter counts can be compared further down.
my_encoder, my_encoder_rgb = (
    encoder_cls(encoder_config, image_size, is_input_normalized=False)
    for encoder_cls in (RawVisionTransformerEncoder, RawVisionTransformerEncoderRGB)
)

In [10]:
class EncoderCombinator(torch.nn.Module):
    """Run an RGB vision-transformer encoder, then a three-channel combinator on its output.

    The same encoder instance is stored on this module and handed to the combinator.
    Every forward pass prints the encoder and combinator output shapes (debug aid).
    """

    def __init__(self, encoder: RawVisionTransformerEncoderRGB):
        super(EncoderCombinator, self).__init__()
        self.encoder = encoder
        self.combinator = ThreeChannelsCombinatorToThreeChannels(encoder)

    def forward(self, x):
        encoded = self.encoder(x)
        print("encoder output shape:", encoded.shape)

        # NOTE(review): in the recorded run the combinator turned a (1, 3, 1280, 1280)
        # encoder output into (3, 3, 1280, 1280) — the leading dimension changed.
        # Investigate ThreeChannelsCombinatorToThreeChannels.
        combined = self.combinator(encoded)
        print("combinator output shape:", combined.shape)
        return combined


class MyVisionTransformer(torch.nn.Module):
    """Stack of encoder→combinator stages, each followed by a sigmoid.

    Args:
        encoder_config: configuration passed to every ``RawVisionTransformerEncoderRGB``.
        image_size: image size passed to every ``RawVisionTransformerEncoderRGB``.
        n_high_level_layers: number of stacked stages; with 0 stages ``forward`` returns
            its input unchanged.
        is_input_normalized: forwarded to each ``RawVisionTransformerEncoderRGB``.
    """

    def __init__(
        self,
        encoder_config,
        image_size,
        n_high_level_layers: int,
        is_input_normalized: bool,
    ):
        super().__init__()

        # Each stage gets its own RawVisionTransformerEncoderRGB instance.
        stages = []
        for _ in range(n_high_level_layers):
            rgb_encoder = RawVisionTransformerEncoderRGB(
                encoder_config,
                image_size,
                is_input_normalized=is_input_normalized,
            )
            stages.append(EncoderCombinator(rgb_encoder))
        self.encoder_combinator_list = torch.nn.ModuleList(stages)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        # NOTE(review): stages are chained, so each stage's output must be a valid input
        # for the next stage's encoder; this notebook only builds a single stage — confirm
        # shapes before stacking more.
        out = x
        for stage in self.encoder_combinator_list:
            out = self.sigmoid(stage(out))
        return out

In [11]:
# One encoder→combinator stage fed raw (un-normalized) pixel values.
my_transformer = MyVisionTransformer(
    encoder_config=encoder_config,
    image_size=image_size,
    n_high_level_layers=1,
    is_input_normalized=False,
)


In [12]:
# FIXME: the combinator changes the leading (batch?) dimension — the encoder output is
# (1, 3, 1280, 1280) but the combinator returns (3, 3, 1280, 1280); see the printed shapes.
inference_test(image=image, model=my_transformer)

encoder output shape: torch.Size([1, 3, 1280, 1280])
combinator output shape: torch.Size([3, 3, 1280, 1280])
Time taken: 0.9028410911560059 seconds
torch.Size([3, 3, 1280, 1280])


In [None]:
# Trainable-parameter counts: plain encoder first, then the RGB encoder.
for encoder_module in (my_encoder, my_encoder_rgb):
    print(count_parameters(encoder_module))

In [None]:
# Show the module tree of the plain (non-RGB) encoder.
print(str(my_encoder))