In [2]:
from typing import Optional, Tuple
from dataclasses import dataclass
import math
import torch
from torch import nn
import torch.nn.functional as F
import hiq
from llama import ModelArgs, Transformer, Tokenizer, LLaMA
from pathlib import Path


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer_path = '/home/nishantbhansali/MyProject/model/tokenizer.model'
ckpt_dir = '/home/nishantbhansali/MyProject/model/7B'
tokenizer = Tokenizer(model_path=tokenizer_path)
print(tokenizer.n_words)

32000


In [3]:
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
ckpt_path = checkpoints[0]
checkpoint = torch.load(ckpt_path, map_location="cpu")

In [4]:
import json
max_seq_len = 1024
max_batch_size = 1

In [5]:
with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

# create a model args object
# ModelArgs is a a simple dataclass that contains the parameters for the model
# file in llama/model_single.py
model_args: ModelArgs = ModelArgs(
    max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
)

In [6]:
model_args.vocab_size = tokenizer.n_words
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
torch.set_default_tensor_type(torch.FloatTensor)
model.load_state_dict(checkpoint, strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['layers.0.attention.inner_attention.rope.freqs', 'layers.1.attention.inner_attention.rope.freqs', 'layers.2.attention.inner_attention.rope.freqs', 'layers.3.attention.inner_attention.rope.freqs', 'layers.4.attention.inner_attention.rope.freqs', 'layers.5.attention.inner_attention.rope.freqs', 'layers.6.attention.inner_attention.rope.freqs', 'layers.7.attention.inner_attention.rope.freqs', 'layers.8.attention.inner_attention.rope.freqs', 'layers.9.attention.inner_attention.rope.freqs', 'layers.10.attention.inner_attention.rope.freqs', 'layers.11.attention.inner_attention.rope.freqs', 'layers.12.attention.inner_attention.rope.freqs', 'layers.13.attention.inner_attention.rope.freqs', 'layers.14.attention.inner_attention.rope.freqs', 'layers.15.attention.inner_attention.rope.freqs', 'layers.16.attention.inner_attention.rope.freqs', 'layers.17.attention.inner_attention.rope.freqs', 'layers.18.attention.inner_attention.rope.freqs', 'layers.

In [13]:
prompts = [
        # For these prompts, the expected answer is the natural continuation of the prompt
        "I believe the meaning of life is",  # removed: keep only one prompt
    ]

max_gen_len=256
temperature=0.8
top_p=0.95
max_seq_len= 1024
max_batch_size= 1

In [14]:
bsz = len(prompts) # 1
params = model.params # those same ModelArgs
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

prompt_tokens = [tokenizer.encode(x, bos=True, eos=False) for x in prompts]
# [[1, 306, 4658, 278, 6593, 310, 2834, 338]]

min_prompt_size = min([len(t) for t in prompt_tokens]) # 8
max_prompt_size = max([len(t) for t in prompt_tokens]) # 8

total_len = min(params.max_seq_len, max_gen_len + max_prompt_size) # 264

tokens = torch.full((bsz, total_len), tokenizer.pad_id).cuda().long()
# a tensor of size (1, 264) filled with -1's


In [15]:
for k, t in enumerate(prompt_tokens):
    tokens[k, : len(t)] = torch.tensor(t).cuda().long()
input_text_mask = tokens != tokenizer.pad_id # a tensor of size (1, 264) filled with True's ,
# where tokens is not -1, other wise False
start_pos = min_prompt_size # 8
prev_pos = 0

In [16]:
def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

In [18]:
i = tokens[:, prev_pos:8]
print(i.shape)
logits = model(i, prev_pos)
print(logits.shape)

torch.Size([1, 8])
torch.Size([1, 32000])


In [20]:
tok_embeddings = nn.Embedding(model.params.vocab_size, model.params.dim)
tok_embeddings(torch.tensor([[1,2,3,4,5,6,7,8]])).shape

torch.Size([1, 8, 4096])

In [25]:
precompute_freqs_cis(128,2048)[0:0+1024].shape

torch.Size([1024, 64])

In [12]:
class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) # (32_000, 4096)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False) # (4096, 32_000)

        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
            # 4096 // 32 = 128, 1024 * 2
        ) # torch.Size([2048, 64])

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape # (1,8)
        h = self.tok_embeddings(tokens) # (1,8,4096)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] # torch.Size([1024, 64])

        mask = None
        if seqlen > 1:
            mask = torch.full(
                (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
            ) # (1,1,8,8) , filled with -inf
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
            # (1,1,8,8) , filled with -inf, but only the upper triangle, lower triangle is 0
            # diagnol = start_pos + 1, so the first 8 tokens are not masked, it basically pushes the diagonola above


        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h) # (1,8,4096)
        output = self.output(h[:, -1, :])  # only compute last logits # (1, 4096) * (4096, 32_000) = (1, 32_000)
        return output.float() # (1, 32_000)


