In [1]:
from typing import Optional, Tuple
from dataclasses import dataclass
import math
import torch
from torch import nn
import torch.nn.functional as F
import hiq
from llama import ModelArgs, Tokenizer, LLaMA # Transformer
from pathlib import Path
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import Optional, Tuple
from dataclasses import dataclass
import math
import torch
from torch import nn
import torch.nn.functional as F
import hiq


@dataclass
class ModelArgs:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    norm_eps: float = 1e-5

    max_batch_size: int = 1
    max_seq_len: int = 2048


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)



class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_local_heads = args.n_heads // 1 # 32 // 1 = 32
        self.head_dim = args.dim // args.n_heads # 4096 // 32 = 128

        self.wq = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        ) # (4096, 4096)
        self.wk = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        ) # (4096, 4096)
        self.wv = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        )  # (4096, 4096)
        self.wo = nn.Linear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
        ) # (4096, 4096)
        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
            # (1,1024,32,128)
        )
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
            # (1,1024,32,128)
        )
        if hiq.get_env_bool("KV_CAHCHE_IN_GPU", True):
            self.cache_k = self.cache_k.cuda()
            self.cache_v = self.cache_v.cuda()

    def forward(
        self,
        x: torch.Tensor, # (1,8,4096)
        start_pos: int, # 0 (initially)
        freqs_cis: torch.Tensor,  # (1024, 64)
        mask: Optional[torch.Tensor],  # (1,1,8,8)
    ):
        bsz, seqlen, _ = x.shape

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        # all of shape (1,8,4096)
        print(f" shape of xq is {xq.shape}")

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) # (1,8,32,128)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim) # (1,8,32,128)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim) # (1,8,32,128)
        print(f" shape of xq is {xq.shape}")

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) # (1,8,32,128), (1,8,32,128)
        print(f" shape of xq is {xq.shape}")

        self.cache_k = self.cache_k.to(xq) # (1,1024,32,128)
        self.cache_v = self.cache_v.to(xq) # (1,1024,32,128)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk # (1,1024,32,128)
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv # (1,1024,32,128)

        keys = self.cache_k[:bsz, : start_pos + seqlen] # (1,1024,32,128)
        values = self.cache_v[:bsz, : start_pos + seqlen] # (1,1024,32,128)

        xq = xq.transpose(1, 2) # (1,32,8,128)
        keys = keys.transpose(1, 2) # (1,32,1024,128)
        values = values.transpose(1, 2) # (1,32,1024,128)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) # (1,32,8,1024)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq) # (1,32,8,1024)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim) # (1,32,8,128)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) # (1,8,4096)

        return self.wo(output) # (1,8,4096)

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int, # 4096
        hidden_dim: int, # 4 * 4096 = 16384
        multiple_of: int, # 256
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3) # 2 * 16384 / 3 = 10922
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) # 256 * (10922 + 256 - 1) // 256 = 11177

        self.w1 = nn.Linear(dim, hidden_dim, bias=False) # (4096, 11177)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False) # (11177, 4096)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False) # (4096, 11177)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x)) # (1,8,4096)


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
            # 4096, 4 * 4096, 256
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor, # (1,8,4096)
        start_pos: int, # 0 (initially)
        freqs_cis: torch.Tensor, # (1024, 64)
        mask: Optional[torch.Tensor], # (1,1,8,8)
    ):
        # this is a skip connection
        h = x + self.attention.forward(
            self.attention_norm(x), start_pos, freqs_cis, mask
            # (1,8,4096), 0, (1024, 64), (1,1,8,8)
        ) # (1,8,4096)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out # (1,8,4096)


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params # ModelArgs
        self.vocab_size = params.vocab_size # 32_000
        self.n_layers = params.n_layers # 32

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) # (32_000, 4096)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps) # shape of output is same as input
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False) # (4096, 32_000)

        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
            # 4096 // 32 = 128, 1024 * 2
        ) # torch.Size([2048, 64])

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape # (1,8)
        h = self.tok_embeddings(tokens) # (1,8,4096)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] # torch.Size([8, 64])
        print(f"shape of freqs_cis: {freqs_cis.shape}")

        mask = None
        if seqlen > 1:
            mask = torch.full(
                (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
            ) # (1,1,8,8) , filled with -inf
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
            # (1,1,8,8) , filled with -inf, but only the upper triangle, lower triangle is 0
            # diagnol = start_pos + 1, so the first 8 tokens are not masked, it basically pushes the diagonola above


        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h) # (1,8,4096)
        output = self.output(h[:, -1, :])  # only compute last logits # (1, 4096) * (4096, 32_000) = (1, 32_000)
        return output.float() # (1, 32_000)


