In [None]:
import os

# Route every Hugging Face download cache to a single directory. These variables are
# set before vllm (and the Hugging Face libraries it pulls in) is imported so that
# those libraries pick them up. Export HF_CACHE_DIR to use another location without
# editing the notebook; the default is the original shared scratch path.
cache_dir = os.environ.get(
    'HF_CACHE_DIR',
    '/scratch/hakeem.at/Queryable-Shared-Reference-Repository/notebooks/pretrained_models',
)

os.environ['HF_HOME'] = cache_dir
os.environ['TRANSFORMERS_CACHE'] = cache_dir  # legacy variable, kept for older transformers releases
os.environ['HUGGINGFACE_HUB_CACHE'] = cache_dir

import json
import random
from tqdm.auto import tqdm
import pandas as pd

import torch
from vllm import LLM, SamplingParams

# Seed both global RNGs this cell imports (Python's `random` and PyTorch).
seed = 42
random.seed(seed)
torch.manual_seed(seed)


In [None]:
# Load the context-expansion evaluation set (one JSON object per line) and summarize it.
input_file = "context_expansion_dataset.jsonl"
eval_data = pd.read_json(input_file, lines=True)

# Columns the generation loop reads; fail here, before the model is loaded, if any are missing.
required_columns = {
    'original_idx', 'source', 'question_type', 'query', 'original_context',
    'expanded_context', 'expansion_type', 'target_pct', 'target_tokens',
    'actual_tokens', 'ground_truth',
}
missing_columns = required_columns - set(eval_data.columns)
if missing_columns:
    raise ValueError(f"{input_file} is missing columns: {sorted(missing_columns)}")

print(f"Total evaluation samples: {len(eval_data)}")
print("\nDistribution by expansion type and context %:")
print(eval_data.groupby(['expansion_type', 'target_pct']).size().unstack())
print("\nQuestion type distribution:")
print(eval_data['question_type'].value_counts())

In [None]:
PROMPT_TEMPLATE = """<instructions>
Answer the question using ONLY information from the context provided below. If the context does not contain enough information to answer the question, respond with exactly: "I don't know."
</instructions>

<context>
{context}
</context>

<question>
{question}
</question>

<answer>"""


In [None]:
# Offline vLLM engine hosting the model under evaluation.
model_id = "Qwen/Qwen3-8B"

# Engine settings (see the vLLM LLM / EngineArgs reference for exact semantics).
gpu_memory_utilization = 0.95   # fraction of GPU memory the engine may use
max_model_len = 32768           # context length: prompt + generated tokens must fit in this
max_num_seqs = 32               # max sequences scheduled per engine step
enforce_eager = True            # eager PyTorch execution, no CUDA-graph capture

engine_kwargs = dict(
    model=model_id,
    gpu_memory_utilization=gpu_memory_utilization,
    max_model_len=max_model_len,
    max_num_seqs=max_num_seqs,
    enforce_eager=enforce_eager,
    # NOTE(review): allows executing model-repo Python code; confirm it is still needed for this model.
    trust_remote_code=True,
)
model = LLM(**engine_kwargs)

In [None]:
# Decoding settings: temperature 0 (greedy), at most 512 new tokens, and generation
# stops at the closing </answer> tag that matches the open <answer> ending each prompt.
decoding_kwargs = {
    "temperature": 0,
    "max_tokens": 512,
    "stop": ["</answer>"],
}
sampling_params = SamplingParams(**decoding_kwargs)

In [None]:
# test_prompt = """<context>
# Researchers found trash at the bottom of the ocean.
# </context>

# <question>
# Which brand of beer did the researchers found at the bottom of the ocean
# </question>

# <answer>"""

# # test_params = SamplingParams(temperature=0, max_tokens=100)
# output = model.generate([test_prompt], sampling_params)[0]
# print(f"✅ Response: {output.outputs[0].text}")

In [None]:
# Run the model over the whole evaluation set, one context-length bucket at a time,
# rewriting `output_file` with every response collected so far after each bucket.
# NOTE(review): prompts are plain strings passed to model.generate, so no chat template
# is applied to them — confirm that is intended for this instruct model.
results = []
output_file = "context_expansion_responses.jsonl"
failed_prompts = []  # rows whose prompt failed even when generated on its own

