In [6]:
import torch
import torch.nn as nn
from torch.nn import functional as F

import tiktoken
from dataclasses import dataclass

In [97]:
""" Text Decoder (GPT-style) """

@dataclass
class TextConfig:
    """Hyperparameters shared by the text tokenizer and the text transformer."""

    # Maximum sequence length in tokens; also the number of learned positions.
    block_size: int = 77
    # GPT-2 vocabulary (50257) plus the extra start-of-text token added by TextTokenizer.
    vocab_size: int = 50258
    # Number of stacked transformer blocks.
    n_layer: int = 6
    # Attention heads per block; must divide n_embd evenly.
    n_head: int = 8
    # Width of the token/position embeddings and of the residual stream.
    n_embd: int = 512
    # Size of the final projected text embedding.
    out_dim: int = 512

class Attention(nn.Module):
    """Causal multi-head self-attention using one fused Q/K/V projection.

    The config must expose `n_embd` and `n_head`, with `n_embd` divisible by
    `n_head` (each head gets `n_embd // n_head` channels).
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Single matmul yields queries, keys and values for every head at once.
        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd)
        # Mixes the concatenated head outputs back into the residual width.
        self.c_proj = nn.Linear(self.n_embd, self.n_embd)

    def _split_heads(self, tensor, batch, seq_len):
        """Reshape (B, T, n_embd) into per-head layout (B, n_head, T, head_size)."""
        return tensor.view(batch, seq_len, self.n_head, -1).transpose(1, 2)

    def forward(self, x):
        """Attend over `x` of shape (B, T, n_embd); returns a tensor of the same shape."""
        batch, seq_len, _ = x.shape
        fused = self.c_attn(x)  # (B, T, 3 * n_embd)
        query, key, value = (
            self._split_heads(part, batch, seq_len)
            for part in fused.split(self.n_embd, dim=2)
        )

        # Fused kernel; is_causal=True stops a position from attending to later ones.
        attended = F.scaled_dot_product_attention(query, key, value, is_causal=True)

        # (B, n_head, T, head_size) -> (B, T, n_head * head_size)
        merged = attended.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.c_proj(merged)

class MLP(nn.Module):
    """Position-wise feed-forward network run after attention.

    Expands each token to 4x the embedding width, applies tanh-approximated
    GELU, then projects back to the embedding width.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden)      # up-projection
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(hidden, width)    # down-projection

    def forward(self, x):
        """Transform `x` (..., n_embd) token-wise; output has the same shape."""
        return self.c_proj(self.gelu(self.c_fc(x)))

class Block(nn.Module):
    """Pre-norm transformer block: attention, then MLP, each on a residual path.

    Tokens first exchange information (attention), then each token is processed
    independently (MLP). Both sub-layers see a LayerNorm-ed copy of the input
    and add their result back onto the un-normalized residual stream.
    """

    def __init__(self, config, is_decoder=True):
        # `is_decoder` is accepted for interface compatibility; this
        # implementation does not read it.
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)  # normalizes the input to attention
        self.attn = Attention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)  # normalizes the input to the MLP
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply the block to `x` (B, T, n_embd); output has the same shape."""
        after_attn = x + self.attn(self.ln_1(x))
        return after_attn + self.mlp(self.ln_2(after_attn))

class TextDecoder(nn.Module):
    """GPT-style causal transformer that turns token ids into one text embedding.

    Token and learned position embeddings are summed, passed through
    `config.n_layer` transformer blocks and a final LayerNorm. The hidden state
    at the first end-of-text token of each sequence is projected to
    `config.out_dim` and L2-normalized.

    The config must provide `vocab_size`, `block_size`, `n_layer`, `n_embd` and
    `out_dim` (plus whatever `Block` reads). An optional `eot_token_id`
    attribute overrides the default end-of-text id.
    """

    # Used when the config has no `eot_token_id` (GPT-2's <|endoftext|> id).
    DEFAULT_EOT_TOKEN_ID = 50256

    # The annotation is a string so the class can be defined without the config
    # class being resolvable at definition time.
    def __init__(self, config: "TextConfig"):
        super().__init__()
        self.config = config
        self.eot_token_id = getattr(config, "eot_token_id", self.DEFAULT_EOT_TOKEN_ID)
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.proj = nn.Linear(config.n_embd, config.out_dim, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """GPT-2 style init: N(0, 0.02) weights, zero biases; LayerNorm untouched."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        # PyTorch's default LayerNorm init (weight=1, bias=0) is left as is.

    def forward(self, idx, targets=None):
        """Embed a batch of token id sequences.

        Args:
            idx: integer tensor of shape (B, T) with T <= config.block_size.
            targets: unused; kept so existing callers keep working.

        Returns:
            Tensor of shape (B, out_dim) whose rows are L2-normalized.
        """
        B, T = idx.shape
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"

        token_emb = self.transformer.wte(idx)  # (B, T, n_embd)
        pos_emb = self.transformer.wpe(torch.arange(T, device=idx.device))  # (T, n_embd)
        x = token_emb + pos_emb  # broadcast over the batch -> (B, T, n_embd)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        # Pool at the first EOT token of each row. argmax over a 0/1 mask gives
        # the first 1, but gives 0 when there is none; in that case use the last
        # position, which under causal attention has seen the whole sequence.
        is_eot = idx == self.eot_token_id
        eot_positions = is_eot.int().argmax(dim=-1)  # (B,)
        eot_positions = torch.where(
            is_eot.any(dim=-1), eot_positions, torch.full_like(eot_positions, T - 1)
        )
        x = x[torch.arange(B, device=idx.device), eot_positions]  # (B, n_embd)

        x = self.proj(x)  # (B, out_dim)
        # Unit length so dot products are cosine similarities; F.normalize
        # clamps the denominator, so an all-zero row stays finite.
        return F.normalize(x, dim=-1)

