In [1]:
import os

# Directory used as the Hugging Face cache location for this notebook.
# NOTE(review): this is an absolute, user-specific path; anyone else running
# the notebook has to edit it.
cache_dir = '/scratch/hakeem.at/Queryable-Shared-Reference-Repository/notebooks/pretrained_models'

# Point all three Hugging Face cache variables at that directory in one call.
# These are assigned ahead of the torch/vllm imports further down in this cell,
# presumably so the libraries pick them up at import time - confirm.
os.environ.update(
    HF_HOME=cache_dir,
    TRANSFORMERS_CACHE=cache_dir,
    HUGGINGFACE_HUB_CACHE=cache_dir,
)

import json
import random
from tqdm.auto import tqdm
import pandas as pd

import torch
from vllm import LLM, SamplingParams

# Seed Python's built-in `random` module so the dataset shuffle performed later
# in the notebook is repeatable. Only the stdlib generator is seeded here.
seed = 42
random.seed(seed)




INFO 11-18 15:41:07 [__init__.py:216] Automatically detected platform cuda.


In [2]:
# Read the whole evaluation dataset (one JSON document) into memory.
input_file = "prompt_thresholding_eval_dataset.json"
with open(input_file, mode="r", encoding="utf-8") as f:
    eval_dataset = json.loads(f.read())

# Build the flat list of evaluation questions. Every selected dataset item
# contributes three questions over the same chunk: one answerable, one
# borderline ("imaginative") and one unanswerable.

# Shuffle in place with a fixed seed so the selected subset is reproducible.
random.seed(42)
random.shuffle(eval_dataset)
max_data_points = 500

# Filter out empty entries *before* applying the cap. Comparing the cap against
# the raw position in the shuffled list would let empty entries use up part of
# the budget, and an empty entry sitting exactly at position `max_data_points`
# would bypass the stop condition and let every remaining item through.
selected_items = [item for item in eval_dataset if item][:max_data_points]

# (dataset key holding the question, question_type label,
#  dataset key holding the ground truth or None when there is none)
question_specs = (
    ('answerable_query', 'answerable', 'excerpt'),
    ('imaginative_query', 'borderline', None),
    ('unanswerable_query', 'unanswerable', None),
)

eval_data = []
for item in selected_items:
    for query_key, question_type, truth_key in question_specs:
        eval_data.append({
            'question': item[query_key],
            'question_type': question_type,
            'ground_truth': item[truth_key] if truth_key is not None else None,
            'chunk': item['chunk'],
            'source': item['source'],
        })

# Sanity check: each question type should appear once per selected item.
print(f"Total evaluation questions: {len(eval_data)}")
print("Question type distribution:")
for qtype in ['answerable', 'borderline', 'unanswerable']:
    count = sum(1 for q in eval_data if q['question_type'] == qtype)
    print(f"  {qtype}: {count}")

Total evaluation questions: 1500
Question type distribution:
  answerable: 500
  borderline: 500
  unanswerable: 500


In [None]:
# Tail shared by every prompt: the context, the question, and an opening
# <answer> tag left for the model to complete.
_PROMPT_TAIL = """<context>
{context}
</context>

<question>
{question}
</question>

<answer>"""


def _with_instructions(instruction_block):
    """Wrap ``instruction_block`` in <instructions> tags and append the shared tail.

    ``instruction_block`` is expected to start and end with a newline so that
    the opening and closing tags each sit on their own line.
    """
    return "<instructions>" + instruction_block + "</instructions>\n\n" + _PROMPT_TAIL


# Prompt templates keyed by strategy name. Each template carries `{context}`
# and `{question}` placeholders that are filled with str.format.
PROMPTS = {
    # No instructions at all.
    # NOTE(review): this is the only template that begins with a blank line;
    # kept as-is so the generated prompts stay unchanged.
    'baseline': "\n" + _PROMPT_TAIL,

    # Plain instruction to abstain when the context is insufficient.
    'explicit_idk': _with_instructions("""
Answer the question using ONLY information from the context provided below. If the context does not contain enough information to answer the question, respond with exactly: "I don't know."
"""),

    # Abstain below a stated confidence level, with an explicit penalty scheme.
    # The first sentence ends with a trailing space, preserved from the original.
    'confidence_threshold': _with_instructions("""
Answer only if you are at least 80% confident that your answer is correct based on the context provided. 

Scoring system:
- Mistakes cost you 4 points (calculated as 80/(100-80) = 4)
- Correct answers earn 1 point
- Saying "I don't know" earns 0 points

If your confidence is below 80%, respond with exactly: "I don't know."
"""),

    # Abstain unless a five-point checklist is mostly satisfied.
    'confidence_rubric': _with_instructions("""
Before answering, evaluate these 5 criteria:

1. The answer is explicitly stated in the context (not requiring inference or speculation)
2. All necessary information to answer the question is present in the context
3. There is no contradictory or ambiguous information in the context
4. You can identify the specific sentence(s) in the context that support your answer
5. The question asks for information of the type provided in the context (not future predictions, comparisons to unmentioned work, etc.)

Only answer if at least 4 out of 5 criteria are clearly satisfied. Otherwise, respond with exactly: "I don't know."
"""),
}

