In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math, time, os
from torch.nn.attention.flex_attention import create_block_mask
from torch.nn.attention import flex_attention

In [2]:
with open('autocomplete/lecture/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print(f'Length of dataset in characters: {len(text)}')

Length of dataset in characters: 1115394


In [3]:
characters = sorted(list(set(text)))
vocab_size = len(characters)
print("All the unique characters:", ''.join(characters))

All the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [4]:
#encoder: string to integer
def encode(string):
    return [characters.index(c) for c in string]

#decoder: integer to string
def decode(index):
  return ''.join([characters[i] for i in index])

In [5]:
string = "hello there"
print(encode(string))
print(decode(encode(string)))

[46, 43, 50, 50, 53, 1, 58, 46, 43, 56, 43]
hello there


In [6]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)

torch.Size([1115394]) torch.int64


In [7]:
train_ratio = 0.9
n = int(train_ratio * len(data))
train_data = data[:n]
val_data = data[n:]
print(train_data.shape, val_data.shape)

torch.Size([1003854]) torch.Size([111540])


In [8]:
# hyperparameters
batch_size = 64# how many independent sequences will we process in parallel?
block_size = 256# what is the maximum context length for predictions?
max_iters = 5000# how many batches to train on
eval_interval = 500 # how often to evaluate the model
learning_rate = 3e-4 # learning rate for optimizer
device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu' # use GPU if available
# device='cpu'
eval_iters = 200 # how many batches to use for evaluation
n_embd = 384 # embedding dimension
n_head = 6 # number of attention heads
n_layer = 6 # number of transformer blocks
dropout = 0.2 # dropout rate
sliding_window_len = 64

In [9]:
def get_batch(is_train = True):
    data = train_data if is_train else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    # find batch_size random starting indices in the data
    x = torch.stack([data[i:i+block_size] for i in ix])
    # get block_size length sequences starting from those indices
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    # for testing, offset start index by 1 to preduct the next character
    x, y = x.to(device), y.to(device)
    return x, y

In [10]:
print(torch.tril(torch.ones(3,3)))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])


In [11]:
class SelfAttentionHead(nn.Module):
  def causal_mask(self, b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

  def __init__(self, head_size):
    # self.device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu' # use GPU if available
    self.device = device
    super().__init__()
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)
    self.o_proj = nn.Linear(head_size, n_embd, bias=False)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    self.dropout = nn.Dropout(dropout)
    # self.block_mask = create_block_mask(self.causal_mask, 1, 1, block_size,block_size, device=self.device)
  
  def forward(self, x):
    B, T, C = x.shape # batch size, sequence length, embedding dimension (n_embd)
    k = self.key(x)   # (B, T, head_size)
    q = self.query(x)
    
    #compute attention scores
    weights = torch.matmul(q, k.transpose(-2, -1)) * C**-0.5 # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
    weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
    weights = F.softmax(weights, dim=-1) # (B, T, T)
    weights = self.dropout(weights)
    
    value = self.value(x) # (B, T, head_size)
    # output = flex_attention.flex_attention(q, k, value,block_mask=self.block_mask)
    # output = self.o_proj(output)
    out = torch.matmul(weights, value) # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
    return out
    # return output

In [None]:
class MultiHeadAttention(nn.Module):
  def __init__(self, num_heads, head_size):
    super().__init__()
    self.heads = nn.ModuleList(SelfAttentionHead(head_size) for _ in range(num_heads))
    self.proj = nn.Linear(head_size * num_heads, n_embd)
    self.dropout = nn.Dropout(dropout)
    
  def forward(self, x):
    out = torch.cat([h(x) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

class MultiHeadFlexAttention(nn.Module):

    def __init__(self, num_heads, head_size, n_embd, block_size, dropout, device, sliding_window_len):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.n_embd = n_embd
        self.device = device

        # define causal_mask inline or pass in
        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx
          
        def sliding_window(b,h,q_idx,kv_idx):
          return q_idx-kv_idx <= sliding_window_len
        
        def mask(b,h,q_idx,kv_idx):
          return causal_mask(b,h,q_idx,kv_idx) | sliding_window(b,h,q_idx,kv_idx)

        # mask = torch.nn.attention.and_masks(causal_mask, sliding_window)
        # self.register_buffer('causal_mask', mask)
        self.block_mask = create_block_mask(mask, 1, 1, block_size, block_size, device=device)

        # projections project full embedding into all heads
        self.k = nn.Linear(n_embd, head_size * num_heads, bias=False)
        self.q = nn.Linear(n_embd, head_size * num_heads, bias=False)
        self.v = nn.Linear(n_embd, head_size * num_heads, bias=False)
        self.o_proj = nn.Linear(head_size * num_heads, n_embd, bias=False)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        B, T, C = x.shape
        nh, hs = self.num_heads, self.head_size

        # project
        k = self.k(x).view(B, T, nh, hs).transpose(1,2)
        q = self.q(x).view(B, T, nh, hs).transpose(1,2)
        v = self.v(x).view(B, T, nh, hs).transpose(1,2)

        # crop mask to actual sequence length
        print(T)
        self.block_mask = self.block_mask._adjust(T, T)
        print(self.block_mask.shape)

        # call flex attention
        out = flex_attention.flex_attention(q, k, v, block_mask=self.block_mask, enable_gqa=True)

        # merge heads back
        out = out.transpose(1, 2).contiguous().view(B, T, nh*hs)
        out = self.o_proj(out)
        return self.dropout(out)

In [13]:
class FFN(nn.Module):
  def __init__(self, n_embd):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embd, 4 * n_embd),
        nn.ReLU(),
        nn.Linear(4 * n_embd, n_embd),
        nn.Dropout(dropout),
    )
    
  def forward(self, x):
    return self.net(x)

    