In [28]:
mask = torch.full(
                (1, 1, 8, 8), float("-inf"), device=tokens.device
            )
mask = torch.triu(mask, diagonal=0 + 1)
mask

tensor([[[[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., -inf, -inf],
          [0., 0., 0., 0., 0., 0., 0., -inf],
          [0., 0., 0., 0., 0., 0., 0., 0.]]]], device='cuda:0')

In [11]:
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
            # 4096, 4 * 4096, 256
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor, # (1,8,4096)
        start_pos: int, # 0 (initially)
        freqs_cis: torch.Tensor, # (1024, 64)
        mask: Optional[torch.Tensor], # (1,1,8,8)
    ):
        # this is a skip connection
        h = x + self.attention.forward(
            self.attention_norm(x), start_pos, freqs_cis, mask
            # (1,8,4096), 0, (1024, 64), (1,1,8,8)
        ) # (1,8,4096)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out # (1,8,4096)

In [32]:
256 * (10922 + 256 - 1) // 256

11177

In [33]:
2 * 16384 / 3

10922.666666666666

In [7]:
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int, # 4096
        hidden_dim: int, # 4 * 4096 = 16384
        multiple_of: int, # 256
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3) # 2 * 16384 / 3 = 10922
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) # 256 * (10922 + 256 - 1) // 256 = 11177

        self.w1 = nn.Linear(dim, hidden_dim, bias=False) # (4096, 11177)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False) # (11177, 4096)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False) # (4096, 11177)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x)) # (1,8,4096)

In [8]:

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_local_heads = args.n_heads // 1 # 32 // 1 = 32
        self.head_dim = args.dim // args.n_heads # 4096 // 32 = 128

        self.wq = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        ) # (4096, 4096)
        self.wk = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        ) # (4096, 4096)
        self.wv = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        )  # (4096, 4096)
        self.wo = nn.Linear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
        ) # (4096, 4096)
        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
            # (1,1024,32,128)
        )
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
            # (1,1024,32,128)
        )
        if hiq.get_env_bool("KV_CAHCHE_IN_GPU", True):
            self.cache_k = self.cache_k.cuda()
            self.cache_v = self.cache_v.cuda()

    def forward(
        self,
        x: torch.Tensor, # (1,8,4096)
        start_pos: int, # 0 (initially)
        freqs_cis: torch.Tensor,  # (1024, 64)
        mask: Optional[torch.Tensor],  # (1,1,8,8)
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        # all of shape (1,8,4096)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) # (1,8,32,128)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim) # (1,8,32,128)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim) # (1,8,32,128)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) # (1,8,32,128), (1,8,32,128)

        self.cache_k = self.cache_k.to(xq) # (1,1024,32,128)
        self.cache_v = self.cache_v.to(xq) # (1,1024,32,128)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk # (1,1024,32,128)
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv # (1,1024,32,128)

        keys = self.cache_k[:bsz, : start_pos + seqlen] # (1,1024,32,128)
        values = self.cache_v[:bsz, : start_pos + seqlen] # (1,1024,32,128)

        xq = xq.transpose(1, 2) # (1,32,8,128)
        keys = keys.transpose(1, 2) # (1,32,1024,128)
        values = values.transpose(1, 2) # (1,32,1024,128)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) # (1,32,8,1024)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq) # (1,32,8,1024)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim) # (1,32,8,128)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) # (1,8,4096)

        return self.wo(output) # (1,8,4096)

In [9]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

In [31]:
xq = torch.randn(1,8,32,128)
xk = torch.randn(1,8,32,128)
freqs_ciss = torch.randn(1024,64)
print(freqs_ciss.shape)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_ciss)
print(xq.shape)
print(xk.shape)

