# imports

In [1]:
import os
import math
import time
import random
import json
import gc
import requests
from io import BytesIO
from PIL import Image
from dataclasses import dataclass
from typing import Optional, Tuple, List
from sentencepiece import SentencePieceProcessor
import torch
from torch import nn, tensor
import torch.autograd.profiler as profiler
import torch.nn.init as init
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import matplotlib.pyplot as plt
from clip import load
import logging

logging.basicConfig(filename='log.txt', level=logging.INFO, filemode='w')

# model secondary

In [2]:
@dataclass
class ModelArgs:
  dim: int = 4096
  n_layers: int = 32
  n_heads: int = 32
  n_kv_heads: Optional[int] = None
  vocab_size: int = 32000
  multiple_of: int = 256  
  ffn_dim_multiplier: Optional[float] = None
  norm_eps: float = 1e-5
  max_batch_size: int = 32
  max_seq_len: int = 2048 # 1024
  img_len: int = 577
  img_dim: int = 1024
  max_gen_len: int = 32

class Tokenizer:
  def __init__(self, model_path: str):
    assert os.path.isfile(model_path), model_path
    self.sp_model = SentencePieceProcessor(model_file=model_path)

    self.n_words: int = self.sp_model.vocab_size()
    self.bos_id: int = self.sp_model.bos_id()
    self.eos_id: int = self.sp_model.eos_id()
    self.pad_id: int = self.sp_model.pad_id()
    assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

  def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
    assert type(s) is str
    t = self.sp_model.encode(s)
    if bos:
      t = [self.bos_id] + t
    if eos:
      t = t + [self.eos_id]
    return t

  def decode(self, t: List[int]) -> str:
    return self.sp_model.decode(t)    

class RMSNorm(torch.nn.Module):
  def __init__(self, dim: int, eps: float = 1e-6):
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(dim))

  def _norm(self, x):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

  def forward(self, x):
    output = self._norm(x.float()).type_as(x)
    return output * self.weight

class MLP(nn.Module):
  def __init__(self, args):
    super().__init__()
    self.w1 = nn.Linear(args.img_dim, args.dim, bias=True).to(dtype=torch.float16)
    self.gelu = nn.GELU()
    self.w2 = nn.Linear(args.dim, args.dim, bias=True).to(dtype=torch.float16)

  def forward(self, x):
    h = self.w1(x)
    h = self.gelu(h)
    h = self.w2(h)
    return h

class FeedForward(nn.Module):
  def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None):
    super().__init__()
    hidden_dim = int(2 * hidden_dim / 3)
    if ffn_dim_multiplier is not None:
      hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    self.w1 = nn.Linear(dim, hidden_dim, bias=False)
    self.w2 = nn.Linear(hidden_dim, dim, bias=False)
    self.w3 = nn.Linear(dim, hidden_dim, bias=False)

  def forward(self, x):
    x = checkpoint(lambda x: F.silu(self.w1(x)) * self.w3(x), x)
    return self.w2(x)      

# model primary

In [3]:
class Attention(nn.Module):
  def __init__(self, args):
    super().__init__()
    self.n_heads = args.n_heads
    self.head_dim = args.dim // args.n_heads
    # print('self.head_dim', self.head_dim)
    self.wq = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False)
    self.wk = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False)
    self.wv = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False)
    self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)  

  def reshape_for_broadcast(self, freqs_cis, x):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

  def apply_rotary_emb(self, xq, xk, freqs_cis):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

  def forward(self, x, freqs_cis, mask=None):
    bsz, seqlen, _ = x.shape
    # apply linear layer with checkpointing
    xq = checkpoint(lambda x: self.wq(x), x)
    # print('xq.shape', xq.shape)
    xk = checkpoint(lambda x: self.wk(x), x)
    xv = checkpoint(lambda x: self.wv(x), x)
    # reshape to heads and head dims
    xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
    # print('xq.shape', xq.shape)
    xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
    xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
    # apply rotary encoding
    xq, xk = self.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
    # transpose heads to 2D for easier computation
    xq = xq.transpose(1, 2)
    xk = xk.transpose(1, 2)
    xv = xv.transpose(1, 2)
    # dot product of q v, scaled by the sqrt of head dims
    dot_product = torch.matmul(xq, xk.transpose(2, 3))
    scores = dot_product / math.sqrt(self.head_dim)
    # apply masks
    if mask is not None:
      scores = scores + mask
    # softmax to get attn scores
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
    # get weighted sum from scores and value
    output = torch.matmul(scores, xv)
    # reshape back to original shape, apply linear layer
    output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
    return self.wo(output)

