In [None]:
# @title 1. Setup: Install Libraries
# (Keep as is)
!pip install transformers torch accelerate bitsandbytes sentencepiece -q
!pip install -U bitsandbytes # Ensure latest bitsandbytes

# @title 2. Import Libraries
# (Keep as is)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, LogitsProcessorList, MinLengthLogitsProcessor, StoppingCriteriaList, MaxLengthCriteria
import gc
import re
import time # For timing comparison

# @title 3. Configuration: Model Names, Quantization, Device
# --- Specify your Qwen2 models ---
# Generation combines next-token logits as base + alpha * (expert - anti_expert),
# and all three checkpoints are fed ids from the same (base) tokenizer.
model_base_name = "Qwen/Qwen2.5-72B"
model_expert_name = "Qwen/QwQ-32B"
# Canonical repo-id casing: the local Hugging Face cache folder is named after
# this exact string, so "Qwen2.5-32b" would not reuse a "Qwen2.5-32B" download.
model_anti_expert_name = "Qwen/Qwen2.5-32B"

print("--- Model Configuration ---")
print(f"Base (M):         {model_base_name}")
print(f"Expert (M+):      {model_expert_name}")
print(f"Anti-Expert (M-): {model_anti_expert_name}")
print("-------------------------")

# --- Device ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if not torch.cuda.is_available():
    print("WARNING: CUDA not available, running on CPU will be extremely slow.")

# --- Quantization Configuration (Applied to ALL models) ---
# 4-bit NF4 weights with nested (double) quantization; compute in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
print("Using 4-bit NF4 quantization for all models.")

# @title 4. Load Models and Tokenizer


def _configure_tokenizer(tok):
    """Warn when `tok` lacks a chat template and set up its padding in place.

    A missing pad token falls back to the EOS token; padding goes on the left.
    """
    if tok.chat_template is None:
        print("WARNING: Tokenizer does not have a chat_template defined. Using default (or potentially incorrect) formatting.")
        # A ChatML template could be installed here if needed, e.g.:
        # tok.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
        print(f"Tokenizer pad_token set to eos_token ({tok.eos_token}).")

    tok.padding_side = "left"
    print(f"Tokenizer padding side set to '{tok.padding_side}'.")


# --- Load Tokenizer (Use tokenizer from the base model) ---
# This one tokenizer produces the ids fed to every model in the notebook.
print("Loading Tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(model_base_name, trust_remote_code=True)
    _configure_tokenizer(tokenizer)
    print("Tokenizer Loaded. Chat template likely available via tokenizer.apply_chat_template.")
except Exception as e:
    # Report which checkpoint failed, then let the error stop the notebook.
    print(f"ERROR loading tokenizer for {model_base_name}: {e}")
    raise

# --- Function to load model ---
def load_model(model_name, config, device):
    """Load a causal LM with `config` as its quantization config.

    Args:
        model_name: Hugging Face repo id (or local path) of the checkpoint.
        config: quantization config forwarded to ``from_pretrained`` (e.g. a
            ``BitsAndBytesConfig``); may be None for an unquantized load. When
            it carries ``bnb_4bit_compute_dtype``, that dtype is also used for
            the modules that stay unquantized.
        device: unused -- placement is delegated to ``device_map="auto"``.
            Kept so existing call sites keep working.

    Returns:
        The model in eval mode, or None if loading raised (the error is printed).
    """
    print(f"Loading Model: {model_name}...")
    # Without an explicit torch_dtype, transformers loads the non-quantized
    # parts of a bitsandbytes model (embeddings, norms, lm_head) in float16,
    # which disagrees with a bf16 compute dtype. Match them when we can.
    load_kwargs = {}
    compute_dtype = getattr(config, "bnb_4bit_compute_dtype", None)
    if compute_dtype is not None:
        load_kwargs["torch_dtype"] = compute_dtype
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=config,
            device_map="auto",  # let accelerate distribute layers over GPU(s)/CPU
            # NOTE(review): allows running code shipped in the model repo; may be
            # unnecessary for architectures transformers supports natively.
            trust_remote_code=True,
            **load_kwargs,
        )
        model.eval()  # inference only
        print(f"{model_name} Loaded Successfully.")
        mem_bytes = model.get_memory_footprint()
        print(f"Estimated memory footprint for {model_name}: {mem_bytes / 1e9:.2f} GB")
        return model
    except Exception as e:
        # Best effort: the caller checks for None and decides whether to abort.
        print(f"ERROR loading model {model_name}: {e}")
        return None
    finally:
        # Free loader temporaries whether or not loading succeeded.
        gc.collect()
        torch.cuda.empty_cache()