torch.Size([1024, 64])


AssertionError: 

In [7]:
xq = torch.randn(1,8,32,128)
xk = torch.randn(1,8,32,128)
xv = torch.randn(1,8,32,128)
cache_k = torch.zeros(1,1024,32,128)
cache_v = torch.zeros(1,1024,32,128)

In [8]:
cache_k = cache_k.to(xq) # (1,1024,32,128)
cache_v = cache_v.to(xq) # (1,1024,32,128)
print(cache_k.shape)

torch.Size([1, 1024, 32, 128])


In [9]:
cache_k[:1, 0 : 0 + 8] = xk # (1,1024,32,128)
cache_v[:1, 0 : 0 + 8] = xv # (1,1024,32,128)

In [10]:
cache_k

tensor([[[[ 0.7062, -0.0164,  1.0580,  ..., -0.3052, -0.6578,  0.2834],
          [ 1.2989,  0.5767,  0.6553,  ..., -0.6626, -0.6152,  0.1348],
          [-1.8242, -0.2470, -1.0356,  ..., -0.3468, -0.0629, -1.7530],
          ...,
          [ 0.4131, -1.2540, -0.3932,  ...,  0.5691, -0.5716,  0.5818],
          [ 0.1771, -0.1443,  1.4504,  ...,  0.3732, -1.1157,  1.1206],
          [-0.3930, -1.0321,  0.4053,  ..., -0.1372,  1.8330, -0.4297]],

         [[-1.3268,  0.4006,  0.0696,  ..., -1.0020, -2.3518, -1.3615],
          [-0.1848, -0.0380,  0.4700,  ...,  0.9859,  0.9182,  0.7895],
          [ 2.6625, -0.2266,  1.1674,  ..., -0.3145,  0.0824,  1.5076],
          ...,
          [ 0.0749,  1.8606, -0.5492,  ...,  0.4825, -1.1236,  0.3204],
          [-1.4965,  0.3555,  0.2601,  ..., -0.4831, -0.3213, -0.5264],
          [-0.3599,  1.1146, -1.0524,  ...,  0.1825,  1.3348,  0.9725]],

         [[ 0.6107, -0.5555,  0.4862,  ...,  1.3714, -0.5348,  0.8950],
          [-0.0682,  2.0073,  

In [13]:
keys = cache_k[:1, : 0 + 8] # (1,8,32,128)
values = cache_v[:1, : 0 + 8] # (1,8,32,128)
print(keys.shape,values.shape)
xq = xq.transpose(1, 2) # (1,32,8,128)
keys = keys.transpose(1, 2) # (1,32,8,128)
values = values.transpose(1, 2) # (1,32,8,128)
xq.shape,keys.shape,values.shape # (1,32,8,128), (1,32,8,128), (1,32,8,128)


torch.Size([1, 8, 32, 128]) torch.Size([1, 8, 32, 128])


In [15]:
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(128) # matrix multiply of (1,32,8,128) and (1,32,128,8) resulting in # (1,32,8,8)

In [16]:
mask = torch.zeros(1,1,8,8)
scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)