Load the pre-trained model

In [None]:
from unsloth import FastModel
from transformers import AutoTokenizer
from peft import PeftModel, PeftConfig
# Settings used by this notebook. They are defined in this first code cell so
# that it runs on a fresh kernel (Restart Kernel -> Run All): the model loader
# below reads config["max_seq_length"], and `config` was otherwise only created
# in a later cell, which made this cell fail with NameError.
# The dataset cell further down declares the same `config` dict; keep the two
# in sync (or delete the later one) when changing a value.
config = {
    "max_seq_length": 128,
    "batch_size": 128,
    "valid_ratio": 0.2,
    "dataset_path": "datasets/5k_hdfs_train.txt"
}

# Tokenizer saved alongside training; its length is used below to resize the
# model vocabulary.
tokenizer = AutoTokenizer.from_pretrained("tokenizers/tokenizer")  # change this to the path of the tokenizer

# The second value returned by FastModel.from_pretrained is discarded on
# purpose: the tokenizer loaded above is the one used in this notebook.
model, _ = FastModel.from_pretrained(
    "trained_models/model",  # change this to the path of the model
    max_seq_length=config["max_seq_length"],
    load_in_4bit=True,  # Must match training quantization
    resize_model_vocab=len(tokenizer),
    device_map="auto",
)

# Use the model for inference.
# NOTE(review): this relies on FastModel.for_inference returning the model
# object - confirm for the installed unsloth version.
model = FastModel.for_inference(model)

Load the validation dataset so that k can be calculated

In [None]:
# Notebook-wide settings, looked up by key in the cells below.
config = dict(
    max_seq_length=128,  # maximum number of tokens kept per line (longer lines are truncated)
    batch_size=128,  # number of examples handed to the tokenizer per `map` batch
    valid_ratio=0.2,  # fraction of the dataset held out as the validation split
    dataset_path="datasets/5k_hdfs_train.txt",  # text file passed to `load_dataset`
)

In [None]:
from datasets import load_dataset

# Read the text file with the "text" loader; the loaded data is exposed under
# the "train" key.
raw_dataset = load_dataset("text", data_files=config["dataset_path"])

# Hold out the configured fraction as the "test" split without shuffling.
# `seed` is passed as in the original call; with shuffle=False it should not
# influence the result.
dataset = raw_dataset["train"].train_test_split(
    test_size=config["valid_ratio"],
    shuffle=False,
    seed=42,
)

def sliding_window_tokenize(examples):
    """Tokenize a batch of text lines, truncating each to the configured length.

    Despite its name, this function does not build overlapping windows itself:
    it makes a single tokenizer call with truncation enabled.

    Args:
        examples: Batch mapping whose "text" entry holds the strings to tokenize.

    Returns:
        Whatever the global `tokenizer` returns for that batch.
    """
    texts = examples["text"]
    length_limit = config["max_seq_length"]
    return tokenizer(texts, truncation=True, max_length=length_limit)

# Tokenize the held-out split in batches and drop the raw "text" column so
# that only the tokenizer's output columns remain.
validation_split = dataset["test"]
valid_tokenized_dataset = validation_split.map(
    sliding_window_tokenize,
    batched=True,
    batch_size=config["batch_size"],
    remove_columns=["text"],
)



Compute the miss rate for each candidate k

