In [2]:
from typing import Optional
from functools import partial

from tinygrad import Tensor, nn
from tinygrad.nn.state import get_parameters
from tinygrad.helpers import Timing, GlobalCounters

In [5]:
def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
  bs, seqlen, n_kv_heads, head_dim = x.shape
  if n_rep == 1: return x
  # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)

In [6]:
class Attention:
  def __init__(self, dim, n_heads, n_kv_heads=None, max_context=1024, linear=nn.Linear):
    self.n_heads = n_heads
    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
    self.head_dim = dim // n_heads
    self.n_rep = self.n_heads // self.n_kv_heads
    self.max_context = max_context

    self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
    self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)

  def __call__(self, x:Tensor, start_pos:int, mask:Optional[Tensor]=None) -> Tensor:
      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
      xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
      xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
      xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)

      # todo: add RoPE
      bsz, seqlen, _, _ = xq.shape

      # create kv cache
      if not hasattr(self, "cache_kv"):
        self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()

      # update the cache
      self.cache_kv[:, :, start_pos:start_pos+seqlen, :, :].assign(Tensor.stack(xk, xv)).realize()

      keys = self.cache_kv[0, :, 0:start_pos+seqlen, :, :]
      values = self.cache_kv[1, :, 0:start_pos+seqlen, :, :]

      keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
      xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
      attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
      attn = attn.reshape(bsz, seqlen, -1)
      return self.wo(attn)

In [7]:
def stats(st:float, et:float):
    dur = GlobalCounters.time_sum_s-st
    msg = f", {dur*1e3:.2f} ms on GPU"
    msg += f", {GlobalCounters.global_ops*1e-9:.2f} GOPs"
    msg += f", {GlobalCounters.global_mem*1e-9:.2f} GB"
    msg += f", {(0.0 if GlobalCounters.global_mem == 0.0 else GlobalCounters.global_ops / GlobalCounters.global_mem):.2f} OPs/byte"
    # msg += f", {(0.0 if dur == 0 else GlobalCounters.global_mem*1e-9/dur):.2f} GB/s"
    # msg += f", param {(0.0 if dur == 0 else param_bytes*1e-9/dur):.2f} GB/s"
    return msg

In [8]:
dim = 1024
n_heads = 64
max_context = 1024
model = Attention(dim, n_heads, max_context=max_context)
param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(model))
print(f"model params take up {param_bytes*1e-9} GB")

GlobalCounters.reset()
st = GlobalCounters.time_sum_s
with Timing("Realize Model (no kv cache) in", on_exit=partial(stats, st)):
    # realize wq, wk, wv, wo
    for x in get_parameters(model):
        x.realize()
        print(x.shape)
        print(x.lazydata.size)
        print(x.lazydata.dtype.itemsize)

bsz, seqlen, dim = 1, 8, 1024
start_pos = 0

# with Timing("total ", on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"):
with Timing("Attention in ", on_exit=partial(stats, st)):
    x = Tensor.empty(bsz, seqlen, dim)
    out = model(x, start_pos)
    out.numpy()

model params take up 0.016777216 GB
(1024, 1024)
1048576
4
(1024, 1024)
1048576
4
(1024, 1024)
1048576
4
(1024, 1024)
1048576
4
Realize Model (no kv cache) in503.80 ms, 0.00 ms on GPU, 0.93 GOPs, 0.13 GB, 6.96 OPs/byte
Attention in 178.00 ms, 0.00 ms on GPU, 1.00 GOPs, 0.17 GB, 5.95 OPs/byte


In [9]:
for x in get_parameters(model):
    print(x.shape)

(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(2, 1, 1024, 64, 16)


In [10]:
class MLP:
    def __init__(self, d1, d2, d3):
        self.l1, self.l2 = nn.Linear(d1, d2),  nn.Linear(d2, d3)
    def __call__(self, x): return self.l2(self.l1(x).relu())

model = MLP(1024, 2048, 10)
param_bytes = sum([x.lazydata.size * x.dtype.itemsize for x in get_parameters(model)])
print(f'param bytes: {param_bytes*1e-9:.2f} GB')

GlobalCounters.reset()
for x in get_parameters(model): x.realize()
print(f"global mem: {GlobalCounters.global_mem*1e-9:.2f} GB")
print(f"mem used: {GlobalCounters.mem_used*1e-9:.2f} GB")

param bytes: 0.01 GB
global mem: 0.07 GB
mem used: 0.01 GB


I noticed there is a difference between the bytes allocated for the model parameters and the bytes allocated globally for everything. Specifically, why do we allocate 0.06 GBs more for the global memory than for just the model params? What is taking up this extra memory? Where does this overhead come from? I tried reseting the global memory but that didn't work. Is this because we have also allocated space for all of the nested uops and all the complexity of the tinygrad stack?