This encoder produces two outputs: `cls_out`, a global feature vector that summarizes the state, and `feature_map`, a smaller spatial feature map of size $H/p \times W/p$ (where $p$ is the patch size) that captures spatial details. We use `feature_map` as the input to the next stage and use `cls_out` to initialize the hidden states.

In [None]:
import torch
import torch.nn as nn
class ViTEncoder(nn.Module):
    """Vision Transformer (ViT) encoder producing a global vector and a spatial feature map.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution. A learnable class token is prepended,
    learnable positional embeddings are added, and the token sequence is run
    through a stack of ``nn.TransformerEncoderLayer`` blocks.

    Args:
        img_size: Side length of the (square) image the positional table is
            sized for; must be divisible by ``patch_size``.
        patch_size: Side length of each square patch (conv kernel and stride).
        in_channels: Number of channels of the input image.
        embed_dim: Token embedding width ``D``; must be divisible by
            ``num_heads`` (required by PyTorch multi-head attention).
        num_layers: Number of transformer encoder layers.
        num_heads: Number of attention heads per layer.

    Raises:
        ValueError: If ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, img_size, patch_size, in_channels, embed_dim, num_layers=6, num_heads=8):
        super().__init__()
        # Raise rather than assert: asserts are stripped under `python -O`.
        if img_size % patch_size != 0:
            raise ValueError("Image size must be divisible by patch size")
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.embed_dim = embed_dim

        # Patch embedding: kernel == stride == patch_size, so each
        # non-overlapping patch becomes one embed_dim-dimensional vector.
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

        # Learnable class token and positional embeddings (one slot for the
        # class token plus one per patch of an img_size x img_size image).
        # NOTE(review): both start at zero; the reference ViT uses a
        # truncated-normal init (std=0.02) — consider nn.init.trunc_normal_.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))

        # Transformer encoder. Layers use PyTorch's default sequence-first
        # layout (batch_first=False), hence the transposes in forward().
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=4 * embed_dim
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image tensor of shape ``(B, in_channels, H, W)``. Pixels beyond
                the last full patch are dropped by the strided convolution.
                The number of patches must not exceed ``self.num_patches``.

        Returns:
            Tuple ``(cls_out, feature_map)`` where ``cls_out`` is the encoded
            class token, shape ``(B, embed_dim)``, and ``feature_map`` holds the
            encoded patch tokens laid back out on the patch grid, shape
            ``(B, embed_dim, H // patch_size, W // patch_size)``.

        Raises:
            ValueError: If ``x`` yields more patches than the positional
                embedding table was built for.
        """
        batch_size = x.size(0)

        # (B, D, H/p, W/p): one embedding per patch, still on the spatial grid.
        patches = self.patch_embed(x)
        # Keep the real grid so the output is reshaped correctly even when the
        # input is not exactly img_size x img_size.
        grid_h, grid_w = patches.shape[-2:]
        # (B, N, D) with N = grid_h * grid_w, patches in row-major order.
        patches = patches.flatten(2).transpose(1, 2)

        # Prepend the class token -> (B, N + 1, D).
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls_tokens, patches], dim=1)

        num_tokens = tokens.size(1)
        if num_tokens > self.pos_embed.size(1):
            raise ValueError(
                f"Input produces {num_tokens - 1} patches but the positional "
                f"embedding supports at most {self.num_patches}"
            )
        # NOTE(review): for a smaller or non-square grid this takes the first
        # positions of the row-major img_size grid, which do not correspond
        # to the same 2-D locations — interpolate pos_embed if that matters.
        tokens = tokens + self.pos_embed[:, :num_tokens, :]

        # The encoder expects (N + 1, B, D); transpose in, then back to (B, N + 1, D).
        enc_outputs = self.transformer(tokens.transpose(0, 1)).transpose(0, 1)

        # Encoded class token -> (B, D).
        cls_out = enc_outputs[:, 0, :]
        # Patch tokens (B, N, D) -> (B, D, N) -> (B, D, grid_h, grid_w).
        # reshape works regardless of the transposed slice's memory layout.
        patch_out = enc_outputs[:, 1:, :].transpose(1, 2)
        feature_map = patch_out.reshape(batch_size, self.embed_dim, grid_h, grid_w)
        return cls_out, feature_map