In [None]:
# !jupyter nbconvert PMLM.ipynb --to python

**Imports**
---

In [None]:
import os
import re
import sys
from typing import Union, List
import random
import numpy as np
import torch
import pandas as pd
from pathlib import Path
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, MinLengthLogitsProcessor
from peft import PeftModel
from transformers import StoppingCriteria, StoppingCriteriaList

**Installations**
---

In [None]:
# !{sys.executable} -m pip install sentencepiece

**Config**
---

In [None]:
# Point the Hugging Face home and hub cache at one project-local directory so downloads are reused.
# NOTE(review): huggingface_hub reads these variables when it is first imported, and transformers/peft
# are imported in an earlier cell. load_model() and the cache scan pass this path explicitly, but other
# defaults (e.g. where a login token is stored) may still use the previous location. Set these before
# the imports if that matters.
HF_CACHE_DIR = "../huggingface_cache"
for _hf_env_var in ("HF_HOME", "HF_HUB_CACHE"):
    os.environ[_hf_env_var] = HF_CACHE_DIR

In [None]:
from huggingface_hub import login

# Authenticate with the Hugging Face Hub. The token comes from the HF_TOKEN environment variable
# instead of being pasted into the notebook, so it never ends up in the saved .ipynb or in git.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(hf_token, add_to_git_credential=True)
else:
    # No token in the environment: let huggingface_hub prompt for one interactively.
    login(add_to_git_credential=True)

In [None]:
# Seed every RNG this notebook draws from: torch (used for sampled generation), NumPy's legacy
# global RNG, and Python's `random`. GPU kernels may still add some nondeterminism.
SEED = 42
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(SEED)

In [None]:
from huggingface_hub import scan_cache_dir

# Report what is already in the local Hugging Face hub cache. Use the same directory the config
# cell exported as HF_HUB_CACHE instead of repeating the literal path.
cache_dir = os.environ.get("HF_HUB_CACHE", "../huggingface_cache")

if not os.path.isdir(cache_dir):
    # scan_cache_dir raises when the directory does not exist (e.g. before the first download),
    # which would break a fresh Restart-&-Run-All.
    print(f"Cache directory {cache_dir} does not exist yet; nothing cached.")
else:
    cache_info = scan_cache_dir(cache_dir)

    # size_on_disk is in bytes; divide by 1024**3 to report sizes in GB.
    print(f"Cache size: {cache_info.size_on_disk / (1024**3):.2f} GB")
    print("Cached Repositories:")
    for repo in cache_info.repos:
        print(f"- {repo.repo_id} ({repo.repo_type}): {repo.size_on_disk / (1024**3):.2f} GB")

**Hparams**
---

In [None]:
# Core run settings.
batch_size = 20
USING_MAX_DATASET = False  # True -> 10k-prompt CSV, False -> 1k-prompt CSV
output_token_limit = 2048  # default max_new_tokens used by query_model

custom_prompts_data_path = (
    "Datasets/10k_category_prompts.csv" if USING_MAX_DATASET else "Datasets/1k_category_prompts.csv"
)
    
# Default instruction given to the prompting LLM, appended after the quoted [USER REQUEST]. It asks
# the LLM to write a [SYSTEM INSTRUCTION] that describes a watermarking strategy. The exact text,
# including the leading newline, the 8-space indentation and the trailing spaces, is sent to the
# model verbatim, so any edit here changes the generated data. The ablation cell further down may
# overwrite this with a single-strategy variant.
strategy_hint = (
        """
        Generate a [SYSTEM INSTRUCTION] based on the provided [USER REQUEST]. This [SYSTEM INSTRUCTION] will be combined 
        with the [USER REQUEST] and input into another language model to produce a watermarked output. 
        The [SYSTEM INSTRUCTION] should specify watermarking strategies that adapt dynamically to the content of the [USER REQUEST].
        Example [SYSTEM INSTRUCTION]: 'Use specific strategies to embed watermarks such as including special tokens or phrases that fit naturally with the content. The watermark should be later detectable by a classifier.'
        Example watermarking strategies:
        • Lexical Strategy: Incorporate specific rare or uncommon tokens as watermarks.
        • Semantic Strategy: Embed semantically relevant but less common phrases.
        • Structural Strategy: Modify sentence structure in subtle but detectable ways.
        • <You can add Strategies if necessary>
        Ensure watermarks are evenly distributed throughout the output.
        Your task is to output ONLY the [SYSTEM INSTRUCTION] that specifies the concrete watermarking strategy.
        """
    )

