# Test vLLM with a merged model

***

## Imports

In [None]:
import dataclasses
import os
import copy
import pickle

import numpy as np

from notebooks_utils import get_clean_prompts

# Prompts
from llm_prompts.reader import ReaderMany
# from llm_prompts.prompts.grid_formatter import GridFormatter
from llm_prompts.prompts.text_prompts import PromptSolveInstrV2
from llm_prompts.type_aliases import Grid

# Transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

# vLLM
from vllm import SamplingParams, LLM

## Legacy code (local copy of `GridFormatter`)

In [2]:
DEFAULT_END_GRID_TOKEN = "+++"
DEFAULT_END_ROW_TOKEN = "\n"
DEFAULT_START_GRID_TOKEN = "!grid\n"
DEFAULT_START_ROW_TOKEN = ""


@dataclasses.dataclass(slots=True, kw_only=True, frozen=True)
class GridFormatter:
    """Serialize grids (rows of integer colours) into text for LLM prompts.

    A grid is rendered as ``start_grid_token``, then for every row
    ``start_row_token`` + the row's colours joined by
    ``color_separator_token`` + ``end_row_token``, and finally
    ``end_grid_token``. With the defaults, ``[[1, 2], [3, 4]]`` becomes
    the text ``!grid`` / ``12`` / ``34`` / ``+++`` on four lines.
    """

    # Emitted once before the first row.
    start_grid_token: str = dataclasses.field(default=DEFAULT_START_GRID_TOKEN)
    # Emitted once after the last row.
    end_grid_token: str = dataclasses.field(default=DEFAULT_END_GRID_TOKEN)
    # Emitted before every row (empty by default).
    start_row_token: str = dataclasses.field(default=DEFAULT_START_ROW_TOKEN)
    # Emitted after every row, including the last one.
    end_row_token: str = dataclasses.field(default=DEFAULT_END_ROW_TOKEN)
    # Placed between consecutive colours of a row (empty by default).
    color_separator_token: str = dataclasses.field(default="")

    def get_special_tokens_not_in(self, tokenizer: AutoTokenizer) -> list[str]:
        """Find which formatting tokens would need to be added to ``tokenizer``.

        Empty tokens are ignored. Each missing token is reported only once,
        in the order it first appears among the formatter's fields, even if
        several fields share the same value.

        Args:
            tokenizer: Tokenizer whose ``vocab`` mapping is checked.

        Returns:
            The non-empty formatting tokens absent from ``tokenizer.vocab``.
        """
        grid_formatting_tokens = (
            self.start_grid_token,
            self.end_grid_token,
            self.start_row_token,
            self.end_row_token,
            self.color_separator_token,
        )

        additional_special_tokens: list[str] = []
        for token in grid_formatting_tokens:
            if token and token not in tokenizer.vocab and token not in additional_special_tokens:
                additional_special_tokens.append(token)

        return additional_special_tokens

    def encode_grid(self, grid: Grid, input_or_output: str) -> str:
        """Format a grid into a string to be used inside LLM prompts.

        Args:
            grid: Sequence of rows, each a sequence of colour values; every
                value is rendered with ``str``.
            input_or_output: Not used by this formatter; accepted so callers
                (see ``encode_pairs``) can say which side of a pair is encoded.

        Returns:
            The serialized grid. An empty row is rendered as
            ``start_row_token + end_row_token`` instead of raising.
        """
        # str.join avoids quadratic string concatenation and handles empty rows.
        rows = "".join(
            self.start_row_token
            + self.color_separator_token.join(str(color) for color in row)
            + self.end_row_token
            for row in grid
        )
        return self.start_grid_token + rows + self.end_grid_token

    def encode_pairs(self, pairs: list[dict[str, Grid]]) -> str:
        """Format input/output pairs as numbered "Input i:" / "Output i:" sections.

        Args:
            pairs: Non-empty list of dicts with ``"input"`` and ``"output"`` grids.

        Returns:
            The concatenated sections, each grid followed by a newline.

        Raises:
            AssertionError: If ``pairs`` is empty (unless asserts are disabled).
        """
        assert len(pairs) > 0

        formatted_examples: list[str] = []
        for i, pair in enumerate(pairs):
            formatted_examples.append(
                f"Input {i}:\n"
                + self.encode_grid(pair["input"], input_or_output="input")
                + "\n"
                + f"Output {i}:\n"
                + self.encode_grid(pair["output"], input_or_output="output")
                + "\n"
            )

        return "".join(formatted_examples)


## Configs

In [None]:
# BASE_MODEL_ID and LORA_MODEL_ID are only echoed below for the record;
# the vLLM engine is built from MERGED_MODEL_ID.
BASE_MODEL_ID = "finetuned_models/base-llama-32-3B-fp32"

# Alternative adapters (commented out):
# LORA_MODEL_ID = "finetuned_models/tmp_finetuning_llama_3B_max_seq_3072_comb_instr"
# LORA_MODEL_ID = "tmp_finetuning_llama_3B_max_seq_3072_comb_short"
# LORA_MODEL_ID = "tmp_finetuning_llama_3B_max_seq_3072_comb_descr"
# LORA_MODEL_ID = "tmp_finetuning_llama_32_3B_max_seq_2048_comb_instr"
# LORA_MODEL_ID = "tmp_finetuning_llama_32_3B_max_seq_2048_comb_descr"
LORA_MODEL_ID = "finetuned_models/llama_32_3B_rearc_400x200_8bit_lr1_4"

MERGED_MODEL_ID = "finetuned_models/merged_llama_32_3B_rearc_400x200_8bit_lr1_4-8bit-lr1-4"
# Fail fast, before any expensive engine start-up, if the checkpoint is missing.
assert os.path.exists(MERGED_MODEL_ID), f"{MERGED_MODEL_ID} does not exist"

PROMPT_FN = PromptSolveInstrV2(grid_formatter=GridFormatter())
# Selects the cached prompt file: clean_eval_data_<prompt_type>.pickle
prompt_type = "descr"
MAX_SEQ_LENGTH = 2048
# Only referenced by the commented-out get_clean_prompts call.
MAX_NUM_EVAL_TASKS = 8

# Same output as print(f">>> {NAME=}") for each entry: name, "=", repr(value).
for _config_name, _config_value in (
    ("BASE_MODEL_ID", BASE_MODEL_ID),
    ("LORA_MODEL_ID", LORA_MODEL_ID),
    ("MERGED_MODEL_ID", MERGED_MODEL_ID),
    ("PROMPT_FN", PROMPT_FN),
):
    print(f">>> {_config_name}={_config_value!r}")

## Data

In [None]:
def _read_tasks(dataset_type):
    """Read all tasks of one split from ./kaggle/input, with test outputs requested."""
    reader = ReaderMany(
        dataset_dir="./kaggle/input",
        dataset_type=dataset_type,
        read_test_output=True,
    )
    return reader.read_tasks()


train_tasks = _read_tasks("training")
eval_tasks = _read_tasks("evaluation")

# Fold the training tasks into the evaluation pool so both splits are scored
# below; on a duplicate task ID the training entry wins (dict.update semantics).
eval_tasks.update(train_tasks)

len(eval_tasks)

Write the cleaned evaluation prompts to disk

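**Note:** Using a tokenizer in this notebook interfered with `vllm`, so the cell below is kept commented out and the cleaned prompts are loaded from a previously written pickle instead.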

In [5]:
# Disabled: regenerating the prompts calls get_clean_prompts with MERGED_MODEL_ID
# (presumably loading its tokenizer), which per the note above interferes with vllm.
# Uncomment (or run elsewhere) to (re)create clean_eval_data_<prompt_type>.pickle.
# clean_eval_prompts = get_clean_prompts(
#     tasks=eval_tasks,
#     model_id=MERGED_MODEL_ID,
#     max_seq_length=MAX_SEQ_LENGTH,
#     max_num_tasks=MAX_NUM_EVAL_TASKS,
#     prompt_fn=PROMPT_FN,
# )

# with open(f"clean_eval_data_{prompt_type}.pickle", "wb") as f:
#     pickle.dump(clean_eval_prompts, f)

Read the cleaned evaluation prompts from disk

In [None]:
# Load the cleaned prompts. Judging by how entries are indexed below, each value
# is (is_valid, prompt_text, ...) keyed by task ID -- confirm against get_clean_prompts.
# NOTE: pickle.load can execute arbitrary code; only open pickles this project wrote.
with open(f"clean_eval_data_{prompt_type}.pickle", "rb") as f:
    clean_eval_prompts = pickle.load(f)

sorted_task_ids = sorted(clean_eval_prompts)
task_id = sorted_task_ids[0]

# Prompts are truncated right before this header before being sent to vLLM.
input_prompt_separator = "<|start_header_id|>assistant<|end_header_id|>"

example_prompt = clean_eval_prompts[task_id][1]
index = example_prompt.find(input_prompt_separator)

print(f">>> Number of prompts: {len(clean_eval_prompts)}")
print(f">>> Number of valid prompts: {sum(x[0] for x in clean_eval_prompts.values())}")
print(f">>> {index=}")

print("---\nExample prompt:")
# str.find returns -1 when the separator is missing; slicing with [:-1] would
# silently drop the last character, so show the whole prompt in that case.
print(example_prompt if index == -1 else example_prompt[:index])
print("...")
print("...")

## Inference

In [None]:
# vLLM engine over the merged checkpoint, split across 8 GPUs via tensor parallelism.
# NOTE(review): enforce_eager=True disables CUDA-graph capture, so
# max_seq_len_to_capture presumably has no effect here -- confirm with vLLM docs.
llm_engine_kwargs = dict(
    model=MERGED_MODEL_ID,
    dtype="float16",
    tensor_parallel_size=8,
    gpu_memory_utilization=0.90,
    cpu_offload_gb=0,
    seed=0,
    enforce_eager=True,
    max_seq_len_to_capture=MAX_SEQ_LENGTH,
)
llm = LLM(**llm_engine_kwargs)

In [None]:
# Keep only prompts flagged as valid, truncating each one right before the
# assistant header (everything from the separator on is dropped).
sorted_task_ids = sorted(clean_eval_prompts.keys())
valid_indices = []  # positions in sorted_task_ids of the prompts sent to vLLM
valid_prompts = []
for idx, task_id in enumerate(sorted_task_ids):
    if not clean_eval_prompts[task_id][0]:
        continue
    full_prompt = clean_eval_prompts[task_id][1]
    index_separator = full_prompt.find(input_prompt_separator)
    if index_separator == -1:
        # str.find reports "not found" as -1; full_prompt[:-1] would silently
        # drop the last character, so send the whole prompt instead.
        print(f"WARNING: separator not found for {task_id=}; using the full prompt")
        index_separator = len(full_prompt)
    valid_indices.append(idx)
    valid_prompts.append(full_prompt[:index_separator])

print(f">>> {len(valid_indices)=}")

In [None]:
# Show the first truncated prompt followed by a 60-character separator line.
print(valid_prompts[0], "-" * 60, sep="\n")

In [None]:
%%time

# Greedy decoding (temperature 0). Log-probabilities are requested for both the
# generated tokens and the prompt tokens; they show up in the raw outputs.
sampling_params = SamplingParams(
    max_tokens=950,
    temperature=0.0,
    top_p=0.95,
    logprobs=1,
    prompt_logprobs=1,
)

# One RequestOutput per prompt, in the same order as valid_prompts.
all_requests_outputs = llm.generate(valid_prompts, sampling_params)

In [None]:
# Sanity check: one vLLM output is expected per submitted prompt.
print(
    f">>> {len(all_requests_outputs)=}",
    f">>> {len(valid_prompts)=}",
    sep="\n",
)

In [13]:
# Placeholder answer used for every attempt until a parsed prediction replaces it.
DEFAULT_GRID = [[0, 0], [0, 0]]
DEFAULT_ATTEMPTS = dict.fromkeys(("attempt_1", "attempt_2"), DEFAULT_GRID)

# One independent copy of the placeholder attempts per test input of every task.
# (Within each copy deepcopy keeps both attempts aliased to one list; harmless,
# because entries are later reassigned rather than mutated.)
submission = {
    task_id: [copy.deepcopy(DEFAULT_ATTEMPTS) for _ in eval_tasks[task_id]["test"]]
    for task_id in sorted_task_ids
}

In [None]:
solved_task_ids = []


def pad_grid(grid: list[list[int]], pad_value: int = 10, size: int = 30) -> np.ndarray:
    """Embed ``grid`` in the top-left corner of a ``size`` x ``size`` canvas.

    Cells not covered by ``grid`` are filled with ``pad_value``. The width of
    the first row is used for every row, and rows/columns beyond ``size``
    are dropped. An empty grid yields a canvas made entirely of padding.

    Args:
        grid: Row-major grid of colour values (must fit in ``uint8``).
        pad_value: Fill value for cells outside ``grid``.
        size: Side length of the square canvas; 30 matches the size limit
            enforced by the evaluation loop.

    Returns:
        A ``(size, size)`` array of dtype ``np.uint8``.
    """
    padded_grid = np.full((size, size), pad_value, dtype=np.uint8)
    num_rows = min(len(grid), size)
    # Guard against grid == [] (grid[0] would raise IndexError).
    row_length = min(len(grid[0]), size) if grid else 0

    for i in range(num_rows):
        for j in range(row_length):
            padded_grid[i, j] = grid[i][j]

    return padded_grid

# For every valid prompt: parse the response into a grid, store it as attempt_1
# of the task's first test input, and score it against that input's expected output.
perc_correct_pixels = []  # one entry per valid prompt; 0.0 when parsing failed
for idx, task_index in enumerate(valid_indices):
    task_id = sorted_task_ids[task_index]
    expected_output_grid = eval_tasks[task_id]["test"][0]["output"]
    req = all_requests_outputs[idx]
    response_body = req.outputs[0].text
    # print(f"{response_body=}")

    grid = copy.deepcopy(DEFAULT_GRID)
    try:
        # ! Note: need to improve this parsing
        # The grid is assumed to start at the first character of the response
        # (one digit per cell, rows separated by "\n") and to end right before
        # the "\n+++" end-of-grid marker.
        start_grid = ""  # "!grid\n"
        end_grid = "\n+++"
        start_index = 0  # response_body.find(start_grid)
        end_index = response_body.find(end_grid)
        if end_index == -1:
            # No end marker (possibly a truncated generation): parse the whole
            # response instead of response_body[:-1], which would silently
            # drop its last character.
            end_index = len(response_body)

        grid_str = response_body[start_index + len(start_grid):end_index]

        grid = [[int(c) for c in row] for row in grid_str.split("\n")]
        print(f"Response body:    {repr(response_body)}")
        print(f"Start index grid: {start_index}")
        print(f"End index grid:   {end_index}")
        print(f"Grid string:      {repr(grid_str)}")
        print(f"Pre-parsing grid: {grid}")
        print(f"{expected_output_grid=}")

        len_first_row = len(grid[0])
        for row in grid:
            if len(row) != len_first_row:
                raise ValueError("Not same numbe of row elements")
        # Empty rows or grids larger than 30x30 fall back to the placeholder.
        if len_first_row == 0 or len_first_row > 30 or len(grid) > 30:
            grid = copy.deepcopy(DEFAULT_GRID)
    except Exception as e:
        # Best-effort parsing: any failure is logged and replaced by the placeholder.
        print(f"ERROR: {str(e)}")
        grid = copy.deepcopy(DEFAULT_GRID)

    print(f"Final grid: {grid}")

    assert len(grid) <= 30
    submission[task_id][0]["attempt_1"] = grid
    if grid != DEFAULT_GRID:
        # Prediction and target are padded with different values (10 vs 11), and
        # parsed cells are single digits, so only overlapping cells with equal
        # colours count as correct; the ratio is relative to the expected size.
        num_correct_pixels = np.sum(pad_grid(grid, pad_value=10) == pad_grid(expected_output_grid, pad_value=11))
        tot_pixels = len(expected_output_grid) * len(expected_output_grid[0])
        print(f"Total same pixels: {num_correct_pixels}")
        print(f"Number pixels: {tot_pixels}")
        perc_correct_pixels.append(num_correct_pixels / tot_pixels)
    else:
        # NOTE: a genuine [[0, 0], [0, 0]] prediction is indistinguishable from a
        # parse failure here and is scored 0.0.
        print(f"!ERROR final grid is DEFAULT_GRID={grid=}")
        perc_correct_pixels.append(0.0)

    if grid == expected_output_grid:
        solved_task_ids.append(task_id)
        print(f"{task_id=} was solved")
        print(f"{grid=}")
        print(f"{expected_output_grid=}")

    print("-" * 60)

In [None]:
# Notebook display: task IDs whose parsed grid equals the first test output exactly.
solved_task_ids

In [None]:
# Pixel accuracy is averaged only over prompts that were sent to the model,
# while the solve rate divides by every task ID (valid prompt or not).
mean_pixel_pct = 100 * np.mean(perc_correct_pixels)
num_solved = len(solved_task_ids)
num_tasks = len(sorted_task_ids)

print(f"Average percentage correct pixels: {mean_pixel_pct:.2f} %")
print(
    "Percentage correct prompt answers: "
    f"{num_solved} / {num_tasks} ~ {100*(num_solved / num_tasks):.2f} %"
)

In [19]:
# Optional: uncomment to save the submission dict (task ID -> attempts) as JSON.
# import json

# with open("submission_fp16_descr.json", "w") as f:
#     json.dump(submission, f)

In [None]:
# Print the raw vLLM request outputs, one per valid prompt, in the same order
# the evaluation loop consumed them. The per-task lookups the original loop
# performed here were never used and have been dropped.
for req in all_requests_outputs[: len(valid_indices)]:
    print(req)

## Results (caveat: evaluated only on the first test input/output pair of each task)

***
#### 1 - tmp_finetuning_llama_32_3B_max_seq_2048_comb_instr

Solved training (1 input)
```python
['239be575', '6f8cd79b', '6fa7a44f']
```

Solved eval (1 input)
```python
['332efdb3']
```

```text
Average percentage correct pixels: 14.38 %
Percentage correct prompt answers: 4 / 800 ~ 0.50 %
```

#### 2 - tmp_finetuning_llama_32_3B_max_seq_2048_comb_descr

Solved training (1 input)

```python
['25ff71a9', '27a28665', '44f52bb0', '6fa7a44f', 'a85d4709', 'ff28f65a']
```

Solved eval (1 input)

```python
['9110e3c5', 'f3e62deb']
```

```text
Average percentage correct pixels: 25.64 %
Percentage correct prompt answers: 8 / 800 ~ 1.00 %
```

#### 3 - `finetuned_models/merged_llama_32_3B_rearc_400x200_8bit_lr1_4-8bit-lr1-4`

Solved train (1 input)

```python
['23b5c85d', '44f52bb0', '6150a2bd', '68b16354', '7b7f7511', '7e0986d6', '7f4411dc', '8be77c9e', '9110e3c5', 'b9b7f026']
```

Solved eval (1 input)

```python
['9110e3c5']
```

```text
Average percentage correct pixels: 34.90 %
Percentage correct prompt answers: 11 / 800 ~ 1.38 %
```