<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/Finetune_deepseek_Essential_web_v1_0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Dataset: [EssentialAI/essential-web-v1.0](https://huggingface.co/datasets/EssentialAI/essential-web-v1.0) (loaded in streaming mode below).

In [None]:
# --- 1. Set Up Your Environment ---
!pip install scikit-learn -q # For potential evaluation metrics (optional)
!pip install -U transformers -q
!pip install -U datasets -q
!pip install -U accelerate -q
!pip install -U peft -q
!pip install -U trl -q # For SFTTrainer
!pip install -U bitsandbytes -q
!pip install unsloth -q # Recommended for speed and efficiency
!pip install --force-reinstall --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git # For latest Unsloth

!pip install colab-env -q

In [None]:
import colab_env

## Dataset

In [None]:
from datasets import load_dataset

# Stream the dataset: records are fetched lazily instead of downloading the whole corpus.
raw_dataset = load_dataset(
    "EssentialAI/essential-web-v1.0",
    streaming=True,
)
data_stream = raw_dataset["train"]  # iterable view of the 'train' split

In [None]:
# Peek at the first five raw records to see the record layout (e.g. 'id', 'text').
preview_records = data_stream.take(5)
for record in preview_records:
    print(record)


{'id': -3908994749044929748, 'text': "Wednesday, May 25, 2011\n\nNEVER GROW UP\n\n\n\nThese are two persons I am always happy to meet, not only because they have amazing style.\n\nThe bubbles were in front of Monki's Helsinki store last Saturday, when Monki and Weekday were celebrating their one year in Helsinki.\nI was an extra at Monki party in the evening and I apologize for spilling those two beers onto somebodys clothes!\nI think I'm tired or something, today I just fell on the ground at work very smoothly.\nI guess there wasn't a chair there...my workmates didn't even notice before I was like, well, I'm on the ground, oops.\nWell, well, I hope everyone else is paying attention, I will soon follow!\n\n\n\nWednesday, May 18, 2011\n\nSHOOK ONES\n\n\n\n\nFirst photo Karoliina Niemenkari, second Sinikka Konttinen.\n\nSorry guys for not updating or reading your great blogs!\nIt's all good in the hood, I'm interning the whole summer at a pr office that's in a really posh street in Helsi

In [None]:
# --- 4. Prepare the Training Dataset ---
print("Loading and preparing EssentialAI/essential-web-v1.0 dataset...")

# Tiny proof-of-concept split carved from the front of the stream:
# the first `eval_set_size` records are held out, the next `train_set_size` are for training.
eval_set_size = 2
train_set_size = 10

raw_dataset = data_stream
eval_dataset = raw_dataset.take(eval_set_size)
train_dataset = raw_dataset.skip(eval_set_size).take(train_set_size)

# No separate test split yet: the test set reuses the evaluation records.
test_dataset = eval_dataset

Loading and preparing EssentialAI/essential-web-v1.0 dataset...


In [None]:
from itertools import islice

print("Dataset preparation complete.")
print("\n" + "="*70 + "\n")

# Show at most MAX_EXAMPLES_TO_SHOW records from test_dataset. islice stops after exactly
# that many (the previous `if i >= 2: break` let a third record through).
MAX_EXAMPLES_TO_SHOW = 2
for i, example in enumerate(islice(test_dataset, MAX_EXAMPLES_TO_SHOW)):
    print(f"Example {i+1}: {example}")

print("\n" + "="*70 + "\n")

Dataset preparation complete.


Example 1: {'id': -3908994749044929748, 'text': "Wednesday, May 25, 2011\n\nNEVER GROW UP\n\n\n\nThese are two persons I am always happy to meet, not only because they have amazing style.\n\nThe bubbles were in front of Monki's Helsinki store last Saturday, when Monki and Weekday were celebrating their one year in Helsinki.\nI was an extra at Monki party in the evening and I apologize for spilling those two beers onto somebodys clothes!\nI think I'm tired or something, today I just fell on the ground at work very smoothly.\nI guess there wasn't a chair there...my workmates didn't even notice before I was like, well, I'm on the ground, oops.\nWell, well, I hope everyone else is paying attention, I will soon follow!\n\n\n\nWednesday, May 18, 2011\n\nSHOOK ONES\n\n\n\n\nFirst photo Karoliina Niemenkari, second Sinikka Konttinen.\n\nSorry guys for not updating or reading your great blogs!\nIt's all good in the hood, I'm interning the whole summer at a pr off

In [None]:
eval_sample_count = sum(1 for _ in eval_dataset)  # iterates the stream once to count records
print(f"Eval sample size: {eval_sample_count}")

Eval sample size: 2


In [None]:
train_sample_count = sum(1 for _ in train_dataset)  # iterates the stream once to count records
print(f"Train sample size: {train_sample_count}")

Train sample size: 10


## Fine-tuning  

In [None]:
import torch
import gc # For garbage collection, helps manage memory
from datasets import load_dataset, Dataset # Import Dataset to convert lists to Dataset objects
from unsloth import FastLanguageModel
from transformers import AutoTokenizer, TrainingArguments
from trl import SFTTrainer

# Release memory left over from earlier cells before loading the model.
gc.collect()
torch.cuda.empty_cache()

# --- Run configuration ---
# Number of streamed records held out for evaluation / used for training.
# These replace the tiny exploratory sizes (2 / 10) from the dataset-preview cells.
eval_set_size = 2000
train_set_size = 12000

# --- Helper Functions ---

# Filter: keep only records whose 'text' field (the content field of EssentialAI/essential-web-v1.0) is present and non-blank.
def filter_essential_web_example(example):
    """Keep a record only if it has a 'text' value that is neither None nor blank."""
    if "text" not in example:
        return False
    text = example["text"]
    if text is None:
        return False
    return text.strip() != ""

# Format function: splits the 'text' field into a synthetic user prompt and assistant reply,
# then stores the chat-formatted conversation back into 'text' for SFTTrainer.
def format_essential_web_example(example, tokenizer=None):
    """Turn one raw essential-web record into a single chat-formatted training string.

    The dataset only provides free-form ``text``, so a synthetic instruction/response
    pair is built heuristically: roughly the first third of the text (at most 300
    characters; up to 50 characters for short texts) becomes the user prompt, prefixed
    with "Analyze the following web content: ", and the remainder becomes the assistant
    reply. The pair is rendered with the tokenizer's chat template.
    Replace this heuristic with a domain-specific one for real applications.

    Args:
        example: Mapping with a ``text`` field. It is mutated in place and returned,
            which is the shape ``datasets`` ``.map`` expects.
        tokenizer: Object providing ``apply_chat_template``. Defaults to the
            module-level ``tokenizer`` so the function can be passed directly to ``.map``.

    Returns:
        ``example`` with ``text`` replaced by the formatted string, or by ``""`` when
        the record is blank, too short to split, or templating fails. Callers filter
        out the empty ones.

    Raises:
        ValueError: if no tokenizer is passed and none is defined globally. This is a
            configuration error, so it is raised instead of silently emptying every example.
    """
    import warnings

    full_text = example.get("text") or ""  # tolerate a None value as well as a missing key
    if not full_text.strip():
        example["text"] = ""
        return example

    prompt_length = min(len(full_text) // 3, 300)
    if prompt_length < 50:
        # Short documents: use up to the first 50 characters as the prompt.
        prompt_length = min(len(full_text), 50)

    user_prompt_content = f"Analyze the following web content: {full_text[:prompt_length].strip()}"
    assistant_response_content = full_text[prompt_length:].strip()

    # The prompt always carries the instruction prefix, so only the response can be empty
    # (e.g. texts of 50 characters or fewer are consumed entirely by the prompt).
    if not assistant_response_content:
        example["text"] = ""
        return example

    if tokenizer is None:
        tokenizer = globals().get("tokenizer")
    if tokenizer is None:
        raise ValueError("Tokenizer is not available in format_essential_web_example function scope.")

    messages = [
        {"role": "user", "content": user_prompt_content},
        {"role": "assistant", "content": assistant_response_content},
    ]

    try:
        # tokenize=False returns the rendered string; the model's template supplies its own
        # special tokens.
        formatted_text = tokenizer.apply_chat_template(messages, tokenize=False)
    except Exception as exc:
        # Drop just this record, but make the failure visible (warnings dedupes repeats).
        warnings.warn(f"apply_chat_template failed ({type(exc).__name__}); example dropped.")
        formatted_text = ""

    example["text"] = formatted_text if formatted_text.strip() else ""
    return example

# --- 2. Load the Model and Tokenizer ---
print("Loading DeepSeek-R1 model and tokenizer (unsloth/DeepSeek-R1-Distill-Llama-8B)...")
max_seq_length = 4096  # Maximum sequence length for tokenization
# Prefer bfloat16 where the GPU supports it, otherwise fall back to float16.
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
load_in_4bit = True  # 4-bit quantization for significant memory savings

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/DeepSeek-R1-Distill-Llama-8B",  # base model to fine-tune
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
print("Model and tokenizer loaded.")

# Batching needs a padding token; fall back to EOS only if the tokenizer defines none.
# The message now reflects what actually happened (it used to print unconditionally).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print(f"Tokenizer pad_token was None, set to eos_token: {tokenizer.pad_token}")
else:
    print(f"Tokenizer pad_token already set: {tokenizer.pad_token}")
print("="*70 + "\n")

gc.collect()
torch.cuda.empty_cache()

# --- 3. Apply LoRA Adapters ---
print("Applying LoRA adapters to the model for efficient fine-tuning...")

# Adapt every attention projection and every MLP projection of the transformer blocks.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
model = FastLanguageModel.get_peft_model(
    model,
    r=16,                             # rank of the LoRA update matrices
    lora_alpha=16,                    # scaling factor for the LoRA update
    lora_dropout=0,
    bias="none",                      # bias terms are left untouched
    target_modules=lora_target_modules,
    use_gradient_checkpointing=True,  # trade compute for activation memory
    random_state=3407,                # reproducible adapter initialisation
    use_rslora=False,                 # standard LoRA scaling
    loftq_config=None,                # no LoftQ initialisation
)
print("LoRA adapters applied.")
print("="*70 + "\n")

gc.collect()
torch.cuda.empty_cache()

# --- 4. Prepare the Training Dataset ---
# Stream the corpus, then materialize small filtered/formatted subsets in memory:
# the first `eval_set_size` raw records feed evaluation, the next `train_set_size` feed training.
print("Preparing and explicitly materializing small training and evaluation datasets for SFTTrainer...")


def _materialize_split(stream_slice):
    """Run filter -> chat formatting -> drop-empty over a streamed slice and collect it into a list."""
    return list(
        stream_slice
        .filter(filter_essential_web_example)
        .map(format_essential_web_example, batched=False)
        .filter(lambda x: x.get("text", "") != "")
    )


try:
    # A fresh stream, independent of any iteration done in earlier cells.
    raw_data_stream_for_processing = load_dataset("EssentialAI/essential-web-v1.0", streaming=True)["train"]

    print(f"Processing {eval_set_size} examples for evaluation dataset...")
    eval_dataset_materialized = _materialize_split(raw_data_stream_for_processing.take(eval_set_size))
    eval_dataset_for_trainer = Dataset.from_list(eval_dataset_materialized)
    print(f"Evaluation dataset materialized with {len(eval_dataset_for_trainer)} examples.")

    # skip() starts the training window right after the evaluation window, so they never overlap.
    print(f"Processing {train_set_size} examples for training dataset...")
    train_dataset_materialized = _materialize_split(
        raw_data_stream_for_processing.skip(eval_set_size).take(train_set_size)
    )
    train_dataset_for_trainer = Dataset.from_list(train_dataset_materialized)
    print(f"Training dataset materialized with {len(train_dataset_for_trainer)} examples.")

    print("Dataset preparation complete. Materialized subsets are ready for trainer.")
except Exception as e:
    print(f"ERROR: A critical error occurred during dataset preparation and materialization: {e}")
    # None tells the trainer cell to skip training.
    eval_dataset_for_trainer = None
    train_dataset_for_trainer = None

print("="*70 + "\n")

gc.collect()
torch.cuda.empty_cache()

# --- 5. Set Up and Configure the Trainer ---
print("Defining TrainingArguments for SFTTrainer...")

# Optimisation: effective batch per step = 4 per device x 2 accumulation steps.
_optimisation_args = dict(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    learning_rate=2e-4,
    warmup_steps=5,
    optim="adamw_8bit",
    seed=3407,
)
# Mixed precision follows the dtype chosen when the model was loaded.
_precision_args = dict(
    fp16=(dtype == torch.float16),
    bf16=(dtype == torch.bfloat16),
)
# Checkpointing and evaluation run rarely to keep the POC fast.
_checkpoint_eval_args = dict(
    output_dir="./sft_results",
    save_steps=500,
    save_total_limit=1,  # cap on checkpoints kept on disk
    eval_strategy="steps",
    eval_steps=500,
    load_best_model_at_end=False,
    metric_for_best_model="eval_loss",
    greater_is_better=False,  # lower loss is better
)
_logging_args = dict(
    logging_steps=50,
    report_to="none",  # no external experiment trackers
)

training_args = TrainingArguments(
    **_optimisation_args,
    **_precision_args,
    **_checkpoint_eval_args,
    **_logging_args,
)
print("TrainingArguments defined.")
print("="*70 + "\n")

In [None]:
# --- 6. Instantiate SFTTrainer ---
print("Instantiating SFTTrainer...")
datasets_ready = train_dataset_for_trainer is not None and eval_dataset_for_trainer is not None

if not datasets_ready:
    print("SFTTrainer instantiation and training skipped due to an error in dataset preparation.")
    print("Please check the error messages in section 4: 'Prepare the Training Dataset'.")
else:
    # Both datasets are already materialized in memory, so construction is quick.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset_for_trainer,
        eval_dataset=eval_dataset_for_trainer,
        dataset_text_field="text",      # column holding the chat-formatted string
        max_seq_length=max_seq_length,  # maximum tokenized length per example
        # packing=True would concatenate short examples for throughput; it needs care
        # with very diverse text lengths.
    )
    # To keep the adapters, save them AFTER trainer.train() has run, e.g.
    # trainer.save_model("final_deepseek_finetuned_model_adapters")

In [None]:
# --- 7. Train ---
# The previous cell only creates `trainer` when both datasets were prepared. Fail with a
# clear message here, rather than claiming success and then hitting a NameError.
if train_dataset_for_trainer is None or eval_dataset_for_trainer is None:
    raise RuntimeError(
        "Cannot start fine-tuning: dataset preparation failed, so no SFTTrainer was created. "
        "Check the error messages in section 4: 'Prepare the Training Dataset'."
    )

print("SFTTrainer instantiated successfully.")
print("="*70 + "\n")

print("Starting fine-tuning...")
trainer.train()
print("Fine-tuning complete!")
print("="*70 + "\n")

SFTTrainer instantiated successfully.

Starting fine-tuning...


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 12,000 | Num Epochs = 1 | Total steps = 1,500
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 2
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 2 x 1) = 8
 "-____-"     Trainable parameters = 41,943,040/8,000,000,000 (0.52% trained)


Step,Training Loss,Validation Loss


## Evaluation

In [None]:
import torch
import gc
from datasets import load_dataset, Dataset # Import load_dataset and Dataset
from nltk.translate.bleu_score import sentence_bleu
from nltk.tokenize import word_tokenize
import nltk
import os # Import os module for path operations

# --- Evaluation configuration ---
# Size of the held-out window at the start of the stream. Keep it equal to the training
# cell's eval_set_size, because training skipped exactly that many records.
eval_set_size = 2000


# --- Helper Functions for Dataset Preparation (Copied from training script for self-containment) ---
# Filter: keep only records with a present, non-blank 'text' field (the content field of EssentialAI/essential-web-v1.0).
def filter_essential_web_example(example):
    """Return True iff the record carries non-blank 'text' (a missing key or None fails)."""
    text = example["text"] if "text" in example else None
    return False if text is None else text.strip() != ""

# Format function (evaluation variant, with stricter length checks than the training version):
# splits the 'text' field into a synthetic user prompt and assistant reply and chat-formats them.
# This function requires 'tokenizer' to be defined globally before it's called.
def format_essential_web_example(example, tokenizer=None):
    """Build a chat-formatted evaluation string from one raw essential-web record.

    This is stricter than the training-time formatter. Records whose stripped text is
    shorter than 100 characters are rejected, and so are records whose prompt or response
    part is shorter than 30 characters. The prompt is roughly the first third of the
    stripped text, clamped to 50-300 characters and prefixed with
    "Analyze the following web content: ". The rest becomes the reference assistant
    response.

    Args:
        example: Mapping with a ``text`` field; mutated in place and returned.
        tokenizer: Object providing ``apply_chat_template``; defaults to the
            module-level ``tokenizer``.

    Returns:
        ``example`` with ``text`` set to the formatted string, or to ``""`` when the
        record is rejected or templating fails.

    Raises:
        ValueError: if no tokenizer is passed and none is defined globally (a
            configuration error, so it is not silently turned into empty examples).
    """
    # Strip once up front so the length checks and the split offsets refer to the same
    # string. Measuring the stripped length but slicing the raw text misaligned the split
    # whenever the text had leading whitespace.
    full_text = (example.get("text") or "").strip()

    MIN_ORIGINAL_TEXT_LENGTH = 100  # minimum length for the raw text
    if len(full_text) < MIN_ORIGINAL_TEXT_LENGTH:
        debug_text_snippet = full_text[:50].replace('\n', ' ')
        print(f"DEBUG: Skipping example (short original text, len={len(full_text)}): '{debug_text_snippet}...'")
        example["text"] = ""
        return example

    effective_length = len(full_text)
    # Roughly 1/3 of the text, at least 50 and at most 300 characters.
    prompt_len_candidate = max(50, min(effective_length // 3, 300))
    user_prompt_content = full_text[:prompt_len_candidate].strip()
    assistant_response_content = full_text[prompt_len_candidate:].strip()

    MIN_PART_LENGTH = 30  # minimum characters for a meaningful prompt/response part
    if len(user_prompt_content) < MIN_PART_LENGTH or len(assistant_response_content) < MIN_PART_LENGTH:
        print(f"DEBUG: Skipping example (short parts after split): Prompt len={len(user_prompt_content)}, Response len={len(assistant_response_content)}")
        example["text"] = ""
        return example

    user_prompt_content = f"Analyze the following web content: {user_prompt_content}"
    messages = [
        {"role": "user", "content": user_prompt_content},
        {"role": "assistant", "content": assistant_response_content},
    ]

    if tokenizer is None:
        tokenizer = globals().get("tokenizer")
    if tokenizer is None:
        raise ValueError("Tokenizer is not available in format_essential_web_example function scope.")

    try:
        # tokenize=False returns the rendered string; the template supplies its own special tokens.
        formatted_text = tokenizer.apply_chat_template(messages, tokenize=False)
    except Exception as e:
        print(f"DEBUG: Error during apply_chat_template for example: {e}")
        example["text"] = ""
        return example

    if not formatted_text.strip():
        print(f"DEBUG: Skipping example (empty formatted text after apply_chat_template): Original len={effective_length}")
        formatted_text = ""
    example["text"] = formatted_text
    return example


# --- Load the Fine-tuned Model ---
# Load the base model, then apply the LoRA adapters saved by the trainer and merge them.
# Training writes checkpoints to output_dir="./sft_results" every save_steps=500 steps and
# prunes old ones (save_total_limit=1). Which "checkpoint-<step>" directories exist therefore
# depends on the run, and a hard-coded step such as checkpoint-150 may not exist.
# Load the newest checkpoint found there instead.
from unsloth import FastLanguageModel
from peft import PeftModel
from transformers.trainer_utils import get_last_checkpoint

SFT_OUTPUT_DIR = "./sft_results"  # must match TrainingArguments.output_dir
PEFT_CHECKPOINT_OVERRIDE = None   # set to a specific checkpoint directory to bypass auto-detection

peft_model_load_path = PEFT_CHECKPOINT_OVERRIDE
if peft_model_load_path is None and os.path.isdir(SFT_OUTPUT_DIR):
    peft_model_load_path = get_last_checkpoint(SFT_OUTPUT_DIR)  # None when no checkpoint-* dir exists
if peft_model_load_path is None:
    raise FileNotFoundError(
        f"No 'checkpoint-*' directory found under '{SFT_OUTPUT_DIR}'. "
        "Run the fine-tuning cells first or set PEFT_CHECKPOINT_OVERRIDE."
    )

print(f"Loading the fine-tuned model from '{peft_model_load_path}'...")

# Model loading parameters must match the ones used for training.
max_seq_length = 4096
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
load_in_4bit = True

base_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/DeepSeek-R1-Distill-Llama-8B",  # the original base model
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

try:
    model = PeftModel.from_pretrained(base_model, peft_model_load_path)
    print(f"Fine-tuned model (LoRA adapters) loaded from '{peft_model_load_path}' successfully.")
except Exception as e:
    print(f"Error loading fine-tuned model adapters from '{peft_model_load_path}': {e}")
    print("Please ensure the specified checkpoint path is correct and contains 'adapter_config.json'.")
    raise  # stop here: evaluation is meaningless without the adapters

# Merge the LoRA weights into the base model so inference runs a single plain model.
model = model.merge_and_unload()
print("LoRA adapters merged into base model successfully.")

# --- Prepare Evaluation Dataset (self-contained) ---
# Only the first `eval_set_size` raw records are eligible. The training cell trained on the
# records immediately after them (skip(eval_set_size).take(train_set_size)). Scanning
# further into the stream to replace rejected records would therefore evaluate on training
# data. Records rejected by the stricter formatter simply shrink the evaluation set.
print("Preparing evaluation dataset...")
try:
    raw_data_stream_for_processing = load_dataset("EssentialAI/essential-web-v1.0", streaming=True)["train"]

    eval_dataset_materialized = []
    for data_point in raw_data_stream_for_processing.take(eval_set_size):
        if not filter_essential_web_example(data_point):
            continue
        formatted_data_point = format_essential_web_example(data_point)
        if formatted_data_point.get("text", "").strip():  # formatting succeeded
            eval_dataset_materialized.append(formatted_data_point)

    eval_dataset_for_trainer = Dataset.from_list(eval_dataset_materialized)
    print(
        f"Evaluation dataset prepared with {len(eval_dataset_for_trainer)} examples "
        f"(from the first {eval_set_size} held-out records)."
    )
except Exception as e:
    print(f"ERROR: A critical error occurred during evaluation dataset preparation: {e}")
    eval_dataset_for_trainer = None
    raise  # halt: evaluation cannot proceed without data

# --- NLTK data for word_tokenize ---
# word_tokenize relies on the Punkt sentence tokenizer. Current NLTK releases publish it as
# 'punkt_tab' and older ones as 'punkt', so both are ensured; no POS-tagger data is needed.
# Data is stored in a writeable directory under /tmp.
nltk_data_path = "/tmp/nltk_data"
os.makedirs(nltk_data_path, exist_ok=True)
if nltk_data_path not in nltk.data.path:  # keep re-runs from appending duplicates
    nltk.data.path.append(nltk_data_path)

nltk_resources = (
    ("tokenizers/punkt", "punkt"),
    ("tokenizers/punkt_tab/english/", "punkt_tab"),
)
for resource_path, package_name in nltk_resources:
    try:
        # Searches every directory on nltk.data.path, including nltk_data_path.
        nltk.data.find(resource_path)
    except LookupError:
        print(f"Downloading NLTK '{package_name}' tokenizer data...")
        try:
            downloaded = nltk.download(package_name, download_dir=nltk_data_path, quiet=True)
        except Exception as e:
            print(f"An unexpected error occurred while downloading NLTK '{package_name}': {e}")
            downloaded = False
        if downloaded:
            print(f"NLTK '{package_name}' download complete.")
        else:
            print(f"WARNING: NLTK '{package_name}' could not be downloaded; word_tokenize may fail.")
            print("Please ensure NLTK is correctly installed and you have network access.")
    else:
        print(f"NLTK '{package_name}' tokenizer data found.")


# Switch to inference mode (turns off dropout).
model.eval()
print("Model set to evaluation mode.")
print("="*70 + "\n")

# Clear CUDA cache before generation.
gc.collect()
torch.cuda.empty_cache()

print("Starting evaluation of the fine-tuned model...")

# Generation samples, so seed the RNG to make BLEU scores reproducible across runs
# (3407 mirrors the seed used for training).
torch.manual_seed(3407)

generation_config = dict(
    max_new_tokens=100,      # Max tokens to generate per response
    do_sample=True,          # Enable sampling (diverse outputs)
    temperature=0.7,         # Controls randomness
    top_p=0.9,               # Nucleus sampling
    num_return_sequences=1,  # Generate one sequence per prompt
    eos_token_id=tokenizer.eos_token_id,  # Stop generation at EOS token
    pad_token_id=tokenizer.pad_token_id,  # Pad token for generation
)

# Sentence-level BLEU without smoothing is 0 (and NLTK warns) whenever the candidate
# shares no 4-gram with the reference, which is common for free-form continuations.
# method1 smoothing keeps per-example scores informative.
from nltk.translate.bleu_score import SmoothingFunction

bleu_smoothing = SmoothingFunction().method1

# Chat-template markers. BOS/EOS come from the tokenizer so they match the exact special-token
# text: DeepSeek spells these with '▁' rather than spaces (e.g. '<｜end▁of▁sentence｜>'),
# so a hand-typed 'end of sentence' marker never matches.
begin_sentence_marker = tokenizer.bos_token or ""
end_sentence_marker = tokenizer.eos_token or ""
user_start_marker = "<｜User｜>"
assistant_start_marker = "<｜Assistant｜>"


def _split_formatted_example(formatted_text):
    """Recover (user prompt, reference response) from one chat-formatted example.

    Expected layout: [BOS]<｜User｜>prompt<｜Assistant｜>response[EOS].
    Raises ValueError when a marker is missing or either part is empty.
    """
    search_start_idx = formatted_text.find(begin_sentence_marker) if begin_sentence_marker else -1
    search_start_idx = max(search_start_idx, 0)  # no BOS found: search from the beginning

    user_marker_pos = formatted_text.find(user_start_marker, search_start_idx)
    if user_marker_pos == -1:
        raise ValueError(f"'{user_start_marker}' marker not found in formatted text.")

    assistant_marker_pos = formatted_text.find(assistant_start_marker, user_marker_pos)
    if assistant_marker_pos == -1:
        raise ValueError(f"'{assistant_start_marker}' marker not found after user prompt.")

    prompt = formatted_text[user_marker_pos + len(user_start_marker):assistant_marker_pos].strip()

    response_start = assistant_marker_pos + len(assistant_start_marker)
    response_end = formatted_text.rfind(end_sentence_marker) if end_sentence_marker else -1
    if response_end == -1:
        response_end = len(formatted_text)  # no EOS: the response runs to the end of the string
    response = formatted_text[response_start:response_end].strip()

    if not prompt or not response:
        raise ValueError("Parsed prompt or ground truth response is empty after token parsing.")
    return prompt, response


def _generate_response(prompt):
    """Generate the fine-tuned model's reply to a single user prompt."""
    input_text = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        # End the prompt with the assistant marker so the model answers as the assistant,
        # mirroring the '<｜Assistant｜>' that preceded every response during training.
        # NOTE(review): some DeepSeek-R1 template revisions also append '<think>\n' here; confirm.
        add_generation_prompt=True,
    )
    # The chat template already emits BOS; add_special_tokens=False prevents a second one.
    encoded = tokenizer(input_text, return_tensors="pt", add_special_tokens=False).to(model.device)
    with torch.no_grad():
        # Passing the BatchEncoding supplies attention_mask as well as input_ids.
        outputs = model.generate(**encoded, **generation_config)
    new_tokens = outputs[0][encoded["input_ids"].shape[1]:]
    generated = tokenizer.decode(new_tokens, skip_special_tokens=True)
    if end_sentence_marker:
        generated = generated.replace(end_sentence_marker, "")
    return generated.strip()


total_bleu_score = 0.0
evaluated_count = 0

if eval_dataset_for_trainer is not None and len(eval_dataset_for_trainer) > 0:
    for i, example in enumerate(eval_dataset_for_trainer):
        full_formatted_text = example["text"]

        try:
            original_user_prompt, ground_truth_response = _split_formatted_example(full_formatted_text)
        except ValueError as ve:
            print(f"DEBUG: Skipping example {i+1} due to malformed chat template structure during parsing: {ve}")
            debug_full_text_snippet = full_formatted_text.replace('\n', ' ')
            print(f"DEBUG: Full text (showing markers): '{debug_full_text_snippet}'")
            continue
        except Exception as e:
            print(f"DEBUG: Error parsing example {i+1}'s formatted text: {e}")
            debug_full_text_snippet_short = full_formatted_text[:200].replace('\n', ' ')
            print(f"DEBUG: Full text (snippet): '{debug_full_text_snippet_short}'")
            continue

        generated_text = _generate_response(original_user_prompt)

        # --- BLEU Score Calculation ---
        reference_tokens = word_tokenize(ground_truth_response)
        candidate_tokens = word_tokenize(generated_text)

        if reference_tokens and candidate_tokens:
            bleu_score = sentence_bleu([reference_tokens], candidate_tokens, smoothing_function=bleu_smoothing)
            total_bleu_score += bleu_score
            evaluated_count += 1
        else:
            bleu_score = 0.0
            print(f"Warning: Cannot calculate BLEU for example {i+1} due to empty reference or candidate tokens.")

        print(f"\n--- Evaluation Example {i+1} ---")
        print(f"User Prompt:\n{original_user_prompt}")
        print(f"\nModel's Response:\n{generated_text}")
        print(f"\nGround Truth:\n{ground_truth_response}")
        print(f"\nBLEU Score: {bleu_score:.4f}")
        print("-" * 70)
else:
    print("\nEvaluation skipped: eval_dataset_for_trainer is empty or not prepared. Check dataset preparation logs above.")


print("\nEvaluation complete.")

if evaluated_count <= 0:
    print("\nNo examples were successfully evaluated for BLEU score.")
else:
    average_bleu_score = total_bleu_score / evaluated_count
    print(f"\nAverage BLEU Score across {evaluated_count} evaluated examples: {average_bleu_score:.4f}")

# How to read the number: the reference is an arbitrary heuristic split of web text.
bleu_interpretation_notes = (
    "\nNote on BLEU Score Interpretation for this POC:",
    "BLEU measures the n-gram overlap between generated text and reference text.",
    "For this specific dataset and heuristic-based 'prompt'/'chosen' creation,",
    "a very high BLEU score might not be expected, as the model's generated",
    "continuation might diverge from the arbitrary 'ground truth' split, even if it's coherent.",
    "BLEU is most effective when the generated text is expected to be a close paraphrase or exact match.",
)
for note_line in bleu_interpretation_notes:
    print(note_line)
