In [None]:
import argparse
import os
import time

# Fail fast on Windows, before the heavy third-party imports below; the message explains why
# (no Triton wheel is available for Windows).
assert (
    os.name != "nt"
), "Windows is not supported, due to Triton not having a wheel for Windows. Please use Linux or macOS."

import datasets
from dotenv import load_dotenv
from transformers import DataCollatorForSeq2Seq, TrainingArguments, pipeline
from trl import SFTTrainer
from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import train_on_responses_only

# Populate os.environ from a .env file, if one is found, before checking for credentials.
load_dotenv()

# HF_TOKEN is expected by the Hugging Face tooling (presumably for model downloads).
# Check the value, not just the key: an empty `HF_TOKEN=` entry would otherwise pass.
assert os.environ.get(
    "HF_TOKEN"
), "HF_TOKEN is not set. Use os.environ['HF_TOKEN'] = '...' to set it, or put it in .env file."

In [None]:
def get_default_arguments():
    """Build the notebook's default configuration.

    Returns:
        argparse.Namespace with, in order: the model-loading options
        (max_seq_length, dtype, load_in_4bit), the base model and chat-template
        names, the LoRA/PEFT hyper-parameters (peft_rank, lora_alpha,
        lora_dropout) and the system prompt for the cache-eviction task.
    """
    return argparse.Namespace(
        # Model loading; dtype=None is forwarded as-is to FastLanguageModel.from_pretrained.
        max_seq_length=2048,
        dtype=None,
        load_in_4bit=True,
        # Base model (more models at https://huggingface.co/unsloth) and chat template name.
        # NOTE(review): chat_template is not referenced by the other cells shown here — confirm it is needed.
        model_name="unsloth/Llama-3.2-3B-Instruct",
        chat_template="llama-3.1",
        # PEFT / LoRA hyper-parameters.
        peft_rank=16,
        lora_alpha=16,
        lora_dropout=0,
        # System prompt placed at the start of every eviction conversation.
        replacement_policy_system_prompt="You are a helpful assistant helping with CPU cache line eviction. You always output something from the cache lines list",
    )


# Build the configuration once; later cells read it through the global `args`.
args = get_default_arguments()
print("Using arguments: {}".format(args))

In [None]:
print("Loading model")
# Context length, dtype and 4-bit quantization all come from the configuration in `args`.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=args.model_name,
    max_seq_length=args.max_seq_length,
    dtype=args.dtype,
    load_in_4bit=args.load_in_4bit,
)
print("Model loaded")

In [None]:
# PEFT - Parameter Efficient Fine-Tuning: https://huggingface.co/docs/peft
# Wrap the base model with LoRA adapters on the modules listed below; rank, alpha and
# dropout come from `args`, the remaining options are fixed for this notebook.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
model = FastLanguageModel.get_peft_model(
    model,
    r=args.peft_rank,
    target_modules=lora_target_modules,
    lora_alpha=args.lora_alpha,
    lora_dropout=args.lora_dropout,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,  # same value as the trainer seed
    use_rslora=False,
    loftq_config=None,
)

In [None]:
from cache_replacement.policy_learning.cache.eviction_policy import BeladyScorer, GreedyEvictionPolicy
from cache_replacement.policy_learning.cache.memtrace import MemoryTrace
from task13.cache import Cache

cache_config = {
    "cache_line_size": 64,  # bytes per cache line
    "capacity": 64 * 16 * 64,  # total bytes = line size (64) x associativity (16) x 64 sets
    "associativity": 16,
}
memtrace_path = os.path.join(os.getcwd(), "data", "sample_trace.csv")
memtrace = MemoryTrace(
    memtrace_path,
    cache_line_size=cache_config["cache_line_size"],
    max_look_ahead=100_000,
)

# Belady's optimal eviction policy: it acts as the "teacher" whose evictions label the training data.
belady_scorer = BeladyScorer(memtrace)
belady_policy = GreedyEvictionPolicy(belady_scorer)

cache = Cache.from_config(cache_config, eviction_policy=belady_policy)


class BeladyObserver:
    """Cache observer that records the cache contents seen on its most recent call.

    Attributes:
        last_cache_lines: hex strings of the line addresses from the last
            `cache_access.cache_lines` it was called with (empty before the first call).
    """

    def __init__(self):
        self.last_cache_lines = []

    def __call__(self, cache_access, eviction_decision):
        # Each entry of cache_access.cache_lines is a 2-tuple whose first item is the line address;
        # the eviction decision is not needed here.
        snapshot = []
        for address, _ in cache_access.cache_lines:
            snapshot.append(hex(address))
        self.last_cache_lines = snapshot


# Observer passed to every cache.read during trace replay, and the list that collects
# one chat-formatted example per Belady eviction.
belady_observer = BeladyObserver()
dataset = []