# Define the model names
# Maps a short model key (matched by substring in format_prompt / clean_response and used in
# output paths) to the Hugging Face Hub repo id that load_model downloads.
MODEL_NAMES = {
    # Working PLM models
    "mistral_7b_v03_instruct": "mistralai/Mistral-7B-Instruct-v0.3",  # ✅ Works

    # MLM models (teacher)
    "deepseek_llm_chat": "deepseek-ai/deepseek-llm-7b-chat",  # ✅ Works
    "qwen2.5_7b_instruct": "Qwen/Qwen2.5-7B-Instruct",  # ✅ Works
    "llama3_8b_instruct": "meta-llama/Meta-Llama-3-8B-Instruct",  # ✅ Works
    "gemma_7b_it": "google/gemma-7b-it",  # ✅ Works
    "ministral_8b_instruct": "mistralai/Ministral-8B-Instruct-2410",  # ✅ Works
    "glm_4_9b_chat": "THUDM/glm-4-9b-chat",  # ✅ Works
    # NOTE(review): the key says InternLM2.5 but the repo id is InternLM2 (internlm/internlm2-chat-7b);
    # InternLM2.5 is published separately (internlm/internlm2_5-7b-chat) — confirm which was intended.
    "internlm2.5_7b_chat": "internlm/internlm2-chat-7b",  # ✅ Works

    # gpt4o: manual — not loaded through this notebook (no Hub checkpoint), so it has no entry here.
    # It must stay a comment: a bare string without a comma would be silently concatenated with
    # the next key.

    # Student models
    "mistral_7b_v02_instruct": "mistralai/Mistral-7B-Instruct-v0.2",  # ✅ Works
    "qwen1.5_1.8b_instruct": "Qwen/Qwen1.5-1.8B-Chat",  # ✅ Works
    "tinyllama_1.1b_chat": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # ✅ Works
    "gemma_1.1_2b_it": "google/gemma-1.1-2b-it",  # ✅ Works
}

# Tier	Model
# ⭐ Top-tier	LLaMA-3 8B Instruct
# ⭐ Top-tier	Qwen2.5-7B-Instruct
# ⭐ Top-tier	DeepSeek-LLM 7B-Chat
# ⭐ Mid-tier	Mistral-7B-Instruct-v0.3
# ⭐ Mid-tier	GEMMA-7B-IT
# ✅ Bonus	InternLM2.5-7B-Chat
# ✅ Bonus	GLM-4-9B-Chat

In [None]:
# Initial list:
# ---
# GPT3.5-turbo-0125
# QWEN-plus
# LLAMA3-8B
# QWEN2.5-7B
# QWEN2-1.5B
# vicuna_7b_v1_3
# vicuna_7b_v1_5
# open_llama_3b
# open_llama_7b
# mistral_7b_v03
# mistral_7b_v03_instruct
# baize_v2_7b
# GLM-4-plus
# GLM-3-Turbo
# LLAMA3-8B
# GEMMA-7B
# GPT4o-mini
# GPT-4o
# DEEPSEEK v2
# CLAUDE-3.5-sonnet
# INTERNLM2.5-7B

**Special case for ablation**
---

In [None]:
test_ablation = True
experiment = 'structural'

# Ablation: the prompting LLM is shown exactly ONE watermarking strategy instead of the full list in
# the default strategy_hint. This only applies when test_ablation is on, so turning the flag off
# really falls back to the default multi-strategy hint.
ABLATION_STRATEGIES = {
    'lexical': "• Lexical Strategy: Incorporate specific rare or uncommon tokens as watermarks.",
    'semantic': "• Semantic Strategy: Embed semantically relevant but less common phrases.",
    'structural': "• Structural Strategy: Modify sentence structure in subtle but detectable ways.",
}

# Text shared by every single-strategy hint; `{strategy}` is filled from ABLATION_STRATEGIES.
# The whitespace (leading newline, 8-space indents, trailing spaces) is part of the prompt the
# model sees.
ABLATION_HINT_TEMPLATE = (
        """
        Generate a [SYSTEM INSTRUCTION] based on the provided [USER REQUEST]. This [SYSTEM INSTRUCTION] will be combined 
        with the [USER REQUEST] and input into another language model to produce a watermarked output. 
        The [SYSTEM INSTRUCTION] should specify watermarking strategies that adapt dynamically to the content of the [USER REQUEST].
        Example [SYSTEM INSTRUCTION]: 'Use specific strategies to embed watermarks such as including special tokens or phrases that fit naturally with the content. The watermark should be later detectable by a classifier.'
        Watermarking strategy:
        {strategy}
        Ensure watermarks are evenly distributed throughout the output.
        Your task is to output ONLY the [SYSTEM INSTRUCTION] that specifies the concrete watermarking strategy.
        """
    )

if test_ablation:
    if experiment not in ABLATION_STRATEGIES:
        # Fail loudly instead of silently running the "ablation" with the full multi-strategy hint.
        raise ValueError(
            f"Unknown ablation experiment {experiment!r}; choose from {sorted(ABLATION_STRATEGIES)}."
        )
    strategy_hint = ABLATION_HINT_TEMPLATE.format(strategy=ABLATION_STRATEGIES[experiment])

In [None]:
# Ablation runs read their user requests from a dedicated test CSV instead of the 1k/10k prompt set.
# This cell runs before the Dataset cell, so the override is what gets loaded.
if test_ablation:
    custom_prompts_data_path = "Datasets/test.csv"

**Dataset**
---

In [None]:
# Map-style dataset over (prompt, category) pairs.
class TextDataset(Dataset):
    """Expose a sequence of (prompt, category) tuples as dict items.

    Each item is ``{"prompt": ..., "category": ...}``; the DataLoader's default collation groups
    these into one dict per batch with the same keys.
    """

    def __init__(self, requests):
        # Stored by reference; each element must unpack into exactly (prompt, category).
        self.requests = requests

    def __len__(self):
        return len(self.requests)

    def __getitem__(self, idx):
        text, label = self.requests[idx]
        return {"prompt": text, "category": label}

