In [1]:
import json
import sys
import time
import warnings
from pathlib import Path
from typing import Literal, Optional

import lightning as L
import torch
from lightning.fabric.strategies import FSDPStrategy

from generate.base import generate
from lit_gpt import Tokenizer
from lit_gpt.adapter import Block
from lit_gpt.adapter import GPT, Config
from lit_gpt.adapter_v2 import add_adapter_v2_parameters_to_linear_layers
from lit_gpt.utils import lazy_load, check_valid_checkpoint_dir, quantization
from scripts.prepare_alpaca import generate_prompt

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(


In [2]:
# Select the internal precision PyTorch may use for float32 matrix multiplications.
# NOTE(review): per the PyTorch docs, "high" allows faster reduced-precision
# kernels (e.g. TF32) where the hardware supports them — confirm for the target GPU.
torch.set_float32_matmul_precision("high")

In [3]:
# --- Inference settings: every value a reader might want to tune lives here ---

# Accepted values for `quantize`; None (the default below) requests no quantization.
QuantizeMode = Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]

# Model artefacts: base checkpoint directory and the fine-tuned adapter weights.
checkpoint_dir: Path = Path("s3/checkpoint/meta-llama/Llama-2-7b-hf")
adapter_path: Path = Path("s3/out/adapter_finetune/lit_model_adapter_finetuned.pth")
quantize: Optional[QuantizeMode] = None

# Text that goes into the "input" field of the prompt sample.
# NOTE: this name shadows the `input` builtin for the rest of the notebook.
input: str = "What is AIPI?"

# Sampling controls forwarded to generate().
max_new_tokens: int = 100
top_k: int = 200
temperature: float = 0.8

# Options handed to L.Fabric.
precision: str = "bf16-true"
devices: int = 1
strategy: str = "auto"

In [4]:
# Build and launch the Fabric runtime from the configuration values.
#
# The "fsdp" shortcut is expanded into an FSDPStrategy that uses the adapter
# `Block` class as its auto-wrap policy, with CPU offloading disabled. The
# expanded object is stored under its own name so that the `strategy` setting
# stays a plain string: re-running this cell then creates a fresh FSDPStrategy
# instead of handing the instance used by the previous Fabric to a new one.
if strategy == "fsdp":
    fabric_strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
else:
    fabric_strategy = strategy
fabric = L.Fabric(devices=devices, precision=precision, strategy=fabric_strategy)
fabric.launch()

# Validate the checkpoint directory (the exact checks are defined by lit_gpt.utils).
check_valid_checkpoint_dir(checkpoint_dir)

# The model hyper-parameters are stored as JSON next to the weights. The file is
# read as UTF-8 explicitly so parsing does not depend on the platform's default
# text encoding.
with open(checkpoint_dir / "lit_config.json", encoding="utf-8") as fp:
    config = Config(**json.load(fp))

# Quantized inference on more than one device is not implemented here; say so
# explicitly instead of raising an empty NotImplementedError.
if quantize is not None and devices > 1:
    raise NotImplementedError(
        f"Quantization ({quantize!r}) with more than one device (devices={devices}) is not implemented."
    )

# "gptq.int4" needs a separately prepared weights file; every other setting
# (including no quantization) loads the regular checkpoint file.
if quantize == "gptq.int4":
    model_file = "lit_model_gptq.4bit.pth"
    if not (checkpoint_dir / model_file).is_file():
        raise ValueError("Please run `python quantize/gptq.py` first")
else:
    model_file = "lit_model.pth"
checkpoint_path = checkpoint_dir / model_file

# --- Step 1: instantiate the model and attach the adapter-v2 parameters ---
fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
t0 = time.time()
with fabric.init_module(empty_init=True):
    with quantization(quantize):
        model = GPT(config)
        add_adapter_v2_parameters_to_linear_layers(model)
fabric.print(f"Time to instantiate model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

# --- Step 2: load the base weights, overlaid with the fine-tuned adapter weights ---
t0 = time.time()
with lazy_load(checkpoint_path) as checkpoint:
    with lazy_load(adapter_path) as adapter_checkpoint:
        # The adapter file may keep its weights under a "model" key; otherwise
        # the loaded mapping itself is used.
        adapter_state = adapter_checkpoint.get("model", adapter_checkpoint)
        checkpoint.update(adapter_state)
        # Strict key matching is only requested when no quantization is used.
        strict_loading = quantize is None
        model.load_state_dict(checkpoint, strict=strict_loading)
fabric.print(f"Time to load the model weights: {time.time() - t0:.02f} seconds.", file=sys.stderr)

# Put the model in eval mode (nn.Module.eval() returns the module itself) and
# hand it to Fabric for setup.
model = fabric.setup(model.eval())

tokenizer = Tokenizer(checkpoint_dir)

# Build the prompt text from the question and tokenize it on the model's device.
sample = {"input": input}
prompt = generate_prompt(sample)
encoded = tokenizer.encode(prompt, device=model.device)

# Total token budget for generation: the prompt tokens plus the new tokens requested.
prompt_length = encoded.size(0)
max_returned_tokens = max_new_tokens + prompt_length

Loading model 's3/checkpoint/meta-llama/Llama-2-7b-hf/lit_model.pth' with {'org': 'meta-llama', 'name': 'Llama-2-7b-hf', 'block_size': 4096, 'vocab_size': 32000, 'padding_multiple': 64, 'padded_vocab_size': 32000, 'n_layer': 32, 'n_head': 32, 'n_embd': 4096, 'rotary_percentage': 1.0, 'parallel_residual': False, 'bias': False, 'n_query_groups': 32, 'shared_attention_norm': False, '_norm_class': 'RMSNorm', 'norm_eps': 1e-05, '_mlp_class': 'LLaMAMLP', 'intermediate_size': 11008, 'condense_ratio': 1, 'adapter_prompt_length': 10, 'adapter_start_layer': 2}
Time to instantiate model: 0.27 seconds.
Time to load the model weights: 13.97 seconds.


In [5]:
# Display the formatted prompt string returned by generate_prompt() for inspection.
prompt

'### Input:\nWhat is AIPI?\n\n### Response:'

In [6]:
# Sampling with temperature / top-k is stochastic. Seed the random number
# generators first so that re-running this cell reproduces the same completion.
# Change the seed value to draw a different sample.
L.seed_everything(1234)

# Generate up to `max_returned_tokens` tokens in total (prompt tokens included).
# The tokenizer's EOS id is passed along, presumably so generation can stop early.
y = generate(
    model,
    encoded,
    max_returned_tokens,
    max_seq_length=max_returned_tokens,
    temperature=temperature,
    top_k=top_k,
    eos_id=tokenizer.eos_id,
)

In [7]:
# Reset the model's cache now that generation has finished.
model.reset_cache()

# Decode the generated token ids, then keep only the segment that follows the
# first "### Response:" marker (the prompt itself ends with that marker).
decoded_text = tokenizer.decode(y)
response_parts = decoded_text.split("### Response:")
output = response_parts[1].strip()
fabric.print(output)

AIPI is an acronym for the Duke-Cornell Artificial Intelligence for Product Innovation degree program.