def get_prompt(pc, address, cache_lines):
    """Render the user-turn text describing a single cache access.

    Args:
        pc: program counter of the access (int, rendered with hex()).
        address: accessed memory address (int, rendered with hex()).
        cache_lines: current cache contents, interpolated via its str() form
            (a list of hex strings in this notebook).

    Returns:
        Three newline-terminated lines (PC, address, cache lines) followed by the
        "Eviction: " cue that the model is trained to complete.
    """
    return "".join(
        (
            f"Current PC is {hex(pc)}\n",
            f"Current address: {hex(address)}\n",
            f"Cache lines are: {cache_lines}\n",
            "Eviction: ",
        )
    )


def generate_datapoint(pc, address, last_cache_lines, belady_eviction):
    """Build one chat-formatted training example for the eviction task.

    Returns:
        A three-message conversation: the system prompt from `args`, the user
        prompt rendered by `get_prompt`, and the assistant answer (the Belady
        eviction formatted with an f-string).
    """
    return [
        dict(role="system", content=args.replacement_policy_system_prompt),
        dict(role="user", content=get_prompt(pc, address, last_cache_lines)),
        dict(role="assistant", content=f"{belady_eviction}"),
    ]


# Replay the trace through the Belady-managed cache. Every access that evicts a line becomes one
# supervised example whose label is the line Belady chose to evict.
with memtrace:
    while not memtrace.done():
        pc, address = memtrace.next()
        cache.read(pc, address, observers=[belady_observer])

        belady_eviction = cache.last_evicted_cache_line
        if belady_eviction is None:
            continue

        # hex() already returns a str; compute it once and reuse it for the check and the label.
        eviction_hex = hex(belady_eviction)
        # NOTE(review): relies on the observer having been called during this read, so that
        # last_cache_lines holds the lines the victim was chosen from — confirm against task13.cache.
        assert eviction_hex in belady_observer.last_cache_lines, (
            f"Belady evicted {eviction_hex}, which is not among the observed cache lines"
        )
        dataset.append(generate_datapoint(pc, address, belady_observer.last_cache_lines, eviction_hex))

hit_rate = cache.hit_rate_statistic.success_rate()
print(hit_rate)

In [None]:
# Wrap the collected conversations in a Hugging Face Dataset with a single "conversations" column.
hf_dataset = datasets.Dataset.from_dict(dict(conversations=dataset))


def formatting_prompts_func(examples):
    """Batched `Dataset.map` callback that renders conversations with the tokenizer's chat template.

    Args:
        examples: batch dict whose "conversations" entry is a list of message lists.

    Returns:
        {"text": [...]}: one untokenized string per conversation, rendered without an
        appended generation prompt (the assistant answer is already in each conversation).
    """
    rendered = []
    for conversation in examples["conversations"]:
        rendered.append(tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False))
    return {"text": rendered}


# Add a "text" column rendered from each conversation, processing the rows in batches.
hf_dataset = hf_dataset.map(function=formatting_prompts_func, batched=True)

In [None]:
# Show the first rendered training example, separated from the heading by a blank line.
print("The following is an example of an input prompt to the model:\n", hf_dataset[0]["text"], sep="\n")

In [None]:
# Unsloth helper that prepares the model for training (counterpart of the for_inference calls further down).
model = FastLanguageModel.for_training(model)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=hf_dataset,
    dataset_text_field="text",  # column produced by formatting_prompts_func
    max_seq_length=args.max_seq_length,
    # NOTE(review): presumably chosen so labels (masked by train_on_responses_only) are padded
    # together with input_ids — confirm against the TRL/transformers versions in use.
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
    dataset_num_proc=8,  # presumably the number of processes used to preprocess the dataset
    packing=False,  # Setting this to True can make training 5x faster for short sequences.
    args=TrainingArguments(
        # Effective batch size per device: 2 sequences x 4 accumulation steps = 8.
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps=60,  # short fixed-length run; use num_train_epochs (above) for a full pass
        learning_rate=1e-5,
        # Use bf16 when the hardware supports it, otherwise fall back to fp16.
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        lr_scheduler_type="linear",
        seed=3407,  # same value as random_state passed to get_peft_model
        output_dir="outputs",
        report_to="none",  # Use this for WandB etc
    ),
)

# Train only on the assistant responses. The two markers are Llama 3 style header strings; they must
# match what the tokenizer's chat template renders, so revisit them if args.model_name changes.
trainer = train_on_responses_only(
    trainer,
    instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
    response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
)

In [None]:
print("Starting training")
# perf_counter is monotonic, so the measured duration is unaffected by wall-clock adjustments
# (time.time can jump, e.g. on NTP sync, during a long training run).
t_start = time.perf_counter()
trainer.train()
t_end = time.perf_counter()
print(f"Training complete in {t_end - t_start:.2f} seconds")

