# Test vLLM inference with Llama 3.2 3B

## Links
- [Extending competition](https://www.kaggle.com/competitions/arc-prize-2024/discussion/536832)

## TODOs
- Inference with adapters and float16
- Inference and pixel metrics
- Data augmentations on the fly with formatter
- Filter data by number of tokens
- Prompts class with config
- Config for training in JSON

***

## Imports

In [None]:
import gc
import os
import torch
from typing import List, Optional, Tuple

# Prompts
# from notebooks_utils import create_eval_dataset

# Transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# vLLM
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest

## Docs

In [None]:
# IPython help syntax: display the docstring/signature of vLLM's SamplingParams.
?SamplingParams

## Config

In [None]:
BASE_MODEL_ID = "finetuned_models/base-llama-32-3B-fp16-4bit"
# LORA_MODEL_ID = "finetuned_models/tmp_finetuning_llama_3B_max_seq_3072_comb_instr"
# LORA_MODEL_ID = "tmp_finetuning_llama_3B_max_seq_3072_comb_short"
LORA_MODEL_ID = "tmp_finetuning_llama_3B_max_seq_3072_comb_descr"

# The merged checkpoint directory is keyed by the last three "_"-separated
# tokens of the LoRA id, e.g. "..._3072_comb_descr" -> "3072-comb-descr".
postfix = "-".join(LORA_MODEL_ID.split("_")[-3:])
MERGED_MODEL_ID = f"finetuned_models/finetuned-llama-32-3B-fp16-4bit-merged-{postfix}"
# Explicit check rather than `assert`, which is skipped under `python -O`.
if not os.path.exists(MERGED_MODEL_ID):
    raise FileNotFoundError(f"{MERGED_MODEL_ID} does not exist")

# Select the prompt builder from the merged-model suffix. An if/elif chain is
# kept on purpose so that only the matching prepare_input_* helper needs to be
# defined in the notebook namespace.
PROMPT_FN = None
if MERGED_MODEL_ID.endswith("instr"):
    PROMPT_FN = prepare_input_v2
elif MERGED_MODEL_ID.endswith("descr"):
    PROMPT_FN = prepare_input_v3
elif MERGED_MODEL_ID.endswith("short"):
    PROMPT_FN = prepare_input_short
else:
    # Preserve the original fallback (None is forwarded to create_eval_dataset
    # as prepare_input_fn), but make the unmatched suffix visible.
    print(f">>> WARNING: no prompt function matches {MERGED_MODEL_ID=}; PROMPT_FN is None")

MAX_SEQ_LENGTH = 3072
MAX_NUM_EVAL_TASKS = 400

print(f">>> {BASE_MODEL_ID=}")
print(f">>> {LORA_MODEL_ID=}")
print(f">>> {MERGED_MODEL_ID=}")

## Evaluation dataset

In [None]:
# Build the evaluation set from the base model id, the sequence-length limit,
# the task cap and the prompt builder selected in the config cell.
eval_dataset_kwargs = {
    "base_model_id": BASE_MODEL_ID,
    "max_seq_length": MAX_SEQ_LENGTH,
    "max_num_eval_tasks": MAX_NUM_EVAL_TASKS,
    "prepare_input_fn": PROMPT_FN,
    "verbose": True,
}
eval_dataset = create_eval_dataset(**eval_dataset_kwargs)

In [None]:
# Take the content of the first conversation turn of each eval sample as the prompt.
prompts = [example["conversations"][0]["content"] for example in eval_dataset]

# Sanity checks with the merged model's tokenizer: show its pad token and the
# token count of the first prompt.
tokenizer = AutoTokenizer.from_pretrained(MERGED_MODEL_ID)
print(tokenizer.pad_token)
first_prompt_ids = tokenizer.encode(prompts[0])
print(len(first_prompt_ids))

## Engine

LoRA-related `EngineArgs` fields and their vLLM default values:

```python
max_loras: int = 1
max_lora_rank: int = 16
enable_prompt_adapter: bool = False
max_prompt_adapters: int = 1
max_prompt_adapter_token: int = 0
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
max_cpu_loras: Optional[int] = None
```

In [6]:
# Weight loading: bitsandbytes quantization of the base model, with the QLoRA
# adapter path supplied to the loader.
quantization_kwargs = {
    "quantization": "bitsandbytes",
    "load_format": "bitsandbytes",
    "qlora_adapter_name_or_path": LORA_MODEL_ID,
}
# LoRA serving: a single adapter slot, fp16 adapter weights.
# NOTE(review): max_lora_rank must be >= the adapter's actual rank — confirm 8 is enough.
lora_kwargs = {
    "enable_lora": True,
    "max_lora_rank": 8,
    "max_loras": 1,
    "fully_sharded_loras": False,
    "lora_dtype": torch.float16,
}
# NOTE(review): max_seq_len_to_capture bounds CUDA-graph capture, not the model
# context length (that is max_model_len) — confirm this is the intended knob.
engine_args = EngineArgs(
    model=BASE_MODEL_ID,
    **quantization_kwargs,
    **lora_kwargs,
    tensor_parallel_size=2,
    max_seq_len_to_capture=MAX_SEQ_LENGTH,
    gpu_memory_utilization=0.5,
)

In [None]:
# Instantiate the vLLM engine from the arguments configured above.
engine = LLMEngine.from_engine_args(engine_args)
# LoRA handle attached to every request below; positional args are
# (adapter name, integer adapter id, adapter path).
lora_req = LoRARequest("lora-test-1", 1, LORA_MODEL_ID)

In [None]:

# Request ids are assigned sequentially by the generation loop.
request_id = 0


def _greedy_sampling_params():
    """Return a fresh greedy-decoding SamplingParams for one request.

    Temperature 0.0, at most 1024 new tokens, and logprobs=1 / prompt_logprobs=1
    so log-probabilities are requested for both generated and prompt tokens.
    """
    # NOTE(review): prompt_logprobs on long prompts adds cost, and the visible
    # loop only prints generated text — confirm the logprobs are needed.
    return SamplingParams(
        max_tokens=1024,
        temperature=0.0,
        prompt_logprobs=1,
        logprobs=1,
    )


# One (prompt, sampling params, LoRA request) tuple per eval prompt.
test_prompts = [(text, _greedy_sampling_params(), lora_req) for text in prompts]

In [None]:

# Run every queued prompt through the engine and print each finished output.
# All requests are submitted up front so the vLLM scheduler can batch them
# (the previous loop added only one request per engine.step(), and used the
# O(n) list.pop(0)). Finished RequestOutputs are also kept in `results`,
# keyed by request id, instead of being discarded after printing.
results = {}
try:
    for prompt, sampling_params, lora_request in test_prompts:
        print(f">>> Processing prompt of length {len(prompt)} with {lora_request=} and {sampling_params=}")
        engine.add_request(str(request_id),
                           prompt,
                           sampling_params,
                           lora_request=lora_request)
        request_id += 1
    # Every prompt has been handed to the engine; leave the queue empty, as the
    # original pop-based loop did.
    test_prompts.clear()

    while engine.has_unfinished_requests():
        request_outputs: List[RequestOutput] = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                results[request_output.request_id] = request_output
                print("----------------------------------------------------")
                print(f"Prompt: {request_output.prompt}")
                print(f"Output: {request_output.outputs[0].text}")
finally:
    # Clean up the GPU memory for the next test, even if generation raised.
    del engine
    gc.collect()
    torch.cuda.empty_cache()