In [17]:
del model

In [19]:
# clear gpu memory
torch.cuda.empty_cache()

In [7]:
with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

# create a model args object
# ModelArgs is a a simple dataclass that contains the parameters for the model
# file in llama/model_single.py
model_args: ModelArgs = ModelArgs(
    max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
)
model_args.vocab_size = tokenizer.n_words
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
torch.set_default_tensor_type(torch.FloatTensor)

In [8]:
model(torch.randint(0, 32_000, (1, 8)).cuda(), 0).shape # (1, 32_000)

shape of freqs_cis: torch.Size([8, 64])
 shape of xq is torch.Size([1, 8, 4096])
 shape of xq is torch.Size([1, 8, 32, 128])
 shape of xq is torch.Size([1, 8, 32, 128])
 shape of xq is torch.Size([1, 8, 4096])
 shape of xq is torch.Size([1, 8, 32, 128])
 shape of xq is torch.Size([1, 8, 32, 128])
 shape of xq is torch.Size([1, 8, 4096])
 shape of xq is torch.Size([1, 8, 32, 128])
 shape of xq is torch.Size([1, 8, 32, 128])
 shape of xq is torch.Size([1, 8, 4096])
 shape of xq is torch.Size([1, 8, 32, 128])
 shape of xq is torch.Size([1, 8, 32, 128])
 shape of xq is torch.Size([1, 8, 4096])
 shape of xq is torch.Size([1, 8, 32, 128])
 shape of xq is torch.Size([1, 8, 32, 128])
 shape of xq is torch.Size([1, 8, 4096])
 shape of xq is torch.Size([1, 8, 32, 128])
 shape of xq is torch.Size([1, 8, 32, 128])
 shape of xq is torch.Size([1, 8, 4096])
 shape of xq is torch.Size([1, 8, 32, 128])
 shape of xq is torch.Size([1, 8, 32, 128])
 shape of xq is torch.Size([1, 8, 4096])
 shape of xq is 

torch.Size([1, 32000])

In [6]:

max_seq_len = 1024
max_batch_size = 1
tokenizer_path = '/home/nishantbhansali/MyProject/model/tokenizer.model'
ckpt_dir = '/home/nishantbhansali/MyProject/model/7B'
tokenizer = Tokenizer(model_path=tokenizer_path)
print(tokenizer.n_words)

32000


In [3]:
tokenizer.bos_id,tokenizer.eos_id,tokenizer.pad_id

(1, 2, -1)

In [7]:
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
ckpt_path = checkpoints[0]
checkpoint = torch.load(ckpt_path, map_location="cpu")

In [9]:
with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

# create a model args object
# ModelArgs is a a simple dataclass that contains the parameters for the model
# file in llama/model_single.py
model_args: ModelArgs = ModelArgs(
    max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
)
model_args.vocab_size = tokenizer.n_words
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
torch.set_default_tensor_type(torch.FloatTensor)