# --- Load Models ---
# Load base, expert and anti-expert in that order with the shared config;
# load_model yields None for any checkpoint that failed to load.
model_base, model_expert, model_anti_expert = (
    load_model(checkpoint, bnb_config, device)
    for checkpoint in (model_base_name, model_expert_name, model_anti_expert_name)
)

if all([model_base, model_expert, model_anti_expert]):
    print("\nAll models loaded successfully with 4-bit quantization.")
else:
    raise RuntimeError("One or more models failed to load. Cannot proceed.")


# @title 5. Implement Proxy-Tuning Generation (Using KV Cache)

# Role names for chat-template messages.
SYSTEM, USER, ASSISTANT = "system", "user", "assistant"

def _build_prompt_ids(prompt, is_math_problem, include_think_prompt):
    """Render `prompt` with the chat template and return its ids as a (1, n) tensor on `device`.

    For math problems the step-by-step / boxed-answer instruction is appended
    (unless the prompt already asks to reason step by step); optionally the
    "<think>\\n" prefix is appended after the generation prompt.
    Raises whatever the tokenizer raises (e.g. when no chat template is set).
    """
    user_content = prompt
    if is_math_problem and "reason step by step" not in prompt.lower():
        user_content += "\nPlease reason step by step, and put your final answer within \\boxed{}."

    messages = [
        {"role": SYSTEM, "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        {"role": USER, "content": user_content},
    ]
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    )
    # NOTE(review): some transformers releases may return a BatchEncoding here
    # instead of a bare tensor; accept both.
    if not isinstance(encoded, torch.Tensor):
        encoded = encoded["input_ids"]
    input_ids = encoded.to(device)

    if include_think_prompt:
        think_ids = tokenizer.encode("<think>\n", add_special_tokens=False, return_tensors="pt")
        input_ids = torch.cat([input_ids, think_ids.to(input_ids.device)], dim=-1)
    return input_ids


def _stop_token_ids():
    """Return the ids that end generation: the tokenizer's EOS and <|im_end|> (None dropped, deduplicated)."""
    candidate_ids = [tokenizer.eos_token_id]
    try:
        candidate_ids.append(tokenizer.convert_tokens_to_ids("<|im_end|>"))
    except Exception as e:
        print(f"Warning: Could not get <|im_end|> token id: {e}. Using only EOS.")
    return list(dict.fromkeys(tid for tid in candidate_ids if tid is not None))


def _forward_with_cache(model, input_ids, past_key_values):
    """Run one cached forward pass; return (last-position logits of shape (batch, vocab), updated cache)."""
    outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
    return outputs.logits[:, -1, :], outputs.past_key_values


def _combine_proxy_logits(logits_base, logits_expert, logits_anti_expert, alpha, step):
    """Return base + alpha * (expert - anti_expert) as float32 on the base logits' device.

    With device_map="auto" each model's logits can come back on a different
    device, and related checkpoints may pad their vocabularies differently, so
    the rows are moved to one device and truncated to the shared vocab size
    first. Falls back to the base logits if the result has NaN/Inf.
    """
    target = logits_base.device
    vocab_size = min(logits_base.shape[-1], logits_expert.shape[-1], logits_anti_expert.shape[-1])
    base = logits_base[..., :vocab_size].to(device=target, dtype=torch.float32)
    expert = logits_expert[..., :vocab_size].to(device=target, dtype=torch.float32)
    anti_expert = logits_anti_expert[..., :vocab_size].to(device=target, dtype=torch.float32)

    combined = base + alpha * (expert - anti_expert)
    if not torch.isfinite(combined).all():
        print(f"Warning: NaN/Inf in logits at step {step+1}. Using base logits.")
        combined = base.clone()
    return combined


