In [15]:
%%capture 
%run config.ipynb
%run InkML-parser.ipynb
%run tokenizer.ipynb

In [16]:
import torch 
from torchinfo import summary
from pprint import pprint
from transformers import ViTConfig, ViTModel
from transformers import MT5Config, MT5Model


In [17]:
# Define ViT model
# The encoder is instantiated directly from a config object (no pretrained
# checkpoint is loaded in this cell). All upper-case hyperparameters and `logger`
# are expected to be in scope already -- presumably via the %run'd notebooks
# at the top of this notebook (TODO confirm which one defines them).
vitConfig = ViTConfig(
    image_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    num_channels=IMG_IN_CHANNELS,
    hidden_size=D_MODEL,
    num_hidden_layers=VIT_N_LAYERS,
    num_attention_heads=VIT_N_HEADS,
    intermediate_size=VIT_FFN_HIDDEN,
    # The same dropout rate is used for hidden states and attention probabilities.
    hidden_dropout_prob=VIT_DROPOUT,
    attention_probs_dropout_prob=VIT_DROPOUT,
)

ViT_model = ViTModel(vitConfig)

# Format the configuration dump once and send the same text to the log and the cell output.
vit_config_msg = f'ViT Model Configuration: {ViT_model.config}'
logger.info(vit_config_msg)
print(vit_config_msg)

# Build the torchinfo report a single time: every summary() call pushes a
# (BATCH_SIZE, IMG_IN_CHANNELS, IMG_SIZE, IMG_SIZE) batch through the model,
# so calling it once for the logger and once for print() doubled that cost.
vit_summary = summary(ViT_model, input_size=(BATCH_SIZE, IMG_IN_CHANNELS, IMG_SIZE, IMG_SIZE))
logger.info(vit_summary)
print(vit_summary)


ViT Model Configuration: ViTConfig {
  "_attn_implementation_autoset": true,
  "attention_probs_dropout_prob": 0.1,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 512,
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 1024,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 8,
  "num_channels": 3,
  "num_hidden_layers": 6,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.47.0"
}

Layer (type:depth-idx)                                  Output Shape              Param #
ViTModel                                                [64, 512]                 --
├─ViTEmbeddings: 1-1                                    [64, 197, 512]            101,376
│    └─ViTPatchEmbeddings: 2-1                          [64, 196, 512]            --
│    │    └─Conv2d: 3-1                                 [64, 512, 14, 14]         393,728
│    └─Dropout: 2-2                                     [64, 1

In [18]:
# Define mT5 Model
# The encoder-decoder is instantiated directly from a config object (no pretrained
# checkpoint is loaded in this cell). All upper-case hyperparameters and `logger`
# are expected to be in scope already -- presumably via the %run'd notebooks
# at the top of this notebook (TODO confirm which one defines them).
mt5Config = MT5Config(
    vocab_size=ENC_NUM_TOKENS,
    d_model=D_MODEL,
    num_layers=MT5_ENC_N_LAYERS,
    num_decoder_layers=MT5_DEC_N_LAYERS,
    # NOTE(review): only the encoder head-count constant is passed; MT5Config takes a
    # single `num_heads`, so it presumably applies to the decoder as well -- confirm.
    num_heads=MT5_ENC_N_HEADS,
    d_ff=MT5_FFN_HIDDEN,
    dropout_rate=MT5_DROPOUT,
)

MT5_model = MT5Model(mt5Config)

# The label previously read "ViT Model Configuration" (copied from the ViT cell),
# which mislabelled the mT5 config in both the log file and the cell output.
mt5_config_msg = f'mT5 Model Configuration: {MT5_model.config}'
logger.info(mt5_config_msg)
print(mt5_config_msg)

# Random token ids used only to drive torchinfo's forward pass.
# NOTE(review): decoder ids are sampled from [0, DEC_NUM_TOKENS) while the model is
# configured with vocab_size=ENC_NUM_TOKENS only. If encoder and decoder share one
# embedding table, DEC_NUM_TOKENS must not exceed ENC_NUM_TOKENS or the lookup can
# go out of range -- verify against the tokenizer/config constants.
tmp_mt5_input_data = {
    "input_ids": torch.randint(0, ENC_NUM_TOKENS, (BATCH_SIZE, MT5_ENC_MAX_SEQ_LEN)),
    "decoder_input_ids": torch.randint(0, DEC_NUM_TOKENS, (BATCH_SIZE, MT5_DEC_MAX_SEQ_LEN)),
}

# Build the torchinfo report a single time and reuse it for the log and the cell
# output; calling summary() twice ran the dummy batch through the model twice.
mt5_summary = summary(MT5_model, input_data=tmp_mt5_input_data)
logger.info(mt5_summary)
print(mt5_summary)


ViT Model Configuration: MT5Config {
  "_attn_implementation_autoset": true,
  "classifier_dropout": 0.0,
  "d_ff": 1024,
  "d_kv": 64,
  "d_model": 512,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "mt5",
  "num_decoder_layers": 3,
  "num_heads": 8,
  "num_layers": 3,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "tokenizer_class": "T5Tokenizer",
  "transformers_version": "4.47.0",
  "use_cache": true,
  "vocab_size": 1002
}

Layer (type:depth-idx)                                       Output Shape              Param #
MT5Model                                                     [64, 512, 512]            --
├─MT5Stack: 1-1                                              [64, 