In [None]:
# --- Cell 1: Setup: Install Libraries ---
print("--- Installing Libraries ---")

# Uninstall potentially conflicting versions first
#!pip uninstall torch torchvision torchaudio transformers accelerate bitsandbytes -y -q
print("--- Uninstalled existing versions ---")

# Reinstall PyTorch, torchvision, and torchaudio together to ensure compatibility
# Let pip determine the compatible versions based on your Colab environment's CUDA
print("--- Installing PyTorch, torchvision, torchaudio ---")
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 -q
# (Using cu121 as Colab often uses CUDA 12.1. Adjust if your environment differs, though Colab usually manages this well)

# Install specific or minimum versions for other critical libraries
print("--- Installing transformers, accelerate, bitsandbytes ---")
!pip install transformers>=4.51.0 -q
!pip install accelerate>=0.28.0 -q
!pip install bitsandbytes>=0.41.3 -q
!pip install sentencepiece -q

print("\n--- Checking Installed Library Versions ---")
import torch
import torchvision
import transformers
import accelerate
import bitsandbytes

print(f"PyTorch version: {torch.__version__}")
print(f"Torchvision version: {torchvision.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"Accelerate version: {accelerate.__version__}")
print(f"BitsandBytes version: {bitsandbytes.__version__}")
print("--- Libraries Installed ---")


--- Installing Libraries ---
--- Uninstalled existing versions ---
--- Installing PyTorch, torchvision, torchaudio ---
--- Installing transformers, accelerate, bitsandbytes ---

--- Checking Installed Library Versions ---
PyTorch version: 2.5.1+cu121
Torchvision version: 0.20.1+cu121
Transformers version: 4.51.3
Accelerate version: 1.6.0
BitsandBytes version: 0.45.5
--- Libraries Installed ---


In [None]:

# --- Cell 2: Import Libraries ---
# (No changes needed)
print("\n--- Importing Libraries ---")
import torch
import torch.nn.functional as F # For padding
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, LogitsProcessorList, MinLengthLogitsProcessor, StoppingCriteriaList, MaxLengthCriteria
import gc
import re
import time # For timing comparison
import traceback # For detailed error printing
from typing import List, Dict, Union # Added for type hinting
import os
import json
import numpy
import numpy as np
print("--- Libraries Imported ---")


# --- Cell 3: Configuration: Model Names, Quantization, Device ---
# Every tunable for the run lives here: chunking, input files, sweep
# parameters, model identifiers, target device and quantization settings.
print("\n--- Configuring Models and Device ---")

print("--- Using Fixed Task Template ---")

# Number of haystack tokens per chunk for the focused M_Small passes.
CHUNK_SIZE_FOR_FOCUSED_SMALL_MODELS = 128
print(f"Chunk size for focused M_Small: {CHUNK_SIZE_FOR_FOCUSED_SMALL_MODELS} tokens")

# --- Input files for the manual run (Colab layout) ---
PROJECT_BASE_PATH = '/content/'
NEEDLE_SET_HARD_PATH = os.path.join(PROJECT_BASE_PATH, "needle_set_hard.json")
HAYSTACK_BOOK_PATH = os.path.join(PROJECT_BASE_PATH, "my_book.txt")

# --- Test parameters for this simplified run ---
# Start with a single context length; the upper bound is whatever M_Large can handle.
CONTEXT_LENGTHS_TO_TEST = [2048]
# Needle depths to probe (presumably fractions of the haystack -- confirm at the call site).
DEPTH_PERCENTAGES_TO_TEST = [0.5]
# Per the original note: not used in baseline mode, kept for completeness.
ALPHA_MOD = 1
MAX_NEW_TOKENS_GENERATION = 50
GLOBAL_SYSTEM_PROMPT_FOR_RUN = "You are a helpful AI assistant. Use the information provided in the book snippet to answer the question. Your answer should be short and based on either explicitly stated facts or strong, logical inferences."
# Toggle for Qwen3 "thinking" mode; how it is consumed is not visible in this cell.
ENABLE_QWEN3_THINKING_FOR_RUN = True


# --- Model identifiers (Hugging Face hub names) ---
model_large_name = "Qwen/Qwen3-1.7B"  # M_Large
model_small_name = "Qwen/Qwen3-0.6B"  # M_Small, used for the deltas

print("--- Model Configuration ---")
print(f"Large (M_Large):    {model_large_name}")
print(f"Small (M_Small):    {model_small_name} (for ICL-based deltas)")
print("-------------------------")

# --- Device: prefer CUDA, fall back to CPU with a loud warning ---
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
if device.type != "cuda":
    print("WARNING: CUDA not available, running on CPU will be extremely slow.")

# --- Quantization configuration shared by ALL models ---
# 4-bit NF4 weights with double quantization; matmuls computed in bfloat16.
# To experiment with 8-bit loading instead, build the config with load_in_8bit=True.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
print("Using 4-bit NF4 quantization with bfloat16 compute dtype for all models.")
print("--- Configuration Complete ---")


# --- Cell 4: Load Models and Tokenizer ---
# The tokenizer of the large model is loaded once and shared by every model in
# the notebook. This section also normalises its padding setup and resolves the
# token ID of '</think>' for later parsing of Qwen3 "thinking" output.
print("\n--- Loading Tokenizer and Models ---")

# --- Load Tokenizer (Use tokenizer from the large Qwen3 model) ---
print("Loading Tokenizer...")
try:
    # For Qwen3's `enable_thinking`, the tokenizer's chat template must support it,
    # which is why it is loaded from the specific Qwen3 checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(model_large_name, trust_remote_code=True)

    # A Qwen3 tokenizer should ship with a chat template; install a best-guess
    # fallback only when none is defined.
    if tokenizer.chat_template is None:
        print("WARNING: Tokenizer does not have a chat_template defined. Attempting to set a default Qwen template.")
        # The template below is a guess; Qwen3's real template may be more complex.
        tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{% if enable_thinking %}{{ '<|im_start|>assistant\n<think>' }}{% else %}{{ '<|im_start|>assistant\n' }}{% endif %}{% endif %}"
        print("Default Qwen chat template (with basic thinking tag consideration) applied. VERIFY THIS IS CORRECT FOR QWEN3.")

    # Make sure a PAD token exists: reuse EOS when available, otherwise add one.
    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Tokenizer pad_token set to eos_token ({tokenizer.eos_token}).")
        else:
            print("WARNING: Tokenizer has no EOS token. Adding a default PAD token '<|pad|>' for Qwen3.")
            # NOTE(review): adding a new special token grows the vocabulary; models
            # loaded afterwards may need their embeddings resized -- confirm.
            tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
    if tokenizer.pad_token_id is None:
        raise ValueError("Tokenizer pad_token_id is None even after setting pad_token. Cannot proceed.")
    print(f"Using PAD token ID: {tokenizer.pad_token_id}")
    # Left padding keeps the last position of every row aligned for batched decoding.
    tokenizer.padding_side = "left"
    print(f"Tokenizer padding side set to '{tokenizer.padding_side}'.")

    # Resolve the ID of the '</think>' token dynamically. The fallback ID 151668
    # comes from the original note for Qwen/Qwen3-30B-A3B and may differ elsewhere.
    try:
        _think_end_ids = tokenizer.encode("</think>", add_special_tokens=False)
        # '</think>' must map to exactly one token; taking the first piece of a
        # multi-token split would silently yield an unrelated ID.
        if len(_think_end_ids) != 1:
            raise ValueError(f"'</think>' encodes to {len(_think_end_ids)} tokens, expected exactly 1")
        think_end_token_id_qwen3 = _think_end_ids[0]
        print(f"Qwen3 '</think>' token ID found: {think_end_token_id_qwen3}")
    except Exception as think_lookup_error:
        # Catch Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
        think_end_token_id_qwen3 = 151668  # Fallback to the example ID
        print(f"Warning: Could not dynamically encode '</think>'. Using fallback ID: {think_end_token_id_qwen3}. VERIFY THIS ID.")
        print(f"  Reason: {think_lookup_error}")

    print("Tokenizer Loaded Successfully.")
except Exception as e:
    # Nothing downstream can work without a tokenizer: report and re-raise.
    print(f"ERROR loading tokenizer for {model_large_name}: {e}")
    traceback.print_exc()
    raise


def top_k_top_p_filtering(logits: torch.Tensor,
                          top_k: int = 0,
                          top_p: float = 0.0,
                          filter_value: float = -float('Inf'),
                          min_tokens_to_keep: int = 1) -> torch.Tensor:
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits of shape (batch size, vocabulary size). A 1-D tensor is
            treated as a batch of one and gains a leading batch dimension.
        top_k: if > 0, keep only the ``top_k`` highest-scoring tokens per row
            (never fewer than ``min_tokens_to_keep``, never more than the vocabulary).
        top_p: if > 0.0, keep the smallest set of top tokens whose cumulative
            probability reaches ``top_p`` (nucleus filtering, Holtzman et al.,
            http://arxiv.org/abs/1904.09751).
        filter_value: value written into every filtered-out position.
        min_tokens_to_keep: minimum number of tokens kept per row (default 1).

    Returns:
        Logits with filtered elements set to ``filter_value``, shape
        (batch size, vocabulary size). When both ``top_k`` and ``top_p`` are
        disabled the input tensor is returned unchanged.
    """
    if top_k == 0 and top_p == 0.0:
        return logits  # No filtering requested

    if logits.ndim == 1:
        logits = logits.unsqueeze(0)

    vocab_size = logits.size(-1)

    if top_k > 0:
        # Clamp k into [min_tokens_to_keep, vocab_size].
        top_k = min(max(top_k, min_tokens_to_keep), vocab_size)
        # The k-th largest value of each row is the cut-off; ties with it are kept.
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, filter_value)

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Positions whose cumulative mass already exceeds top_p are candidates for removal.
        sorted_indices_to_remove = cumulative_probs > top_p

        # Shift right by one so the token that crosses the threshold is kept too.
        # This must happen regardless of min_tokens_to_keep; otherwise the kept
        # mass can fall short of top_p whenever min_tokens_to_keep > 1.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Enforce the minimum after the shift so it only ever adds tokens.
        if min_tokens_to_keep > 1:
            sorted_indices_to_remove[..., :min_tokens_to_keep] = False

        # Map the removal mask from sorted order back to vocabulary order.
        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
        indices_to_remove.scatter_(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)

        logits = logits.masked_fill(indices_to_remove, filter_value)

    return logits
# --- End of modeling.utils inclusion ---

# --- Function to load model ---
def load_model(model_name, config, device):
    """Load a causal LM with the given quantization config and put it in eval mode.

    Args:
        model_name: model identifier passed to ``AutoModelForCausalLM.from_pretrained``.
        config: quantization config forwarded as ``quantization_config``.
        device: accepted but not used here; placement is delegated to
            ``device_map="auto"``.

    Returns:
        The loaded model, or ``None`` if anything went wrong (the error and its
        traceback are printed rather than raised).
    """
    def _release_cached_memory():
        # Drop Python garbage and, when a GPU is present, cached CUDA blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"Loading Model: {model_name}...")
    try:
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        loaded_model.eval()
        print(f"{model_name} Loaded Successfully.")

        # The footprint report is best-effort and never fails the load.
        try:
            footprint_bytes = loaded_model.get_memory_footprint()
            print(f"Estimated memory footprint for {model_name}: {footprint_bytes / 1e9:.2f} GB")
        except Exception:
            print("Could not estimate memory footprint automatically.")

        _release_cached_memory()
        return loaded_model
    except Exception as load_error:
        print(f"ERROR loading model {model_name}: {load_error}")
        traceback.print_exc()
        _release_cached_memory()
        return None

# --- Load Models ---
# Both models share the same quantization config; load_model returns None on failure.
model_large = load_model(model_large_name, bnb_config, device)
model_small = load_model(model_small_name, bnb_config, device)

# Abort early: nothing below can run unless both models are available.
if not (model_large and model_small):
    raise RuntimeError("One or more models failed to load. Cannot proceed. Check logs above.")

print("\nAll models loaded successfully with 4-bit quantization.")
print("--- Tokenizer and Models Loaded ---")



# --- Cell 5: Implement Mixture of Dexperts Generation (BATCHED with KV Cache) ---
print("\n--- Defining BookHaystack Class & MoD Generation ---")

# --- BookHaystack Class (simplified version) ---
class BookHaystack:
    """Builds needle-in-a-haystack contexts from the opening tokens of a book.

    The book is read once at construction time and tokenized lazily on first
    use; the token list is then cached for every later haystack request.
    """

    def __init__(self, book_path: str, tokenizer_instance: "AutoTokenizer"):
        """Read the book text from ``book_path``.

        Args:
            book_path: path to a UTF-8 text file.
            tokenizer_instance: tokenizer used for slicing; this class only calls
                its ``encode(text, add_special_tokens=False)`` and ``decode(ids)``.

        Raises:
            FileNotFoundError: if ``book_path`` does not exist.
        """
        self.book_path = book_path
        self.tokenizer = tokenizer_instance
        if not os.path.exists(book_path):
            raise FileNotFoundError(f"Book path {book_path} does not exist")
        with open(book_path, 'r', encoding='utf-8') as f:
            self.text = f.read()
        # Cache for the token IDs of the full book, filled on first request.
        self.text_encoded_full = None

    def _get_book_tokens(self):
        """Return the token IDs of the whole book, encoding it on the first call."""
        if self.text_encoded_full is None:
            print("Encoding full book for BookHaystack (one time)...")
            self.text_encoded_full = self.tokenizer.encode(self.text, add_special_tokens=False)
            print(f"Book has {len(self.text_encoded_full)} tokens.")
        return self.text_encoded_full

    def get_haystack_with_needle(self, needle_text: str, target_haystack_len: int, depth_percentage: float) -> str:
        """Return book text with ``needle_text`` inserted at a relative depth.

        The haystack is always taken from the start of the book. The needle is
        inserted on top of the book slice, so the result holds up to
        ``target_haystack_len`` book tokens plus the needle's own tokens.

        Args:
            needle_text: text to hide; it is encoded as ``" " + needle_text + "\\n"``.
            target_haystack_len: number of book tokens to use; values below zero
                are treated as zero.
            depth_percentage: insertion depth as a fraction of the book slice
                (0.0 = before the slice, 1.0 = after it); the resulting position
                is clamped to the slice bounds.

        Returns:
            The decoded haystack string, or ``needle_text`` itself when the book
            has no tokens.
        """
        book_tokens = self._get_book_tokens()
        if not book_tokens:
            return needle_text

        # Clamp the budget: a negative length would otherwise slice from the END
        # of the book (book_tokens[:-n]) and return nearly the whole book.
        haystack_budget = max(0, target_haystack_len)
        book_tokens_for_haystack = book_tokens[:haystack_budget]

        needle_tokens = self.tokenizer.encode(" " + needle_text + "\n", add_special_tokens=False)

        # Convert the relative depth to a token index inside the slice.
        slice_len = len(book_tokens_for_haystack)
        insertion_point_in_slice = max(0, min(int(slice_len * depth_percentage), slice_len))

        final_haystack_tokens = (
            book_tokens_for_haystack[:insertion_point_in_slice]
            + needle_tokens
            + book_tokens_for_haystack[insertion_point_in_slice:]
        )
        # No truncation is needed: the result is exactly slice_len + len(needle_tokens)
        # tokens, which can never exceed the budget plus the needle.
        return self.tokenizer.decode(final_haystack_tokens)


print("\n--- Defining BATCHED MoD Generation Function (Qwen3 Thinking Aware) ---")

# Role labels for chat messages.
SYSTEM, USER, ASSISTANT = "system", "user", "assistant"

@torch.inference_mode()
def generate_mod_batch_kv_cache(
    haystacks_full_text: List[str], # MODIFIED: List of haystack strings for the batch
    questions_text: List[str],     # MODIFIED: List of question strings for the batch
    max_new_tokens: int = 150,
    temperature: float = 0.6,
    top_k: int = 50,
    top_p: float = 0.0,
    alpha: float = 1.0,
    delta_magnitude_threshold: float = 0.0,
    # These _in parameters are to avoid conflict if globals with same name exist
    global_system_prompt_in: str = "You are a helpful AI assistant.",
    enable_thinking_in: bool = True,
    qwen3_think_end_token_id_in: int = 151668, # Fallback
    chunk_size_small_in: int = 8192 # Fallback
) -> Dict[str, Union[List[str], List[str], float, str]]:
    global tokenizer, model_large, model_small, device # Using globals for this manual script

    if not all([model_large, model_small, tokenizer]):
        return {"outputs": [], "thinking_outputs": [], "tokens_per_second": 0.0, "error": "Models or tokenizer not loaded."}
    if len(haystacks_full_text) != len(questions_text):
        return {"outputs": [], "thinking_outputs": [], "tokens_per_second": 0.0, "error": "Haystacks and questions must have the same batch size."}


    generation_start_time = time.time()
    batch_size = len(haystacks_full_text)

    # --- Prepare Initial Inputs ---

    # 1. For M_Large (SYSTEM + FULL_HAYSTACK + QUESTION)
    prompts_for_large_model_templated_text = []
    for i in range(batch_size):
        # MODIFIED: Construct combined user content for M_Large
        # The NoLiMa template is "Haystack: {haystack} ... Question: {question}"
        # We recreate this structure if haystacks_full_text[i] is the haystack with needle
        # and questions_text[i] is the retrieval question.
        # Note: The exact formatting might depend on how NoLiMa's original `task_template`
        #       was structured if you were using it. For a manual run, this is a direct way.
        #questions_text[i] += '<think>\n'
        if 'You can only think for 20 words, then give a one word answer.' not in questions_text[i]:
          questions_text[i] += 'You can only think for 20 words, then give a one word answer.\n<think>'

        #combined_user_content_for_large = "Question: " + questions_text[i]
        combined_user_content_for_large = haystacks_full_text[i] + "\n\nQuestion: " + questions_text[i]

        messages_large = [
            {"role": "system", "content": global_system_prompt_in},
            {"role": "user", "content": combined_user_content_for_large}
        ]
        prompts_for_large_model_templated_text.append(tokenizer.apply_chat_template(
            messages_large, add_generation_prompt=True, tokenize=False, enable_thinking=enable_thinking_in
        ))
    initial_inputs_large = tokenizer(prompts_for_large_model_templated_text, return_tensors="pt", padding=True, truncation=True, max_length=model_large.config.max_position_embeddings - 10).to(device)
    prompt_lengths_large = [torch.sum(mask).item() for mask in initial_inputs_large["attention_mask"]] # Used for decoding later

    # 2. For M_Small_unprimed (SYSTEM + QUESTION_ONLY)
    prompts_for_small_unprimed_templated_text = []
    for i in range(batch_size):
        # MODIFIED: User content is ONLY the question
        messages_small_unprimed = [
            {"role": "system", "content": global_system_prompt_in},
            {"role": "user", "content": questions_text[i]}
        ]
        prompts_for_small_unprimed_templated_text.append(tokenizer.apply_chat_template(
            messages_small_unprimed, add_generation_prompt=True, tokenize=False, enable_thinking=enable_thinking_in
        ))
    initial_inputs_small_unprimed = tokenizer(prompts_for_small_unprimed_templated_text, return_tensors="pt", padding=True, truncation=True, max_length=model_small.config.max_position_embeddings - 10).to(device)

    # 3. For M_Small_focused_chunk_i (SYSTEM + CHUNK_i_OF_HAYSTACK + QUESTION)
    #    MODIFIED: This whole section is new/adapted from previous chunking logic
    temp_chunks_plus_question_for_prompts = [[] for _ in range(batch_size)]
    max_num_chunks_across_batch = 0

    for i in range(batch_size):
        current_haystack_text = haystacks_full_text[i] # This is haystack_with_needle
        current_question_text = questions_text[i]

        haystack_only_tokens = tokenizer.encode(current_haystack_text, add_special_tokens=False)
        num_chunks_for_this_prompt = (len(haystack_only_tokens) + chunk_size_small_in - 1) // chunk_size_small_in
        if num_chunks_for_this_prompt == 0 and len(haystack_only_tokens) > 0: num_chunks_for_this_prompt = 1
        elif len(haystack_only_tokens) == 0: num_chunks_for_this_prompt = 0
        max_num_chunks_across_batch = max(max_num_chunks_across_batch, num_chunks_for_this_prompt)

        for chunk_idx in range(num_chunks_for_this_prompt):
            start_token_idx = chunk_idx * chunk_size_small_in
            end_token_idx = min((chunk_idx + 1) * chunk_size_small_in, len(haystack_only_tokens))
            current_haystack_chunk_token_ids = haystack_only_tokens[start_token_idx:end_token_idx]
            current_haystack_chunk_str = tokenizer.decode(current_haystack_chunk_token_ids)

            combined_user_content_for_focused = current_haystack_chunk_str + "\n\nQuestion: " + current_question_text
            chunk_plus_question_messages = [
                {"role": "system", "content": global_system_prompt_in},
                {"role": "user", "content": combined_user_content_for_focused}
            ]
            templated_prompt = tokenizer.apply_chat_template(
                chunk_plus_question_messages, add_generation_prompt=True, tokenize=False, enable_thinking=enable_thinking_in
            )
            tokenized_data = tokenizer(
                templated_prompt, return_tensors="pt", truncation=True,
                max_length=model_small.config.max_position_embeddings - 10
            )
            temp_chunks_plus_question_for_prompts[i].append({
                "input_ids": tokenized_data.input_ids.squeeze(0),
                "attention_mask": tokenized_data.attention_mask.squeeze(0)
            })

    initial_inputs_small_focused_list = []
    if max_num_chunks_across_batch > 0:
        for chunk_k_idx in range(max_num_chunks_across_batch):
            batch_input_ids_k, batch_attn_masks_k = [], []
            max_len_k = 0
            for prompt_idx in range(batch_size):
                if chunk_k_idx < len(temp_chunks_plus_question_for_prompts[prompt_idx]):
                    ids = temp_chunks_plus_question_for_prompts[prompt_idx][chunk_k_idx]["input_ids"]
                    mask = temp_chunks_plus_question_for_prompts[prompt_idx][chunk_k_idx]["attention_mask"]
                    batch_input_ids_k.append(ids)
                    batch_attn_masks_k.append(mask)
                    max_len_k = max(max_len_k, ids.size(0))

            padded_ids_list_k, padded_masks_list_k = [], []
            for idx_in_batch in range(batch_size): # Iterate up to batch_size to ensure all spots are filled
                if chunk_k_idx < len(temp_chunks_plus_question_for_prompts[idx_in_batch]): # If this prompt had this chunk
                    ids = temp_chunks_plus_question_for_prompts[idx_in_batch][chunk_k_idx]["input_ids"]
                    mask = temp_chunks_plus_question_for_prompts[idx_in_batch][chunk_k_idx]["attention_mask"]
                    pad_len = max_len_k - ids.size(0)
                    padded_ids_list_k.append(F.pad(ids, (pad_len, 0), value=tokenizer.pad_token_id))
                    padded_masks_list_k.append(F.pad(mask, (pad_len, 0), value=0))
                else: # This prompt was shorter than k chunks, add full padding
                    padded_ids_list_k.append(torch.full((max_len_k,), tokenizer.pad_token_id, dtype=torch.long, device=device))
                    padded_masks_list_k.append(torch.zeros((max_len_k,), dtype=torch.long, device=device))

            if padded_ids_list_k: # Should always be true if max_num_chunks > 0
                 initial_inputs_small_focused_list.append({
                    "input_ids": torch.stack(padded_ids_list_k).to(device),
                    "attention_mask": torch.stack(padded_masks_list_k).to(device)
                })
    num_focused_experts = len(initial_inputs_small_focused_list)

    # --- KV Cache Init --- (Same as your original MoD script, adapted names)
    past_key_values_large = None
    past_key_values_small_unprimed = None
    past_key_values_small_focused_list = [None] * num_focused_experts # MODIFIED: for focused experts

    # --- Termination Setup --- (Same as your original MoD script)
    stop_token_ids_set = set()
    try:
        im_end_token_list = tokenizer.encode("<|im_end|>", add_special_tokens=False)
        if im_end_token_list: stop_token_ids_set.add(im_end_token_list[0])
        if tokenizer.eos_token_id is not None and tokenizer.eos_token_id not in stop_token_ids_set:
            stop_token_ids_set.add(tokenizer.eos_token_id)
    except Exception as e: print(f"Warning: Error getting termination token IDs: {e}.")
    if not stop_token_ids_set and tokenizer.eos_token_id is not None: stop_token_ids_set = {tokenizer.eos_token_id}
    if not stop_token_ids_set: print("CRITICAL WARNING: No EOS token ID found for stop criteria.")
    else: print(f"Termination Token IDs: {stop_token_ids_set}")
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None: raise ValueError("tokenizer.pad_token_id is None.")


    # --- Vocab Size Handling --- (Same as your original MoD script)
    vocab_size_large = model_large.config.vocab_size
    vocab_size_small = model_small.config.vocab_size
    max_vocab_size = max(vocab_size_large, vocab_size_small)
    needs_padding_large = vocab_size_large < max_vocab_size
    needs_padding_small = vocab_size_small < max_vocab_size
    pad_value_logits = float('-inf')

    # --- Initialize generated_ids and attention masks for the loop ---
    generated_ids = initial_inputs_large["input_ids"].clone() # M_Large drives the sequence length
    current_attention_mask_large = initial_inputs_large["attention_mask"].clone()
    current_attention_mask_small_unprimed = initial_inputs_small_unprimed["attention_mask"].clone() # MODIFIED: for unprimed
    current_attention_masks_small_focused = [] # MODIFIED: for focused
    if num_focused_experts > 0:
        current_attention_masks_small_focused = [item["attention_mask"].clone() for item in initial_inputs_small_focused_list]

    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
    total_new_tokens_generated = 0
    # --- ADD THESE ---
    loop_start_time = time.time()
    last_print_time = loop_start_time
    # --- END ADD ---

    for step in range(max_new_tokens):
        # --- ADD THIS ---
        step_start_time = time.time()
        # --- END ADD ---

        if step == 0:
            input_ids_large_step = initial_inputs_large["input_ids"]
            input_ids_small_unprimed_step = initial_inputs_small_unprimed["input_ids"] # MODIFIED
            if num_focused_experts > 0:
                input_ids_small_focused_step_list = [item["input_ids"] for item in initial_inputs_small_focused_list] # MODIFIED
        else:
            input_ids_large_step = next_token_id # next_token_id is from combined logits
            input_ids_small_unprimed_step = next_token_id # MODIFIED
            if num_focused_experts > 0:
                input_ids_small_focused_step_list = [next_token_id] * num_focused_experts # MODIFIED

        try:
            with torch.cuda.amp.autocast(enabled=torch.cuda.is_available(), dtype=torch.bfloat16):
                # 1. M_Large Forward Pass
                outputs_large = model_large(
                    input_ids=input_ids_large_step,
                    attention_mask=current_attention_mask_large,
                    past_key_values=past_key_values_large,
                    use_cache=True
                )
                logits_large_raw = outputs_large.logits[:, -1, :]
                past_key_values_large = outputs_large.past_key_values
                logits_large_aligned = F.pad(logits_large_raw, (0, max_vocab_size - logits_large_raw.shape[-1]), value=pad_value_logits) if needs_padding_large else logits_large_raw

                average_expert_delta = torch.zeros_like(logits_large_aligned)

                # MODIFIED: Logic for unprimed and focused experts
                if num_focused_experts > 0: # Only proceed if there are chunks to process
                    # 2. M_Small (Unprimed, Question-Only) Forward Pass
                    outputs_small_unprimed = model_small(
                        input_ids=input_ids_small_unprimed_step,
                        attention_mask=current_attention_mask_small_unprimed,
                        past_key_values=past_key_values_small_unprimed,
                        use_cache=True
                    )
                    logits_small_unprimed_raw = outputs_small_unprimed.logits[:, -1, :]
                    past_key_values_small_unprimed = outputs_small_unprimed.past_key_values
                    logits_small_unprimed_aligned = F.pad(logits_small_unprimed_raw, (0, max_vocab_size - logits_small_unprimed_raw.shape[-1]), value=pad_value_logits) if needs_padding_small else logits_small_unprimed_raw

                    # 3. M_Small (Focused on Chunks + Question) Forward Passes
                    all_processed_deltas = [] # Store deltas after potential filtering

                    for i in range(num_focused_experts):
                        outputs_small_focused = model_small(
                            input_ids=input_ids_small_focused_step_list[i],
                            attention_mask=current_attention_masks_small_focused[i],
                            past_key_values=past_key_values_small_focused_list[i],
                            use_cache=True
                        )
                        logits_small_focused_i_raw = outputs_small_focused.logits[:, -1, :]
                        past_key_values_small_focused_list[i] = outputs_small_focused.past_key_values
                        logits_small_focused_i_aligned = F.pad(logits_small_focused_i_raw, (0, max_vocab_size - logits_small_focused_i_raw.shape[-1]), value=pad_value_logits) if needs_padding_small else logits_small_focused_i_raw

                        delta_i = logits_small_focused_i_aligned - logits_small_unprimed_aligned # Delta uses the new unprimed baseline
                        # --- PRINT DELTA MAGNITUDES (ONLY FOR FIRST STEP) ---
                        if False:
                            abs_delta_values = torch.abs(delta_i)
                            print(f"  Shape of delta_i for this expert batch: {delta_i.shape}")
                            print(f"  Max magnitude (across batch & vocab): {abs_delta_values.max().item():.4f}")
                            print(f"  Mean magnitude (across batch & vocab): {abs_delta_values.mean().item():.4f}")
                            print(f"  Median magnitude (across batch & vocab): {torch.median(abs_delta_values).item():.4f}")
                            flat_abs_deltas = abs_delta_values.flatten()
                            if flat_abs_deltas.numel() > 0:
                                percentiles_to_calc = torch.tensor([0.50, 0.75, 0.90, 0.95, 0.99, 0.999], device=flat_abs_deltas.device)
                                quantiles = torch.quantile(flat_abs_deltas.float(), percentiles_to_calc)
                                for p_idx_loop, p_val_tensor in enumerate(percentiles_to_calc): # Renamed p_idx to p_idx_loop
                                    p_val = p_val_tensor.item()
                                    q_val = quantiles[p_idx_loop].item() # Renamed p_idx to p_idx_loop
                                    print(f"  {p_val*100:.1f}th percentile magnitude: {q_val:.4f}")
                        # --- END PRINT DELTA MAGNITUDES ---


                        if delta_magnitude_threshold > 0.0:
                              # Create a mask where the absolute delta is greater than the threshold
                              significant_delta_mask = torch.abs(delta_i) > delta_magnitude_threshold
                              # Apply the mask: keep delta_i where mask is true, else 0
                              processed_delta_i = torch.where(significant_delta_mask, delta_i, torch.zeros_like(delta_i))
                        else: # No thresholding (or threshold is zero), use the original delta
                            processed_delta_i = delta_i
                        all_processed_deltas.append(processed_delta_i)

                    if all_processed_deltas: # If list is not empty
                        average_expert_delta = torch.stack(all_processed_deltas).mean(dim=0)
                        # If all processed_delta_i were zero tensors, average_expert_delta will correctly be a zero tensor.

                modified_logits = logits_large_aligned + alpha * average_expert_delta

        # --- The rest of the loop is largely the same as your original Cell 5 ---
        # (Error handling, NaN/Inf check, Sampling, Update generated_ids and attention masks, Termination Criteria, Print Progress)

        except torch.cuda.OutOfMemoryError as oom_e:
            # ... (your OOM handling) ...
            # MODIFIED: Use qwen3_think_end_token_id_in and enable_thinking_in
            decoded_results_dict = decode_batch_results_qwen3_thinking(generated_ids, prompt_lengths_large, tokenizer, stop_token_ids_set, pad_token_id, enable_thinking_in, qwen3_think_end_token_id_in)
            return {"outputs": decoded_results_dict["outputs"], "thinking_outputs": decoded_results_dict["thinking_outputs"], "tokens_per_second": 0.0, "error": str(oom_e)}
        except Exception as e:
            # ... (your general exception handling) ...
            decoded_results_dict = decode_batch_results_qwen3_thinking(generated_ids, prompt_lengths_large, tokenizer, stop_token_ids_set, pad_token_id, enable_thinking_in, qwen3_think_end_token_id_in)
            return {"outputs": decoded_results_dict["outputs"], "thinking_outputs": decoded_results_dict["thinking_outputs"], "tokens_per_second": 0.0, "error": str(e)}

        # --- Check for and Handle NaN/Inf in modified_logits ---
        if torch.isnan(modified_logits).any() or torch.isinf(modified_logits).any():
            nan_inf_mask = torch.isnan(modified_logits) | torch.isinf(modified_logits)
            if ((step + 1) % 10 == 0 or step == 0) and nan_inf_mask.any():
                print(f"Warning: NaN/Inf detected in combined logits at step {step+1} for {nan_inf_mask.sum().item()} elements. Replacing NaN with -inf.")
            modified_logits = torch.nan_to_num(modified_logits, nan=float('-inf'), posinf=float('inf'), neginf=float('-inf'))
        # --- Sampling ---
        # (The sampling block from your original code seems fine, so it's used here)
        if temperature > 0:
            scaled_logits = modified_logits / temperature
            filtered_logits = top_k_top_p_filtering(scaled_logits, top_k=top_k, top_p=top_p)
            probabilities = F.softmax(filtered_logits, dim=-1)
            nan_probs_mask = torch.isnan(probabilities).any(dim=-1)

            if nan_probs_mask.any():
                # print(f"Warning: NaN in probabilities at step {step+1} for {nan_probs_mask.sum().item()} sequences. Falling back.") # Less verbose
                fallback_k = min(top_k if top_k > 0 else 5, scaled_logits.size(-1)) # Fallback to top_k or 5
                if fallback_k == 0: fallback_k = 1

                _, top_k_indices_fallback = torch.topk(scaled_logits[nan_probs_mask], k=fallback_k, dim=-1)
                uniform_probs_topk = torch.ones_like(top_k_indices_fallback, dtype=torch.float) / fallback_k
                uniform_sampled_relative_indices = torch.multinomial(uniform_probs_topk, num_samples=1)
                uniform_next_token_id = torch.gather(top_k_indices_fallback, dim=-1, index=uniform_sampled_relative_indices)

                next_token_id = torch.full((batch_size, 1), pad_token_id, dtype=torch.long, device=device)
                normal_probs_mask = ~nan_probs_mask
                if normal_probs_mask.any():
                    normal_probs = probabilities[normal_probs_mask]
                    normal_probs = torch.clamp(normal_probs, min=0.0)
                    normal_probs_sum = normal_probs.sum(dim=-1, keepdim=True)
                    # Avoid division by zero if sum is zero (all filtered out)
                    safe_sum = torch.where(normal_probs_sum == 0, torch.ones_like(normal_probs_sum), normal_probs_sum)
                    normal_probs = normal_probs / safe_sum
                    normal_probs = torch.where(torch.isnan(normal_probs), torch.ones_like(normal_probs) / normal_probs.size(-1), normal_probs)

                    if normal_probs.numel() > 0 and not torch.isnan(normal_probs).all(): # Check if multinomial can run
                         normal_next_token_id = torch.multinomial(normal_probs, num_samples=1)
                         next_token_id[normal_probs_mask] = normal_next_token_id
                    else: # All normal_probs were NaN or empty, use fallback for these too (e.g. greedy)
                        # print(f"Warning: Normal probs also problematic for {normal_probs_mask.sum().item()} sequences. Using greedy on scaled_logits.")
                        greedy_fallback_tokens = torch.argmax(scaled_logits[normal_probs_mask], dim=-1).unsqueeze(-1)
                        next_token_id[normal_probs_mask] = greedy_fallback_tokens


                next_token_id[nan_probs_mask] = uniform_next_token_id
                if (next_token_id == pad_token_id).all() and batch_size > 0 :
                    # print(f"Error/Warning: All sequences failed sampling at step {step+1}. Using greedy from modified_logits.")
                    next_token_id = torch.argmax(modified_logits, dim=-1).unsqueeze(-1) # Last resort
            else:
                probabilities = torch.clamp(probabilities, min=0.0)
                probs_sum = probabilities.sum(dim=-1, keepdim=True)
                safe_sum = torch.where(probs_sum == 0, torch.ones_like(probs_sum), probs_sum)
                probabilities = probabilities / safe_sum
                probabilities = torch.where(torch.isnan(probabilities), torch.ones_like(probabilities) / probabilities.size(-1), probabilities)
                if probabilities.numel() > 0 and not torch.isnan(probabilities).all():
                    next_token_id = torch.multinomial(probabilities, num_samples=1)
                else:
                    # print(f"Warning: All probabilities problematic at step {step+1}. Using greedy from modified_logits.")
                    next_token_id = torch.argmax(modified_logits, dim=-1).unsqueeze(-1) # Last resort

        else: # Greedy (temperature == 0)
            next_token_id = torch.argmax(modified_logits, dim=-1).unsqueeze(-1)        # next_token_id = ...

        # --- Stop finished sequences from generating further ---
        # Mask next_token_id with pad_token_id for sequences that are already finished
        next_token_id = next_token_id * unfinished_sequences.unsqueeze(-1) + \
                        pad_token_id * (1 - unfinished_sequences.unsqueeze(-1))
        # (Update generated sequences and attention masks for NEXT iteration - same structure)
        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
        new_token_attention = unfinished_sequences.unsqueeze(-1)
        current_attention_mask_large = torch.cat([current_attention_mask_large, new_token_attention], dim=-1)
        current_attention_mask_small_unprimed = torch.cat([current_attention_mask_small_unprimed, new_token_attention], dim=-1) # MODIFIED
        if num_focused_experts > 0:
            for i in range(num_focused_experts):
                current_attention_masks_small_focused[i] = torch.cat([current_attention_masks_small_focused[i], new_token_attention], dim=-1) # MODIFIED

        # --- Check Termination Criteria ---
        current_token_scalars = next_token_id.squeeze(-1)
        hit_stop_token = torch.isin(current_token_scalars, torch.tensor(list(stop_token_ids_set), device=device))

        # Increment token count only for sequences that were active (unfinished or just finished)
        tokens_generated_this_step_count = unfinished_sequences.sum().item()
        total_new_tokens_generated += tokens_generated_this_step_count

        # --- Print Progress ---
        current_time = time.time()
        if current_time - last_print_time >= 10.0 or (step == max_new_tokens - 1) or (unfinished_sequences.max() == 0 and step > 0):
             active_sequences = unfinished_sequences.sum().item()
             print(f"  Step {step+1}/{max_new_tokens} | Active: {active_sequences}/{batch_size} | Last tokens: {current_token_scalars[:min(5, batch_size)].tolist()} | Step time: {current_time - step_start_time:.3f}s")
             last_print_time = current_time

        if unfinished_sequences.max() == 0:
            print(f"All sequences finished generating at step {step+1}.")
            break
        # (Break if all sequences finished - same as your script)

    # --- Post-Loop ---
    print("Generation loop finished.")
    loop_end_time = time.time()
    total_loop_time = loop_end_time - loop_start_time

    print(f"Total generation time (loop): {total_loop_time:.2f} seconds for {total_new_tokens_generated} total new tokens.")
    tokens_per_second = 0.0
    if total_new_tokens_generated > 0 and total_loop_time > 0.01: # Avoid division by zero for very fast runs
        tokens_per_second = total_new_tokens_generated / total_loop_time
        print(f"Aggregate speed: {tokens_per_second:.2f} tokens/second.")
    elif total_new_tokens_generated == 0:
        print("No new tokens were generated.")

    decoded_results_dict = decode_batch_results_qwen3_thinking(
        generated_ids, prompt_lengths_large, tokenizer, stop_token_ids_set,
        pad_token_id, enable_thinking_in, qwen3_think_end_token_id_in # Use the _in suffixed parameters
    )

    generation_end_time = time.time()
    print(f"Total function execution time: {generation_end_time - generation_start_time:.2f} seconds.")
    print("--- BATCH MoD Generation Function (Qwen3 Thinking Aware) Complete ---")
    return {
        "outputs": decoded_results_dict["outputs"],
        "thinking_outputs": decoded_results_dict["thinking_outputs"],
        "tokens_per_second": tokens_per_second,
        "error": None # No error if successful completion
    }


def decode_batch_results_qwen3_thinking(
    generated_ids: torch.Tensor,
    prompt_lengths: List[int],
    tokenizer,
    stop_token_ids_set: set,
    pad_token_id: int,
    is_thinking_enabled: bool,
    think_end_token_id: int
) -> Dict[str, List[str]]:
    """
    Decode a batch of generated sequences, separating thinking text from the answer.

    For each row ``i`` the first ``prompt_lengths[i]`` ids are dropped and every
    occurrence of ``pad_token_id`` is removed from what remains. When
    ``is_thinking_enabled`` is true, the remaining ids are split immediately after
    the LAST occurrence of ``think_end_token_id``: the ids up to and including that
    token are decoded as the thinking text, the ids after it as the answer. Both
    parts are decoded with ``skip_special_tokens=True`` and stripped. If the token
    is absent, a warning is printed and the whole response is treated as the answer.

    Args:
        generated_ids: 2-D tensor indexed as ``[row, position]`` (prompt + generated ids).
        prompt_lengths: Number of leading positions to discard for each row.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        stop_token_ids_set: Accepted for interface compatibility; not used here.
        pad_token_id: Id removed from the response before decoding.
        is_thinking_enabled: Whether to look for ``think_end_token_id`` and split.
        think_end_token_id: Id marking the end of the thinking section.

    Returns:
        ``{"outputs": [...], "thinking_outputs": [...]}`` with one string per row;
        the thinking string is ``""`` when thinking is disabled or the marker is missing.
    """
    batch_outputs = []
    batch_thinking_outputs = []

    for i in range(generated_ids.shape[0]):
        response_ids_with_pad = generated_ids[i, prompt_lengths[i]:]
        response_ids_list = response_ids_with_pad[response_ids_with_pad != pad_token_id].tolist()

        thinking_content = ""
        content_ids_part = response_ids_list  # default: the whole response is the answer

        if is_thinking_enabled:
            try:
                # Searching the reversed list finds the LAST marker; subtracting from the
                # length yields the index one past it.
                # Example: [1, 2, END, 4, 5] -> reversed index 2 -> split point 5 - 2 = 3.
                split_point_after_think = len(response_ids_list) - response_ids_list[::-1].index(think_end_token_id)
            except ValueError:
                # Only the marker search is guarded, so decode errors are not mistaken
                # for a missing marker.
                print(f"Warning: Qwen3 </think> token ID {think_end_token_id} not found in response for item {i}, though thinking was enabled. Treating all as content.")
            else:
                thinking_ids_part = response_ids_list[:split_point_after_think]
                content_ids_part = response_ids_list[split_point_after_think:]
                thinking_content = tokenizer.decode(thinking_ids_part, skip_special_tokens=True).strip()

        main_content = tokenizer.decode(content_ids_part, skip_special_tokens=True).strip()

        batch_outputs.append(main_content)
        batch_thinking_outputs.append(thinking_content)

    return {"outputs": batch_outputs, "thinking_outputs": batch_thinking_outputs}


def decode_batch_results(generated_ids: torch.Tensor,
                         prompt_lengths: List[int],
                         tokenizer,
                         stop_token_ids_set: set,
                         pad_token_id: int,
                         include_think_prompt: bool) -> List[str]:
    """
    Decode a batch of generated sequences into cleaned response strings.

    For each row ``i`` the first ``prompt_lengths[i]`` ids are dropped, every
    ``pad_token_id`` is filtered out, and the rest is decoded with
    ``skip_special_tokens=True``. A trailing ``</think>`` tag is removed when
    ``include_think_prompt`` is true, then any trailing stop-token strings are
    stripped.

    Args:
        generated_ids: 2-D tensor indexed as ``[row, position]`` (prompt + generated ids).
        prompt_lengths: Number of leading positions to discard for each row.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        stop_token_ids_set: Ids whose decoded strings are stripped from the end of
            each response (the pad id is excluded).
        pad_token_id: Id removed from the response before decoding.
        include_think_prompt: Whether to strip a trailing ``</think>`` tag.

    Returns:
        One stripped response string per row.
    """
    think_end_tag = "</think>"
    stop_token_strings = {tokenizer.decode(stop_id) for stop_id in stop_token_ids_set if stop_id != pad_token_id}
    stop_token_strings.discard("")  # an empty string would match every suffix

    decoded_outputs = []
    for i in range(generated_ids.shape[0]):
        response_ids = generated_ids[i, prompt_lengths[i]:]

        # Filter out padding tokens
        actual_response_ids = response_ids[response_ids != pad_token_id]
        response_text = tokenizer.decode(actual_response_ids, skip_special_tokens=True)

        # Clean up a potential trailing think tag
        if include_think_prompt and response_text.strip().endswith(think_end_tag):
            response_text = response_text.rsplit(think_end_tag, 1)[0].strip()

        # Strip lingering stop strings. A single pass over the set would make the result
        # depend on set iteration order when several stop strings are stacked at the end
        # (e.g. "text<A><B>"), so repeat until none matches. Each removal shortens the
        # text by at least one character, so the loop terminates.
        removed_any = True
        while removed_any:
            removed_any = False
            for stop_str in stop_token_strings:
                if response_text.endswith(stop_str):
                    response_text = response_text[:-len(stop_str)].rstrip()
                    removed_any = True

        decoded_outputs.append(response_text.strip())

    return decoded_outputs




--- Importing Libraries ---
--- Libraries Imported ---

--- Configuring Models and Device ---
--- Using Fixed Task Template ---
Chunk size for focused M_Small: 128 tokens
--- Model Configuration ---
Large (M_Large):    Qwen/Qwen3-1.7B
Small (M_Small):    Qwen/Qwen3-0.6B (for ICL-based deltas)
-------------------------
Using device: cuda
Using 4-bit NF4 quantization with bfloat16 compute dtype for all models.
--- Configuration Complete ---

--- Loading Tokenizer and Models ---
Loading Tokenizer...
Using PAD token ID: 151643
Tokenizer padding side set to 'left'.
Qwen3 '</think>' token ID found: 151668
Tokenizer Loaded Successfully.
Loading Model: Qwen/Qwen3-1.7B...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Qwen/Qwen3-1.7B Loaded Successfully.
Estimated memory footprint for Qwen/Qwen3-1.7B: 1.33 GB
Loading Model: Qwen/Qwen3-0.6B...
Qwen/Qwen3-0.6B Loaded Successfully.
Estimated memory footprint for Qwen/Qwen3-0.6B: 0.53 GB

All models loaded successfully with 4-bit quantization.
--- Tokenizer and Models Loaded ---

--- Defining BookHaystack Class & MoD Generation ---

--- Defining BATCHED MoD Generation Function (Qwen3 Thinking Aware) ---


In [None]:
# --- Test Parameters for this Simplified Run ---
# Start with ONE context length (e.g. 32k); it must not exceed what M_Large can handle.
CONTEXT_LENGTHS_TO_TEST = [2048]
DEPTH_PERCENTAGES_TO_TEST = [0.5] # Needle insertion depths, as fractions of the haystack
ALPHA_MOD = 1 # Passed as `alpha` to generate_mod_batch_kv_cache
MAX_NEW_TOKENS_GENERATION = 100
GLOBAL_SYSTEM_PROMPT_FOR_RUN = "You are a helpful AI assistant. Use the information provided in the book snippet to answer the question. Your answer should be short and based on either explicitly stated facts or strong, logical inferences."
ENABLE_QWEN3_THINKING_FOR_RUN = True # Set to False for baseline as Qwen3 thinking is often a fine-tuning feature



# --- Cell 6: Manual NOLIMA-Hard Evaluation Loop ---
# For every (context length, experiment, question type, test, depth) combination:
# fill the needle/question templates, embed the needle in a book haystack, run the
# MoD generation for a batch of one, and score by case-insensitive substring match
# against the gold answer(s). All rows are written to one JSON file at the end.
print("\n--- Starting Manual NOLIMA-Hard Evaluation ---")

# Ensure models and tokenizer are loaded from Cell 4
if tokenizer is None or model_large is None or model_small is None:
    print("ERROR: Models or tokenizer not loaded. Please run Cell 4 first.")
else:
    # Load NOLIMA-Hard dataset
    nolima_hard_configs_loaded = []
    if os.path.exists(NEEDLE_SET_HARD_PATH):
        with open(NEEDLE_SET_HARD_PATH, "r") as f_json:
            nolima_hard_configs_loaded = json.load(f_json)
        print(f"Loaded {len(nolima_hard_configs_loaded)} experiment configurations from NOLIMA-Hard set.")
    else:
        print(f"ERROR: NOLIMA-Hard file not found at {NEEDLE_SET_HARD_PATH}")

    # Instantiate BookHaystack
    try:
        book_processor = BookHaystack(HAYSTACK_BOOK_PATH, tokenizer_instance=tokenizer)
        print(f"Loaded haystack content from: {HAYSTACK_BOOK_PATH}")
    except Exception as e_book:
        print(f"Error loading haystack: {e_book}")
        book_processor = None

    all_run_results = []

    if nolima_hard_configs_loaded and book_processor:
        for current_ctx_len in CONTEXT_LENGTHS_TO_TEST:
            print(f"\n===== EVALUATING FOR TARGET HAYSTACK CONTEXT LENGTH: {current_ctx_len} tokens =====")

            # NOTE: the [3:4] slice restricts this run to the single configuration at index 3.
            for exp_conf in nolima_hard_configs_loaded[3:4]:
                exp_id_val = exp_conf["id"]

                for q_type, q_template in exp_conf["questions"].items():
                    for t_id, t_details in exp_conf["tests"].items():

                        needle_template_val = exp_conf["needle"]
                        input_args_val = t_details["input_args"]
                        gold_answers_val = t_details.get("gold_answers", "")
                        char_set_val = exp_conf.get("character_set", [])

                        # --- Substitute CHAR and args ---
                        final_needle = needle_template_val
                        final_question = q_template
                        actual_selected_char = None # Stays None when no {CHAR} placeholder is present

                        if "{CHAR}" in needle_template_val or "{CHAR}" in q_template: # Check original templates
                            if not char_set_val:
                                print(f"WARNING: {{CHAR}} in template but no character_set for {exp_id_val}_{t_id}_{q_type}. Skipping this test case.")
                                # The depth loop below has not started for this test case, so
                                # `depth_val` is not available here: record one skipped row per
                                # configured depth instead of reading an unset/stale variable.
                                for skipped_depth in DEPTH_PERCENTAGES_TO_TEST:
                                    all_run_results.append({
                                        "exp_id": exp_id_val, "test_id": t_id, "q_type": q_type,
                                        "haystack_target_len": current_ctx_len, "depth": skipped_depth,
                                        "needle": needle_template_val, "question": q_template,
                                        "char_selected": None, "gold": gold_answers_val,
                                        "answer": "SKIPPED_NO_CHAR_SET", "score": 0,
                                        "tps": None,
                                        "error": "Character set missing for {CHAR} template."
                                    })
                                continue # Skip to the next test case iteration

                            actual_selected_char = np.random.choice(char_set_val)
                            final_needle = final_needle.replace("{CHAR}", actual_selected_char)
                            final_question = final_question.replace("{CHAR}", actual_selected_char)

                            # When a character was selected, it IS the gold answer for scoring.
                            # Wrapped in a list so scoring treats it like multi-answer gold sets.
                            gold_answers_val = [actual_selected_char]

                        # Argument substitution happens AFTER {CHAR} substitution.
                        # Placeholders are 1-based: {1}, {2}, ...
                        for arg_i, arg_v in enumerate(input_args_val):
                            placeholder_str = "{" + str(arg_i + 1) + "}"
                            final_needle = final_needle.replace(placeholder_str, str(arg_v))
                            final_question = final_question.replace(placeholder_str, str(arg_v))

                            # Gold answers loaded from JSON (list or string) may contain the
                            # same placeholders, so substitute there too.
                            if isinstance(gold_answers_val, list):
                                gold_answers_val = [
                                    ans.replace(placeholder_str, str(arg_v)) if isinstance(ans, str) else ans
                                    for ans in gold_answers_val
                                ]
                            elif isinstance(gold_answers_val, str): # Only if not set by {CHAR}
                                gold_answers_val = gold_answers_val.replace(placeholder_str, str(arg_v))
                        # --- End Substitution ---

                        print(f"\n--- Running: {exp_id_val}_{t_id}_{q_type} (Haystack len: {current_ctx_len}) ---")
                        print(f"  Needle: {final_needle[:70]}...")
                        print(f"  Question: {final_question}")

                        for depth_val in DEPTH_PERCENTAGES_TO_TEST:
                            print(f"  Depth: {depth_val*100:.0f}%")

                            # 1. Create haystack with needle
                            # `current_ctx_len` is the target length of the book snippet part
                            haystack_text_with_needle = book_processor.get_haystack_with_needle(
                                needle_text=final_needle,
                                target_haystack_len=current_ctx_len,
                                depth_percentage=depth_val
                            )

                            # 2. Run the MoD generation for a batch of one
                            eval_output = generate_mod_batch_kv_cache(
                                haystacks_full_text=[haystack_text_with_needle], # Batch of 1
                                questions_text=[final_question],             # Batch of 1
                                max_new_tokens=MAX_NEW_TOKENS_GENERATION,
                                temperature=0.0, # Or from config
                                top_k=0,        # Or from config
                                top_p=0.95,      # Or from config
                                alpha=ALPHA_MOD,
                                delta_magnitude_threshold=0.0,
                                global_system_prompt_in=GLOBAL_SYSTEM_PROMPT_FOR_RUN,
                                enable_thinking_in=ENABLE_QWEN3_THINKING_FOR_RUN,
                                qwen3_think_end_token_id_in=think_end_token_id_qwen3,
                                chunk_size_small_in=CHUNK_SIZE_FOR_FOCUSED_SMALL_MODELS
                            )

                            ans_text = eval_output["outputs"][0] if eval_output["outputs"] and not eval_output.get("error") else "ERROR_IN_GENERATION"
                            print(f"    Model Answer: {ans_text[:70]}...")

                            # 3. Score: 1 if any non-empty gold answer occurs in the model
                            # answer (case-insensitive substring match), else 0.
                            current_score = 0
                            if isinstance(gold_answers_val, list):
                                if any(g.lower() in ans_text.lower() for g in gold_answers_val if g): current_score = 1
                            elif isinstance(gold_answers_val, str) and gold_answers_val:
                                if gold_answers_val.lower() in ans_text.lower(): current_score = 1
                            print(f"    Score: {current_score}")

                            all_run_results.append({
                                "exp_id": exp_id_val, "test_id": t_id, "q_type": q_type,
                                "haystack_target_len": current_ctx_len, "depth": depth_val,
                                "needle": final_needle, "question": final_question,
                                "char_selected": actual_selected_char, "gold": gold_answers_val,
                                "answer": ans_text, "score": current_score,
                                "tps": eval_output.get("tokens_per_second"),
                                "error": eval_output.get("error")
                            })
                            # Release Python and CUDA caches between generations.
                            gc.collect()
                            if torch.cuda.is_available(): torch.cuda.empty_cache()

        # --- Save all results once, after every context length has been evaluated ---
        results_filename_temp = f"manual_nolima_hard_results_{time.strftime('%Y%m%d_%H%M%S')}.json"
        with open(os.path.join(PROJECT_BASE_PATH, results_filename_temp), "w") as f_out_json:
            json.dump(all_run_results, f_out_json, indent=2)
        print(f"Intermediate results saved to {results_filename_temp}")

    print("\n--- Manual NOLIMA-Hard Evaluation Finished ---")


--- Starting Manual NOLIMA-Hard Evaluation ---
Loaded 4 experiment configurations from NOLIMA-Hard set.
Loaded haystack content from: /content/my_book.txt

===== EVALUATING FOR TARGET HAYSTACK CONTEXT LENGTH: 2048 tokens =====

--- Running: 0409Inv_T10_C02_onehop (Haystack len: 2048) ---
  Needle: There was an engineer living in Calvinia, named Diana....
  Question: Which character has been to South Africa?
  Depth: 50%
Encoding full book for BookHaystack (one time)...
Book has 98060 tokens.
Termination Token IDs: {151645}


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available(), dtype=torch.bfloat16):


  Step 7/100 | Active: 1/1 | Last tokens: [1430] | Step time: 1.483s