def _sample_next_token(logits, temperature, top_k, step):
    """Pick the next token id for each row of (batch, vocab) `logits`; returns shape (batch, 1).

    temperature <= 0 selects the argmax. Otherwise the logits are divided by
    `temperature`, restricted to the `top_k` largest (all of them when
    top_k <= 0) and sampled; if the probabilities contain NaN, a candidate is
    drawn uniformly instead.
    """
    if temperature <= 0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    scaled = logits / temperature
    vocab_size = scaled.size(-1)
    k = vocab_size if top_k <= 0 else min(top_k, vocab_size)
    # Softmax over the k kept values equals softmax over the full row with all
    # other entries masked to -inf, without building that mask.
    top_values, top_indices = torch.topk(scaled, k, dim=-1)
    probabilities = torch.softmax(top_values, dim=-1)
    if torch.isnan(probabilities).any():
        print(f"Warning: NaN in probs at step {step+1}. Uniform sample from top-k.")
        choice = torch.randint(0, k, (top_indices.shape[0], 1), device=top_indices.device)
    else:
        choice = torch.multinomial(probabilities, num_samples=1)
    return torch.gather(top_indices, -1, choice)


@torch.inference_mode()
def generate_proxy_tuned_kv_cache(
    prompt: str,
    max_new_tokens: int = 150,
    temperature: float = 0.6,
    top_k: int = 50,
    alpha: float = 1.0,
    is_math_problem: bool = False,
    include_think_prompt: bool = True,
) -> str:
    """Generate a reply to `prompt` with proxy-tuning, reusing each model's KV cache.

    Every step, the base, expert and anti-expert models score the next token;
    their logits are combined as base + alpha * (expert - anti_expert), a token
    is chosen, and only that token is fed back through the models.

    Args:
        prompt: user message, wrapped in the tokenizer's chat template.
        max_new_tokens: upper bound on the number of generated tokens.
        temperature: sampling temperature; <= 0 means greedy decoding.
        top_k: sample only among the k highest-scoring tokens; <= 0 disables the cut-off.
        alpha: weight of the expert/anti-expert logit difference.
        is_math_problem: append the step-by-step / boxed-answer instruction.
        include_think_prompt: append "<think>\\n" after the generation prompt.

    Returns:
        The decoded response (prompt excluded, special tokens skipped, a
        trailing "</think>" removed), or "" when the models are not loaded or
        the prompt cannot be built. Inference errors (including CUDA OOM) end
        the loop early and the tokens produced so far are returned.
    """
    if not all([model_base, model_expert, model_anti_expert]):
        print("Error: Models not loaded.")
        return ""

    start_time = time.time()

    try:
        input_ids = _build_prompt_ids(prompt, is_math_problem, include_think_prompt)
        print("--- Final Input String (Decoded from Tokens) ---")
        print(tokenizer.decode(input_ids[0]))
        print("------------------------------------------------")
    except Exception as e:
        print(f"Error applying chat template or adding think prompt: {e}")
        print("Ensure the tokenizer has a valid chat_template attribute.")
        return ""
    prompt_length = input_ids.shape[1]

    stop_token_ids = _stop_token_ids()
    print(f"Termination Token IDs: {stop_token_ids}")

    models = (model_base, model_expert, model_anti_expert)
    caches = [None] * len(models)
    generated_ids = input_ids
    # The first pass feeds the whole prompt; later passes feed only the newest
    # token and rely on each model's cache for the rest.
    current_token_ids = input_ids

    print("Starting generation with KV Cache...")
    for step in range(max_new_tokens):
        step_start_time = time.time()
        try:
            # bf16 autocast on CUDA; disabled on CPU.
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=device.type == "cuda"):
                step_logits = []
                for i, model in enumerate(models):
                    logits, caches[i] = _forward_with_cache(model, current_token_ids, caches[i])
                    step_logits.append(logits)
        except torch.cuda.OutOfMemoryError:
            print(f"CUDA OOM at step {step+1} (PromptLen: {prompt_length}, NewTokens: {step}). Stopping.")
            gc.collect()
            torch.cuda.empty_cache()
            break
        except Exception as e:
            print(f"Inference error at step {step+1}: {e}")
            break

        modified_logits = _combine_proxy_logits(*step_logits, alpha=alpha, step=step)
        next_token_id = _sample_next_token(modified_logits, temperature, top_k, step)
        # The combined logits live on the base model's output device, which
        # with device_map="auto" need not be where the prompt ids are.
        next_token_id = next_token_id.to(generated_ids.device)

        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
        current_token_ids = next_token_id  # shape (batch_size, 1)

        token_value = next_token_id.item()  # batch size is 1 here
        if token_value in stop_token_ids:
            print(f"Termination token generated (ID: {token_value}) at step {step+1}.")
            break

        if step % 20 == 0 and step > 0:  # periodic progress report
            step_seconds = time.time() - step_start_time
            print(f"Step {step+1}/{max_new_tokens} ({step_seconds:.2f}s/step) | Seq Len: {generated_ids.shape[1]}")

    print("Generation finished.")
    elapsed = time.time() - start_time
    total_gen_tokens = generated_ids.shape[1] - prompt_length
    print(f"Total generation time: {elapsed:.2f} seconds for {total_gen_tokens} tokens.")
    if total_gen_tokens > 0:
        print(f"Average speed: {total_gen_tokens / elapsed:.2f} tokens/second.")

    # Decode only the newly generated tokens.
    response_text = tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)

    # Drop a dangling closing think tag at the very end.
    if response_text.strip().endswith("</think>"):
        response_text = response_text.rsplit("</think>", 1)[0]

    return response_text.strip()