In [10]:
model_args.vocab_size = tokenizer.n_words
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
torch.set_default_tensor_type(torch.FloatTensor)
model.load_state_dict(checkpoint, strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['layers.0.attention.inner_attention.rope.freqs', 'layers.1.attention.inner_attention.rope.freqs', 'layers.2.attention.inner_attention.rope.freqs', 'layers.3.attention.inner_attention.rope.freqs', 'layers.4.attention.inner_attention.rope.freqs', 'layers.5.attention.inner_attention.rope.freqs', 'layers.6.attention.inner_attention.rope.freqs', 'layers.7.attention.inner_attention.rope.freqs', 'layers.8.attention.inner_attention.rope.freqs', 'layers.9.attention.inner_attention.rope.freqs', 'layers.10.attention.inner_attention.rope.freqs', 'layers.11.attention.inner_attention.rope.freqs', 'layers.12.attention.inner_attention.rope.freqs', 'layers.13.attention.inner_attention.rope.freqs', 'layers.14.attention.inner_attention.rope.freqs', 'layers.15.attention.inner_attention.rope.freqs', 'layers.16.attention.inner_attention.rope.freqs', 'layers.17.attention.inner_attention.rope.freqs', 'layers.18.attention.inner_attention.rope.freqs', 'layers.

In [12]:
print(model.params)

ModelArgs(dim=4096, n_layers=32, n_heads=32, vocab_size=32000, multiple_of=256, norm_eps=1e-06, max_batch_size=1, max_seq_len=1024)


In [11]:
# count number of parameters in model
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
count_parameters(model)

6738415616

In [2]:
prompts = [
        # For these prompts, the expected answer is the natural continuation of the prompt
        "I believe the meaning of life is",  # removed: keep only one prompt
    ]

max_gen_len=256
temperature=0.8
top_p=0.95
max_seq_len= 1024
max_batch_size= 1

In [3]:
bsz = len(prompts) # 1
params = model.params # those same ModelArgs
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

prompt_tokens = [tokenizer.encode(x, bos=True, eos=False) for x in prompts]
# [[1, 306, 4658, 278, 6593, 310, 2834, 338]]

min_prompt_size = min([len(t) for t in prompt_tokens]) # 8
max_prompt_size = max([len(t) for t in prompt_tokens]) # 8

total_len = min(params.max_seq_len, max_gen_len + max_prompt_size) # 264

tokens = torch.full((bsz, total_len), tokenizer.pad_id).cuda().long()
# a tensor of size (1, 264) filled with -1's


NameError: name 'model' is not defined

In [20]:
for k, t in enumerate(prompt_tokens):
    tokens[k, : len(t)] = torch.tensor(t).cuda().long()
input_text_mask = tokens != tokenizer.pad_id # a tensor of size (1, 264) filled with True's ,
# where tokens is not -1, other wise False
start_pos = min_prompt_size # 8
prev_pos = 0

In [22]:
def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

In [24]:
i = tokens[:, prev_pos:8]
logits = model(i, prev_pos)
logits

tensor([[ -8.8516, -14.3594,   3.5098,  ...,  -6.0430,  -7.9609,  -3.8164]],
       device='cuda:0')

In [25]:
torch.softmax(logits / temperature, dim=-1)

tensor([[8.2456e-14, 8.4374e-17, 4.2344e-07,  ..., 2.7601e-12, 2.5102e-13,
         4.4632e-11]], device='cuda:0')

In [26]:
next_token = sample_top_p(torch.softmax(logits / temperature, dim=-1), top_p)
next_token

tensor([[8471]], device='cuda:0')

In [27]:
next_token = torch.where(
        input_text_mask[:, 8], tokens[:, 8], next_token
    )
next_token

tensor([[8471]], device='cuda:0')

In [28]:
tokens[:, 8] 

tensor([-1], device='cuda:0')

In [29]:
tokens[:, 8] = next_token
i = tokens[:, 8:9]
print(i)

tensor([[8471]], device='cuda:0')


In [30]:
logits = model(i, 8)

In [31]:
logits.shape

torch.Size([1, 32000])

In [32]:
prev_pos

0

In [33]:
for cur_pos in range(start_pos, total_len):
    i = tokens[:, prev_pos:cur_pos]
    logits = model(i, prev_pos) # torch.Size([1, 32000])

    if temperature > 0:
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = sample_top_p(probs, top_p)
    else:
        next_token = torch.argmax(logits, dim=-1)
    next_token = next_token.reshape(-1)
    # only replace token if prompt has already been generated
    next_token = torch.where(
        input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
    )
    tokens[:, cur_pos] = next_token
    prev_pos = cur_pos
    
    # if self._should_stop(tokens, prompt_tokens, stop_ids, stop_words):
    #     break

In [34]:
def postprocessing(output_text, stop_words=None, threshold=10):
    sentences = output_text.split(".")
    filtered_sentences = []
    for sentence in sentences:
        sentence = sentence.strip()
        if len(sentence) > threshold and sentence[-1] == ".":
            filtered_sentences.append(sentence)
    r = '.'.join(sentences).strip()
    if stop_words:
        for w in stop_words:
            if r.endswith(w):
                r = r[0:-len(w)].strip()
    if r[-1] != '.':
        r += '...'
    return r

In [41]:
tokens[0][3]

tensor(278, device='cuda:0')

In [44]:
tokenizer.decode(tokens[0][4].tolist())

'meaning'

In [35]:
tokens[tokens == tokenizer.pad_id] = tokenizer.eos_id
decoded = []
for i, t in enumerate(tokens.tolist()):
    # cut to max gen len
    t = t[: len(prompt_tokens[i]) + max_gen_len]
    # cut to eos tok if any
    try:
        t = t[: t.index(tokenizer.eos_id)]
    except ValueError:
        pass
    decoded.append(tokenizer.decode(t))
#print(decoded)
[postprocessing(i, None) for i in decoded]

['I believe the meaning of life is to do whatever you can to make the world a better place for all, and to create memories.\nIt\'s easy to become depressed about the direction of the world these days, but if you focus on the good and the positive things going on, you\'ll find you have a more optimistic view of the world. If you can bring the light into the world, you can make a difference.\nThat\'s what my series of streetscapes does - it brings the light into the world.\nThe easiest way to make the world a better place is to smile at everyone you meet.\nI have a friend who was in a meeting with a large group of managers at a company where I used to work, and after he left, one of the managers said to him, "I\'ve never seen you smile before. What\'s up?" He said, "I didn\'t think my smile would make a difference."\nI think my smile makes a difference.\nI think the life of a street photographer is the ideal job for someone who wants to make the world a better place. When you go out on t

In [2]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

In [3]:
precompute_freqs_cis(128,2048).shape

torch.Size([2048, 64])

In [20]:
1024*2

2048

In [22]:
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    print("sdfsd",(x.shape[1], x.shape[-1]))
    print("dfsdf",freqs_cis.shape)
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # [1,8,1,64]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor, # (1,8,32,128)
    xk: torch.Tensor, # (1,8,32,128)
    freqs_cis: torch.Tensor, # (8,64)
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # (1,8,32,128) -> (1,8,32,64,2)
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # (1,8,32,128) -> (1,8,32,64,2)
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # (1,8,1,64)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # (1,8,32,64,2) -> (1,8,32,128) 
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) # (1,8,32,64,2) -> (1,8,32,128)
    return xq_out.type_as(xq), xk_out.type_as(xk) # (1,8,32,128), (1,8,32,128)

In [24]:
xq = torch.randn(1,8,32,128)
xq = xq.reshape(*xq.shape[:-1], -1, 2)
print(xq.shape)
xq = torch.view_as_complex(xq.float())
print(xq.shape)
xk = torch.randn(1,8,32,128)
xk = xk.reshape(*xk.shape[:-1], -1, 2)
xk = torch.view_as_complex(xk.float())
print(xk.shape)
freqs_cis = precompute_freqs_cis(128,2048)
freqs_cis = freqs_cis[:8]
print(freqs_cis.shape)
# freqs_cis = reshape_for_broadcast(freqs_cis, xq)
# print(freqs_cis.shape)
apply_rotary_emb(xq,xk,freqs_cis).shape

torch.Size([1, 8, 32, 64, 2])
torch.Size([1, 8, 32, 64])
torch.Size([1, 8, 32, 64])
torch.Size([8, 64])
sdfsd (8, 32)
dfsdf torch.Size([8, 64])


AssertionError: 