In [None]:
def skip_input_prompt(input_tokens, output_tokens):
    """Strip the prompt tokens from a `generate` result, leaving only the newly generated tokens.

    Args:
        input_tokens: prompt token ids, shape (seq,) or (batch, seq).
        output_tokens: ids returned by `model.generate`, assumed to begin with the prompt;
            shape (seq + new,) or (batch, seq + new).

    Returns:
        The generated tokens. A size-1 batch axis is dropped, so the usual single-prompt
        case yields a 1-D array/tensor ready for `tokenizer.decode`; larger batches keep
        shape (batch, new).

    Raises:
        ValueError: if `output_tokens` is shorter than the prompt.
    """
    # Measure along the last (sequence) axis instead of squeezing: squeeze() collapses a one-token
    # prompt to a 0-d tensor (len() then fails) and mistakes the batch axis for the length when batch > 1.
    prompt_length = input_tokens.shape[-1]
    output_length = output_tokens.shape[-1]
    if output_length < prompt_length:
        raise ValueError(f"output has {output_length} tokens, fewer than the {prompt_length}-token prompt")
    generated = output_tokens[..., prompt_length:]
    # Drop a size-1 batch axis so a single completion can be decoded directly.
    if generated.ndim == 2 and generated.shape[0] == 1:
        generated = generated[0]
    return generated

In [None]:
model = FastLanguageModel.for_inference(model)

# Build the user turn with the same helper that produced the training data, so the format matches
# exactly: every cache line is a quoted hex string, as recorded by BeladyObserver.
messages = [
    {"role": "system", "content": args.replacement_policy_system_prompt},
    {"role": "user", "content": get_prompt(0x10, 0x1000, ["0x1e244", "0x1e23", "0x13"])},
]

inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,  # Must add for generation
    return_tensors="pt",
).to(model.device)  # follow the model's device instead of hard-coding "cuda"

out = model.generate(input_ids=inputs, max_new_tokens=1024, use_cache=True, temperature=1.4, min_p=0.1, do_sample=True)

decoded = tokenizer.decode(skip_input_prompt(inputs, out).cpu().numpy(), skip_special_tokens=True)
print("LLM output:\n\n", decoded)

In [None]:
# Switch to Unsloth's inference mode and keep the returned model, as the eviction demo cell does.
model = FastLanguageModel.for_inference(model)

# Free-form question, outside the fine-tuning task format, asking the model to propose one new feature.
messages = [
    {"role": "system", "content": "You are a helpful assistant helping with CPU cache line policy improvement"},
    {
        "role": "user",
        "content": "So far, you only get PC, current address, cache lines. How do we combine them to generate a new feature for better line eviction predictor? Give us just one, it's a feature engineering task",
    },
]

inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,  # Must add for generation
    return_tensors="pt",
).to(model.device)  # follow the model's device instead of hard-coding "cuda"

out = model.generate(input_ids=inputs, max_new_tokens=1024, use_cache=True, temperature=0.3, min_p=0.1, do_sample=True)

decoded = tokenizer.decode(skip_input_prompt(inputs, out).cpu().numpy(), skip_special_tokens=True)
print("LLM output:\n\n", decoded)

In [None]:
# Store model
# Save the model and its tokenizer into the same directory so they can be reloaded together.
# NOTE(review): for the PEFT-wrapped model this presumably writes only the adapter weights — confirm.
store_path = os.path.join(os.getcwd(), "outputs", "llama-3.2-3b-instruct")
for component in (model, tokenizer):
    component.save_pretrained(store_path)
print(f"Model stored in {store_path}")

In [None]:
# Load model
# Reload with the same options used when the base model was first loaded (context length, dtype,
# 4-bit quantization) rather than relying on from_pretrained's defaults.
load_path = os.path.join(os.getcwd(), "outputs", "llama-3.2-3b-instruct")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=load_path,
    max_seq_length=args.max_seq_length,
    dtype=args.dtype,
    load_in_4bit=args.load_in_4bit,
)
print(f"Model loaded from {load_path}")

In [None]:
# Use model with a pipeline
model = FastLanguageModel.for_inference(model)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=500)
# Pass chat messages (not a bare string) so the pipeline applies the tokenizer's chat template and the
# model sees the same system/user layout it was fine-tuned on. With return_full_text=False the
# "generated_text" field holds only the assistant's reply.
pipeline_messages = [
    {"role": "system", "content": args.replacement_policy_system_prompt},
    {"role": "user", "content": get_prompt(0x10, 0x1000, ["0x1e244", "0x1e23", "0x13"])},
]
print(pipe(pipeline_messages, return_full_text=False)[0]["generated_text"])