### Now we have covered *MOST* of the important technobabble in AI
  - Easy! BUT!
#### RESPECT THE CODE!
  - There are plenty of gotchas!
#### The simplified formulations I show **are** the basis for the industrial-scale code being used by OpenAI, Google, Meta, etc.
#### ~~The Devil~~ Bugs are in the details!!
  - You can get pretty far without understanding the math, and adapting the boiler-plate code for your purposes, BUT!
  - You **WILL** write buggy code, which has fundamental *math* flaws.
  - You **WILL** write buggy code, that uses pytorch classes incorrectly.
  - The only way to NOT waste weeks/months of effort is to (slowly?) **LEARN THE MATH**
  - And figure out what each pytorch function you use is actually doing!
  - **Use the source, Luke!**

In [None]:
import torch
from torch.nn import functional as F
from torch import nn
from tinytorch.MyStuff import *
from tinytorch.tensorhelpers import *
import math
import matplotlib.pyplot as plt
import numpy as np

from matplotlib.patches import Ellipse
from matplotlib.text import OffsetFrom
import os
import os.path as path
import random

### Now for Llama 3 / Llama 3.1 / Llama 3.2 (implemented below as `SimpleLlama3`)
  - Arguably one of the top open-weight LLMs out there.
  - Uses a slightly modified `transformer` architecture
  - Meta did a service to humanity by producing this model, and releasing it for anyone to use and modify.
  - Encouraged Microsoft to do the same with their Phi models

### (https://github.com/meta-llama/llama-models.git)
  - `llama-models/models/llama3/api/args.py`       -  `ModelArgs`
  - `llama-models/models/llama3/api/tokenizer.model`  - the tokenizer model file!
  - `llama-models/models/llama3/api/tokenizer.py`  -  the tokenizer class!
  - `llama-models/models/llama3/reference_impl/model.py` - the code
  - `llama-models/models/llama3/reference_impl/generation.py` <- build/load the model
#### Download the -Instruct models! (3.2, 1B or 3B only) or llama3.0 (8B only) from (https://www.llama.com/llama-downloads/)
  - You will need a supported gpu that works with pytorch (almost any nvidia card, with at least 16gb VRAM).

In [None]:
from dataclasses import dataclass, asdict
from typing import Optional
import inspect

from LogRelay import *

In [None]:
@dataclass
class ModelArgs:
  """Hyper-parameters for the Llama-style Transformer built in this notebook.

  Most fields correspond to keys of a Llama checkpoint's ``params.json``
  (``SimpleLlama3.build`` passes that file's contents as keyword arguments);
  ``max_batch_size``, ``max_seq_len`` and ``device`` are runtime settings that
  size the KV cache and choose where tensors are allocated.
  """
  dim: int = 4096  # model (embedding) width
  n_layers: int = 32  # number of EncoderBlocks stacked by Transformer
  n_heads: int = 32  # number of query heads
  n_kv_heads: Optional[int] = None  # key/value heads; None means "same as n_heads"
  vocab_size: int = -1  # Later set in the build method (from the tokenizer)
  multiple_of: int = 256  # FeedForward hidden width is rounded UP to a multiple of this
  ffn_dim_multiplier: Optional[float] = None  # optional scale on the FeedForward hidden width
  norm_eps: float = 1e-5  # epsilon passed to every RMSNorm
  rope_theta: float = 500000  # RoPE base frequency
  # Needed for KV cache
  max_batch_size: int = 32
  max_seq_len: int = 2048  # Llama 3 has 8192; Llama 3.1 supports up to 128K
  device: Optional[str] = None  # must be set before building SelfAttention/Transformer
  # NOTE(review): bias/dropout are not read by the modules defined in this notebook.
  bias: bool = False
  dropout: float = 0.0
  # Present in Llama 3.1/3.2 params.json; declared (last, so positional order is
  # unchanged) so ModelArgs(**params) does not raise TypeError for those files.
  # NOTE(review): precompute_theta_pos_frequencies does not implement the 3.1
  # RoPE frequency scaling, so this flag is currently informational only.
  use_scaled_rope: bool = False

### FeedForward layer
  - [torch.nn.Module](https://pytorch.org/docs/2.2/generated/torch.nn.Module.html#torch.nn.Module)
  - [torch.nn.Linear](https://pytorch.org/docs/2.2/generated/torch.nn.Linear.html)
  - [torch.nn.parameter.Parameter](https://pytorch.org/docs/2.2/generated/torch.nn.parameter.Parameter.html)
  - [torch.nn.functional.silu](https://pytorch.org/docs/2.2/generated/torch.nn.functional.silu.html)
    - [The SiLU paper](https://arxiv.org/abs/1702.03118)
    - [GELU](https://arxiv.org/abs/1606.08415)
    - [Swish](https://arxiv.org/abs/1710.05941v1)

In [None]:
# The sub-layer names (w1, w2, w3) deliberately match the checkpoint's state_dict
# keys, so SimpleLlama3.build can load weights with strict=True without renaming.
class FeedForward(nn.Module):
  """SwiGLU feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

  Hidden width: start from 2/3 of ``4 * dim`` (three matrices of that width hold
  about as many weights as a plain two-matrix 4*dim MLP), optionally scale by
  ``ffn_dim_multiplier``, then round up to a multiple of ``multiple_of``.
  """
  LG = startDebug(__name__)

  def __init__(self, args: ModelArgs):
    super().__init__()
    hidden_dim = int(2 * (4 * args.dim) / 3)
    if args.ffn_dim_multiplier is not None:
      hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
    # Ceil-divide then multiply: round UP to the next multiple of `multiple_of`.
    step = args.multiple_of
    hidden_dim = -(-hidden_dim // step) * step

    # Creation order (w1, w2, w3) is kept so random initialisation is unchanged.
    self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)  # gate branch (goes through SiLU)
    self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)  # projection back to model width
    self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)  # linear "value" branch

  def forward(self, x: torch.Tensor):
    # (B, Seq_Len, Dim) -> two (B, Seq_Len, Hidden_Dim) branches, multiplied
    # element-wise, then projected back to (B, Seq_Len, Dim).
    gate = F.silu(self.w1(x))
    return self.w2(gate * self.w3(x))

### For Rotary Positional Encoding
  - [torch.outer](https://pytorch.org/docs/2.2/generated/torch.outer.html)
  - [torch.arange](https://pytorch.org/docs/2.2/generated/torch.arange.html)
  - [torch.polar](https://pytorch.org/docs/2.2/generated/torch.polar.html)

In [None]:
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, theta: float = 10000.0, device: Optional[str] = ''):
  """Precompute the unit complex rotations used by rotary position embeddings (RoPE).

  Entry ``[m, i]`` of the result is ``exp(1j * m * theta_i)`` where
  ``theta_i = theta ** (-2i / head_dim)`` for ``i = 0 .. head_dim/2 - 1``
  (section 3.2.2 of the RoFormer paper: "xi in R^d where d is even").

  Args:
    head_dim: per-head feature size; must be even because features are rotated in pairs.
    seq_len: number of positions ``m = 0 .. seq_len - 1`` to precompute.
    theta: RoPE base (Transformer passes ``args.rope_theta``).
    device: device for the result. ``''`` (the default) or ``None`` means torch's
      default device; previously the empty-string default was passed straight to
      torch, which rejects it.

  Returns:
    complex64 tensor of shape ``(seq_len, head_dim // 2)``.
  """
  assert head_dim % 2 == 0, "Dimension must be divisible by 2"
  if not device:
    device = None
  # Even exponents 0, 2, ..., head_dim - 2  -> shape (Head_Dim / 2)
  theta_numerator = torch.arange(0, head_dim, 2).float()
  # theta_i = theta ** (-2i / head_dim)     -> shape (Head_Dim / 2)
  inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
  if device is not None:
    inv_freq = inv_freq.to(device)
  # Positions m                             -> shape (Seq_Len)
  m = torch.arange(seq_len, device=device)
  # Every angle m * theta_i via an outer product -> (Seq_Len, Head_Dim / 2)
  freqs = torch.outer(m, inv_freq).float()
  # Unit-magnitude complex numbers exp(i * m * theta_i) -> (Seq_Len, Head_Dim / 2)
  return torch.polar(torch.ones_like(freqs), freqs)

def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
  """Rotate each consecutive feature pair of ``x`` by its RoPE angle.

  Args:
    x: real tensor, expected as (B, Seq_Len, H, Head_Dim); the last axis is
      treated as Head_Dim/2 (real, imag) pairs.
    freqs_complex: complex rotations of shape (Seq_Len, Head_Dim/2), e.g. a
      slice of the table from precompute_theta_pos_frequencies.
    device: device the rotations are moved to, and the result is returned on.

  Returns:
    Tensor with the same shape and dtype as ``x``, on ``device``.
  """
  # Pair up the last axis and reinterpret each pair as one complex number:
  # (B, Seq_Len, H, Head_Dim) -> (B, Seq_Len, H, Head_Dim/2) complex
  lead_shape = x.shape[:-1]
  pairs = x.float().reshape(*lead_shape, -1, 2)
  x_complex = torch.view_as_complex(pairs)
  # Add batch and head axes so the rotations broadcast:
  # (Seq_Len, Head_Dim/2) -> (1, Seq_Len, 1, Head_Dim/2)
  freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2).to(device)
  assert x_complex.device == freqs_complex.device, f'{x_complex.device=} == {freqs_complex.device=}'
  # Complex multiplication by a unit number == 2-D rotation of each pair;
  # then unpack back to real pairs and restore the original layout.
  rotated = torch.view_as_real(x_complex * freqs_complex).reshape(x.shape)
  return rotated.type_as(x).to(device)

### RMSNorm Layer
  - [Llama uses RMSNorm (Zhang & Sennrich, 2019) instead of LayerNorm](https://arxiv.org/abs/1910.07467)
  - [torch.nn.Parameter](https://pytorch.org/docs/2.2/generated/torch.nn.parameter.Parameter.html)
  - [torch.rsqrt](https://pytorch.org/docs/2.2/generated/torch.rsqrt.html)
  - [torch.pow](https://pytorch.org/docs/2.2/generated/torch.pow.html)
  - [torch.mean](https://pytorch.org/docs/2.2/generated/torch.mean.html)

In [None]:
class RMSNorm(nn.Module):
  """Root-mean-square normalisation with a learned per-feature gain.

  Unlike LayerNorm there is no mean subtraction and no bias: each vector along
  the last axis is divided by its RMS (plus ``eps`` for stability) and then
  multiplied element-wise by ``weight``.
  """
  LG = startDebug(__name__)

  def __init__(self, dim: int, eps: float = 1e-6):
    super().__init__()
    self.eps = eps  # added inside the square root to avoid division by zero
    self.weight = nn.Parameter(torch.ones(dim))  # learned gain ("gamma"), starts at 1

  def _norm(self, x: torch.Tensor):
    # mean of squares over the last axis, kept as a size-1 axis so it broadcasts:
    # (B, Seq_Len, Dim) -> (B, Seq_Len, 1)
    mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
    # rsqrt(v) == 1 / sqrt(v)
    return x * torch.rsqrt(mean_sq + self.eps)

  def forward(self, x: torch.Tensor):
    # Normalise in float32, cast back to the input dtype, then apply the gain:
    # (Dim) * (B, Seq_Len, Dim) -> (B, Seq_Len, Dim)
    normed = self._norm(x.float()).type_as(x)
    return self.weight * normed

### repeat_kv (used for Grouped-Query Attention)
  - [torch.Tensor.expand](https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html)
  - [torch.Tensor.reshape](https://pytorch.org/docs/stable/generated/torch.Tensor.reshape.html)

In [None]:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
  """Tile key/value heads so grouped-query attention lines up with the query heads.

  Head k of ``x`` becomes heads ``k*n_rep .. k*n_rep + n_rep - 1`` of the result,
  i.e. the same layout as ``torch.repeat_interleave(x, n_rep, dim=2)``.

  Args:
    x: (B, Seq_Len, N_KV_Heads, Head_Dim) tensor.
    n_rep: number of query heads sharing each key/value head.

  Returns:
    (B, Seq_Len, N_KV_Heads * n_rep, Head_Dim); ``x`` itself when ``n_rep == 1``.
  """
  b, s, kv_heads, hd = x.shape
  if n_rep != 1:
    # New size-1 axis -> broadcast it n_rep times (no copy) -> fold into the head axis.
    x = x.unsqueeze(3).expand(b, s, kv_heads, n_rep, hd).reshape(b, s, kv_heads * n_rep, hd)
  return x

### SelfAttention Layer
  - [torch.zeros](https://pytorch.org/docs/2.2/generated/torch.zeros.html)
  - [torch.Tensor.view](https://pytorch.org/docs/2.2/generated/torch.Tensor.view.html)
  - [torch.transpose](https://pytorch.org/docs/2.2/generated/torch.transpose.html)
  - [torch.nn.functional.softmax](https://pytorch.org/docs/2.2/generated/torch.nn.functional.softmax.html)
  - CAREFUL! [Pytorch broadcasting semantics](https://pytorch.org/tutorials/beginner/introyt/tensors_deeper_tutorial.html)

In [None]:
class SelfAttention(nn.Module):
  """Multi-head self-attention with grouped key/value heads, RoPE and a KV cache.

  Queries use ``args.n_heads`` heads; keys/values use ``args.n_kv_heads`` heads
  (``n_heads`` when unset) that are repeated ``n_rep`` times to line up with the
  query heads. Rotated keys and the values of every processed position are
  stored in ``cache_k`` / ``cache_v`` so later calls only project the new tokens.
  """
  LG = startDebug(__name__)
  def __init__(self, args: ModelArgs):
    nn.Module.__init__(self)
    self.args = args
    # Indicates the number of heads for the Keys and Values
    self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
    # Indicates the number of heads for the Queries
    self.n_heads_q = args.n_heads
    # Indicates how many times the Keys and Values should be repeated
    # NOTE(review): integer division -- assumes n_heads is a multiple of n_kv_heads.
    self.n_rep = self.n_heads_q // self.n_kv_heads
    # Indicates the dimension of each head, that is, the part of the embedding that each head will be responsible for
    self.head_dim = args.dim // args.n_heads
    self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
    self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
    assert isinstance(args, ModelArgs)
    assert args.device is not None
    # KV cache, (Max_Batch, Max_Seq_Len, H_KV, Head_Dim), in torch's default dtype at
    # construction time. These are plain tensor attributes (not registered buffers):
    # they are not part of state_dict and Module.to() does not move them, which is
    # why they are placed on args.device explicitly here.
    self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)).to(args.device)
    self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)).to(args.device)
  def forward(
      self,
      x: torch.Tensor,
      start_pos: int,
      freqs_complex: torch.Tensor,
      mask: Optional[torch.Tensor]
  ):
    """Attend from the new tokens in ``x`` to all cached positions up to them.

    Args:
      x: (B, Seq_Len, Dim) inputs for positions start_pos .. start_pos + Seq_Len - 1.
        Seq_Len is the prompt length on the first call and 1 when decoding token by token.
      start_pos: cache index of the first new position.
      freqs_complex: RoPE rotations for exactly those Seq_Len positions.
      mask: additive mask broadcastable to (B, H_Q, Seq_Len, Seq_Len_KV), or None.

    Returns:
      (B, Seq_Len, Dim) tensor.

    Side effect: overwrites cache rows [start_pos, start_pos + Seq_Len) of the first
    B batch slots in cache_k / cache_v.
    """
    batch_size, seq_len, _ = x.shape  # (B, Seq_Len, Dim)

    # (B, Seq_Len, Dim) -> (B, Seq_Len, H_Q * Head_Dim)
    xq = self.wq(x)
    # (B, Seq_Len, Dim) -> (B, Seq_Len, H_KV * Head_Dim)
    xk = self.wk(x)
    # (B, Seq_Len, Dim) -> (B, Seq_Len, H_KV * Head_Dim)
    xv = self.wv(x)

    # (B, Seq_Len, H_Q * Head_Dim) -> (B, Seq_Len, H_Q, Head_Dim)
    xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)
    # (B, Seq_Len, H_KV * Head_Dim) -> (B, Seq_Len, H_KV, Head_Dim)
    xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
    # (B, Seq_Len, H_KV * Head_Dim) -> (B, Seq_Len, H_KV, Head_Dim)
    xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

    # RoPE is applied to queries and keys only; values are left unrotated.
    # (B, Seq_Len, H_Q, Head_Dim) --> (B, Seq_Len, H_Q, Head_Dim)
    xq = apply_rotary_embeddings(xq, freqs_complex, device=x.device)
    # (B, Seq_Len, H_KV, Head_Dim) --> (B, Seq_Len, H_KV, Head_Dim)
    xk = apply_rotary_embeddings(xk, freqs_complex, device=x.device)

    # Replace the entry in the cache (must happen BEFORE the reads below so the
    # new tokens can attend to themselves).
    self.cache_k[:batch_size, start_pos : start_pos + seq_len] = xk
    self.cache_v[:batch_size, start_pos : start_pos + seq_len] = xv

    # Seq_Len_KV = start_pos + Seq_Len: every position seen so far.
    # (B, Seq_Len_KV, H_KV, Head_Dim)
    keys = self.cache_k[:batch_size, : start_pos + seq_len]
    # (B, Seq_Len_KV, H_KV, Head_Dim)
    values = self.cache_v[:batch_size, : start_pos + seq_len]

    # Since every group of Q shares the same K and V heads, just repeat the K and V heads for every Q in the same group.
    # (B, Seq_Len_KV, H_KV, Head_Dim) --> (B, Seq_Len_KV, H_Q, Head_Dim)
    keys = repeat_kv(keys, self.n_rep)
    # (B, Seq_Len_KV, H_KV, Head_Dim) --> (B, Seq_Len_KV, H_Q, Head_Dim)
    values = repeat_kv(values, self.n_rep)

    # (B, Seq_Len, H_Q, Head_Dim) -> (B, H_Q, Seq_Len, Head_Dim)
    xq = xq.transpose(1, 2)
    # (B, Seq_Len_KV, H_Q, Head_Dim) -> (B, H_Q, Seq_Len_KV, Head_Dim)
    keys = keys.transpose(1, 2)
    # (B, Seq_Len_KV, H_Q, Head_Dim) -> (B, H_Q, Seq_Len_KV, Head_Dim)
    values = values.transpose(1, 2)
    # Scaled dot product:
    # (B, H_Q, Seq_Len, Head_Dim) @ (B, H_Q, Head_Dim, Seq_Len_KV) -> (B, H_Q, Seq_Len, Seq_Len_KV)
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
    if mask is not None:
      scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)

    # Softmax over the key axis, computed in float32 then cast back.
    # (B, H_Q, Seq_Len, Seq_Len_KV) -> (B, H_Q, Seq_Len, Seq_Len_KV)
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)

    # (B, H_Q, Seq_Len, Seq_Len_KV) @ (B, H_Q, Seq_Len_KV, Head_Dim) -> (B, H_Q, Seq_Len, Head_Dim)
    output = torch.matmul(scores, values)
    # (B, H_Q, Seq_Len, Head_Dim) -> (B, Seq_Len, H_Q, Head_Dim) -> (B, Seq_Len, H_Q * Head_Dim)
    output = (output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1))
    return self.wo(output) # (B, Seq_Len, H_Q * Head_Dim) -> (B, Seq_Len, Dim)

### EncoderBlock
  - SelfAttention
  - FeedForward
  - RMSNorm

In [None]:
class EncoderBlock(nn.Module):
  """One pre-norm Transformer layer.

  ``h = x + attention(attention_norm(x))`` followed by
  ``out = h + feed_forward(ffn_norm(h))``. Despite the name, with the causal mask
  built in Transformer.forward it behaves as a decoder-only layer.
  """
  LG = startDebug(__name__)
  def __init__(self, args: ModelArgs):
    nn.Module.__init__(self)

    self.n_heads = args.n_heads
    self.dim = args.dim
    self.head_dim = args.dim // args.n_heads

    self.attention = SelfAttention(args)
    self.feed_forward = FeedForward(args)

    # Normalization BEFORE the attention block
    self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
    # Normalization BEFORE the feed forward block
    self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

  def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor,
              mask: Optional[torch.Tensor]):
    """Apply attention and feed-forward sub-layers, each with a residual connection.

    Args:
      x: (B, Seq_Len, Dim) hidden states.
      start_pos: KV-cache index of the first position in ``x``.
      freqs_complex: RoPE rotations for the Seq_Len positions in ``x``.
      mask: additive attention mask, or None.

    Returns:
      (B, Seq_Len, Dim) hidden states.
    """
    assert isinstance(x, torch.Tensor) and isinstance(start_pos, int)
    # Call the sub-modules themselves rather than `.forward(...)` so that any
    # registered forward hooks / pre-hooks run (calling .forward skips them).
    # (B, Seq_Len, Dim) + (B, Seq_Len, Dim) --> (B, Seq_Len, Dim)
    h = x + self.attention(self.attention_norm(x), start_pos, freqs_complex, mask)
    # (B, Seq_Len, Dim) + (B, Seq_Len, Dim) --> (B, Seq_Len, Dim)
    out = h + self.feed_forward(self.ffn_norm(h))
    return out

### Transformer
  - Uses all of the modules defined above
  - [torch.nn.Embedding](https://pytorch.org/docs/2.2/generated/torch.nn.Embedding.html)
  - [torch.nn.functional.softmax](https://pytorch.org/docs/2.2/generated/torch.nn.functional.softmax.html)
  - [torch.hstack](https://pytorch.org/docs/2.2/generated/torch.hstack.html)
  - [torch.cat](https://pytorch.org/docs/2.2/generated/torch.cat.html)
  - [torch.full](https://pytorch.org/docs/2.2/generated/torch.full.html)
  - [torch.triu](https://pytorch.org/docs/2.2/generated/torch.triu.html)

In [None]:
### quick example of torch.cat and torch.stack
# A small 3x2 matrix [[0, 1], [2, 3], [4, 5]] to concatenate / stack below.
A = torch.arange(6).reshape(3, 2)

In [None]:
print(Shapes(A=A, cat0=torch.cat((A, A), 0), cat1=torch.cat((A, A), dim=1)))  # cat joins along an EXISTING axis

In [None]:
print(Shapes(A=A, stack0=torch.stack((A, A)), stack1=torch.stack((A, A), 1)))  # stack inserts a NEW axis

In [None]:
class Transformer(nn.Module):
  """Llama-style decoder-only language model.

  tok_embeddings -> ``n_layers`` EncoderBlocks -> RMSNorm -> ``output`` linear
  projection to vocabulary logits. RoPE rotations for ``2 * max_seq_len``
  positions are precomputed once, on ``args.device``.
  """
  LG = startDebug(__name__)
  def __init__(self, args: ModelArgs):
    nn.Module.__init__(self)
    setupLogRelay(self)
    assert args.device is not None
    self.info(f'Transformer initializing with {asdict(args)}')
    assert args.vocab_size != -1, "Vocab size must be set"
    # The same ModelArgs object is exposed under three attribute names
    # (presumably for code written against different naming conventions).
    self.params = args
    self.config = args
    self.args = args
    self.vocab_size = args.vocab_size
    self.n_layers = args.n_layers
    self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)

    self.layers = nn.ModuleList()
    for layer_id in range(args.n_layers):
      self.layers.append(EncoderBlock(args))

    self.norm = RMSNorm(args.dim, eps=args.norm_eps)
    # Separate (untied) projection to vocabulary logits.
    self.output = nn.Linear(args.dim, self.vocab_size, bias=False)

    # Plain tensor attribute (not a parameter/buffer): excluded from state_dict and
    # from get_num_params, and not moved by Module.to() -- hence built on args.device.
    self.freqs_complex = precompute_theta_pos_frequencies(self.args.dim // self.args.n_heads, self.args.max_seq_len * 2,
                                args.rope_theta,
                                device=self.args.device)
    self.info("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

  def forward(self, tokens: torch.Tensor, start_pos: int, targets = None):
    """Run the model on ``tokens`` occupying positions start_pos .. start_pos + Seq_Len - 1.

    Args:
      tokens: (B, Seq_Len) token ids.
      start_pos: KV-cache index of the first token (0 for a fresh prompt).
      targets: optional (B, Seq_Len) target ids for a cross-entropy loss;
        entries equal to -1 are ignored.

    Returns:
      ``(logits, loss)``: float32 logits of shape (B, Seq_Len, vocab_size) and the
      loss tensor, or None when ``targets`` is None.
    """
    # (B, Seq_Len)
    batch_size, seq_len = tokens.shape
    # assert seq_len > 1, f"Only one token at a time can be processed {seq_len=}"

    # (B, Seq_Len) -> (B, Seq_Len, Dim)
    h = self.tok_embeddings(tokens)
    assert isinstance(h, torch.Tensor)
    assert isinstance(start_pos, int), f'{start_pos=} not int'
    # Retrieve the pairs (m, theta) corresponding to the positions [start_pos, start_pos + seq_len)
    freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]

    # A single new token may attend to every cached position, so no mask is needed.
    mask = None
    if seq_len > 1:
      mask = torch.full((seq_len, seq_len), float("-inf"), device=tokens.device)

      # Keep -inf strictly above the diagonal (future tokens), 0 elsewhere.
      mask = torch.triu(mask, diagonal=1)

      # When performing key-value caching, we compute the attention scores
      # only for the new sequence. Thus, the matrix of scores is of size
      # (seq_len, cache_len + seq_len), and the only masked entries are (i, j) for
      # j > cache_len + i, since row i corresponds to token cache_len + i.
      mask = torch.hstack(
        [torch.zeros((seq_len, start_pos), device=tokens.device), mask]
      ).type_as(h)

    # Consecutively apply all the encoder layers
    for layer in self.layers:
      h = layer(h, start_pos, freqs_complex, mask)
    h = self.norm(h)
    output = self.output(h).float()
    loss = None
    if targets is not None:
      logits = output
      loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

    return output, loss

  def get_num_params(self, non_embedding=True):
    """
    Return the total number of parameters in the model.

    ``non_embedding`` appears to be carried over from nanoGPT-style code and is
    currently ignored: every parameter is counted, including ``tok_embeddings``
    and the separate (untied) ``output`` projection. This model has no learned
    position embeddings to subtract (RoPE frequencies are not a parameter).
    """
    n_params = sum(p.numel() for p in self.parameters())
    # if non_embedding:
    #   n_params -= self.transformer.wpe.weight.numel()
    return n_params

  def _init_weights(self, module):
    """GPT-2-style init: N(0, 0.02) weights for Linear/Embedding, zero Linear biases.

    NOTE(review): nothing in this class calls ``self.apply(self._init_weights)``,
    so this is not applied automatically; weights normally come from a checkpoint.
    """
    if isinstance(module, nn.Linear):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
      if module.bias is not None:
        torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

### Now for the tokenizer
  - Llama 3 / Llama 3.1 use a BPE tokenizer built on OpenAI's `tiktoken` library

In [None]:
from logging import getLogger
from pathlib import Path
from typing import (
  AbstractSet,
  cast,
  Collection,
  Dict,
  Iterator,
  List,
  Literal,
  Sequence,
  TypedDict,
  Union,
)

import tiktoken
from tiktoken.load import load_tiktoken_bpe


# Module-level logger used by the Tokenizer below.
logger = getLogger(__name__)


# The speaker roles understood by the Llama 3 chat template.
Role = Literal["system", "user", "assistant"]


# One chat turn: who is speaking ("role") and what they said ("content").
Message = TypedDict("Message", {"role": Role, "content": str})

# A conversation is an ordered sequence of messages.
Dialog = Sequence[Message]


class Tokenizer:
  """
  Tokenizing and encoding/decoding text using the Tiktoken tokenizer.

  Wraps a ``tiktoken.Encoding`` built from a BPE rank file (``tokenizer.model``)
  plus ``num_reserved_special_tokens`` special tokens whose ids come directly
  after the base BPE ids.
  """

  # Special-token text (e.g. "<|eot_id|>") -> token id; filled in __init__.
  special_tokens: Dict[str, int]

  # Total number of special-token ids appended after the base vocabulary.
  num_reserved_special_tokens = 256

  # Pre-tokenization regex handed to tiktoken: splits text into chunks
  # (contractions, letter runs, groups of 1-3 digits, punctuation runs,
  # newlines/whitespace) before BPE merges are applied inside each chunk.
  pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

  def __init__(self, model_path: str):
    """
    Initializes the Tokenizer with a Tiktoken model.

    Args:
      model_path (str): The path to the Tiktoken model file.

    Raises:
      AssertionError: if ``model_path`` is not an existing file.
    """
    assert os.path.isfile(model_path), model_path

    mergeable_ranks = load_tiktoken_bpe(model_path)
    num_base_tokens = len(mergeable_ranks)
    # 10 named/reserved tokens, then reserved_special_token_5 .. _250 (246 more):
    # 256 special tokens in total. Their ORDER defines their ids.
    special_tokens = [
      "<|begin_of_text|>",
      "<|end_of_text|>",
      "<|reserved_special_token_0|>",
      "<|reserved_special_token_1|>",
      "<|reserved_special_token_2|>",
      "<|reserved_special_token_3|>",
      "<|start_header_id|>",
      "<|end_header_id|>",
      "<|reserved_special_token_4|>",
      "<|eot_id|>",  # end of turn
    ] + [
      f"<|reserved_special_token_{i}|>"
      for i in range(5, self.num_reserved_special_tokens - 5)
    ]
    # Special-token ids start right after the last base BPE id.
    self.special_tokens = {
      token: num_base_tokens + i for i, token in enumerate(special_tokens)
    }
    self.model = tiktoken.Encoding(
      name=Path(model_path).name,
      pat_str=self.pat_str,
      mergeable_ranks=mergeable_ranks,
      special_tokens=self.special_tokens,
    )
    logger.info(f"Reloaded tiktoken model from {model_path}")

    # Vocabulary size including special tokens (used as ModelArgs.vocab_size by SimpleLlama3.build).
    self.n_words: int = self.model.n_vocab
    # BOS / EOS token IDs
    self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
    self.eos_id: int = self.special_tokens["<|end_of_text|>"]
    # No dedicated padding token; -1 is used as a sentinel id.
    self.pad_id: int = -1
    # Ids treated as end-of-generation markers (end of text, end of turn);
    # presumably consumed by the generation loop -- confirm in generate().
    self.stop_tokens = {
      self.special_tokens["<|end_of_text|>"],
      self.special_tokens["<|eot_id|>"],
    }
    logger.info(
      f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
    )

  def encode(self, s: str,  *, bos: bool, eos: bool,
             allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
             disallowed_special: Union[Literal["all"], Collection[str]] = (),
            ) -> List[int]:
    """
    Encodes a string into a list of token IDs.

    Args:
      s (str): The input string to be encoded.
      bos (bool): Whether to prepend the beginning-of-sequence token.
      eos (bool): Whether to append the end-of-sequence token.
      allowed_special ("all"|set[str]): special tokens that, when their text
        appears in ``s``, are encoded as the special token id.
      disallowed_special ("all"|Collection[str]): special tokens whose text
        raises an error when it appears in ``s``.

    Returns:
      list[int]: A list of token IDs.

    Raises:
      AssertionError: if ``s`` is not exactly a ``str``.

    By default, setting disallowed_special=() encodes a string by ignoring
    special tokens. Specifically:
    - Setting `disallowed_special` to () will cause all text corresponding
      to special tokens to be encoded as natural text (instead of raising
      an error).
    - Setting `allowed_special` to "all" will treat all text corresponding
      to special tokens to be encoded as special tokens.

    Long inputs are encoded piecewise (see the limits below) and the pieces'
    token ids are concatenated.
    """
    assert type(s) is str

    # The tiktoken tokenizer can handle <=400k chars without
    # pyo3_runtime.PanicException.
    TIKTOKEN_MAX_ENCODE_CHARS = 400_000

    # https://github.com/openai/tiktoken/issues/195
    # Here we iterate over subsequences and split if we exceed the limit
    # of max consecutive non-whitespace or whitespace characters.
    MAX_NO_WHITESPACES_CHARS = 25_000

    # Lazily: 400k-char windows, each further split into runs of at most 25k
    # consecutive whitespace / non-whitespace characters.
    substrs = (
      substr
      for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
      for substr in self._split_whitespaces_or_nonwhitespaces(
        s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
      )
    )
    t: List[int] = []
    for substr in substrs:
      t.extend(
        self.model.encode(
          substr,
          allowed_special=allowed_special,
          disallowed_special=disallowed_special,
        )
      )
    if bos:
      t.insert(0, self.bos_id)
    if eos:
      t.append(self.eos_id)
    return t

  def decode(self, t: Sequence[int]) -> str:
    """
    Decodes a list of token IDs into a string.

    Args:
      t (List[int]): The list of token IDs to be decoded.

    Returns:
      str: The decoded string.
    """
    # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
    return self.model.decode(cast(List[int], t))

  @staticmethod
  def _split_whitespaces_or_nonwhitespaces(
    s: str, max_consecutive_slice_len: int
  ) -> Iterator[str]:
    """
    Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
    consecutive whitespaces or consecutive non-whitespaces.

    Concatenating the yielded pieces reproduces ``s`` exactly; for an empty ``s``
    a single empty string is yielded.
    """
    current_slice_len = 0
    current_slice_is_space = s[0].isspace() if len(s) > 0 else False
    slice_start = 0

    for i in range(len(s)):
      is_now_space = s[i].isspace()

      # A switch between whitespace and non-whitespace starts a new run.
      if current_slice_is_space ^ is_now_space:
        current_slice_len = 1
        current_slice_is_space = is_now_space
      else:
        current_slice_len += 1
        # Run too long: cut here and start a new piece at character i.
        if current_slice_len > max_consecutive_slice_len:
          yield s[slice_start:i]
          slice_start = i
          current_slice_len = 1
    yield s[slice_start:]


class ChatFormat:
  """Turn chat messages into Llama 3 prompt token ids.

  A message is rendered as: <|start_header_id|>, the role text,
  <|end_header_id|>, two newlines, the stripped content, <|eot_id|>.
  A dialog prompt is <|begin_of_text|>, every message in order, and finally an
  open assistant header that the model is expected to continue.
  """

  def __init__(self, tokenizer: Tokenizer):
    self.tokenizer = tokenizer

  def encode_header(self, message: Message) -> List[int]:
    """Ids for the role header of ``message`` (start id, role, end id, two newlines)."""
    tok = self.tokenizer
    # A list display evaluates left to right, preserving the original call order.
    return [
      tok.special_tokens["<|start_header_id|>"],
      *tok.encode(message["role"], bos=False, eos=False),
      tok.special_tokens["<|end_header_id|>"],
      *tok.encode("\n\n", bos=False, eos=False),
    ]

  def encode_message(self, message: Message) -> List[int]:
    """Ids for one full message: header, stripped content, then <|eot_id|>."""
    header = self.encode_header(message)
    content = self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
    return [*header, *content, self.tokenizer.special_tokens["<|eot_id|>"]]

  def encode_dialog_prompt(self, dialog: Dialog) -> List[int]:
    """Ids for a whole dialog, ending with an empty assistant header to be completed."""
    prompt = [self.tokenizer.special_tokens["<|begin_of_text|>"]]
    for turn in dialog:
      prompt += self.encode_message(turn)
    prompt += self.encode_header({"role": "assistant", "content": ""})
    return prompt

### Inference engine
 - [torch.multinomial](https://pytorch.org/docs/2.2/generated/torch.multinomial.html)
 - [torch.argmax](https://pytorch.org/docs/2.2/generated/torch.argmax.html)
 - [torch.where](https://pytorch.org/docs/2.2/generated/torch.where.html)
 - [torch.gather](https://pytorch.org/docs/2.2/generated/torch.gather.html)
 - [torch.cumsum](https://pytorch.org/docs/2.2/generated/torch.cumsum.html)

In [None]:
# adapted from official Llama3 repository
from typing import (
  AbstractSet,
  cast,
  Collection,
  Dict,
  Iterator,
  List,
  Literal,
  Sequence,
  TypedDict,
  Optional,
  Union,
  Tuple,
)
import torch
import time
from pathlib import Path
import json
# from sentencepiece import SentencePieceProcessor
from tqdm import tqdm
import tiktoken
from tiktoken.load import load_tiktoken_bpe

# Result of a plain-text completion. With total=False EVERY key (including
# "generation") is optional for type checkers; "tokens"/"logprobs" are presumably
# only filled when log-probabilities are requested -- confirm against generate().
CompletionPrediction = TypedDict(
  "CompletionPrediction",
  {"generation": str, "tokens": List[str], "logprobs": List[float]},
  total=False,
)

# Result of a chat completion: the generated assistant Message plus, presumably
# only when log-probabilities are requested, per-token strings and logprobs.
# total=False makes every key optional for type checkers.
ChatPrediction = TypedDict(
  "ChatPrediction",
  {"generation": Message, "tokens": List[str], "logprobs": List[float]},
  total=False,
)

class SimpleLlama3:
  """Inference wrapper bundling a Llama 3 `Transformer`, its `Tokenizer` and a `ChatFormat`.

  Adapted from the official Llama 3 reference generation code. `generate` does
  batched token-level decoding (greedy or nucleus/top-p sampling);
  `text_completion` and `chat_completion` wrap it for raw text prompts and for
  chat dialogs respectively.
  """
  LG = startDebug(__name__)

  def __init__(self, model: Transformer, tokenizer: Tokenizer, model_args: ModelArgs):
    # tinytorch helper; presumably what provides the self.debug(...) calls below -- TODO confirm.
    setupLogRelay(self)
    self.model = model
    self.tokenizer = tokenizer
    self.args = model_args
    self.formatter = ChatFormat(tokenizer)

  @staticmethod
  def build(checkpoints_dir: str, tokenizer_path: str, load_model: bool, max_seq_len: int, max_batch_size: int, device: str):
    """Create a SimpleLlama3 from a Meta-format checkpoint directory.

    Args:
      checkpoints_dir: directory holding params.json and, when load_model is True, at
        least one *.pth file (only the first one in sorted order is loaded).
      tokenizer_path: tokenizer model file, passed to Tokenizer(model_path=...).
      load_model: if False, no weights are loaded into the Transformer.
      max_seq_len: maximum prompt + generation length, stored in ModelArgs.
      max_batch_size: maximum number of prompts per generate() call, stored in ModelArgs.
      device: torch device string such as "cuda", "cuda:0" or "cpu".

    Returns:
      A SimpleLlama3 wrapping the newly built Transformer.

    Note:
      This changes torch's *global* default tensor type (CUDA half precision for
      CUDA devices, CPU bfloat16 otherwise), which affects every tensor created
      afterwards in the process.
    """
    prev_time = time.time()
    if load_model:
      checkpoints = sorted(Path(checkpoints_dir).glob("*.pth"))
      assert len(checkpoints) > 0, f"no checkpoint files found in {checkpoints_dir}"
      ckpt_path = checkpoints[0]
      print(f'Loading checkpoint "{ckpt_path}"')
      # torch.load unpickles the file: only point this at checkpoints you trust.
      checkpoint = torch.load(ckpt_path, map_location="cpu")
      print(f"Loaded checkpoint in {time.time() - prev_time:.2f}s")
      prev_time = time.time()
    with open(Path(checkpoints_dir) / "params.json", "r") as f:
      params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(
      max_seq_len=max_seq_len,
      max_batch_size=max_batch_size,
      device=device,
      **params
    )

    tokenizer = Tokenizer(model_path=tokenizer_path)
    # The vocabulary size comes from the tokenizer, not from params.json.
    model_args.vocab_size = tokenizer.n_words
    print(f'{model_args.vocab_size=}')
    # Accept "cuda:N" (and torch.device objects), not only the literal "cuda".
    # NOTE: set_default_tensor_type is deprecated in recent PyTorch in favour of
    # set_default_dtype + set_default_device; kept here for compatibility.
    if str(device).startswith("cuda"):
      torch.set_default_tensor_type(torch.cuda.HalfTensor)
    else:
      torch.set_default_tensor_type(torch.BFloat16Tensor)

    model = Transformer(model_args).to(device)

    if load_model:
      # Some older Meta checkpoints carry an extra 'rope.freqs' entry; if strict
      # loading complains about it, `del checkpoint['rope.freqs']` first.
      model.load_state_dict(checkpoint, strict=True)
      print(f"Loaded state dict in {time.time() - prev_time:.2f}s")

    return SimpleLlama3(model, tokenizer, model_args)

  @torch.inference_mode()
  def generate(
    self,
    prompt_tokens: List[List[int]],
    max_gen_len: int,
    temperature: float = 0.6,
    top_p: float = 0.9,
    logprobs: bool = False,
    echo: bool = False,
  ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
    """
    Generate text sequences based on provided prompts using the language generation model.

    Args:
      prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
      max_gen_len (int): Maximum number of tokens to generate after each prompt.
      temperature (float, optional): Softmax temperature; a value <= 0 selects greedy (argmax) decoding. Defaults to 0.6.
      top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
      logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
      echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

    Returns:
      Tuple[List[List[int]], Optional[List[List[float]]]]: the token sequences (cut at the first
      stop token found after the prompt) and, if logprobs is True, the matching per-token log
      probabilities (otherwise None). An empty batch returns ([], []) or ([], None).

    Note:
      All prompts are decoded together: decoding starts at the shortest prompt's length, and
      for longer prompts, positions still inside the prompt keep the prompt token instead of
      the sampled one. A row stops contributing once it generates a stop token.
    """
    bsz = len(prompt_tokens)
    if bsz == 0:
      # min()/max() below would raise on an empty batch; there is nothing to generate.
      return [], ([] if logprobs else None)

    device = self.args.device
    params = self.model.params
    assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

    min_prompt_len = min(len(t) for t in prompt_tokens)
    max_prompt_len = max(len(t) for t in prompt_tokens)
    assert max_prompt_len <= params.max_seq_len
    total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

    # (bsz, total_len) buffer: prompts left-aligned, every other slot holds pad_id until generated.
    pad_id = self.tokenizer.pad_id
    tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
    for k, t in enumerate(prompt_tokens):
      tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
    if logprobs:
      token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

    prev_pos = 0
    eos_reached = torch.tensor([False] * bsz, device=device)
    input_text_mask = tokens != pad_id  # True where the slot holds a prompt token
    if logprobs and min_prompt_len == total_len:
      # Nothing will be generated (the loop below does not run), so only score the prompts.
      logits, loss = self.model.forward(tokens, prev_pos)
      self.debug('got logits.shape=$lg loss $LO temp=$temp topP=$topk', list(logits.shape), loss, temperature, top_p)
      # logits[:, j] predicts token j+1: shift by one, exactly like the generation loop below.
      # Position 0 has no prediction and keeps logprob 0.
      token_logprobs[:, 1:] = -F.cross_entropy(
        input=logits[:, :-1].transpose(1, 2),
        target=tokens[:, 1:],
        reduction="none",
        ignore_index=pad_id,
      )

    # Create on the same device as the sampled tokens so torch.isin below does not mix devices.
    stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens), device=device)

    for cur_pos in range(min_prompt_len, total_len):
      # Feed only the tokens not yet processed; prev_pos is their start offset in the sequence.
      logits, loss = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
      if temperature > 0:
        probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
        next_token = self._sample_top_p(probs, top_p)
      else:
        next_token = torch.argmax(logits[:, -1], dim=-1)

      next_token = next_token.reshape(-1)
      # Keep the prompt token where this position is still inside a (longer) prompt.
      next_token = torch.where(
        input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
      )
      tokens[:, cur_pos] = next_token
      if logprobs:
        token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
          input=logits.transpose(1, 2),
          target=tokens[:, prev_pos + 1 : cur_pos + 1],
          reduction="none",
          ignore_index=pad_id,
        )
      # A row is finished only when it *generates* a stop token (prompt tokens do not count).
      eos_reached |= (~input_text_mask[:, cur_pos]) & (
        torch.isin(next_token, stop_tokens)
      )
      prev_pos = cur_pos
      if eos_reached.all():
        break
    self.debug('got total [$mv, $p]=$t', min_prompt_len, total_len, total_len-min_prompt_len)

    if logprobs:
      token_logprobs = token_logprobs.tolist()
    out_tokens, out_logprobs = [], []
    for i, toks in enumerate(tokens.tolist()):
      prompt_len = len(prompt_tokens[i])
      # cut to max gen len (and drop the prompt unless echo is requested)
      start = 0 if echo else prompt_len
      toks = toks[start : prompt_len + max_gen_len]
      probs = None
      if logprobs:
        probs = token_logprobs[i][start : prompt_len + max_gen_len]
      # Cut at the first stop token in the *generated* part only: an echoed prompt may itself
      # contain stop tokens (e.g. chat prompts end every message with <|eot_id|>).
      gen_offset = prompt_len - start
      for stop_token in self.tokenizer.stop_tokens:
        try:
          eos_idx = toks.index(stop_token, gen_offset)
        except ValueError:
          continue
        toks = toks[:eos_idx]
        if logprobs:
          probs = probs[:eos_idx]
      out_tokens.append(toks)
      out_logprobs.append(probs)
    return (out_tokens, out_logprobs if logprobs else None)

  def _sample_top_p(self, probs, p):
    """Nucleus (top-p) sampling.

    Keeps, per row, the smallest set of most-likely tokens whose cumulative probability
    exceeds p, renormalizes it, and draws one token from it.

    Args:
      probs: (B, vocab_size) probabilities (e.g. a softmax output).
      p: cumulative-probability threshold.

    Returns:
      (B, 1) tensor of sampled token ids.
    """
    # (B, vocab_size), sorted from most to least likely
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    # (B, vocab_size)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Subtracting probs_sort gives the mass *before* each token, so the token that crosses p is kept.
    mask = probs_sum - probs_sort > p
    # Zero out all the probabilities of tokens that are not selected by the Top P
    probs_sort[mask] = 0.0
    # Redistribute the probabilities so that they sum up to 1.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # Sample a token (its index in the sorted order) from the top-p distribution
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sorted index back to the token id in the vocabulary
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

  def text_completion(
    self,
    prompts: List[str],
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_gen_len: Optional[int] = None,
    logprobs: bool = False,
    echo: bool = False,
  ) -> List[CompletionPrediction]:
    """
    Perform text completion for a list of prompts using the language generation model.

    Args:
      prompts (List[str]): List of text prompts for completion (each is encoded with a BOS token).
      temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
      top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
      max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
        If not provided, it's set to the model's maximum sequence length minus 1.
      logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
      echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

    Returns:
      List[CompletionPrediction]: one prediction per prompt, each containing the generated text
      (plus per-token strings and log probabilities when logprobs is True).

    Note:
      This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
    """
    if max_gen_len is None:
      max_gen_len = self.model.params.max_seq_len - 1
    prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
    generation_tokens, generation_logprobs = self.generate(
      prompt_tokens=prompt_tokens,
      max_gen_len=max_gen_len,
      temperature=temperature,
      top_p=top_p,
      logprobs=logprobs,
      echo=echo,
    )
    if logprobs:
      return [
        {
          "generation": self.tokenizer.decode(t),
          "tokens": [self.tokenizer.decode([x]) for x in t],
          "logprobs": logprobs_i,
        }
        for t, logprobs_i in zip(generation_tokens, generation_logprobs)
      ]
    return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

  def chat_completion(
    self,
    dialogs: List[Dialog],
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_gen_len: Optional[int] = None,
    logprobs: bool = False,
  ) -> List[ChatPrediction]:
    """
    Generate assistant responses for a list of conversational dialogs using the language generation model.

    Args:
      dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
        Each dialog in the list gets its own, independent reply.
      temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
      top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
      max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
        If not provided, it's set to the model's maximum sequence length minus 1.
      logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.

    Returns:
      List[ChatPrediction]: one prediction per dialog, each containing the assistant's generated
      message (plus per-token strings and log probabilities when logprobs is True). An empty
      list of dialogs returns [].
    """
    if not dialogs:
      return []
    if max_gen_len is None:
      max_gen_len = self.model.params.max_seq_len - 1

    prompt_tokens = [
      self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
    ]
    self.debug('got $L tokens $M', len(prompt_tokens[0]), " ".join(map(str, prompt_tokens[0])))
    generation_tokens, generation_logprobs = self.generate(
      prompt_tokens=prompt_tokens,
      max_gen_len=max_gen_len,
      temperature=temperature,
      top_p=top_p,
      logprobs=logprobs,
    )
    if logprobs:
      return [
        {
          "generation": {
            "role": "assistant",
            "content": self.tokenizer.decode(t),
          },
          "tokens": [self.tokenizer.decode([x]) for x in t],
          "logprobs": logprobs_i,
        }
        for t, logprobs_i in zip(generation_tokens, generation_logprobs)
      ]
    return [
      {
        "generation": {
          "role": "assistant",
          "content": self.tokenizer.decode(t),
        },
      }
      for t in generation_tokens
    ]

In [None]:
## What the heck is this?!?! If you edited any of the classes above, this cell re-wraps
## the already-loaded weights in the new class definitions, skipping the slow reload.
## On the first run (no `model` yet) it builds the model from the checkpoint instead.
try:
  model  # raises NameError when no model has been built in this session yet
  have_loaded_model = hasattr(model, 'model')
except NameError:
  have_loaded_model = False

if have_loaded_model:
  print("rebuilding model from existing (because we changed the code!)")
  # Reassign in place (no extra temporary name), so a later `del model` drops the last reference.
  model = SimpleLlama3(model.model, model.tokenizer, model.args)
  print("rebuilt model")
else:
  print('building model from scratch')
  T0 = time.monotonic()
  model = SimpleLlama3.build(checkpoints_dir='Meta-Llama-3-8B-Instruct',
                             tokenizer_path='Meta-Llama-3-8B-Instruct/tokenizer.model',
                             load_model=True,
                             max_seq_len=1024,
                             max_batch_size=8,
                             device='cuda')
  T1 = time.monotonic()
  print(f'{T1-T0} seconds to load model')

In [None]:
# Print the module structure of the Transformer wrapped by SimpleLlama3.
print(model.model)

In [None]:
## SimpleLlama3 code expects dialogs of this shape:
## a batch (list) of dialogs, where each dialog is a list of {"role": ..., "content": ...} messages.
dialogs: List[Dialog] = [
    [{"role": "user", "content": "what is the recipe of mayonnaise?"}],
]

In [None]:
def addConversation(M, role=1):
  """Return a new one-message dialog: [{"role": <role name>, "content": M}].

  role indexes ["system", "user", "assistant"]; the default 1 means "user".
  """
  role_name = ["system", "user", "assistant"][role]
  return [{"role": role_name, "content": M}]

In [None]:
# This addConversation returns a whole one-message dialog (a list), so Q is a dialog, not a message.
Q =  addConversation('what is the recipe for vanilla rice pudding?')

In [None]:
# chat_completion takes a batch of dialogs: [Q] is a batch of one, so A has one entry.
A = model.chat_completion([Q])

In [None]:
# The assistant's reply text for the first (and only) dialog in the batch.
print(A[0]['generation']['content'])

In [None]:
# Same pattern, built inline: a batch holding one single-message dialog.
A = model.chat_completion([addConversation("Write a poem about trees in winter. Make it like Shakespeare")])

In [None]:
# Show the generated poem.
print(A[0]['generation']['content'])

In [None]:
# PITFALL: each addConversation call here returns its own one-message dialog, so C is a
# batch of THREE independent dialogs, not one three-turn conversation (see len(A2) below).
C = [ addConversation("Write a poem about trees in winter. Make it like Shakespeare"),
       addConversation(A[0]['generation']['content'], role=2),
       addConversation("It's too long. Make it 5 lines, please")  ]

In [None]:
# Runs three separate chats, one per dialog in C; the "too long" request never sees the poem.
A2 = model.chat_completion(C)

In [None]:
# Only the first of the three replies (the answer to the original poem request).
print(A2[0]['generation']['content'])

In [None]:
# Inspect every reply in the batch.
A2

In [None]:
# One reply per dialog in C (3), confirming C was treated as three separate chats.
len(A2)

In [None]:
def addConversation(M, role=1):
  """Return one chat message dict {"role": <role name>, "content": M}.

  role indexes ["system", "user", "assistant"]; the default 1 means "user".
  Put several of these into one inner list to build a multi-turn dialog.
  """
  role_names = ["system", "user", "assistant"]
  return dict(role=role_names[role], content=M)

In [None]:
# Correct shape: a batch with ONE dialog (the inner list) holding three messages
# (user, assistant, user), using the addConversation above that returns a single message dict.
C2 = [[ addConversation("Write a poem about trees in winter. Make it like Shakespeare"),
       addConversation(A[0]['generation']['content'], role=2),
       addConversation("It's too long. Make it 5 lines, please")  ]]

In [None]:
# One dialog in the batch, so one reply -- this time with the poem in context.
A2 = model.chat_completion(C2)

In [None]:
# Show the shortened poem.
print(A2[0]['generation']['content'])

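## How to support Llama 3.2!
Llama 3.1/3.2 checkpoints enable `use_scaled_rope`, which rescales the RoPE frequencies (`apply_scaling` below) so the model can handle a longer context.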

In [None]:
# llama 3.2!
def apply_scaling(freqs: torch.Tensor, scale_factor: float = 8, low_freq_factor: float = 1,
                  high_freq_factor: float = 4, old_context_len: int = 8192):
  """Rescale RoPE frequencies for Llama 3.1/3.2 long-context models ("use_scaled_rope").

  Each frequency is classified by its wavelength 2*pi/freq:
    - shorter than old_context_len / high_freq_factor: kept unchanged;
    - longer than old_context_len / low_freq_factor: divided by scale_factor;
    - in between: linearly blended between those two.

  The defaults are the values hard-coded in the reference implementation
  ("obtained from grid search"). NOTE(review): some published Llama 3.2
  configurations list a different scale factor -- pass scale_factor explicitly
  if your checkpoint needs it.

  Args:
    freqs: RoPE inverse frequencies (any shape).
    scale_factor: divisor applied to the low-frequency band.
    low_freq_factor, high_freq_factor: define the band edges as old_context_len / factor.
    old_context_len: context length the model was originally trained with.

  Returns:
    Tensor with the same shape, dtype and device as `freqs`.

  Raises:
    ValueError: if low_freq_factor == high_freq_factor (the blend would divide by zero).
  """
  if low_freq_factor == high_freq_factor:
    raise ValueError("low_freq_factor and high_freq_factor must differ")
  low_freq_wavelen = old_context_len / low_freq_factor
  high_freq_wavelen = old_context_len / high_freq_factor

  # Vectorized form of the per-element rule (same arithmetic, one tensor op each).
  wavelen = 2 * math.pi / freqs
  # 0 at the low-frequency edge of the band, 1 at the high-frequency edge.
  smooth = (old_context_len / wavelen - low_freq_factor) / (
    high_freq_factor - low_freq_factor
  )
  blended = (1 - smooth) * freqs / scale_factor + smooth * freqs
  scaled = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, blended)
  # High-frequency (short-wavelength) components take priority and stay unchanged.
  return torch.where(wavelen < high_freq_wavelen, freqs, scaled)


def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, theta: float = 10000.0, use_scaled: bool = False, device: str = ''):
  """Precompute the RoPE rotations exp(i * m * theta_k) for every position m and frequency index k.

  Args:
    head_dim: per-head embedding size; must be even because dimensions are rotated in pairs.
    seq_len: number of positions m = 0 .. seq_len - 1 to precompute.
    theta: RoPE base (the paper uses 10000).
    use_scaled: if True, pass the frequencies through `apply_scaling` (Llama 3.1/3.2 long-context RoPE).
    device: device for the result; the default '' (or None) means torch's current default device.

  Returns:
    complex64 tensor of shape (seq_len, head_dim // 2) whose entries all have magnitude 1.
  """
  # As written in paragraph 3.2.2 of the RoFormer paper:
  # >> In order to generalize our results in 2D to any xi ∈ Rd where **d is even**, [...]
  assert head_dim % 2 == 0, "Dimension must be divisible by 2"
  # '' is not a valid torch device string, so the old default crashed; treat it as "default device".
  device = device or None
  # theta_k = base^(-2k / head_dim) for k = 0 .. head_dim/2 - 1
  # (the paper's theta_i = 10000^(-2(i-1)/dim) for i = 1 .. dim/2).  Shape: (head_dim / 2)
  exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
  inv_freq = 1.0 / (theta ** exponents)
  # llama-3.1/3.2: rescale the frequencies for a longer context window.
  if use_scaled:
    inv_freq = apply_scaling(inv_freq)
  # Positions m (the "m" parameter) as float32.  Shape: (seq_len)
  positions = torch.arange(seq_len, device=device, dtype=torch.float32)
  # Outer product gives the angle m * theta_k.  Shape: (seq_len, head_dim / 2)
  angles = torch.outer(positions, inv_freq).float()
  # Polar form with R = 1: cos(angle) + i*sin(angle).  Shape: (seq_len, head_dim / 2)
  return torch.polar(torch.ones_like(angles), angles)

In [None]:
@dataclass
class ModelArgs:
  """Llama 3 model hyper-parameters plus runtime settings (KV-cache limits, device, dtype).

  Construct with keyword arguments only. Keys that are not attributes of this class
  (e.g. extra entries in a checkpoint's params.json) are silently ignored.
  """
  dim: int = 4096
  n_layers: int = 32
  n_heads: int = 32
  n_kv_heads: Optional[int] = None
  vocab_size: int = -1 # Later set in the build method
  multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
  ffn_dim_multiplier: Optional[float] = None
  norm_eps: float = 1e-5
  rope_theta: float = 500000

  ## llama 3.2
  use_scaled_rope: bool = False

  # Needed for KV cache
  max_batch_size: int = 32
  max_seq_len: int = 2048

  # vision model params
  vision_chunk_size: int = -1  # image resolution for image models
  vision_max_num_chunks: int = 4
  vision_num_cross_attention_layers: int = -1

  # others!
  bias: bool = False
  dropout: float = 0.0

  # for SimpleTrainer and TrainerConfig
  dtype : Optional[torch.dtype] = None  # may also be given by name, e.g. "bfloat16"
  device: Optional[str] = None

  def __init__(self, **kwargs):
    # Hand-written __init__ (dataclass keeps it instead of generating one) so that unknown
    # keys are ignored rather than raising TypeError; unset fields fall back to the
    # class-level defaults above.
    for k, v in kwargs.items():
      if hasattr(self, k):
        setattr(self, k, v)
    # With a custom __init__, dataclass never calls __post_init__ for us.
    self.__post_init__()

  def __post_init__(self):
    # Resolve a dtype given by name ("bfloat16" or "torch.bfloat16") to a torch.dtype.
    if isinstance(self.dtype, str):
      candidate = getattr(torch, self.dtype.split(".")[-1], None)
      if isinstance(candidate, torch.dtype):
        self.dtype = candidate
      else:
        self.dtype = findByName(self.dtype, locals(), globals())


In [None]:
  # Load Llama 3.2 3B Instruct. ModelArgs must define `use_scaled_rope`, otherwise its
  # __init__ silently drops that key from params.json.
  # If VRAM is tight, free the previous model first (del model; gc.collect(); torch.cuda.empty_cache()).
  T0 = time.monotonic()
  model = SimpleLlama3.build(checkpoints_dir='Llama3.2-3B-Instruct', 
                       tokenizer_path='Llama3.2-3B-Instruct/tokenizer.model',
                       load_model=True,
                       max_seq_len=1024,
                       max_batch_size=8,
                       device='cuda')
  T1 = time.monotonic()
  print(f'{T1-T0} seconds to load model')

In [None]:
# Display the module structure of the Transformer that was just loaded.
model.model

In [None]:
# Drop this name's reference so the next two cells can actually release the memory.
# Any other name still bound to the same object (e.g. a leftover alias from the rebuild
# cell) keeps it alive -- delete those too.
del model

In [None]:
import gc

In [None]:
# Collect unreachable Python objects (including reference cycles) so their tensors are freed.
gc.collect()

In [None]:
# Return cached-but-unused blocks held by PyTorch's CUDA caching allocator to the driver.
torch.cuda.empty_cache()

In [None]:
# NOTE(review): `model` was deleted a few cells above -- re-run the load cell first, or this
# raises NameError. SimpleLlama3 itself stores its config as `model.args`, while the wrapped
# Transformer's config is read elsewhere as `.params`; TODO confirm `.model.args` exists.
model.model.args

In [None]:
# addConversation now returns a single message dict, so [[...]] is a batch of one one-message dialog.
Q =[[addConversation("What is the recipe for vanilla rice pudding?")]]

In [None]:
# Peek at the message dict that addConversation builds.
addConversation("What is the recipe for vanilla rice pudding?")

In [None]:
# Q is already a batch (list of dialogs), so it is passed as-is.
A3 = model.chat_completion(Q)

In [None]:
# The reply text for the single dialog in the batch.
print(A3[0]['generation']['content'])

In [None]:
# C still holds the three separate one-message dialogs built earlier with the list-returning addConversation.
C

In [None]:
# Therefore this runs three independent chats (A4 has three replies), not a follow-up conversation.
A4 = model.chat_completion(C)

In [None]:
# One three-turn dialog; reuses the poem generated earlier (A) as the assistant turn.
C2 = [[ addConversation("Write a poem about trees in winter. Make it like Shakespeare"),
       addConversation(A[0]['generation']['content'], role=2),
       addConversation("It's too long. Make it 5 lines, please")  ]]

In [None]:
# Ask the currently loaded model for the follow-up, with the full conversation as context.
A5 = model.chat_completion(C2)

In [None]:
# Show the shortened poem.
print(A5[0]['generation']['content'])