In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Union, Optional
from dataclasses import dataclass

In [20]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an optional causal mask.

    Args:
        d_embed: Size of the input and output feature dimension.
        n_heads: Number of attention heads; must divide ``d_embed`` evenly.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        out_bias: Whether the output projection has a bias.

    Raises:
        ValueError: If ``d_embed`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_embed, n_heads: int = 4, qkv_bias=True, out_bias=True) -> None:
        super().__init__()

        if d_embed % n_heads != 0:
            raise ValueError(
                f"d_embed ({d_embed}) must be divisible by n_heads ({n_heads})"
            )

        self.n_heads = n_heads
        self.d_head = d_embed // n_heads
        # Scores are dot products of per-head vectors of length d_head, so the
        # softmax temperature is 1/sqrt(d_head), not 1/sqrt(d_embed).
        self.scale = self.d_head ** -0.5
        self.QKV = nn.Linear(d_embed, d_embed * 3, bias=qkv_bias)
        self.O = nn.Linear(d_embed, d_embed, bias=out_bias)

    def forward(self, x: torch.Tensor, mask: bool = False) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, d_embed).
            mask: Causal masking flag. Any value other than the literal
                ``False`` enables the mask, so each position attends only to
                itself and earlier positions.

        Returns:
            Tensor of shape (batch_size, seq_len, d_embed).
        """
        bs, seq_len, d_embed = x.shape

        # (bs, seq_len, d_embed) -> (bs, seq_len, 3*d_embed) -> 3 x (bs, seq_len, d_embed)
        q, k, v = self.QKV(x).chunk(3, dim=-1)

        # (bs, seq_len, d_embed) -> (bs, seq_len, n_heads, d_head) -> (bs, n_heads, seq_len, d_head)
        q = q.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        # (bs, n_heads, seq_len, seq_len)
        attn_scores = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not False:
            # A single (seq_len, seq_len) strictly-upper-triangular mask
            # broadcasts over batch and heads. The diagonal is never masked,
            # so every row keeps a finite score and -inf cannot produce NaNs.
            causal_mask = torch.ones(
                seq_len, seq_len, dtype=torch.bool, device=attn_scores.device
            ).triu(1)
            attn_scores = attn_scores.masked_fill(causal_mask, float("-inf"))
        weights = F.softmax(attn_scores, dim=-1)

        # (bs, n_heads, seq_len, seq_len) @ (bs, n_heads, seq_len, d_head) -> (bs, n_heads, seq_len, d_head)
        output = weights @ v
        # (bs, n_heads, seq_len, d_head) -> (bs, seq_len, n_heads, d_head) -> (bs, seq_len, d_embed)
        output = output.transpose(1, 2).contiguous().view(bs, seq_len, d_embed)
        return self.O(output)  # (bs, seq_len, d_embed)

In [21]:
@dataclass
class Config:
    """Hyperparameters consumed by the embedding, block and model classes in this notebook."""

    # Number of rows in the token embedding table.
    vocab_size: int = 49_408
    # Width of every token/position embedding and of each transformer layer.
    hidden_size: int = 768
    # Number of learned position embeddings.
    seq_len: int = 77
    # Attention heads per transformer block.
    n_heads: int = 12
    # Number of stacked transformer blocks.
    n_layers: int = 12
    # Epsilon passed to every LayerNorm.
    layer_norm_eps: float = 1e-5

In [22]:
class CLIPEmbedding(nn.Module):
    """Convert token ids to embeddings by summing token and learned position embeddings.

    Args:
        config: Object providing ``vocab_size``, ``hidden_size`` and ``seq_len``
            (the maximum number of positions).
    """

    # The annotation is a forward reference so defining this class does not
    # require ``Config`` to exist yet; it is still needed when instantiating.
    def __init__(self, config: "Config") -> None:
        super().__init__()

        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = nn.Parameter(torch.zeros(config.seq_len, config.hidden_size))

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids``.

        Args:
            input_ids: Integer tensor of shape (batch_size, seq_len) with
                ``seq_len`` no larger than the configured maximum.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).

        Raises:
            ValueError: If the input is longer than the position table.
        """
        seq_len = input_ids.shape[-1]
        max_positions = self.position_embedding.shape[0]
        if seq_len > max_positions:
            raise ValueError(
                f"input sequence length ({seq_len}) exceeds the number of "
                f"position embeddings ({max_positions})"
            )
        # Slice the table to the actual length: adding the full table fails to
        # broadcast for shorter inputs and silently expands length-1 inputs.
        positions = self.position_embedding[:seq_len]
        return self.token_embedding(input_ids) + positions  # (batch_size, seq_len, hidden_size)

In [23]:
class CLIPBlock(nn.Module):
    """Pre-norm transformer block: masked self-attention then a 4x-wide GELU MLP, each with a residual."""

    def __init__(self, config: Config) -> None:
        super().__init__()

        width = config.hidden_size
        eps = config.layer_norm_eps
        mlp_width = width * 4

        self.layer_norm1 = nn.LayerNorm(width, eps=eps)
        self.layer_norm2 = nn.LayerNorm(width, eps=eps)
        self.attention = SelfAttention(width, config.n_heads)
        # Positional Sequential keeps the submodule names "0", "1", "2".
        expand = nn.Linear(width, mlp_width)
        activation = nn.GELU()
        project = nn.Linear(mlp_width, width)
        self.mlp = nn.Sequential(expand, activation, project)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on ``x`` of shape (batch_size, seq_len, hidden_size) and return the same shape."""
        # Attention sub-layer: normalize, attend with masking enabled, add residual.
        attended = self.attention(self.layer_norm1(x), mask=True)
        x = x + attended
        # Feed-forward sub-layer: normalize, MLP, add residual.
        fed_forward = self.mlp(self.layer_norm2(x))
        return x + fed_forward  # (batch_size, seq_len, hidden_size)

In [24]:
class CLIP(nn.Module):
    """
    A basic implementation of OpenAI's CLIP model: token ids are embedded,
    passed through ``config.n_layers`` CLIPBlock layers, and layer-normalized.
    """

    def __init__(self, config: Config) -> None:
        super().__init__()

        self.embeddings = CLIPEmbedding(config)
        blocks = (CLIPBlock(config) for _ in range(config.n_layers))
        self.layers = nn.ModuleList(blocks)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Map token ids of shape (batch_size, seq_len) to hidden states of shape (batch_size, seq_len, hidden_size)."""
        # Cast to int64 before the embedding lookup.
        token_ids = x.to(torch.long)
        hidden = self.embeddings(token_ids)  # (batch_size, seq_len, hidden_size)
        for block in self.layers:
            hidden = block(hidden)
        normalized = self.layer_norm(hidden)
        return normalized  # (batch_size, seq_len, hidden_size)

In [25]:
# Smoke test: build the model from the default config and a random batch of token ids.
torch.manual_seed(0)  # reproducible weights and inputs on Restart & Run All
config = Config()
clip = CLIP(config)
# Take the id range and sequence length from the config instead of repeating its defaults.
input_ids = torch.randint(0, config.vocab_size, (1, config.seq_len))
clip

CLIP(
  (embeddings): CLIPEmbedding(
    (token_embedding): Embedding(49408, 768)
  )
  (layers): ModuleList(
    (0-11): 12 x CLIPBlock(
      (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attention): SelfAttention(
        (QKV): Linear(in_features=768, out_features=2304, bias=True)
        (O): Linear(in_features=768, out_features=768, bias=True)
      )
      (mlp): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=3072, out_features=768, bias=True)
      )
    )
  )
  (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [26]:
# Shape check only, so skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    hidden_states = clip(input_ids)
# One hidden vector per input token.
assert hidden_states.shape == (*input_ids.shape, config.hidden_size)
hidden_states.shape

torch.Size([1, 77, 768])