In [1]:
from typing import Optional,Tuple
import torch
import torch.nn as nn

In [2]:
# This file implements a Vision Transformer (ViT) model that processes images by:
# 1. Splitting the image into patches
# 2. Converting each patch into an embedding
# 3. Adding positional embeddings
# 4. Processing through transformer layers
# 5. Outputting final image features

In [14]:
class VisionConfig:
    
    def __init__(
        self,
        hidden_size=768,  # Size of the embeddings used throughout the model
        intermediate_size=3072,  # Size of the intermediate layer in MLP
        num_hidden_layers=12,  # Number of transformer layers
        num_attention_heads=12,  # Number of attention heads in each transformer layer
        num_channels=3,  # Number of input image channels (3 for RGB)
        image_size=224,  # Input image size (224x224 pixels)
        patch_size=16,  # Size of each image patch (16x16 pixels)
        layer_norm_eps=1e-6,  # Small constant for numerical stability in layer norm
        attention_dropout=0.0,  # Dropout rate for attention
        num_image_tokens: int = None,  # Number of image tokens (patches) - calculated as (image_size/patch_size)^2
        **kwargs
    ):
        super().__init__()
        
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.num_image_tokens = num_image_tokens
        

In [4]:
"""Converts input images into patch embeddings and adds positional embeddings"""

'Converts input images into patch embeddings and adds positional embeddings'

In [15]:
class VisionEmbeddings(nn.Module):
    
    def __init__(self,config: VisionConfig):
        
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding='valid'
        )
        
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        
        self.position_embedding = nn.Embedding(self.num_positions,self.embed_dim)
        
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1,-1)),
            persistent=False
        )
        
    def forward(self,pixel_values: torch.FloatTensor):
        
        _,_,height,width = pixel_values.shape
        
        patch_embeds = self.patch_embedding(pixel_values)
        print('patch1 shape',patch_embeds.shape)
        
        embeddings = patch_embeds.flatten(2)
        print('emb shape',embeddings.shape)
        
        embeddings = embeddings.transpose(1,2)
        print('emb1 shape',embeddings.shape)
        
        embeddings = embeddings + self.position_embedding(self.position_ids)
        print('embeddings shape',embeddings.shape)

        return embeddings

In [16]:
config = VisionConfig(hidden_size=128, image_size=256, patch_size=16, num_channels=3)

In [17]:
# Initialize the VisionEmbeddings model
model = VisionEmbeddings(config)

# Create a sample input tensor (batch_size=1, channels=3, height=256, width=256)
input_tensor = torch.randn(1, 3, 256, 256)

# Pass the input tensor through the model
output = model(input_tensor)

print("Output shape:", output.shape)  # Should be (1, num_patches, embed_dim)

patch1 shape torch.Size([1, 128, 16, 16])
emb shape torch.Size([1, 128, 256])
emb1 shape torch.Size([1, 256, 128])
embeddings shape torch.Size([1, 256, 128])
Output shape: torch.Size([1, 256, 128])


In [26]:
class Attention(nn.Module):
    
    def __init__(self,config):
        
        super().__init__()
        
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        
        self.scale = self.head_dim ** -0.5
        self.dropout = config.attention_dropout
        
        self.k_proj = nn.Linear(self.embed_dim,self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim,self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim,self.embed_dim)
        
        self.out_proj = nn.Linear(self.embed_dim,self.embed_dim)
        
    
    def forward(self,hidden_states: torch.Tensor) -> Tuple[torch.Tensor,
                                                          Optional[torch.Tensor]]:
        
        """Apply multi-headed self-attention to the input.
        
        Args:
            hidden_states (torch.Tensor): Input tensor of shape [Batch_Size, Num_Patches, Embed_Dim]
            
        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]: 
                - Attention output of shape [Batch_Size, Num_Patches, Embed_Dim]
                - Attention weights of shape [Batch_Size, Num_Heads, Num_Patches, Num_Patches]
        """
        batch_size, seq_len, _ = hidden_states.size()
        
        query_states = self.q_proj(hidden_states)
        
        key_states = self.k_proj(hidden_states)
        
        value_states = self.v_proj(hidden_states)
        
        query_states = query_states.view(batch_size, seq_len,self.num_heads,self.head_dim).transpose(1,2)
        
        key_states = key_states.view(batch_size, seq_len,self.num_heads,self.head_dim).transpose(1,2)

        value_states = value_states.view(batch_size, seq_len,self.num_heads,self.head_dim).transpose(1,2)
        
        
        attn_weights = (torch.matmul(query_states,key_states.transpose(2,3))*self.scale )
        
        if attn_weights.size() != (batch_size, self.num_heads, seq_len, seq_len):
            raise ValueError(
                f"Attention weights should be of size {(batch_size, self.num_heads, seq_len, seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        
        attn_weights = nn.functional.softmax(attn_weights,dim=-1,dtype=torch.float32).to(query_states.dtype)
        
        attn_weights = nn.functional.dropout(attn_weights,p=self.dropout,training=self.training)
        
        attn_output = torch.matmul(attn_weights,value_states)
        
        
        if attn_output.size() != (batch_size, self.num_heads, seq_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size, self.num_heads, seq_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        
        attn_output = attn_output.transpose(1,2).contiguous()
        
        attn_output = attn_output.reshape(batch_size,seq_len,self.embed_dim)

        attn_output = self.out_proj(attn_output)
        
        return attn_output,attn_weights

In [27]:
config = VisionConfig(hidden_size = 128,num_attention_heads = 8,attention_dropout = 0.1)
attention_layer = Attention(config)
# 假设输入的 hidden_states 形状为 [batch_size, seq_len, embed_dim]
hidden_states = torch.randn(32, 10, 128)  # batch_size=32, seq_len=10, embed_dim=128

# 前向传播
attn_output, attn_weights = attention_layer(hidden_states)

# 输出的形状
print(attn_output.shape)  # 应该是 [batch_size, seq_len, embed_dim]
print(attn_weights.shape)  # 应该是 [batch_size, num_heads, seq_len, seq_len]

torch.Size([32, 10, 128])
torch.Size([32, 8, 10, 10])


In [28]:
class MLP(nn.Module):
    
    
    def __init__(self,config):
        
        super().__init__()
        
        self.config = config
        self.fc1 = nn.Linear(config.hidden_size,config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size,config.hidden_size)
        
    
    def forward(self,hidden_states:torch.Tensor) -> torch.Tensor:
        
        hidden_states = self.fc1(hidden_states)
        
        hidden_states = nn.functional.gelu(hidden_states,approximate='tanh')
        
        hidden_states = self.fc2(hidden_states)
        
        return hidden_states
        

In [29]:
        
config = config = VisionConfig(hidden_size = 128,intermediate_size = 64)
mlp_layer = MLP(config)

# 假设输入的 hidden_states 形状为 [batch_size, seq_len, embed_dim]
hidden_states = torch.randn(32, 10, 128)  # batch_size=32, seq_len=10, embed_dim=128

# 前向传播
output = mlp_layer(hidden_states)

# 输出的形状
print(output.shape)  # 应该是 [batch_size, seq_len, hidden_size] = [32, 10, 128]


torch.Size([32, 10, 128])


In [32]:
class EncoderLayer(nn.Module):
    
    def __init__(self,config:VisionConfig):
        
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Attention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,eps=config.layer_norm_eps)
        self.mlp = MLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,eps=config.layer_norm_eps)
    
    
    def forward(self,hidden_states: torch.Tensor) -> torch.Tensor:
        
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        
        hidden_states,_ = self.self_attn(hidden_states=hidden_states)
        
        hidden_states = residual + hidden_states
        
        residual = hidden_states
        
        hidden_states = self.layer_norm2(hidden_states)
        
        hidden_states = self.mlp(hidden_states)
        
        hidden_states = residual + hidden_states
        
        return hidden_states