In [None]:
# Helper: wrap request tuples in an order-preserving DataLoader.
def prepare_dataset(requests, batch_size=batch_size):
    """Return a non-shuffling DataLoader over (prompt, category) pairs.

    The default `batch_size` is the global hparam's value at the time this cell ran.
    """
    return DataLoader(TextDataset(requests), batch_size=batch_size, shuffle=False)

# Load the request table (expects 'prompt' and 'category' columns) from custom_prompts_data_path.
csv_prompts = pd.read_csv(custom_prompts_data_path)

# Pull both columns out as plain Python lists, then pair them row by row.
user_requests, categories = (csv_prompts[column].tolist() for column in ('prompt', 'category'))
all_user_requests = [(request, category) for request, category in zip(user_requests, categories)]

# Batched, order-preserving loader over the (prompt, category) pairs.
dataloader = prepare_dataset(all_user_requests, batch_size=batch_size)

In [None]:
# Sanity check: number of batches the DataLoader will yield for the loaded requests.
len(dataloader)

**Prompt engineering**
---

In [None]:
def format_prompt(model_key: str, prompt: str) -> str:
    """Wrap `prompt` in the hand-written chat template of the model family named by `model_key`.

    The family is chosen by the first substring match, tried in this order: gpt (no template),
    llama3, tinyllama, mistral/ministral, qwen, gemma, internlm, deepseek, glm. clean_response
    strips these same templates, so the two functions must stay in sync.

    NOTE(review): these templates are hand-rolled and may differ from each model's official chat
    template (tokenizer.apply_chat_template); confirm before comparing against published results.

    Raises:
        ValueError: if no family matches `model_key`.
    """
    templates = (
        (("gpt",), lambda text: text),
        (("llama3",), lambda text: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n{text}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n"),
        (("tinyllama",), lambda text: f"<|user|>\n{text}\n<|assistant|>\n"),
        (("mistral", "ministral"), lambda text: f"<s>[INST]{text}[/INST]"),
        (("qwen",), lambda text: f"<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant"),
        (("gemma",), lambda text: f"<start_of_turn>user\n{text}\n<end_of_turn>\n<start_of_turn>model\n"),
        (("internlm",), lambda text: f"<|User|>:{text}\n<|Bot|>:"),
        (("deepseek",), lambda text: f"### Instruction:\n{text}\n### Response:"),
        (("glm",), lambda text: f"[Round 1]\n\n问：{text}\n\n答："),
    )

    for needles, build in templates:
        if any(needle in model_key for needle in needles):
            return build(prompt)

    raise ValueError(f"Unknown model_key '{model_key}' in format_prompt.")

**Loading model**
---

In [None]:
def load_model(model_key: str, modified_model_path: str = None):
    """Load a frozen tokenizer/model pair for batched inference.

    Args:
        model_key: key into MODEL_NAMES selecting the base checkpoint.
        modified_model_path: optional directory of a PEFT (LoRA) adapter. When given, the tokenizer
            is loaded from this directory, the adapter is applied on top of the base model, and a
            merge into the base weights is attempted (falling back to unmerged adapters).

    Returns:
        (tokenizer, model): a left-padding tokenizer whose pad id is set to its eos id, and a
        bfloat16 model placed with device_map="auto" whose parameters all have requires_grad=False.

    Raises:
        ValueError: if `model_key` is not in MODEL_NAMES.
    """
    if model_key not in MODEL_NAMES:
        raise ValueError(f"Invalid model key '{model_key}'. Choose from: {list(MODEL_NAMES.keys())}")

    model_name = MODEL_NAMES[model_key]
    hf_cache = os.environ["HF_HOME"]

    tokenizer = AutoTokenizer.from_pretrained(
        modified_model_path or model_name,
        cache_dir=hf_cache,
        trust_remote_code=True,
        padding_side="left",
    )
    # Batched generation pads on the left with the EOS id (query_model passes the same id to generate).
    tokenizer.pad_token_id = tokenizer.eos_token_id

    if modified_model_path:
        print(f"Loading fine-tuned model from {modified_model_path}...")
    else:
        print(f"Loading base model: {model_name}...")

    # The base checkpoint is loaded the same way in both cases; an adapter, if any, is applied after.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        cache_dir=hf_cache,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    if modified_model_path:
        model = PeftModel.from_pretrained(model, modified_model_path)
        try:
            model = model.merge_and_unload()
            print("Successfully merged LoRA and base model!")
        except Exception as e:
            # Best effort: keep the unmerged PeftModel, which still runs with adapters applied.
            print(f"Warning: merge_and_unload failed, using adapters dynamically. {e}")

    # Inference only: freeze every parameter.
    model.requires_grad_(False)
    return tokenizer, model

**Query**
---

In [None]:
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnNthString(StoppingCriteria):
    """Stopping criterion: stop once every row of the batch contains a stop string `occurrence` times.

    Occurrences are counted in each full decoded row (prompt + generated text, special tokens
    skipped), so a stop string that already appears in the prompt counts toward `occurrence`.
    A row whose last token is the tokenizer's EOS id counts as done too (query_model also uses the
    EOS id as the pad id for finished rows), so one row ending normally cannot keep the batch going.

    __call__ returns a single bool for the whole batch (True = stop generating).

    NOTE: each unfinished row is fully decoded on every step, which is heavier than decoding only
    a short tail, but it is the only way to count occurrences without double counting.
    """

    def __init__(self, stop_strings, tokenizer, occurrence=1):
        self.stop_strings = list(stop_strings)
        self.tokenizer = tokenizer
        self.occurrence = occurrence
        # Highest count of each stop string seen in any single row so far (informational).
        self.counter = {s: 0 for s in self.stop_strings}
        # Row indices that already met a stop condition; they are not decoded again.
        self._finished_rows = set()

    def __call__(self, input_ids, scores, **kwargs):
        eos_id = getattr(self.tokenizer, "eos_token_id", None)

        for row_idx, row in enumerate(input_ids):
            if row_idx in self._finished_rows:
                continue

            if eos_id is not None and row[-1].item() == eos_id:
                self._finished_rows.add(row_idx)
                continue

            # Decode the whole row. Appending overlapping tail windows to a running buffer would
            # count the same occurrence once for every step it stays inside the window.
            text = self.tokenizer.decode(row.tolist(), skip_special_tokens=True)
            for stop_str in self.stop_strings:
                hits = text.count(stop_str)
                self.counter[stop_str] = max(self.counter[stop_str], hits)
                if hits >= self.occurrence:
                    self._finished_rows.add(row_idx)
                    break

        return len(self._finished_rows) == len(input_ids)

def custom_stopping_criteria_for_strings(stop_strings, tokenizer, occurrence=1):
    """Build a StoppingCriteriaList holding one StopOnNthString, ready for model.generate()."""
    criterion = StopOnNthString(stop_strings, tokenizer, occurrence=occurrence)
    return StoppingCriteriaList([criterion])

In [None]:
def query_model(model_key: str, tokenizer, model, chat_prompts: Union[str, List[str]]) -> Union[str, List[str]]:
    """
    Generates responses from the selected model for a single prompt or a batch of prompts.

    Args:
        model_key: key into MODEL_NAMES; selects the prompt template (format_prompt) and any
            model-specific generation settings below.
        tokenizer: tokenizer returned by load_model (left padding, pad id = eos id).
        model: model returned by load_model.
        chat_prompts: one raw (un-templated) prompt or a list of them.

    Returns:
        For each prompt, the decoded output sequence (templated prompt followed by the completion,
        special tokens removed): a str for a single prompt, otherwise a list (empty for an empty
        list). Pass the result through clean_response to strip the prompt part.
    """
    is_single = isinstance(chat_prompts, str)
    if is_single:
        chat_prompts = [chat_prompts]
    if not chat_prompts:
        # Nothing to generate; tokenizing an empty batch would fail.
        return []

    formatted_prompts = [format_prompt(model_key, p) for p in chat_prompts]

    # Some templates (llama3, mistral) already begin with the tokenizer's BOS text. Letting the
    # tokenizer prepend its own BOS as well would feed those models a doubled BOS token.
    bos_text = getattr(tokenizer, "bos_token", None)
    prompts_carry_bos = bool(bos_text) and all(p.startswith(bos_text) for p in formatted_prompts)

    inputs = tokenizer(
        formatted_prompts,
        return_tensors="pt",
        padding=True,
        add_special_tokens=not prompts_carry_bos,
    ).to(model.device)

    generation_kwargs = {
        "max_new_tokens": output_token_limit,
        "do_sample": True,
        "temperature": 0.5,
        "top_p": 0.7,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.eos_token_id,
        "use_cache": True,
    }
    # 🔥 SPECIAL CASE PATCHING: greedy decoding without KV cache and a shorter budget.
    if any(m in model_key for m in ["internlm", "glm_4"]):
        print(f"Note: For {model_key}, forcing safe generation parameters.")
        generation_kwargs["do_sample"] = False
        generation_kwargs["use_cache"] = False
        generation_kwargs["temperature"] = None
        generation_kwargs["top_p"] = None
        generation_kwargs["max_new_tokens"] = 512

    # GLM: stop on the FIRST [Round 2] (a follow-up turn invented by the model).
    if any(m in model_key for m in ["glm_4", "glm_3"]):
        generation_kwargs["stopping_criteria"] = custom_stopping_criteria_for_strings(
            ["[Round 2]"], tokenizer, occurrence=1
        )

    # InternLM: stop on the SECOND <|User|>: (the first one is part of the prompt template).
    if "internlm" in model_key:
        generation_kwargs["stopping_criteria"] = custom_stopping_criteria_for_strings(
            ["<|User|>:"], tokenizer, occurrence=2
        )

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # Decode prompt + completion; clean_response relies on the echoed prompt to find the answer.
    responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return responses[0] if is_single else responses

**Cleaning**
---

In [None]:
def clean_response(model_key: str, prompts: Union[str, List[str]], responses: Union[str, List[str]]) -> Union[str, List[str]]:
    """Cleans model responses based on prompt and model_key.

    `responses` are full decoded sequences as returned by query_model (templated prompt followed by
    the completion). `prompts` are the raw prompts that were given to query_model. The template and
    the echoed prompt are stripped so only the model's answer remains.

    Args:
        model_key: model family key, matched by substring in the same order as format_prompt.
        prompts: one prompt or a list of prompts.
        responses: one response or a list of responses aligned with `prompts`.

    Returns:
        The cleaned response (str) or the list of cleaned responses.

    Raises:
        ValueError: if prompts/responses are not both str or both lists, if the lists differ in
            length, or if `model_key` matches no known model family.
    """

    def answer_after_prompt(prompt: str, response: str, marker: str):
        """Return the text after the first `marker` that follows the echoed `prompt`, or None.

        Anchoring the search at the echoed prompt keeps words inside the prompt itself (e.g. "model"
        in "language model", or "assistant" in a user request) from being taken as the template's
        role marker. Returns None when the prompt or the marker cannot be found; callers then fall
        back to the unanchored search.
        """
        if not prompt:
            return None
        prompt_at = response.find(prompt)
        if prompt_at == -1:
            return None
        marker_at = response.find(marker, prompt_at + len(prompt))
        if marker_at == -1:
            return None
        return response[marker_at + len(marker):].strip()

    def clean_pair(prompt: str, response: str) -> str:
        prompt = prompt.strip()
        response = response.strip()

        if "gpt" in model_key:
            return response  # raw output, no cleaning

        elif "llama3" in model_key:
            answer = answer_after_prompt(prompt, response, "assistant")
            if answer is not None:
                return answer
            if "assistant" in response:
                response = response.split("assistant", 1)[-1].strip()
            return response

        elif "tinyllama" in model_key:
            if "<|assistant|>" in response:
                response = response.split("<|assistant|>")[-1].strip()
            return response

        elif "mistral" in model_key or "ministral" in model_key:
            prompt_clean = re.sub(r"<s>|\[INST\]|\[/INST\]", "", prompt).strip()
            response_clean = re.sub(r"<s>|\[INST\]|\[/INST\]", "", response).strip()
            return response_clean[len(prompt_clean):].strip() if response_clean.startswith(prompt_clean) else response_clean

        elif "qwen" in model_key:
            answer = answer_after_prompt(prompt, response, "assistant")
            if answer is None:
                match = re.search(r"assistant\s*", response, flags=re.IGNORECASE)
                answer = response[match.end():].strip() if match else response
            # Drop any extra "assistant" role labels repeated at the start of the answer.
            return re.sub(r"^(assistant\s*)+", "", answer, flags=re.IGNORECASE).strip()

        elif "gemma" in model_key:
            answer = answer_after_prompt(prompt, response, "model")
            if answer is not None:
                return answer
            if "model" in response:
                return response.split("model", 1)[1].strip()
            return response

        elif "internlm" in model_key:
            # Keep everything between first <|Bot|>: and second <|User|>: (truncate at second <|User|>:)
            if "<|User|>:" in response and "<|Bot|>:" in response:
                first_bot_idx = response.find("<|Bot|>:")
                second_user_idx = response.find("<|User|>:", first_bot_idx + 1)
                if second_user_idx != -1:
                    return response[first_bot_idx + len("<|Bot|>:"):second_user_idx].strip()
                else:
                    return response[first_bot_idx + len("<|Bot|>:"):].strip()
            else:
                return response.strip()

        elif "deepseek" in model_key:
            answer = answer_after_prompt(prompt, response, "### Response:")
            if answer is not None:
                return answer
            prefix = f"### Instruction:\n{prompt}\n### Response:"
            return response[len(prefix):].strip() if response.startswith(prefix) else response

        elif "glm" in model_key:
            # Remove up to and including the "答：" that follows the prompt. The anchored search also
            # works when the original prompt had surrounding whitespace that was stripped above.
            answer = answer_after_prompt(prompt, response, "答：")
            if answer is not None:
                response = answer
            else:
                prefix = f"[Round 1]\n\n问：{prompt}\n\n答："
                if response.startswith(prefix):
                    response = response[len(prefix):].strip()
            # Remove everything after [Round 2] (a follow-up turn invented by the model).
            if "[Round 2]" in response:
                response = response.split("[Round 2]")[0].strip()
            return response

        raise ValueError(f"Unknown model_key '{model_key}' in clean_response.")

    if isinstance(prompts, str) and isinstance(responses, str):
        return clean_pair(prompts, responses)

    if isinstance(prompts, list) and isinstance(responses, list):
        if len(prompts) != len(responses):
            # zip() would silently drop the unmatched tail.
            raise ValueError(f"Got {len(prompts)} prompts but {len(responses)} responses.")
        return [clean_pair(p, r) for p, r in zip(prompts, responses)]

    raise ValueError("prompts and responses must both be str or List[str]")

**Single prompt**
---

In [None]:
# # Choose the model you want to use
# MODEL_KEY = "mistral_7b_v03_instruct"
# tokenizer, model = load_model(MODEL_KEY)

In [None]:
# # Example usage
# test_prompt = "Write a story about a dog going to the moon."
# # compound_prompt = f"[USER REQUEST]: '{test_prompt}'\n{strategy_hint}"
# compound_prompt = f"[USER REQUEST]: '{test_prompt}'\n{strategy_hint}"

# response = query_model(MODEL_KEY, tokenizer, model, compound_prompt)

In [None]:
# print(test_prompt)
# print(compound_prompt)

In [None]:
# print(response)

In [None]:
# cleaned_response = clean_response(MODEL_KEY, compound_prompt, response)

In [None]:
# print(cleaned_response)

**Batch Prompts**
---

In [None]:
# BREAKPOINT_0

In [None]:
# Choose the model you want to use
PROMPING_MODEL_KEY = "llama3_8b_instruct"
prompting_tokenizer, prompting_model = load_model(PROMPING_MODEL_KEY)

MARKING_MODEL_KEY = "llama3_8b_instruct"
if MARKING_MODEL_KEY == PROMPING_MODEL_KEY:
    # Same base checkpoint for both roles: reuse the already-loaded weights (load_model freezes
    # them for inference) instead of putting a second copy of the model on the GPU.
    marking_tokenizer, marking_model = prompting_tokenizer, prompting_model
else:
    marking_tokenizer, marking_model = load_model(MARKING_MODEL_KEY)

# To use a fine-tuned or distilled marking model instead, uncomment one pair below
# (it loads a separate model and overrides the assignment above).
# fine_tuned_path = f"tuned_models/fine_tuned_{MARKING_MODEL_KEY}"
# marking_tokenizer, marking_model = load_model(MARKING_MODEL_KEY, fine_tuned_path)

# TEACHER_MODEL_KEY = "deepseek_llm_chat"
# distilled_path = f"tuned_models/distilled_{TEACHER_MODEL_KEY}_to_{MARKING_MODEL_KEY}"
# marking_tokenizer, marking_model = load_model(MARKING_MODEL_KEY, distilled_path)

In [None]:
# Smoke test: run the whole pipeline on the first batch of `dataloader` only; later cells inspect it.
batch = next(iter(dataloader), None)
if batch is None:
    raise ValueError(f"No requests were loaded from {custom_prompts_data_path}; nothing to test.")

prompts = batch['prompt']  # Get the prompts from the batch
categories = batch['category']  # Get the category from the batch

# Prompting LLM: one [SYSTEM INSTRUCTION] per request.
prompts_for_promptingLLM = [f"[USER REQUEST]: '{p}'\n{strategy_hint}" for p in prompts]
system_instructions = query_model(PROMPING_MODEL_KEY, prompting_tokenizer, prompting_model, prompts_for_promptingLLM)
system_instructions = [clean_response(PROMPING_MODEL_KEY, p, s) for p, s in zip(prompts_for_promptingLLM, system_instructions)]

# Marking LLM without a system instruction (non-watermarked baseline); cleaned in the next cell.
prompts_for_markingLLM_nowatermark = [f"{p}\nThe response must be a paragraph under 100 words." for p in prompts]
non_watermarked_responses = query_model(MARKING_MODEL_KEY, marking_tokenizer, marking_model, prompts_for_markingLLM_nowatermark)

# Marking LLM following the generated system instruction (watermarked); cleaned in the next cell.
prompts_for_markingLLM_watermark = [f"[USER REQUEST]: {p}\n[SYSTEM INSTRUCTION]: {s}\nThe response must be a paragraph under 100 words." for p, s in zip(prompts, system_instructions)]
watermarked_responses = query_model(MARKING_MODEL_KEY, marking_tokenizer, marking_model, prompts_for_markingLLM_watermark)

In [None]:
# Strip the echoed prompt and template residue from both sets of marking-LLM outputs
# (clean_response pairs up the two lists element by element).
cleaned_non_watermarked_responses = clean_response(MARKING_MODEL_KEY, prompts_for_markingLLM_nowatermark, non_watermarked_responses)
cleaned_watermarked_responses = clean_response(MARKING_MODEL_KEY, prompts_for_markingLLM_watermark, watermarked_responses)

In [None]:
# Dump every intermediate of the smoke-test batch: inputs, raw outputs and cleaned outputs.
# The third field is what follows the value (the last two sections print no trailing blank line).
section_rule = "----------------------------------"
smoke_test_sections = [
    ("[USER REQUESTS]", prompts, "\n"),
    ("[INPUTS FOR PROMPTING_LLM]", prompts_for_promptingLLM, "\n"),
    ("[SYSTEM INSTRUCTIONS]", system_instructions, "\n"),
    ("[INPUTS FOR MARKING_LLM] (WATERMARK)", prompts_for_markingLLM_watermark, "\n"),
    ("[WATERMARKED RESPONSES]", watermarked_responses, "\n"),
    ("[WATERMARKED RESPONSES] (cleaned)", cleaned_watermarked_responses, "\n"),
    ("[INPUTS FOR MARKING_LLM (NO WATERMARK)]", prompts_for_markingLLM_nowatermark, "\n"),
    ("[NON WATERMARKED RESPONSES]", non_watermarked_responses, ""),
    ("[NON WATERMARKED RESPONSES] (cleaned)", cleaned_non_watermarked_responses, ""),
]
for section_title, section_value, section_tail in smoke_test_sections:
    print(f"{section_title}:\n{section_rule}\n{section_value}{section_tail}")

In [None]:
# Compare one sample (index 1) before and after clean_response, watermarked first.
sample_idx = 1
for raw_batch, cleaned_batch in (
    (watermarked_responses, cleaned_watermarked_responses),
    (non_watermarked_responses, cleaned_non_watermarked_responses),
):
    print('unclean ----------------')
    print(raw_batch[sample_idx])
    print('cleaned ----------------')
    print(cleaned_batch[sample_idx])

**Data generation**
---

In [None]:
# BREAKPOINT_1

In [None]:
# Choose the model you want to use
PROMPING_MODEL_KEY = "mistral_7b_v03_instruct"
prompting_tokenizer, prompting_model = load_model(PROMPING_MODEL_KEY)

MARKING_MODEL_KEY = "mistral_7b_v03_instruct"
if MARKING_MODEL_KEY == PROMPING_MODEL_KEY:
    # Same base checkpoint for both roles: reuse the already-loaded weights (load_model freezes
    # them for inference) instead of putting a second copy of the model on the GPU.
    marking_tokenizer, marking_model = prompting_tokenizer, prompting_model
else:
    marking_tokenizer, marking_model = load_model(MARKING_MODEL_KEY)

# To use a fine-tuned or distilled marking model instead, uncomment one pair below
# (it loads a separate model and overrides the assignment above).
# fine_tuned_path = f"tuned_models/fine_tuned_{MARKING_MODEL_KEY}"
# marking_tokenizer, marking_model = load_model(MARKING_MODEL_KEY, fine_tuned_path)

# TEACHER_MODEL_KEY = "deepseek_llm_chat"
# distilled_path = f"tuned_models/distilled_{TEACHER_MODEL_KEY}_to_{MARKING_MODEL_KEY}"
# marking_tokenizer, marking_model = load_model(MARKING_MODEL_KEY, distilled_path)

In [None]:
# Path to the output directory: the 10k layout is used only for a full (non-ablation) max-dataset
# run; every other case writes under the 1k layout (ablation runs are redirected again below).
use_10k_layout = USING_MAX_DATASET and not test_ablation
dataset_root = "Datasets/10k" if use_10k_layout else "Datasets/1k"
output_directory = f"{dataset_root}/PLM_{PROMPING_MODEL_KEY}/{MARKING_MODEL_KEY}_data"
# output_directory = f"{dataset_root}/PLM_{PROMPING_MODEL_KEY}/fine_tuned_{MARKING_MODEL_KEY}_data"
# output_directory = f"{dataset_root}/PLM_{PROMPING_MODEL_KEY}/distilled_{TEACHER_MODEL_KEY}_to_{MARKING_MODEL_KEY}_data"

if test_ablation:
    print("Testing on 300")
elif USING_MAX_DATASET:
    print("20k dataset woo!")
else:
    print("2k dataset smh")

**Special case for testing**
---

In [None]:
# Ablation runs write to ablation/PLM_<prompting model>/<marking model>_data/<experiment>,
# overriding the dataset directory chosen in the previous cell.
if test_ablation:
    output_directory = f"ablation/PLM_{PROMPING_MODEL_KEY}/{MARKING_MODEL_KEY}_data/{experiment}"

In [None]:
# Create the output directory and any missing parents. exist_ok avoids a separate
# check-then-create race and still raises if the path exists but is a regular file,
# instead of failing later when the CSVs are written.
os.makedirs(output_directory, exist_ok=True)

# Helper: make sure a CSV exists with the output schema's header row.
def initialize_csv(file_path):
    """Create `file_path` containing only the header row of the output schema.

    Does nothing when the file already exists, so existing content is never touched.
    """
    if os.path.exists(file_path):
        return
    header_columns = ["CATEGORY", "USER REQUEST", "SYSTEM INSTRUCTION", "WATERMARKED RESPONSE", "NON-WATERMARKED RESPONSE"]
    pd.DataFrame(columns=header_columns).to_csv(file_path, index=False)

In [None]:
# Column order of every output CSV.
_OUTPUT_COLUMNS = ["CATEGORY", "USER REQUEST", "SYSTEM INSTRUCTION", "WATERMARKED RESPONSE", "NON-WATERMARKED RESPONSE"]


def _generate_records(prompts, categories, prompting_tokenizer, prompting_model, marking_tokenizer, marking_model):
    """Run the three generation steps for one batch and return one output-row dict per request."""
    # default_collate turns numeric categories into a tensor; store plain Python values instead.
    if isinstance(categories, torch.Tensor):
        categories = categories.tolist()

    # Step 1: the prompting LLM writes a watermarking [SYSTEM INSTRUCTION] for each request.
    prompts_for_promptingLLM = [f"[USER REQUEST]: '{p}'\n{strategy_hint}" for p in prompts]
    system_instructions = query_model(PROMPING_MODEL_KEY, prompting_tokenizer, prompting_model, prompts_for_promptingLLM)
    system_instructions = [clean_response(PROMPING_MODEL_KEY, p, s) for p, s in zip(prompts_for_promptingLLM, system_instructions)]

    # Step 2: the marking LLM answers each request plainly (non-watermarked baseline).
    prompts_for_markingLLM_nowatermark = [f"{p}\nThe response must be a paragraph under 100 words." for p in prompts]
    non_watermarked_responses = query_model(MARKING_MODEL_KEY, marking_tokenizer, marking_model, prompts_for_markingLLM_nowatermark)
    non_watermarked_responses = [clean_response(MARKING_MODEL_KEY, p, r) for p, r in zip(prompts_for_markingLLM_nowatermark, non_watermarked_responses)]

    # Step 3: the marking LLM answers each request while following its system instruction (watermarked).
    prompts_for_markingLLM_watermark = [f"[USER REQUEST]: {p}\n[SYSTEM INSTRUCTION]: {s}\nThe response must be a paragraph under 100 words." for p, s in zip(prompts, system_instructions)]
    watermarked_responses = query_model(MARKING_MODEL_KEY, marking_tokenizer, marking_model, prompts_for_markingLLM_watermark)
    watermarked_responses = [clean_response(MARKING_MODEL_KEY, p, r) for p, r in zip(prompts_for_markingLLM_watermark, watermarked_responses)]

    return [
        {
            "CATEGORY": category,
            "USER REQUEST": prompt,
            "SYSTEM INSTRUCTION": system_inst,
            "WATERMARKED RESPONSE": wm_response,
            "NON-WATERMARKED RESPONSE": nwm_response,
        }
        for category, prompt, system_inst, wm_response, nwm_response
        in zip(categories, prompts, system_instructions, watermarked_responses, non_watermarked_responses)
    ]


def _write_records(records, file_index):
    """Write `records` plus a header row to `<output_directory>/<file_index>.csv`, replacing any existing file.

    Returns the path that was written.
    """
    os.makedirs(output_directory, exist_ok=True)
    output_file_path = os.path.join(output_directory, f"{file_index}.csv")
    pd.DataFrame(records, columns=_OUTPUT_COLUMNS).to_csv(output_file_path, index=False)
    return output_file_path


def process_requests(dataloader, prompting_tokenizer, prompting_model, marking_tokenizer, marking_model, start_file_index=1, sub_req=100):
    """
    Generates system instructions plus watermarked / non-watermarked responses for every request in
    `dataloader` and saves them to numbered CSV files in `output_directory`.

    File `<k>.csv` holds requests (k-1)*sub_req .. k*sub_req - 1 (0-based, in dataloader order), so an
    interrupted run can be resumed by passing the index of the first missing file. Each file is
    written in one go (header + rows) and replaces any previous file of that name, so re-running
    never duplicates rows.

    Args:
        dataloader: DataLoader yielding dicts with 'prompt' and 'category' batches (see TextDataset).
        prompting_tokenizer: Tokenizer of the prompting LLM (PROMPING_MODEL_KEY).
        prompting_model: Prompting LLM; writes one [SYSTEM INSTRUCTION] per request.
        marking_tokenizer: Tokenizer of the marking LLM (MARKING_MODEL_KEY).
        marking_model: Marking LLM; writes the watermarked and non-watermarked responses.
        start_file_index: 1-based index of the first CSV file to produce; earlier requests are skipped.
        sub_req: Number of requests per CSV file.

    Raises:
        ValueError: If `start_file_index` < 1 or `sub_req` < 1.

    Also reads the notebook globals PROMPING_MODEL_KEY, MARKING_MODEL_KEY, strategy_hint,
    output_directory and batch_size (the fallback when the loader has no batch_size).
    """
    if start_file_index < 1:
        raise ValueError(f"start_file_index must be >= 1, got {start_file_index}")
    if sub_req < 1:
        raise ValueError(f"sub_req must be >= 1, got {sub_req}")

    # Use the loader's own batch size for the resume arithmetic; fall back to the global hparam.
    loader_batch_size = getattr(dataloader, "batch_size", None) or batch_size

    start_request_index = (start_file_index - 1) * sub_req
    count = 0
    file_index = start_file_index
    data_accumulator = []

    for idx, batch in enumerate(dataloader):
        prompts = batch['prompt']
        categories = batch['category']
        batch_start = idx * loader_batch_size

        # Skip batches that lie entirely before the resume point.
        if batch_start + len(prompts) <= start_request_index:
            print(batch_start, start_request_index)
            continue

        # A batch straddling the resume point keeps only the requests at or after it.
        offset = max(0, start_request_index - batch_start)
        if offset:
            prompts = prompts[offset:]
            categories = categories[offset:]

        records = _generate_records(prompts, categories, prompting_tokenizer, prompting_model, marking_tokenizer, marking_model)
        for record in records:
            data_accumulator.append(record)
            count += 1
            print(f"Processed {count + start_request_index} user requests.")

        # Flush every complete group of `sub_req` rows. This also works when the batch size does not
        # divide sub_req, where a `count % sub_req == 0` test would never fire.
        while len(data_accumulator) >= sub_req:
            chunk, data_accumulator = data_accumulator[:sub_req], data_accumulator[sub_req:]
            output_file_path = _write_records(chunk, file_index)
            print(f"Saved {output_file_path} with {count + start_request_index - len(data_accumulator)} requests.")
            file_index += 1

    # Save any remaining (partial) group.
    if data_accumulator:
        output_file_path = _write_records(data_accumulator, file_index)
        print(f"Saved {output_file_path} with remaining {len(data_accumulator)} requests.")

In [None]:
# Example usage:
# Full generation pass. start_file_index=k skips the first (k - 1) * sub_req requests (resume),
# and sub_req is the number of requests collected per output CSV file.
process_requests(dataloader, prompting_tokenizer, prompting_model, marking_tokenizer, marking_model, start_file_index=1, sub_req=100)