class MoEFFN(nn.Module):
  def __init__(self, n_embd, num_experts=4, num_experts_per_token=2):
    super().__init__()
    self.num_experts_per_token = num_experts_per_token
    self.num_experts = num_experts
    self.experts = nn.ModuleList([FFN(n_embd) for _ in range(num_experts)])
    self.gate = nn.Linear(n_embd, num_experts)
    
  def forward(self, x):
    B, T, C = x.shape
    gate_scores = F.softmax(self.gate(x), dim=-1) # (B, T, num_experts)
    expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-1) # (B, T, C, num_experts)
    # print(expert_outputs.shape, gate_scores.shape)
    topk_scores, topk_indices = torch.topk(gate_scores, self.num_experts_per_token, dim=-1) # (B, T, 2)
    top_probs = F.softmax(topk_scores, dim=-1) # (B, T, 2)
    expert_outputs = torch.gather(expert_outputs, 3, topk_indices.unsqueeze(2).expand(-1, -1, C, -1)) # (B, T, C, 2)
    out = (expert_outputs * top_probs.unsqueeze(2)).sum(dim=-1)  # (B, T, C)
    return out

class EfficientMoEFFN(nn.Module):
    def __init__(self, n_embd, num_experts=4, num_experts_per_token=2):
        super().__init__()
        self.num_experts_per_token = num_experts_per_token
        self.num_experts = num_experts
        self.experts = nn.ModuleList([FFN(n_embd) for _ in range(num_experts)])
        self.gate = nn.Linear(n_embd, num_experts)

    def forward(self, x):
        B, T, C = x.shape
        x_flat = x.view(B*T, C)  # Flatten tokens to (batch*tokens, d_model)

        # Gating
        gate_scores = self.gate(x_flat)   # (B*T, num_experts)
        topk_scores, topk_indices = torch.topk(
            gate_scores, self.num_experts_per_token, dim=-1
        )  # (B*T, k)
        topk_probs = F.softmax(topk_scores, dim=-1)  # (B*T, k), normalized

        # Output buffer
        out = torch.zeros_like(x_flat)

        # For each expert: route only the tokens assigned to it
        for expert_id, expert in enumerate(self.experts):
            # Find where this expert is selected
            mask = (topk_indices == expert_id)  # (B*T, k)
            if not mask.any():
                continue # if it's not part of the top k selected experts for any token, skip it

            token_ids, which_slot = mask.nonzero(as_tuple=True)

            # Select actual tokens
            tokens_for_expert = x_flat[token_ids]

            # Apply expert FFN
            expert_out = expert(tokens_for_expert)  # (num_tokens, C)

            # Scale by probability
            probs = topk_probs[token_ids, which_slot].unsqueeze(-1)
            expert_out = expert_out * probs

            # Scatter-add back to output buffer
            out.index_add_(0, token_ids, expert_out)

        return out.view(B, T, C)

In [14]:
class Block(nn.Module):
  # block where you have mha and feedforward then layer normalization
  def __init__(self, n_embd, n_head):
    super().__init__()
    head_size = n_embd // n_head
    # self.sa = MultiHeadAttention(n_head, head_size)
    self.sa = MultiHeadFlexAttention(n_head, head_size, n_embd, block_size, dropout, device, sliding_window_len)
    # self.ffwd = FeedForward(n_embd)
    self.ffwd =EfficientMoEFFN(n_embd, num_experts=4)
    self.ln1 = nn.LayerNorm(n_embd)
    self.ln2 = nn.LayerNorm(n_embd)

  def forward(self, x):
    x = x + self.sa(self.ln1(x))
    x = x + self.ffwd(self.ln2(x))
    return x