In [36]:
config.hidden_size

768

In [34]:
# 创建配置对象
config = VisionConfig()

# 创建EncoderLayer对象
encoder_layer = EncoderLayer(config)

# 创建一个随机输入张量，假设batch size为10，序列长度为20
hidden_states = torch.randn(20, 10, config.hidden_size)

# 传入编码器层
output = encoder_layer(hidden_states)

print(output.shape)  # 输出的形状

torch.Size([20, 10, 768])


In [39]:
class Encoder(nn.Module):
    
    def __init__(self,config: VisionConfig):
        
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([
          EncoderLayer(config)  for _ in range(config.num_hidden_layers)
        ])
        
    def forward(self,input_embeds: torch.Tensor) -> torch.Tensor:
        
        hidden_states = input_embeds
        
        for encoder in self.layers:
            hidden_states = encoder(hidden_states)
        return hidden_states

In [40]:
# 创建配置对象
config = VisionConfig()

# 创建EncoderLayer对象
encoder_layer = Encoder(config)

# 创建一个随机输入张量，假设batch size为10，序列长度为20
hidden_states = torch.randn(20, 10, config.hidden_size)

# 传入编码器层
output = encoder_layer(hidden_states)

print(output.shape)  # 输出的形状

torch.Size([20, 10, 768])


In [41]:
class VisionTransformer(nn.Module):
    
    def __init__(self,config: VisionConfig):
        
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        
        self.embeddings = VisionEmbeddings(config)
        self.encoder = Encoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim,eps=config.layer_norm_eps)
    
    def forward(self,pixel_values: torch.Tensor) -> torch.Tensor:
        
        hidden_states = self.embeddings(pixel_values)
        
        last_hidden_state = self.encoder(input_embeds=hidden_states)
        
        last_hidden_state = self.post_layernorm(last_hidden_state)
        
        return last_hidden_state

In [46]:
input_tensor = torch.randn(1, 3, 224, 224)

config = VisionConfig()

model = VisionTransformer(config)

hidden_state = model(input_tensor)

patch1 shape torch.Size([1, 768, 14, 14])
emb shape torch.Size([1, 768, 196])
emb1 shape torch.Size([1, 196, 768])
embeddings shape torch.Size([1, 196, 768])


In [47]:
hidden_state.shape

torch.Size([1, 196, 768])

In [48]:
class VisionModel(nn.Module):
    
    def __init__(self,config: VisionConfig):
        
        super().__init__()
        self.config = config
        self.vision_model = VisionTransformer(config)
    
    
    def forward(self,pixel_values) -> Tuple:
        
        return self.vision_model(pixel_values=pixel_values)

In [51]:
input_tensor = torch.randn(10, 3, 224, 224)

config = VisionConfig()

model = VisionModel(config)

hidden_state = model(input_tensor)

patch1 shape torch.Size([10, 768, 14, 14])
emb shape torch.Size([10, 768, 196])
emb1 shape torch.Size([10, 196, 768])
embeddings shape torch.Size([10, 196, 768])


In [52]:
hidden_state.shape

torch.Size([10, 196, 768])