# Dataset columns copied verbatim into every result record, in this key order.
RESULT_COLUMNS = (
    'original_idx', 'source', 'question_type', 'query', 'original_context',
    'expansion_type', 'target_pct', 'target_tokens', 'actual_tokens', 'ground_truth',
)


def _batch_size_for(target_pct):
    """Return how many prompts to send per generate call; longer contexts get smaller batches."""
    if target_pct <= 0.25:
        return 512
    if target_pct <= 0.50:
        return 256
    if target_pct <= 0.75:
        return 128
    return 64


def _to_result(row, text):
    """Build one output record from a dataset row and the model's raw completion text."""
    record = {column: row[column] for column in RESULT_COLUMNS}
    record['raw_response'] = text.strip()
    return record


def _generate_rows(batch_df):
    """Generate answers for every row of `batch_df`.

    Returns ``(records, failures)``. The batch goes through a single
    ``model.generate`` call. If that call raises (for example because one
    prompt is longer than ``max_model_len``), each prompt is retried on its
    own, so one bad row no longer discards the rest of the batch.
    """
    rows = [row for _, row in batch_df.iterrows()]
    prompts = [
        PROMPT_TEMPLATE.format(context=row['expanded_context'], question=row['query'])
        for row in rows
    ]
    try:
        responses = model.generate(prompts, sampling_params)
    except Exception as batch_error:
        tqdm.write(f"Batch of {len(prompts)} failed ({batch_error}); retrying prompts one by one")
    else:
        return [_to_result(row, resp.outputs[0].text) for row, resp in zip(rows, responses)], []

    records, failures = [], []
    for row, prompt in zip(rows, prompts):
        try:
            response = model.generate([prompt], sampling_params, use_tqdm=False)[0]
        except Exception as row_error:
            failures.append({
                'original_idx': row['original_idx'],
                'target_pct': row['target_pct'],
                'error': repr(row_error),
            })
            continue
        records.append(_to_result(row, response.outputs[0].text))
    return records, failures


def _json_default(obj):
    """Let json.dumps encode numpy scalars/arrays by converting them to Python objects."""
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _write_checkpoint(records, path):
    """Rewrite `path` as JSONL holding `records`, via a temp file so a crash cannot truncate it."""
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as fh:
        for record in records:
            fh.write(json.dumps(record, ensure_ascii=False, default=_json_default) + "\n")
    os.replace(tmp_path, path)


for target_pct in sorted(eval_data['target_pct'].unique()):
    subset = eval_data[eval_data['target_pct'] == target_pct]
    batch_size = _batch_size_for(target_pct)
    # round() before int(): int(0.29 * 100) would give 28 because of float error.
    pct_label = f"{int(round(target_pct * 100))}%"

    print(f"\n{'='*60}")
    print(f"Processing context length: {pct_label} ({subset.iloc[0]['target_tokens']} tokens)")
    print(f"Samples: {len(subset)}")
    print(f"{'='*60}\n")

    for start in tqdm(range(0, len(subset), batch_size), desc=pct_label):
        # iloc slicing clamps at the end, so the last partial batch needs no special case.
        records, failures = _generate_rows(subset.iloc[start:start + batch_size])
        results.extend(records)
        failed_prompts.extend(failures)

    # Checkpoint after each percentage
    _write_checkpoint(results, output_file)
    print(f"Checkpoint saved: {len(results)} results")
print(f"\nCompleted! Total results: {len(results)}, failed prompts: {len(failed_prompts)}")

In [None]:
# with open(output_file, "w", encoding="utf-8") as f:
#     f.write(json.dumps(results, ensure_ascii=False) + "\n")

In [None]:
# Sanity check: percentage of responses that are refusals ("I don't know"),
# broken down by context length and by expansion type.
# The pattern accepts both the straight (') and typographic (’) apostrophe.
IDK_PATTERN = r"don['’]t know"


def idk_rate(responses):
    """Return the percentage of strings in `responses` that contain IDK_PATTERN (case-insensitive)."""
    return responses.str.contains(IDK_PATTERN, case=False, regex=True).mean() * 100


results_df = pd.DataFrame(results)
if results_df.empty:
    print("\nNo results to summarize.")
else:
    for column, label in (('target_pct', 'context length'), ('expansion_type', 'expansion type')):
        print(f"\nIDK rate by {label}:")
        print(results_df.groupby(column)['raw_response'].apply(idk_rate).round(1).rename("IDK %"))