In [None]:
class TextTokenizer:
    """
    tiktoken wrapper producing fixed-length, CLIP-like token sequences.

    Not quite the CLIP tokenizer, but approximates it using the GPT-2 tokenizer.
    Vocab size = GPT-2 vocab size (50257) + 1 (for the new SOT token) = 50258.

    Every encoded sequence has exactly `config.block_size` ids laid out as
    [SOT] + text tokens + padding + [EOT]. The pad id equals the EOT id.
    """

    def __init__(self, config):
        """Build the tokenizer.

        Args:
            config: object exposing an integer `block_size` (>= 2).

        Raises:
            ValueError: if `block_size` cannot hold the SOT and EOT tokens.
        """
        if config.block_size < 2:
            raise ValueError(
                f"block_size must be at least 2 to hold SOT and EOT, got {config.block_size}"
            )
        self.enc = tiktoken.get_encoding("gpt2")

        # Special tokens
        self.eot_token = "<|endoftext|>"
        self.pad_token = self.eot_token
        self.sot_token = "<|startoftext|>"
        self.eot_token_id = 50256  # already exists in GPT-2 tokenizer
        self.pad_token_id = self.eot_token_id
        self.sot_token_id = self.eot_token_id + 1  # doesn't exist in GPT-2 tokenizer

        self.block_size = config.block_size

    def encode(self, text):
        """Tokenize `text` into a list of exactly `block_size` ids.

        Text that does not fit is truncated to `block_size - 2` tokens; shorter
        text is padded with `pad_token_id` before the closing EOT.
        """
        max_text_tokens = self.block_size - 2  # room left after SOT and EOT
        text_ids = self.enc.encode(text)[:max_text_tokens]
        padding = [self.pad_token_id] * (max_text_tokens - len(text_ids))
        return [self.sot_token_id, *text_ids, *padding, self.eot_token_id]

    def decode(self, ids, include_special_tokens=True):
        """Convert token ids back into text.

        Args:
            ids: iterable of token ids (Python ints or anything `int()` accepts).
            include_special_tokens: when True, SOT/EOT ids are rendered as their
                literal token strings; when False they are dropped. Padding
                shares the EOT id, so it is rendered or dropped the same way.

        Returns:
            The decoded string.
        """
        special_text = {
            self.sot_token_id: self.sot_token,
            self.eot_token_id: self.eot_token,
        }
        pieces = []
        pending = []  # run of ordinary token ids awaiting a joint decode

        for raw_id in ids:
            token_id = int(raw_id)
            if token_id not in special_text:
                pending.append(token_id)
                continue
            # A special token ends the current run. The run is decoded as a
            # whole because one character may span several byte-level BPE
            # tokens, and decoding those tokens one at a time would corrupt it.
            if pending:
                pieces.append(self.enc.decode(pending))
                pending = []
            if include_special_tokens:
                pieces.append(special_text[token_id])

        if pending:
            pieces.append(self.enc.decode(pending))
        return "".join(pieces)

In [98]:
# Smoke test: tokenize a few captions and push them through an untrained TextDecoder.
# The model weights are randomly initialized here, so only the shapes are meaningful.
labels = [
    "a boy and a girl",
    "a red ball",
    "a boy and a girl playing soccer in the park with a red ball",
]

# Each caption becomes a fixed-length (block_size) id sequence: SOT, text, padding, EOT.
enc = TextTokenizer(TextConfig())
encodings = [torch.tensor(enc.encode(label), dtype=torch.long) for label in labels]

# All encodings share the same length, so they stack into one (num_labels, block_size) batch.
batch = torch.stack(encodings)
print(f"Batch shape: {batch.shape}\n----")

# Show the first 20 ids of each sequence; the padding id equals the EOT id (50256).
for encoding in encodings:
    print(f"Encoding preview: {encoding[:20]}")

print("----")

# One embedding of size out_dim per caption.
model = TextDecoder(TextConfig())
text_emb = model(batch)
print(f"Text embedding shape: {text_emb.shape}")

Batch shape: torch.Size([3, 77])
----
Encoding preview: tensor([50257,    64,  2933,   290,   257,  2576, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256])
Encoding preview: tensor([50257,    64,  2266,  2613, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256])
Encoding preview: tensor([50257,    64,  2933,   290,   257,  2576,  2712, 11783,   287,   262,
         3952,   351,   257,  2266,  2613, 50256, 50256, 50256, 50256, 50256])
----
Text embedding shape: torch.Size([3, 512])
