In [1]:
!pip install -U transformers



## Local Inference on GPU
Model page: https://huggingface.co/meta-llama/Llama-2-7b-hf

⚠️ If the generated code snippets do not work, please open an issue on either the [model repo](https://huggingface.co/meta-llama/Llama-2-7b-hf) and/or on [huggingface.js](https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/model-libraries-snippets.ts) 🙏

The model you are trying to use is gated. Make sure you have access to it by visiting the model page. To run inference, either set `HF_TOKEN` in your environment variables or Colab secrets, or run the following cell to log in. 🤗
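
As a rough sketch of the environment-variable route, the check below looks for the token before any gated download is attempted. It assumes the token is exposed as the `HF_TOKEN` environment variable; the helper name `get_hf_token` is hypothetical and not part of `huggingface_hub`.

```python
import os

def get_hf_token(env=None):
    """Return the Hugging Face token from the environment, or None if unset or blank."""
    env = os.environ if env is None else env
    token = env.get("HF_TOKEN", "").strip()
    return token or None

if get_hf_token() is None:
    # No token found: the interactive login cell is the fallback.
    print("HF_TOKEN is not set; run the login cell before loading the gated model.")
```

Passing a plain dictionary as `env` makes the helper easy to try without touching the real environment.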

In [1]:
from huggingface_hub import login
login(new_session=False)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(
	messages,
	add_generation_prompt=True,
	tokenize=True,
	return_dict=True,
	return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:]))

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


I'm an artificial intelligence model known as Llama. Llama stands for "Large Language Model Meta AI."<|eot_id|>


In [4]:
import torch
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM

# 1. Clean up memory from previous attempts
if 'model' in globals():
    del model
if 'pipe' in globals():
    del pipe
gc.collect()
torch.cuda.empty_cache()

# 2. Global Model Loading (Resource Efficient)
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load model in bfloat16 (Best for A100, saves memory)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,  # `torch_dtype` is deprecated in recent transformers releases
    device_map="auto"
)
model.eval() # Set to evaluation mode
print("Model loaded successfully!")

# 3. Define the function using the global model
def call_llama2(prompt, logprob=False):
    """
    Calls the Llama-3.1 model.
    Supports returning logprobs for uncertainty estimation.
    """
    # Apply chat template for Llama 3 to ensure best performance
    messages = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True
    ).to(model.device)

    # EXPLICITLY create attention mask of 1s
    # This prevents the model from ignoring parts of the prompt if pad_token_id == eos_token_id
    attention_mask = torch.ones_like(input_ids)

    input_len = input_ids.shape[1]

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask, # Pass the mask
            max_new_tokens=200,
            do_sample=not logprob, # Use greedy if we need consistent logprobs, sample otherwise
            temperature=0.7 if not logprob else None,
            top_p=0.9 if not logprob else None,
            return_dict_in_generate=True,
            output_scores=logprob,
            pad_token_id=tokenizer.eos_token_id
        )

    sequences = outputs.sequences[0]
    generated_ids = sequences[input_len:]  # Remove prompt tokens
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    result = {"text": generated_text, "avg_logprob": None, "uncertainty": None}

    if logprob:
        # Calculate average log probability from scores
        # scores is a tuple of tensors (one per generated token)
        scores = outputs.scores
        token_logprobs = []

        for t, token_id in enumerate(generated_ids):
            if t < len(scores):
                # scores[t] is (batch_size, vocab_size), we take batch 0
                logits = scores[t][0]
                # Calculate log_softmax to get log probabilities
                log_probs = torch.log_softmax(logits, dim=-1)
                # Get the log prob of the chosen token
                token_logprobs.append(log_probs[token_id].item())

        if token_logprobs:
            avg_logprob = sum(token_logprobs) / len(token_logprobs)
            result["avg_logprob"] = avg_logprob
            result["uncertainty"] = -avg_logprob # Higher uncertainty = lower logprob (more negative)

    return result

Loading model... this may take a few minutes


`torch_dtype` is deprecated! Use `dtype` instead!


Model loaded successfully!
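
The uncertainty returned by `call_llama2` is the negative mean log-probability of the generated tokens. The arithmetic can be sketched on toy logits without loading a model. This is a minimal pure-Python illustration; the function names and the toy numbers are made up for the example, and the real cell uses `torch.log_softmax` on `outputs.scores` instead.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def mean_logprob_uncertainty(step_logits, chosen_ids):
    """Return (avg_logprob, uncertainty) for one generated sequence.

    step_logits: one list of vocabulary logits per generated token.
    chosen_ids:  the token id that was actually emitted at each step.
    """
    token_logprobs = [
        log_softmax(logits)[tok] for logits, tok in zip(step_logits, chosen_ids)
    ]
    if not token_logprobs:
        return None, None
    avg = sum(token_logprobs) / len(token_logprobs)
    return avg, -avg

# Toy example: two generation steps over a three-token vocabulary.
step_logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
chosen_ids = [0, 1]
avg, uncertainty = mean_logprob_uncertainty(step_logits, chosen_ids)
print(avg, uncertainty)
```

The first step is fairly confident and the second is a uniform guess, so the averaged log-probability sits between the two and the uncertainty is its negation.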


In [5]:
import sqlite3

In [6]:
import json
def load_multipleqa_entries(json_path):
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    entries = []
    for item in data:
        ambiguous_question = (item.get("question") or "").strip()
        if not ambiguous_question:
            continue
        annotations = item.get("annotations", [])
        for ann in annotations:
            if ann.get("type") != "multipleQAs":
                continue
            qa_pairs = ann.get("qaPairs", [])
            interpretations = [
                qa.get("question").strip()
                for qa in qa_pairs
                if qa.get("question")
            ]
            if interpretations:
                entries.append(
                    {
                        "id": item.get("id"),
                        "question": ambiguous_question,
                        "interpretations": interpretations,
                        "qa_pairs": qa_pairs,  # keep the full qa_pairs
                    }
                )
    return entries
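
The loader above expects a JSON list in which each item carries `annotations`, and only annotations of type `multipleQAs` contribute interpretations. The sketch below shows that shape on invented records. The sample data, the `answer` fields, the `singleAnswer` type, and the helper name `interpretations_of` are assumptions for illustration. It also merges all annotations of an item into one list, whereas the loader emits one entry per annotation.

```python
import json

# Invented records that mimic the structure the loader reads.
sample = [
    {
        "id": "demo-1",
        "question": "Who sang the song?",
        "annotations": [
            {
                "type": "multipleQAs",
                "qaPairs": [
                    {"question": "Who sang version A? ", "answer": ["Singer A"]},
                    {"question": "Who sang version B?", "answer": ["Singer B"]},
                ],
            }
        ],
    },
    {
        "id": "demo-2",
        "question": "An unambiguous question?",
        "annotations": [{"type": "singleAnswer", "answer": ["yes"]}],
    },
]

def interpretations_of(item):
    """Collect the disambiguated questions from multipleQAs annotations."""
    found = []
    for ann in item.get("annotations", []):
        if ann.get("type") != "multipleQAs":
            continue  # other annotation types carry no interpretations
        for qa in ann.get("qaPairs", []):
            if qa.get("question"):
                found.append(qa["question"].strip())
    return found

# Keep only items that yield at least one interpretation.
kept = {}
for item in sample:
    interps = interpretations_of(item)
    if interps:
        kept[item["id"]] = interps

print(json.dumps(kept, indent=2))
```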

In [15]:
import requests
import json
from dotenv import load_dotenv
import os
import sqlite3

DATASET_PATH = "train_light.json"

MAX_REQUESTS = 500

TEXT_DB_PATH = "clarification_texts.db"
def extract_interpretations_and_answers(gpt_text):
    pairs = []
    current_interpretation = None

    lines = [line.strip() for line in gpt_text.splitlines() if line.strip()]
    for line in lines:
        lower = line.lower()
        if lower.startswith("interpretation"):

            parts = line.split(":", 1)
            current_interpretation = parts[1].strip() if len(parts) > 1 else ""
        elif lower.startswith("answer") and current_interpretation is not None:
            parts = line.split(":", 1)
            answer = parts[1].strip() if len(parts) > 1 else ""
            pairs.append((current_interpretation, answer))
            current_interpretation = None

    return pairs

def init_text_database(db_path: str = TEXT_DB_PATH):

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    cursor.execute("""
        CREATE TABLE IF NOT EXISTS clarifications (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            source_id TEXT,
            ambiguous_question TEXT,
            search_result_idx INTEGER,
            concatenated_text TEXT,
            embedding BLOB
        )
    """)

    conn.commit()
    conn.close()

def call_google_search_api(query, num_results=10):
    url = "https://www.googleapis.com/customsearch/v1"
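    params = {
        # Assumption: credentials are read from environment variables. The names
        # GOOGLE_API_KEY and GOOGLE_CSE_ID are placeholders for your own settings.
        "key": os.environ["GOOGLE_API_KEY"],
        "cx": os.environ["GOOGLE_CSE_ID"],
        "q": query,
        "num": num_results
    }

    # A timeout keeps a stalled request from blocking the whole run
    response = requests.get(url, params=params, timeout=30)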
    if response.status_code == 200:
        data = response.json()
        results = []
        for item in data.get("items", []):
            results.append({
                "title": item.get("title", ""),
                "snippet": item.get("snippet", ""),
                "link": item.get("link", "")
            })
        return results
    else:
        raise Exception(f"Google Search API returned an error: {response.status_code}: {response.text}")


def relax_question(ambiguous_question):
    relax_prompt = f"""Given an ambiguous question, relax it to turn it into a suitable query for web search engines. Please do not add restrictions or extra details. Return only the relaxed query text, nothing else.

Ambiguous question: {ambiguous_question}
"""
    try:
        result = call_llama2(relax_prompt)
        relaxed_query = result['text']
        print("relaxed query:", relaxed_query)
        return relaxed_query
    except Exception as e:
        print(f"Relax failed, using the original question: {e}")
        return ambiguous_question

def load_multipleqa_entries(json_path):
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    entries = []
    for item in data:
        ambiguous_question = (item.get("question") or "").strip()
        if not ambiguous_question:
            continue
        annotations = item.get("annotations", [])
        for ann in annotations:
            if ann.get("type") != "multipleQAs":
                continue
            qa_pairs = ann.get("qaPairs", [])
            interpretations = [
                qa.get("question").strip()
                for qa in qa_pairs
                if qa.get("question")
            ]
            if interpretations:
                entries.append(
                    {
                        "id": item.get("id"),
                        "question": ambiguous_question,
                        "interpretations": interpretations,
                        "qa_pairs": qa_pairs
                    }
                )
    return entries

def build_prompt(ambiguous_question, passage):
    intro = """Given an ambiguous query and one of the passages from retrieval results, provide a disambiguated query which can be answered by the passage. Try to infer the user's intent with the ambiguous query and think of possible concrete, non-ambiguous rewritten questions. If you cannot find any of
them, which can be answered by the provided document, simply abstain by replying with 'null'. You should provide at most one subquestion, the most relevant one you can think of.
Here are the rules to follow when generating the question and answer:
1. The generated question must be a disambiguation of the original ambiguous query.
2. The question should be fully answerable from information present in given passage. Even if the passage is relevant to the original ambiguous query, if it is not self-contained, abstain by
responding with 'null'.
3. Make sure the question is clear and unambiguous, while clarifying the intent of the original
ambiguous question.
4. Phrases like 'based on the provided context', 'according to the passage', etc., are not allowed to
appear in the question. Similarly, questions such as "What is not mentioned about something in the
passage?" are not acceptable.
5. When addressing questions tied to a specific moment, provide the clearest possible time
reference. Avoid ambiguous questions such as "Which country has won the most recent World
Cup?" since the answer varies depending on when the question is asked.
6. The answer must be specifically based on the information provided in the passage. Your prior
knowledge should not intervene in answering the identified clarification question.
Input fields are:
Question: {ambiguous question (q)}
Passage: {passage (p)}
Output fields are:
Interpretation: {generated interpretation (q̂)}
Answer: {generated answer (ŷ)}""".strip()

    prompt = [intro]
    prompt.append(f"\nQuestion: {ambiguous_question}")
    prompt.append(f"\nPassage: {passage}")
    prompt.append("\n")
    return "\n".join(prompt)

def save_to_database(db_path, entry_id, question, qa_pairs, llm_answer, uncertainty, evaluation_score=None):
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    qa_pairs_json = json.dumps(qa_pairs, ensure_ascii=False)

    cursor.execute("""
        INSERT OR REPLACE INTO results
        (id, question, qa_pairs, llm_answer, uncertainty, evaluation_score)
        VALUES (?, ?, ?, ?, ?, ?)
    """, (str(entry_id), question, qa_pairs_json, llm_answer, uncertainty, evaluation_score))

    conn.commit()
    conn.close()
    print(f"Data saved to database (ID: {entry_id})\n")

def save_concatenated_text(
    ambiguous_question: str,
    concatenated_text: str,
    source_id: str = None,
    search_result_idx: int = None,
    db_path: str = TEXT_DB_PATH,
):
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    cursor.execute(
        """
        INSERT INTO clarifications
        (source_id, ambiguous_question, search_result_idx, concatenated_text, embedding)
        VALUES (?, ?, ?, ?, NULL)
        """,
        (str(source_id) if source_id is not None else None, ambiguous_question, search_result_idx, concatenated_text),
    )

    conn.commit()
    conn.close()
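
The parsing rule used by `extract_interpretations_and_answers` can be checked on a hand-written reply before running the full pipeline. The block below is a self-contained sketch that mirrors that rule under a hypothetical name, `parse_reply`; the reply text is invented for the example.

```python
def parse_reply(reply):
    """Pair each 'Interpretation:' line with the next 'Answer:' line."""
    pairs = []
    current = None  # interpretation waiting for its answer
    for raw in reply.splitlines():
        line = raw.strip()
        if not line:
            continue
        lower = line.lower()
        if lower.startswith("interpretation"):
            parts = line.split(":", 1)
            current = parts[1].strip() if len(parts) > 1 else ""
        elif lower.startswith("answer") and current is not None:
            parts = line.split(":", 1)
            answer = parts[1].strip() if len(parts) > 1 else ""
            pairs.append((current, answer))
            current = None  # an answer closes the pending interpretation
    return pairs

reply = """Interpretation: Who wrote the demo song in 1999?
Answer: A hypothetical songwriter
null
Answer: an orphan answer with no interpretation before it"""

print(parse_reply(reply))
print(parse_reply("null"))
```

An `Answer:` line is kept only when an interpretation is pending, so abstentions such as `null` and stray answers produce no pairs and nothing is written to the database for them.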

In [16]:
init_text_database(TEXT_DB_PATH)

try:
    entries = load_multipleqa_entries(DATASET_PATH)
except FileNotFoundError:
    print(f"Error: cannot find file {DATASET_PATH}!")
    raise SystemExit(1)

if not entries:
    print("No valid multipleQAs entries in the data.")
    raise SystemExit(1)

total = len(entries)
start_idx = 201
end_idx = min(MAX_REQUESTS, total)  # never index past the end of the dataset

print("=== Google Search & Clarification Generation ===\n")
for i in range(start_idx - 1, end_idx):
    entry = entries[i]
    print(f"Processing Entry [{i+1}/{total}]: {entry['question']}")

    relaxed_query = relax_question(entry["question"])

    print(f"Calling Google Search API ({relaxed_query}), please wait\n")
    try:
        search_results = call_google_search_api(relaxed_query)
        print(f"Found {len(search_results)} search results\n")
    except Exception as e:
        print(f"Google Search call failed: {e}\n")
        search_results = []

    if search_results:
        prompts = []
        for result in search_results:
            passage = f"{result['title']}\n{result['snippet']}"
            prompts.append(build_prompt(entry["question"], passage))

        gpt_texts = []
        try:
            tokenizer.padding_side = "left"
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            batch_messages = [[{"role": "user", "content": p}] for p in prompts]

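            # return_dict=True also returns the tokenizer's own attention mask.
            # Deriving the mask from `input_ids != pad_token_id` can also mask
            # real end-of-turn tokens in the prompt when pad_token == eos_token.
            inputs = tokenizer.apply_chat_template(
                batch_messages,
                return_tensors="pt",
                add_generation_prompt=True,
                padding=True,
                return_dict=True
            ).to(model.device)
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]

            with torch.no_grad():
                # Greedy decoding: temperature and top_p have no effect when do_sample=False
                generated_ids = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=200,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id
                )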

            input_len = input_ids.shape[1]
            new_tokens = generated_ids[:, input_len:]
            gpt_texts = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)

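        except Exception as e:
            # Report the failure instead of silently skipping this entry
            print(f"  batched generation failed: {e}")
            gpt_texts = []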

        for result_idx, (result, gpt_text) in enumerate(zip(search_results, gpt_texts), 1):
            print(f"\n  processing search result {result_idx}/{len(search_results)}:")
            print(f"  title: {result['title']}")
            print(f"  link: {result['link']}")
            print("  response from Llama 3.1:")
            print(f"  {gpt_text.strip()[:100]}...")
            ia_pairs = extract_interpretations_and_answers(gpt_text)
            if ia_pairs:
                concatenated = " || ".join(
                    [f"{interp} {ans}" for interp, ans in ia_pairs]
                )
                print(f"  -> extraction succeeded: {concatenated}")

                save_concatenated_text(
                    ambiguous_question=entry["question"],
                    concatenated_text=concatenated,
                    source_id=entry.get("id"),
                    search_result_idx=result_idx,
                    db_path=TEXT_DB_PATH,
                )
            else:
                print("  -> failed to extract Interpretation/Answer")

    else:
        print("No search results\n")

    print("=" * 60 + "\n")

Streaming output truncated to the last 5000 lines.
  -> extraction succeeded: Where does the Grinch live in the story "How the Grinch Stole Christmas"? Mount Crumpit.

Processing Entry [450/4792]: Who sang the song a change is going to come?
Relax prompt sent to the LLM:
------------------------------------------------------------
Given an ambiguous question, relax it to turn it into a suitable query for web search engines. Please do not add restrictions or extra details. Return only the relaxed query text, nothing else.

Ambiguous question: Who sang the song a change is going to come?

------------------------------------------------------------

relaxed query: Song a change is going to come
Calling Google Search API (Song a change is going to come), please wait...

Found 10 search results

  Batch-calling Llama 3 on 10 results, which is much faster...

  Generated replies: 10 / search results: 10

  Processing search result 1/10:
  Title: Sam Cooke - A Change Is Gonna Come (Official Lyric Video ...
  Link: https://www.youtube.com/watch?v=wEBlaMOmKV4
  Response from Llama2:
  Interpretation: W