# transformer, block

In [4]:
class TransformerBlock(nn.Module):
  def __init__(self, layer_id: int, args):
    super().__init__()
    self.n_heads = args.n_heads
    self.dim = args.dim
    self.head_dim = args.dim // args.n_heads
    self.attention = Attention(args)
    self.feed_forward = FeedForward(
      dim=args.dim,
      hidden_dim=4 * args.dim,
      multiple_of=args.multiple_of,
      ffn_dim_multiplier=args.ffn_dim_multiplier,
    )
    self.layer_id = layer_id
    self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
    self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

  def forward(self, x, freqs_cis, mask=None, ):
    h = x + self.attention.forward(
      self.attention_norm(x), freqs_cis, mask
    )
    out = h + self.feed_forward.forward(self.ffn_norm(h))
    return out


class Transformer(nn.Module):
  def __init__(self, params):
    super().__init__()
    self.clip, _ = load("ViT-L/14@336px")
    self.mlp = MLP(params)
    self.params = params
    self.vocab_size = params.vocab_size
    self.n_layers = params.n_layers
    self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
    self.layers = torch.nn.ModuleList()
    self.freqs_cis = self.precompute_freqs_cis(
      self.params.dim // self.params.n_heads, 
      self.params.max_seq_len * 2
    )
    
    for layer_id in range(params.n_layers):
      self.layers.append(TransformerBlock(layer_id, params))
    self.norm = RMSNorm(params.dim, eps=params.norm_eps)
    self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

  def precompute_freqs_cis(self, dim, end, theta=10000.0):
    indices = torch.arange(0, dim, 2) 
    sliced_indices = indices[: (dim // 2)].float()
    scaled_indices = sliced_indices / dim
    theta_power = theta ** scaled_indices
    freqs = 1.0 / theta_power
    t = torch.arange(end, device="cuda")
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

  # @torch.inference_mode()
  def forward(self, toks, imgs=None): 
    _bsz, seqlen = toks.shape
    h = self.tok_embeddings(toks)
    
    if imgs != None:
      image_encoded = self.clip.encode_image(imgs).to(h.device)
      image_encoded.detach()
      image_projected = self.mlp(image_encoded)
      image_projected = image_projected.expand(_bsz, -1, -1)
      seqlen += image_projected.size(1)
      
      h_before = h[:, :1, :]
      h_after = h[:, 1:, :]
      h = torch.cat([h_before, image_projected, h_after], dim=1)
    
    self.freqs_cis = self.freqs_cis.to(h.device)
    freqs_cis = self.freqs_cis[:seqlen]
    mask = None
    
    if seqlen > 1:
      mask = torch.full((seqlen, seqlen), float("-inf"), device=toks.device)
      mask = torch.triu(mask, diagonal=1)
      mask = torch.hstack([
        torch.zeros((seqlen, 0), device=toks.device), mask
      ]).type_as(h)
    
    for layer in self.layers:
      h = layer(h, freqs_cis, mask)
    h = self.norm(h)
    output = self.output(h).float()
    return output

# hooks

In [14]:
def register_act_hook(module, layer_filter=None):
  activations = {}
  def hook(module, input, output):
    if module not in activations:
      ACT_DICT[module] = output.detach()
    return hook   
  for name, module in model.named_modules():
    if layer_filter:
      if layer_filter in name:
        module.register_forward_hook(register_hook(module))
    else:
      module.register_forward_hook(register_hook(module))
  return activations

def register_mem_hooks(module, module_name_prefix='', sub=False):
  def forward_hook(module, input, output):
    log_message = f"\n{module_name_prefix} - forward:\n"
    log_message += f"Layer: {module.__class__.__name__}\n"
    if hasattr(input[0], 'shape'):
      log_message += f"Input shape: {input[0].shape}\n"
    if hasattr(output, 'shape'):
      log_message += f"Output shape: {output.shape}\n"
    if hasattr(module, 'weight'):
      log_message += f"Weight mean: {module.weight.data.mean()}, std: {module.weight.data.std()}\n"
    logging.info(log_message)

  def backward_hook(module, grad_input, grad_output):
    log_message = f"{module_name_prefix} - backward:\n"
    logging.info(log_message)

  module.register_forward_hook(forward_hook)
  module.register_full_backward_hook(backward_hook)

  if sub:
    for name, child in module.named_children():
      new_prefix = f"{module_name_prefix}.{name}" if module_name_prefix else name
      register_mem_hooks(child, new_prefix, sub)

# build, generate, test

In [6]:
def build(model_args):
  start_time = time.time()
  torch.cuda.set_device(0)
  torch.set_default_tensor_type(torch.cuda.HalfTensor)  
  ckpt = torch.load("consolidated.00.pth", map_location="cuda")
  model = Transformer(model_args)
  model.load_state_dict(ckpt, strict=False)
  print(f"llama loaded in {time.time() - start_time:.2f} seconds")
  return model

def generate(model, tkzr, model_args, tkns, image=None, max_gen=32):
  start_time = time.time()
  bsz = len(tkns)
  if max_gen:
    max_gen_len = max_gen
  else:  
    max_gen_len = model_args.max_seq_len - 1
  min_tkn_len = min(len(t) for t in tkns)
  max_tkn_len = max(len(t) for t in tkns)
  ttl_len = min(model_args.max_seq_len, max_gen_len + max_tkn_len)

  pad_id = tkzr.pad_id
  gen_tkns = torch.full((bsz, ttl_len), pad_id, dtype=torch.long)
  for k, t in enumerate(tkns):
    gen_tkns[k, : len(t)] = torch.tensor(t, dtype=torch.long)
  eos_reached = torch.tensor([False] * bsz)
  input_text_mask = gen_tkns != pad_id
  
  for cur_pos in range(min_tkn_len, ttl_len):
    logits = model.forward(gen_tkns[:, :cur_pos])#, image)
    nxt_tkn = torch.argmax(logits[:, -1], dim=-1)
    nxt_tkn = nxt_tkn.reshape(-1)
    nxt_tkn = torch.where(input_text_mask[:, cur_pos], gen_tkns[:, cur_pos], nxt_tkn)
    print(nxt_tkn)
    gen_tkns[:, cur_pos] = nxt_tkn
    eos_reached |= (~input_text_mask[:, cur_pos]) & (nxt_tkn == tkzr.eos_id)
    if all(eos_reached): break

  out_tkns = []
  for i, t in enumerate(gen_tkns.tolist()):
    tkns_len = len(tkns[i])
    t = t[tkns_len : tkns_len + max_gen_len]
    if tkzr.eos_id in t:
      eos_idx = t.index(tkzr.eos_id)
      t = t[:eos_idx]
    out_tkns.append(t)  
  print(f"generated in {time.time() - start_time:.2f} seconds")
  return out_tkns

def test(model, tkzr, model_args):
  prompts = {
    'txt' : [
      "Simply put, the theory of relativity states that", # the laws of physics are the same for all non-accelerating observers, regardless of their state of motion or their energy content.
      "Long ago there lived a magical cat named Puss" # in Boots. Puss in Boots was a very clever cat. He was so clever that he could talk. He was so clever that he could talk
    ],
    'img' : ["Simply put, the theory of relativity states that"],
    'train' : [
      "image label:", "image label:",
      "photo description:", "photo description:",
      "image title:", "image title:",
      "picture summary:", "picture summary:",
    ]
  }
  tkns = [tkzr.encode(x, bos=True, eos=False) for x in prompts['txt']]
  text_out_tkns = generate(model, tkzr, model_args, tkns, image=None)
  [print(tkzr.decode(t)) for t in text_out_tkns]    

def train(model, train_len=1, bsz=4, accum=2):  
  
  for n in range(1, train_len + 1):
    src_txt, src_img, tgt = ds[n]
    
    split_idx = src_txt.shape[1] - 1
    if split_idx == 0: continue
      
    logits = model.forward(src_txt, src_img)
    
    logits = logits[:, -split_idx:, :]
    logits = logits.reshape(-1, logits.size(-1))
    
    tgt_prefix = src_txt[:, 2:]
    tgt = tgt.unsqueeze(1)
    tgt = torch.cat((tgt_prefix, tgt), dim=1)
    tgt = tgt.reshape(-1)
    
    loss = loss_fn(logits, tgt)
    loss_value = loss.item()
      
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f'\n{n}\t{loss_value}')

# dataset

In [7]:
class Dataset():
  def __init__(self, seq_len=32, bsz=4, shuffle=True):
    self.clip, self.image_pre = load("ViT-L/14@336px")
    self.tkzr = Tokenizer('tokenizer.model')
    self.seq_len = seq_len

    self.prompts = ["image label: ", "photo description: ","image title: ", "picture summary: "]

    with open('data/metadata.json', 'r') as file: 
      data = json.load(file)
    ds = [x for x in data if x.get('image') and x.get('blip_caption')]
    if shuffle: random.shuffle(ds)
    self.bsz_mult = 3
    self.bsz = bsz * self.bsz_mult
    ds = [ds[i:i + self.bsz] for i in range(0, len(ds) - len(ds) % self.bsz)] 
    self.ds = ds
    self.index = 0

  def __len__(self):
    return len(self.ds)

  def __iter__(self):
    self.index = 0
    return self

  def __next__(self):
    if self.index < len(self.ds):
      result = self.__getitem__(self.index)
      self.index += 1
      return result
    else:
      raise StopIteration

  def __getitem__(self, idx):
    batch = self.ds[idx]
    text_toks = [self.tkzr.encode(d['blip_caption'], bos=False, eos=False) for d in batch]
    min_len = min(len(x) for x in text_toks)
    # print('min_len', min_len)

    _images = []
    _text = []
    _target = []
    for data in batch:
      if len(_images) == self.bsz / self.bsz_mult:
        break
      img = Image.open(f"data/images/{data['image']}")
      img = self.image_pre(img).unsqueeze(0).to("cuda")
      _images.append(img)

      txt = data['blip_caption']
      prefix = random.choice(self.prompts)
      txt = prefix + txt
      txt = self.tkzr.encode(data['blip_caption'], bos=True, eos=False)
      txt = txt[: min_len]
      _target.append(txt[-1])
      txt = txt[:-1]
      txt = torch.tensor(txt, dtype=torch.long, device="cuda")
      _text.append(txt)
      
    tgt = torch.tensor(_target, dtype=torch.long, device="cuda")
    src_img = torch.cat(_images, dim=0)
    src_txt = torch.stack(_text, dim=0)
    return src_txt, src_img, tgt
# ds = Dataset(bsz=4)
# src_txt, src_img, tgt = ds[0]
# print('src_txt, src_img, tgt', src_txt.shape, src_img.shape, tgt.shape)

# train

# init

In [47]:
clip, img_pre = load("ViT-L/14@336px")
TKZR = Tokenizer("tokenizer.model")
DS = Dataset(bsz=4)

plt.style.use('dark_background')

In [16]:
if 'model' in locals() or 'model' in globals():
  del model
  gc.collect()
torch.cuda.empty_cache()

MODEL_ARGS = ModelArgs(max_batch_size=4)
MODEL = build(MODEL_ARGS)

OPTIMIZER = optim.SGD(MODEL.parameters(), lr=1e-4, momentum=0.9)
LOSS_FN = nn.CrossEntropyLoss()

for name, param in MODEL.named_parameters():
  if 'clip' in name:
    param.requires_grad = False

llama loaded in 9.45 seconds


# introspect

In [122]:
def debug(action, layer_filter=None):
  model, tkzr, loss_fn, optimizer = MODEL, TKZR, LOSS_FN, OPTIMIZER
  optimizer.zero_grad()
    
  activations, layer_names = [], []
  def hook_fn(module, input, output):
    activations.append(output)
    layer_names.append(module.__class__.__name__)

  hooks = []
  if action == 'activation':
    for name, layer in model.named_modules():
      if layer_filter in name:
        hook = layer.register_forward_hook(hook_fn)
        hooks.append(hook)
        
  prompts = ["Simply put, the theory of relativity states that", "Long ago there lived a magical cat named Puss"]
  toks = [tkzr.encode(x, bos=True, eos=False) for x in prompts]
  toks = [torch.tensor(x) for x in toks]
  toks = torch.stack(toks, dim=0)
  src = toks[:,:-1]
  tgt = toks[:,1:].reshape(-1)

  logits = model.forward(src) 
  logits = logits.reshape(-1, logits.size(-1)) 
  loss = loss_fn(logits, tgt)
  loss.backward()
    
  if action == 'logits': 
    optimizer.zero_grad()
    logits = model.forward(src) 
    # loss = loss_fn(logits, tgt) ; print(f'loss:\t{loss.item()}')
    # Logic for examining logits
    last_logits = logits[:, -1, :]
    
    values, idxs = torch.topk(last_logits, 6, dim=1)  # Get top 6 logits and their indices

    for i, seq in enumerate(idxs.tolist()):
      print(f'\n{tkzr.decode(toks[i].tolist())}')
      print(f'last tok: [{tkzr.decode(toks[i].tolist()[-1:])}]')
      for j, tok in enumerate(seq):
        print(f'[{tkzr.decode([tok])}] \t {values[i, j].item()}')
    return
      
    # d = [tkzr.decode([i]) for i in   

      
    # decod = tkzr.decode(top_indices.tolist()[0]) ; print(decod) ; return
    decod = tkzr.decode(top_indices.tolist()[1][2:3]) ; print(decod) ; return
    decoded_tokens = [tkzr.decode(x) for x in top_indices.tolist()]

  if action == 'overfit':
    optimizer.zero_grad()
    for n in range(1, 20):
      logits = model.forward(src) 
      logits = logits.reshape(-1, logits.size(-1)) 
      loss = loss_fn(logits, tgt) ; print(f'{n}\t{loss.item()}')
      loss.backward()     
      optimizer.step()
          
  if action in ['weight', 'gradient']:
    print(f'i\tmean\t\tstd\t\tlayer_name')  
    plt.figure(figsize=(20, 4))
    for i, (name, param) in enumerate(model.named_parameters()):
      if param.grad is not None and (layer_filter is None or layer_filter in name):
        data = param.data if viz_type == 'weight' else param.grad
        print(f'{i}\t{data.mean():+f}\t{data.std():e}\t{name}')
        hy, hx = torch.histogram(data.cpu().float(), density=True)
        plt.plot(hx[:-1].detach(), hy.detach()) ; plt.show()

  if action == 'activation':
    print(f'i\tmean\t\tstd\t\tsat\t\tlayer_name')    
    plt.figure(figsize=(20, 4))
    for i, activation in enumerate(activations):
      act = activation.detach().cpu().float()
      hy, hx = torch.histogram(act, density=True)
      print(f'{i}\t {layer_names[i]:10}, mean: {act.mean():+.4f}, std: {act.std():e}, saturated: {(act.abs() > 0.97).float().mean() * 100:.2f}%')
      plt.plot(hx[:-1], hy) ; plt.show()
    [hook.remove() for hook in hooks]

  if action == 'weight_update':
    optimizer.zero_grad()      
    logits = model.forward(src) 
    logits = logits.reshape(-1, logits.size(-1)) 
    loss = loss_fn(logits, tgt) ; print(f'loss:\t{loss.item()}')
    loss.backward()     

    pre_update_weights = {}
    for name, param in model.named_parameters():
      if (layer_filter is None or layer_filter in name) and param.grad is not None:
        pre_update_weights[name] = param.clone()         

    optimizer.step()

    print(f'i\tmean\t\tstd\t\tlayer_name')          
    for i, (name, param) in enumerate(model.named_parameters()):
      if name in pre_update_weights:
        update = param - pre_update_weights[name]
        if param.grad is not None:  # Ensure there is a gradient
          epsilon = 1e-8
          normalized_update = update / (pre_update_weights[name].abs() + epsilon)
          
          mean_normalized_update = normalized_update.mean().item()
          std_normalized_update = normalized_update.std().item()
          
          print(f'{i}\t{mean_normalized_update:+.4e}\t{std_normalized_update:e}\t{name}')      
          hy, hx = torch.histogram(normalized_update.cpu().float(), density=True)
          plt.plot(hx[:-1].detach(), hy.detach(), label=name)        
    optimizer.zero_grad()
  
  plt.show()

debug('logits')


Simply put, the theory of relativity states that
last tok: [that]
[that] 	 26.953125
[,] 	 19.984375
[the] 	 18.6875
[there] 	 18.328125
[:] 	 17.875
[a] 	 17.84375

Long ago there lived a magical cat named Puss
last tok: [uss]
[uss] 	 25.46875
[us] 	 17.921875
[ush] 	 16.625
[urr] 	 16.40625
[ete] 	 15.5625
[ink] 	 15.171875