In [15]:
class LanguageModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.token_embed_table = nn.Embedding(vocab_size, n_embd)
    self.position_embed_table = nn.Embedding(block_size, n_embd)
    self.blocks = nn.Sequential(
      *[Block(n_embd, n_head) for _ in range(n_layer)])
    self.ln_f = nn.LayerNorm(n_embd) # final layer norm
    self.lm_head = nn.Linear(n_embd, vocab_size)

  def forward(self, idx, targets=None):
    B, T = idx.shape
    token_emb = self.token_embed_table(idx) # (B, T, n_embd)
    position_emb = self.position_embed_table(torch.arange(T, device=device))

    x = token_emb + position_emb # (B, T, n_embd)
    x = self.blocks(x) # (B, T, n_embd)
    x = self.ln_f(x) # (B, T, n_embd)
    logits = self.lm_head(x) # (B, T, vocab_size)
    
    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss
  
  def generate(self, idx, max_new_tokens):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):
      # if(print_characters):
      #   print(f"\r{decode(idx[-1].tolist())}", end="", flush=True) 
      # crop idx to the last block_size tokens
      idx_cond = idx[:, -block_size:]
      # get the predictions
      logits, loss = self(idx_cond)
      # focus only on the last time step
      logits = logits[:, -1, :] # becomes (B, C)
      # apply softmax to get probabilities
      probs = F.softmax(logits, dim=-1) # (B, C)
      # sample from the distribution
      idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
      # append sampled index to the running sequence
      idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx

In [16]:
model = LanguageModel().to(device)
print(sum(p.numel() for p in model.parameters())/1e6, 'Million Model Parameters')

optim = torch.optim.AdamW(model.parameters(), lr=learning_rate)

def generate_streaming(model, context, max_new_tokens):
    max_new_tokens = max_new_tokens-1
    for _ in range(max_new_tokens):
      context = model.generate(context, max_new_tokens=1)
      generated = decode(context[0].tolist())[-1]
      print(f"{generated}", end="", flush=True)
    print()
    return generated

32.064089 Million Model Parameters


In [17]:
for iter in range(max_iters):
  print(f"\r Iter {iter+1}/{max_iters}", end="", flush=True)
  x_data, y_data = get_batch(is_train=True)
  logits, loss = model(x_data, y_data)
  optim.zero_grad(set_to_none=True)
  loss.backward()
  optim.step()
#train loop
torch.save(model.state_dict(), 'autocomplete/models/model_v3_flex_attn.pth')

 Iter 6/5000

KeyboardInterrupt: 

In [18]:
# model.load_state_dict(torch.load('autocomplete/models/model_lecture_style.pth'))
model.load_state_dict(torch.load('autocomplete/models/model_v3_flex_attn.pth', map_location=device))

<All keys matched successfully>

In [19]:
context = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))

AttributeError: 'MultiHeadFlexAttention' object has no attribute 'mask'

In [None]:
text_to_be_continued = "ROMEO"
context = torch.tensor(encode(text_to_be_continued), dtype=torch.long, device=device).unsqueeze(0)
# print(decode(model.generate(context, max_new_tokens=5000)[0].tolist()))
# model.generate(context, max_new_tokens=50,print_characters=True)
generate_streaming(model, context, max_new_tokens=5000)

:
The first sends between the warld of itself,
That the prepared for my kinsmen's delight,
Re-quicken'd her shore, her on God,
Incapable to an impeach shall my cherish here
A mile richest shepherd.

BUSHY:
Cousin, I pray you, lords; and this grown business
It would but see term than my reported:
Such that in slain them restored me wrath
Thus hurrip'd with magnany a dish death
Throw upon thy house.

RICHMOND:
It is my lady; I'll bear weeping you to the grave:
Then but valiant first and witeight.
But, love, go, come by Minitio.
Hasting title our san;
Say is better 's un gravator greypard's sains;
There's no more of it. Pray you, s
Havi been your humility, and you deliver
Margary to sanctuary fortune, and chaste
That murdering your handship subject?
For sorrow'st of royalive, when I loved how;
I'll follow me for what I hammerchanced an happy;
For I will give you on his knees.

NORFOLOLANUS:
Mark, my sovereign neighbours, and theirs lie,
His answer'd his hands: in a great day for Roques,
S

'\n'