# EoMT Architecture Visualization

This notebook visualizes the architecture of the EoMT (Encoder-only Mask Transformer) model used in the project. It instantiates the model with the default configuration parameters and runs a dummy forward pass to show the input and output shapes.

In [18]:
import sys
import os

# Add the current directory to sys.path to ensure modules can be imported
sys.path.append(os.getcwd())
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")

import torch
from models.eomt import EoMT
from models.vit import ViT

try:
    from torchinfo import summary
except ImportError:
    print("torchinfo not found. Install it with `pip install torchinfo` for a prettier summary, or we will just use print(model).")
    summary = None

## 1. Model Configuration
We set up the configuration parameters matching `configs/dinov2/cityscapes/semantic/eomt_base_640.yaml`.

In [19]:
# Configuration Parameters
img_size = (640, 640)
patch_size = 14
backbone_name = "vit_base_patch14_reg4_dinov2"
num_classes = 19  # Cityscapes
num_q = 100
num_blocks = 3
masked_attn_enabled = True
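
As a quick sanity check on these settings, the sketch below works out how many whole patches fit along each side of the image. It assumes a non-overlapping, unpadded patch embedding (kernel size and stride both equal to `patch_size`); `patch_grid` is a hypothetical helper written for this illustration, not part of the project.

```python
def patch_grid(img_size, patch_size):
    """Whole patches along (height, width), assuming stride == kernel == patch_size and no padding."""
    height, width = img_size
    return height // patch_size, width // patch_size


img_size = (640, 640)
patch_size = 14

grid_h, grid_w = patch_grid(img_size, patch_size)
num_patches = grid_h * grid_w

# Pixels at the bottom/right edge that no full patch covers (assumption: no padding)
leftover = (img_size[0] - grid_h * patch_size, img_size[1] - grid_w * patch_size)

print(f"Patch grid: {grid_h} x {grid_w} = {num_patches} patches")
print(f"Uncovered pixels (height, width): {leftover}")
```

Because 640 is not a multiple of 14, the grid under these assumptions is 45 x 45 with 10 pixels left over on each axis.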

## 2. Model Instantiation
First we create the ViT encoder backbone, then wrap it in the EoMT model.

In [20]:
print("Initializing ViT Backbone...")
encoder = ViT(
    img_size=img_size,
    patch_size=patch_size,
    backbone_name=backbone_name,
    ckpt_path=None,  # Pretrained weights are not needed for visualization
)

print("Initializing EoMT Model...")
model = EoMT(
    encoder=encoder,
    num_classes=num_classes,
    num_q=num_q,
    num_blocks=num_blocks,
    masked_attn_enabled=masked_attn_enabled
)

# Switch to evaluation mode
model.eval();

Initializing ViT Backbone...
Initializing EoMT Model...


## 3. Architecture Summary
We use `torchinfo` if it is available, and print the PyTorch module tree as well.

In [21]:
if summary:
    summary(model, input_size=(1, 3, *img_size), depth=4)

# Print the raw module tree once, whether or not torchinfo is installed
print(model)

EoMT(
  (encoder): ViT(
    (backbone): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (patch_drop): Identity()
      (norm_pre): Identity()
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): LayerScale()
          (drop_path1): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_feat

In [22]:
print(encoder)

ViT(
  (backbone): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): LayerScale()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
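
The layer shapes in the printout above are enough to work out per-layer parameter counts by hand. The sketch below does that arithmetic for the layers that are visible in the truncated output; `linear_params` and `conv2d_params` are hypothetical helpers for this illustration, and the bias terms are assumed from the printed `bias=True` (and from the absence of `bias=False` on the `Conv2d`).

```python
def linear_params(in_features, out_features, bias=True):
    """Weights plus optional bias of a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


def conv2d_params(in_channels, out_channels, kernel_size, bias=True):
    """Weights plus optional bias of an ungrouped 2D convolution."""
    kernel_h, kernel_w = kernel_size
    return in_channels * out_channels * kernel_h * kernel_w + (out_channels if bias else 0)


# Shapes copied from the printed module tree
counts = {
    "patch_embed.proj": conv2d_params(3, 768, (14, 14)),
    "attn.qkv": linear_params(768, 2304),
    "attn.proj": linear_params(768, 768),
    "mlp.fc1": linear_params(768, 3072),
}

for name, count in counts.items():
    print(f"{name:<18}{count:>12,d}")
```

These numbers can be compared against the per-layer counts that `torchinfo` reports when it is installed.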
   

## 4. Forward Pass Check
Running a dummy input to verify output shapes.

Expected outputs:
1. `mask_logits_list`: a list with one entry per intermediate block plus the final output, so its length should be `num_blocks + 1`.
2. `class_logits_list`: a list of the same length, with one entry per stage.
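
The expectations above can be written down as a small specification before running the model. This is a minimal sketch: `expected_output_spec` is a hypothetical helper, and the extra class slot (`num_classes + 1`) is an assumption based on the "Classes (N+1)" label used in the check below; it is commonly a "no object" class in mask-classification models, but that is not confirmed here.

```python
def expected_output_spec(batch_size, num_q, num_classes, num_blocks):
    """List lengths and shapes expected from the forward pass, per the description above."""
    return {
        # One prediction per intermediate block, plus the final one
        "num_stages": num_blocks + 1,
        # Assumption: one extra class slot on top of the dataset classes
        "class_logits_shape": (batch_size, num_q, num_classes + 1),
        # Mask logits are (batch, queries, height, width); only the leading dims are fixed by the config
        "mask_logits_leading_dims": (batch_size, num_q),
    }


spec = expected_output_spec(batch_size=1, num_q=100, num_classes=19, num_blocks=3)
for key, value in spec.items():
    print(f"{key}: {value}")
```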

In [23]:
# Determine device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Move model to device
model = model.to(device)

# Create input on the same device
dummy_input = torch.randn(1, 3, *img_size).to(device)

with torch.no_grad():
    mask_logits_list, class_logits_list = model(dummy_input)

print(f"Input Shape: {dummy_input.shape}")
print(f"Number of outputs (stages): {len(mask_logits_list)}")

# Inspect the final prediction (last entry of each list)
final_mask = mask_logits_list[-1]
final_class = class_logits_list[-1]

print(f"\nFinal Mask Logits Shape: {final_mask.shape}")
print(f"   Batch: {final_mask.shape[0]}")
print(f"   Queries: {final_mask.shape[1]}")
print(f"   Height: {final_mask.shape[2]}")
print(f"   Width: {final_mask.shape[3]}")

print(f"\nFinal Class Logits Shape: {final_class.shape}")
print(f"   Batch: {final_class.shape[0]}")
print(f"   Queries: {final_class.shape[1]}")
print(f"   Classes (N+1): {final_class.shape[2]}")

Using device: cuda
Input Shape: torch.Size([1, 3, 640, 640])
Number of outputs (stages): 4

Final Mask Logits Shape: torch.Size([1, 100, 90, 90])
   Batch: 1
   Queries: 100
   Height: 90
   Width: 90

Final Class Logits Shape: torch.Size([1, 100, 20])
   Batch: 1
   Queries: 100
   Classes (N+1): 20
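
The 90 x 90 mask resolution can be related back to the configuration. The sketch below infers the ratio between the mask logits and the patch grid purely from the shapes printed above; it assumes an unpadded patch embedding with stride equal to `patch_size`, and `mask_upscale_factor` is a hypothetical helper, not something read from the model code.

```python
def mask_upscale_factor(img_size, patch_size, mask_hw):
    """Return (patch grid, integer ratio of mask resolution to patch grid)."""
    grid = (img_size[0] // patch_size, img_size[1] // patch_size)
    factor_h, rem_h = divmod(mask_hw[0], grid[0])
    factor_w, rem_w = divmod(mask_hw[1], grid[1])
    if rem_h or rem_w or factor_h != factor_w:
        raise ValueError(f"Mask size {mask_hw} is not a uniform integer multiple of grid {grid}")
    return grid, factor_h


# Values taken from the configuration and the output above
grid, factor = mask_upscale_factor(img_size=(640, 640), patch_size=14, mask_hw=(90, 90))
print(f"Patch grid: {grid}, mask logits are {factor}x the grid resolution")
```

Under these assumptions the patch grid is 45 x 45, so the mask logits come out at twice the patch-grid resolution rather than at the full 640 x 640 input size.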


## 5. Raw Timm Model Inspection
Checking the layers of the raw `vit_base_patch14_reg4_dinov2` model from `timm` directly.

In [1]:
import timm

print("Downloading and creating raw timm model...")
raw_model = timm.create_model("vit_base_patch14_reg4_dinov2", pretrained=True)
print(raw_model)