# Model under evaluation and the vLLM engine settings used to load it.
model_id = "Qwen/Qwen3-8B"

gpu_memory_utilization = 0.95
# NOTE(review): each prompt embeds a full chunk and generation may add up to
# 512 tokens; confirm chunk + question + answer always fit in this window.
max_model_len = 4096
max_num_seqs = 64
enforce_eager = True

# Gather the engine arguments in one place, then build the engine from them.
# NOTE(review): trust_remote_code=True is passed through to vLLM; confirm it is
# actually required for this model.
engine_kwargs = dict(
    model=model_id,
    gpu_memory_utilization=gpu_memory_utilization,
    max_model_len=max_model_len,
    max_num_seqs=max_num_seqs,
    enforce_eager=enforce_eager,
    trust_remote_code=True,
)
model = LLM(**engine_kwargs)

In [4]:
# Decoding settings shared by every prompt variant: temperature 0, at most 512
# generated tokens, and "</answer>" as the stop string.
# `include_stop_str_in_output=True` was tried here and is currently disabled,
# so the option stays at its default.
sampling_params = SamplingParams(temperature=0, max_tokens=512, stop=["</answer>"])

In [5]:
# test_prompt = """<context>
# Researchers found trash at the bottom of the ocean.
# </context>

# <question>
# Which brand of beer did the researchers find at the bottom of the ocean?
# </question>

# <answer>"""

# # test_params = SamplingParams(temperature=0, max_tokens=100)
# output = model.generate([test_prompt], sampling_params)[0]
# print(f"✅ Response: {output.outputs[0].text}")

In [None]:
# Run every prompt template over the full evaluation set and collect one
# result record per (question, prompt type) pair.
results = []
# (prompt_type, batch start index) for every batch that raised, so that gaps in
# `results` stay visible after the loop instead of living only in scrolled-away
# log output.
failed_batches = []
output_file = "prompt_thresholding_responses.jsonl"
batch_size = 2048

for prompt_type, prompt_template in PROMPTS.items():
    print(f"\n{'='*60}")
    print(f"Running inference with prompt type: {prompt_type}")
    print(f"{'='*60}\n")

    for idx in tqdm(range(0, len(eval_data), batch_size), desc=f"{prompt_type}"):
        # Slicing already clamps at the end of the list.
        batch_data = eval_data[idx:idx + batch_size]
        batch_prompts = [
            prompt_template.format(context=item['chunk'], question=item['question'])
            for item in batch_data
        ]

        try:
            responses = model.generate(
                batch_prompts,
                sampling_params,
            )
            # Build the batch's records separately so that a failure part-way
            # through never leaves a half-recorded batch in `results`.
            batch_results = [
                {
                    'question': item['question'],
                    'question_type': item['question_type'],
                    'context': item['chunk'],
                    'source': item['source'],
                    'ground_truth': item['ground_truth'],
                    'prompt_type': prompt_type,
                    'raw_response': response.outputs[0].text.strip(),
                }
                for item, response in zip(batch_data, responses)
            ]
        except Exception as e:
            # Best effort: keep going with the remaining batches, but remember
            # what was lost.
            print(f"Error in batch {idx}: {e}")
            failed_batches.append((prompt_type, idx))
            continue

        results.extend(batch_results)

if failed_batches:
    print(
        f"WARNING: {len(failed_batches)} batch(es) failed and are missing from "
        f"results: {failed_batches}"
    )

In [13]:
# Persist all collected responses to `output_file`, overwriting any previous run.
# NOTE(review): although `output_file` carries a .jsonl extension, the whole
# `results` list is written as a single JSON array on one line rather than one
# record per line - confirm downstream readers expect that layout.
serialized_results = json.dumps(results, ensure_ascii=False)
with open(output_file, mode="w", encoding="utf-8") as f:
    f.write(serialized_results + "\n")