In [None]:
import numpy as np
from tqdm import tqdm
import torch
def calculate_multiple_topk_miss_rates(sequence, topk_candidates):
    """Compute per-sequence next-token top-k miss rates for several values of k.

    For every token list, the global `model` is run on all tokens but the last
    and its logits are compared with the shifted tokens (the "labels"). A label
    is a hit for a given k when it appears among the k highest-scoring token
    ids at its position; the miss rate is ``1 - hits / number_of_labels``.

    Special cases (kept from the original implementation):
      * a sequence with fewer than two tokens (nothing to predict) is recorded
        as a miss rate of 1.0 for every k;
      * a sequence with only one or two labels is recorded as 0.0 for every k;
      * a miss rate strictly between 0 and 0.01 is raised to 0.01.
    NOTE(review): the 1.0 / 0.0 asymmetry for very short sequences looks
    deliberate but is not explained anywhere visible - confirm.

    Args:
        sequence: Sized iterable of token-id lists (one list per sequence).
        topk_candidates: Iterable of k values, expected to be positive
            integers. Duplicates are evaluated once. Values larger than the
            vocabulary size are treated as "the whole vocabulary" instead of
            making ``torch.topk`` raise.

    Returns:
        dict mapping each k to a NumPy array holding one miss rate per
        sequence, in input order. An empty dict when no candidates are given.
    """
    model.eval()
    device = model.device

    # Keyed by k; building the dict first de-duplicates the candidates so each
    # k receives exactly one value per sequence.
    miss_rates_by_k = {k: [] for k in topk_candidates}
    if not miss_rates_by_k:
        return {}

    # Loop invariants.
    max_k = max(miss_rates_by_k)
    # Autocast is requested for CUDA only; enabling it just when the model is
    # on a CUDA device avoids the "CUDA is not available" warning elsewhere.
    use_cuda_autocast = torch.device(device).type == "cuda"

    with torch.no_grad():
        for token_list in tqdm(sequence, desc="Evaluation", unit="sequence"):
            token_tensor = torch.tensor(token_list, dtype=torch.long, device=device).unsqueeze(0)
            input_ids = token_tensor[:, :-1]
            labels = token_tensor[:, 1:]
            total_tokens_to_predict = labels.size(1)

            # Very short sequences get a fixed value; decide this before the
            # forward pass so no model call is wasted on them.
            if total_tokens_to_predict <= 2:
                fixed_rate = 1.0 if total_tokens_to_predict == 0 else 0.0
                for k in miss_rates_by_k:
                    miss_rates_by_k[k].append(fixed_rate)
                continue

            with torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=use_cuda_autocast):
                outputs = model(input_ids=input_ids)

            logits = outputs.logits.to(torch.float32)

            # torch.topk raises when k exceeds the last dimension, so clamp it.
            effective_k = min(max_k, logits.size(-1))
            _, topk_predictions = torch.topk(logits, k=effective_k, dim=-1)

            # hits[i, r] is True when the label at position i equals the
            # prediction at rank r. Labels are moved to the logits' device in
            # case the model output lives on a different device than its input.
            ranked_ids = topk_predictions[0, :total_tokens_to_predict]
            targets = labels[0].to(ranked_ids.device).unsqueeze(-1)
            hits = ranked_ids.eq(targets)

            # correct_within_rank[r] = number of labels found within the first
            # r + 1 predictions; one device-to-host transfer serves every k.
            correct_within_rank = (hits.cumsum(dim=-1) > 0).sum(dim=0).tolist()

            for k in miss_rates_by_k:
                if k < 1:
                    correct = 0  # an empty candidate list can never contain the label
                else:
                    correct = correct_within_rank[min(k, effective_k) - 1]
                miss_rate = 1.0 - (correct / total_tokens_to_predict)
                # Floor tiny non-zero miss rates at 0.01 (original behaviour).
                if 0 < miss_rate < 0.01:
                    miss_rate = 0.01
                miss_rates_by_k[k].append(miss_rate)

    return {k: np.array(miss_rates) for k, miss_rates in miss_rates_by_k.items()}


# Candidate k values 1 through 19 (the upper bound of `range` is exclusive).
# Change this to a range that fits your dataset.
topk_candidates = [*range(1, 20)]

# One array of per-sequence miss rates for each candidate k.
miss_rates_dict = calculate_multiple_topk_miss_rates(
    valid_tokenized_dataset["input_ids"],
    topk_candidates,
)

# Report the average miss rate over the validation sequences for each k.
for k in miss_rates_dict:
    miss_rates = miss_rates_dict[k]
    print(f"Top-{k} avg miss rate: {miss_rates.mean():.4f}")