# @title 6. Run Example Generation (with KV Cache)
# Raw string: the LaTeX delimiters \[ ... \] are not valid Python escape
# sequences (SyntaxWarning on 3.12+); the resulting text is unchanged.
prompt = r"Determine all positive integers $n$ for which there exist positive integers $a$, $b$, and $c$ satisfying \[2a^n + 3b^n = 4c^n.\]"
print(f"Original User Prompt: {prompt}")

# --- Run Proxy-Tuned Generation with KV Cache ---
# Only run when the model-loading cell has produced a base model.
if globals().get("model_base") is not None:
    generated_output_kv = generate_proxy_tuned_kv_cache(
        prompt,
        max_new_tokens=500,  # long reasoning traces may need more
        alpha=1.0,
        temperature=0.6,
        top_k=50,
        is_math_problem=True,
        include_think_prompt=True,
    )
    print("\n--- Proxy-Tuned Output (Using KV Cache + R1 Hints) ---")
    print(generated_output_kv)
else:
    print("Models not loaded correctly. Please check the loading logs.")


# @title 7. Qwen2 / DeepSeek R1 Documentation Resources
# References:
#   - Proxy-tuning: Liu et al., "Tuning Language Models by Proxy" (arXiv:2401.08565)
#   - Model cards: https://huggingface.co/Qwen/Qwen2.5-72B
#                  https://huggingface.co/Qwen/QwQ-32B
print("\nRelevant documentation links provided in comments and text cell above.")

# When finished, uncomment the lines below to drop the models and release GPU memory:
# print("Cleaning up models...")
# del model_base, model_expert, model_anti_expert
# gc.collect()
# torch.cuda.empty_cache()
# print("Models deleted and cache cleared.")

[2K   [91m━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━[0m [32m173.0/363.4 MB[0m [31m36.7 MB/s[0m eta [36m0:00:06[0m