In [18]:
from transformers import DPRQuestionEncoder, DPRContextEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer

# Question encoder and tokenizer (pretrained DPR checkpoint for the question side)
question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')

# Context encoder and tokenizer (the matching DPR checkpoint for the passage side)
context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')


Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the

In [None]:
import json
import os
from transformers import pipeline, BartTokenizer, BartForConditionalGeneration

# Load BART model and tokenizer
model_name = "facebook/bart-large"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

# Create text generation pipeline.
# `generator` is read as a module-level global by generate_sentence_for_row below.
generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)

def clean_generated_text(generated_text):
    """Strip echoed prompt fragments from the start of a generated sentence.

    The prefixes are tried in list order and matched case-insensitively.
    After a prefix is removed the remaining prefixes are still checked
    against the shortened text, so stacked fragments are removed as well.

    The list includes the prompt wording used by generate_sentence_for_row
    in this cell ("Generate a descriptive sentence for this table row:"),
    which was previously missing and therefore never stripped.

    Args:
        generated_text: Raw text returned by the generation pipeline.

    Returns:
        The text with any leading prompt fragment removed and surrounding
        whitespace trimmed; the input unchanged when no prefix matches.
    """
    prefixes = [
        # Wording of the prompt actually sent by generate_sentence_for_row.
        "Generate a descriptive sentence for this table row:",
        "a descriptive sentence for this table row:",
        # Older prompt wordings, kept for backward compatibility.
        "Generate a description for this table row:",
        "Description for this table row:",
        "a description for this table row:",
        "table row:"
    ]
    for prefix in prefixes:
        if generated_text.lower().startswith(prefix.lower()):
            generated_text = generated_text[len(prefix):].strip()
    return generated_text

def generate_sentence_for_row(row, headers):
    """Turn one table row into a generated descriptive sentence.

    Args:
        row: Sequence of cells; the text of each cell is read from index 0.
        headers: Column names, paired positionally with the cells of `row`.

    Returns:
        The cleaned generated sentence, "No relevant data" when every cell
        is empty, or "Error" when generation raises.
    """
    # Keep only cells with a truthy value at index 0, rendered as "header: value".
    labelled_cells = []
    for header, value in zip(headers, row):
        cell_text = value[0]
        if cell_text:
            labelled_cells.append(f"{header}: {cell_text}")
    row_text = ", ".join(labelled_cells)

    # Nothing to describe: skip the model call entirely.
    if not row_text:
        return "No relevant data"

    prompt = f"Generate a descriptive sentence for this table row: {row_text}"

    try:
        # `generator` is the module-level text2text pipeline.
        results = generator(prompt, max_length=50, num_return_sequences=1)
        raw_sentence = results[0]['generated_text'].strip()
        # Drop any prompt fragment echoed back by the model.
        return clean_generated_text(raw_sentence)
    except Exception as err:
        # Best-effort: report the failure and return a sentinel instead of raising.
        print(f"Error generating sentence: {err}")
        return "Error"

def process_table(json_data):
    """Generate one descriptive sentence per row of a table.

    Args:
        json_data: Mapping with a 'header' list (the column name is taken
            from index 0 of each entry) and a 'data' list of rows.

    Returns:
        A list of sentences, one per entry of json_data['data'], in order.
    """
    column_names = []
    for header_cell in json_data['header']:
        column_names.append(header_cell[0])
    return [generate_sentence_for_row(table_row, column_names) for table_row in json_data['data']]

def process_files(input_json_path, table_tok_dir, output_dir, num_files=None):
    """Generate row summaries for the tables referenced by an index file.

    Args:
        input_json_path: Path to a JSON list whose items each carry a
            'table_id' key.
        table_tok_dir: Directory containing one "<table_id>.json" per table.
        output_dir: Directory that receives one "<table_id>.txt" per table;
            created if it does not exist.
        num_files: Maximum number of distinct tables to process, or None
            for all of them.

    Side effects:
        Writes one text file per processed table (each sentence wrapped in
        double quotes, joined by ",\\n") and prints a progress line per table.
        Tables whose JSON file is missing are reported and skipped.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Read the index explicitly as UTF-8 so the result does not depend on the
    # platform default encoding (the output below is written as UTF-8 too).
    with open(input_json_path, 'r', encoding='utf-8') as file:
        input_data = json.load(file)

    # Several index entries can reference the same table. De-duplicate while
    # keeping first-seen order, so each table is generated once and
    # `num_files` counts distinct tables rather than index entries.
    table_ids = list(dict.fromkeys(item['table_id'] for item in input_data))

    # If number of files is specified, only process that many
    if num_files is not None:
        table_ids = table_ids[:num_files]

    for table_id in table_ids:
        input_file_path = os.path.join(table_tok_dir, f"{table_id}.json")

        if not os.path.exists(input_file_path):
            print(f"File not found: {input_file_path}")
            continue

        with open(input_file_path, 'r', encoding='utf-8') as file:
            json_data = json.load(file)

        sentences = process_table(json_data)

        # One quoted sentence per line; empty sentences are dropped.
        output_file_path = os.path.join(output_dir, f"{table_id}.txt")
        with open(output_file_path, 'w', encoding='utf-8') as file:
            formatted_sentences = [f'"{sentence}"' for sentence in sentences if sentence]
            file.write(',\n'.join(formatted_sentences))

        print(f"Summary has been written to {output_file_path}")

# Paths are relative to the notebook's working directory.
input_json_path = "dev.json"
table_tok_dir = "WikiTables-WithLinks-master/tables_tok"  # Folder containing JSON files with table data
output_dir = "row_summary"  # Output directory
    
# Interactive limit: an empty answer means "process everything".
# A non-numeric answer makes int() raise ValueError.
num_files = input("Enter the number of files to process (press Enter for all): ").strip()
num_files = int(num_files) if num_files else None
    
process_files(input_json_path, table_tok_dir, output_dir, num_files)

In [25]:
import torch

# Encode the question with the DPR question encoder.
# Inference only, so run under no_grad to avoid building an autograd graph.
question = "What place was achieved by the person who finished the Berlin marathon in 2 hours, 13 minutes, ans 32 seconds in 2011 the first time he competed in a marathon ?"
question_inputs = question_tokenizer(question, return_tensors='pt')
with torch.no_grad():
    question_embedding = question_encoder(**question_inputs).pooler_output

# Candidate row summaries to rank against the question.
contexts = [
    'Patrick Makau Musyoki from Kenya finished in 1st place with a time of 2:3:38',
    'Stephen Kwelio Chemlany from Kenya came in 2nd place with a time of 2:7:55',
    "Edwin Kimaiyo from Kenya secured 3rd place with a time of 2:9:50",
    "Felix Limo from Kenya finished 4th with a time of 2:10:38",
    "Scott Overall from the United Kingdom placed 5th with a time of 2:10:55",
    "Ricardo Serrano from Spain took 6th place with a time of 2:13:32",
    "Simon Munyutu from France came in 8th place with a time of 2:14:20",
    "Driss El Himer from France secured 9th place with a time of 2:14:46",
    "Hendrick Ramaala from South Africa finished in 10th place with a time of 2:16:0",
]

# Encode each context separately with the DPR context encoder.
context_embeddings = []
with torch.no_grad():
    for context in contexts:
        context_inputs = context_tokenizer(context, return_tensors='pt')
        context_embedding = context_encoder(**context_inputs).pooler_output
        context_embeddings.append(context_embedding)

# Cosine similarity between the question and every context embedding.
similarities = [torch.nn.functional.cosine_similarity(question_embedding, context_embedding).item() for context_embedding in context_embeddings]

# Report the single best-scoring context.
most_similar_idx = similarities.index(max(similarities))
print(f"The most relevant document is: '{contexts[most_similar_idx]}' with similarity score {similarities[most_similar_idx]}")

The most relevant document is: 'Scott Overall from the United Kingdom placed 5th with a time of 2:10:55' with similarity score 0.5861918926239014


In [21]:
import torch
from transformers import AutoTokenizer, AutoModel
import re

# Use a model suited to question-answering tasks.
# AutoModel loads the bare encoder; encode_text below reads only
# last_hidden_state from its output.
model_name = "deepset/roberta-base-squad2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

def encode_text(text):
    """Embed `text` as the mean of the model's last-layer token states.

    Uses the module-level `tokenizer` and `model`. Input is truncated to
    512 tokens and the forward pass runs without gradient tracking.

    Returns:
        The result of averaging `last_hidden_state` over dim 1.
    """
    encoded = tokenizer(text, return_tensors='pt', max_length=512, truncation=True, padding=True)
    with torch.no_grad():
        model_output = model(**encoded)
    token_states = model_output.last_hidden_state
    return token_states.mean(dim=1)

# Question to answer and its embedding, kept as module-level state.
question = "What is the middle name of the player with the second most National Football League career rushing yards?"
question_embedding = encode_text(question)

# Hand-picked keywords; retrieve_context reads this global to boost matching contexts.
keywords = ["second", "most", "rushing yards"]

def _parse_rank_and_yards(candidate):
    """Return (rank, yards) parsed from a candidate string, or None.

    The rank is a run of digits at the very start of the string; the yards
    value is a comma-grouped integer delimited as "| 12,345 |".
    """
    rank_match = re.search(r'^(\d+)', candidate)
    yards_match = re.search(r'\| (\d{1,3}(,\d{3})*) \|', candidate)
    if not (rank_match and yards_match):
        return None
    return int(rank_match.group(1)), int(yards_match.group(1).replace(',', ''))


def retrieve_context(question, contexts, top_k=6):
    """Pick the single most relevant context for `question`.

    Stage 1 ranks every context by cosine similarity to the question
    embedding plus 0.1 for each entry of the module-level `keywords` list
    found in the context, and keeps the best `top_k`.
    Stage 2 re-scores candidates whose rank and yards can be parsed;
    candidates that cannot be parsed score 0. When no candidate is parsable
    the best stage-1 candidate is returned.

    Args:
        question: Question text; it is encoded here instead of relying on a
            previously computed global embedding.
        contexts: List of candidate context strings.
        top_k: Number of stage-1 candidates to keep.

    Returns:
        The selected context string, or None when there is nothing to
        choose from (empty `contexts` or non-positive `top_k`).
    """
    if not contexts:
        return None

    query_embedding = encode_text(question)

    # Stage 1: embedding similarity combined with the keyword-match bonus.
    scored_contexts = []
    for context in contexts:
        context_embedding = encode_text(context)
        similarity = torch.nn.functional.cosine_similarity(query_embedding, context_embedding).item()
        keyword_score = sum(keyword.lower() in context.lower() for keyword in keywords)
        scored_contexts.append((context, similarity + keyword_score * 0.1))
    scored_contexts.sort(key=lambda pair: pair[1], reverse=True)

    candidates = [context for context, _ in scored_contexts[:top_k]]
    if not candidates:
        return None

    # Stage 2: extract rank and yards from structured rows.
    parsed_rows = [_parse_rank_and_yards(candidate) for candidate in candidates]
    max_yards = max((row[1] for row in parsed_rows if row is not None), default=0)

    final_scores = []
    for row in parsed_rows:
        if row is None:
            final_scores.append(0)
            continue
        rank, yards = row
        # NOTE(review): the original code added `rank_score` and `yards_score`
        # without ever defining them (NameError on any parsable row). The
        # weighting below -- reciprocal rank plus yards normalised by the
        # largest candidate value -- is a placeholder; confirm the intended
        # scoring before relying on it.
        rank_score = 1.0 / rank if rank > 0 else 0.0
        yards_score = yards / max_yards if max_yards > 0 else 0.0
        final_scores.append(rank_score + yards_score)

    best_index = final_scores.index(max(final_scores))
    return candidates[best_index]

# NOTE: `contexts` is not assigned anywhere in this cell; it must already exist
# in the kernel from another cell for this call to work.
retrieved_context = retrieve_context(question, contexts)
print("Retrieved context:", retrieved_context)

Some weights of RobertaModel were not initialized from the model checkpoint at deepset/roberta-base-squad2 and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Retrieved context: Walter Payton, ranked second, played for the Chicago Bears (1975-1987). He had 3,838 carries for 16,726 yards, averaging 4.4 yards per carry.


In [18]:
import heapq

# Cosine similarity between the question embedding and every context embedding.
# NOTE: `question_embedding`, `context_embeddings` and `contexts` are not defined
# in this cell; they are taken from whatever earlier cells left in the kernel.
similarities = [torch.nn.functional.cosine_similarity(question_embedding, context_embedding).item() for context_embedding in context_embeddings]

# Get the indices of the top n highest similarity scores (n is hard-coded to 9;
# nlargest returns fewer when there are fewer contexts)
top_n_indices = heapq.nlargest(9, range(len(similarities)), key=similarities.__getitem__)

# Print the top n most relevant documents and their similarity scores
for i, idx in enumerate(top_n_indices, 1):
    print(f"Top {i} relevant document: '{contexts[idx]}' with similarity score {similarities[idx]}")



Top 1 relevant document: 'Walter Payton, ranked second, played for the Chicago Bears (1975-1987). He had 3,838 carries for 16,726 yards, averaging 4.4 yards per carry.' with similarity score -0.0035529606975615025
Top 2 relevant document: 'Frank Gore, ranked third, played for the San Francisco 49ers (2005-2014), Indianapolis Colts (2015-2017), Miami Dolphins (2018), and Buffalo Bills (2019-present). He had 3,548 carries for 15,347 yards, averaging 4.3 yards per carry.' with similarity score -0.007335161790251732
Top 3 relevant document: 'Emmitt Smith, ranked first, played for the Dallas Cowboys (1990-2002) and Arizona Cardinals (2003-2004). He had 4,409 carries for 18,355 yards, averaging 4.2 yards per carry.' with similarity score -0.03946688771247864


In [25]:
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load the T5 model and tokenizer
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Input question
question = "What is the middle name of the player with the second most National Football League career rushing yards?"

# Build the T5 input, framing the request as a keyword-extraction task
input_text = "extract keywords: " + question
inputs = tokenizer(input_text, return_tensors="pt", padding=True)

# Generate keywords; generation parameters are tuned to keep the output short
outputs = model.generate(
    inputs['input_ids'], 
    max_length=16,  # cap the generated length
    num_beams=5, 
    early_stopping=True, 
    no_repeat_ngram_size=1  # never repeat a token
)


# Decode the best sequence and drop the task prefix if the model echoed it
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
cleaned_text = generated_text.replace("extract keywords:", "").strip()  

keywords = [word.strip() for word in cleaned_text.split() if len(word) > 2]  # keep only words longer than two characters

print("Extracted keywords:", keywords)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Extracted keywords: ['Which', 'middle', 'name', 'the', 'player', 'with', 'second', 'most', 'career', 'rushing']


In [None]:
import torch
from transformers import AutoTokenizer, AutoModel, DPRQuestionEncoder, DPRContextEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch.nn.functional as F
import re
import json
import os

# Initialize DPR models and tokenizers (read as globals by the retrieval functions below)
question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')

# Load RoBERTa model for keyword and rule processing
# NOTE(review): none of the functions defined in this cell reference `tokenizer`
# or `roberta_model` -- confirm whether this load is still needed.
roberta_model_name = "deepset/roberta-base-squad2"
tokenizer = AutoTokenizer.from_pretrained(roberta_model_name)
roberta_model = AutoModel.from_pretrained(roberta_model_name)

# Load T5 model and tokenizer (used by extract_keywords)
t5_model_name = "t5-small"
t5_tokenizer = T5Tokenizer.from_pretrained(t5_model_name)
t5_model = T5ForConditionalGeneration.from_pretrained(t5_model_name)

# DPR encoding function
def encode_text_dpr(text, encoder, tokenizer):
    """Encode `text` and return the encoder's `pooler_output`.

    Args:
        text: String to embed.
        encoder: Callable accepting the tokenizer output as keyword arguments.
        tokenizer: Callable producing the encoder inputs for `text`.

    The input is truncated to 512 tokens and the forward pass runs without
    gradient tracking.
    """
    model_inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    with torch.no_grad():
        encoder_output = encoder(**model_inputs)
    return encoder_output.pooler_output

# Keyword extraction function
def extract_keywords(question):
    """Ask the module-level T5 model for keywords describing `question`.

    The question is prefixed with "extract keywords: ", decoded with beam
    search (5 beams, at most 16 tokens, no repeated tokens), and the decoded
    text is split on whitespace.

    Returns:
        The generated words that are longer than two characters.
    """
    prompt = "extract keywords: " + question
    encoded = t5_tokenizer(prompt, return_tensors="pt", padding=True)
    generated_ids = t5_model.generate(
        encoded['input_ids'],
        max_length=16,
        num_beams=5,
        early_stopping=True,
        no_repeat_ngram_size=1,
    )
    decoded = t5_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # Drop the task prefix in case the model echoed it back.
    words = decoded.replace("extract keywords:", "").strip().split()
    return [word.strip() for word in words if len(word) > 2]

# Stage 1: Initial DPR retrieval
def initial_retrieval(question, contexts, top_k=8):
    """Stage 1: rank contexts by DPR cosine similarity to the question.

    Args:
        question: Question text.
        contexts: List of candidate context strings.
        top_k: Number of best-scoring contexts to keep.

    Returns:
        A pair (top contexts, their indices into `contexts`), both ordered
        from most to least similar.
    """
    query_vec = encode_text_dpr(question, question_encoder, question_tokenizer)

    similarities = []
    for context in contexts:
        context_vec = encode_text_dpr(context, context_encoder, context_tokenizer)
        similarities.append(F.cosine_similarity(query_vec, context_vec).item())

    ranked_indices = sorted(range(len(similarities)), key=similarities.__getitem__, reverse=True)
    top_k_indices = ranked_indices[:top_k]
    return [contexts[i] for i in top_k_indices], top_k_indices

def dpr_fine_ranking(question, candidates, keywords):
    """Stage 2: re-rank candidates by DPR similarity plus a keyword bonus.

    Each candidate's score is its cosine similarity to the question plus 0.1
    for every keyword that occurs in it (case-insensitive substring match).

    Returns:
        Up to three (candidate, score) tuples, highest score first.
    """
    query_vec = encode_text_dpr(question, question_encoder, question_tokenizer)

    scored = []
    for candidate in candidates:
        candidate_vec = encode_text_dpr(candidate, context_encoder, context_tokenizer)
        similarity = F.cosine_similarity(query_vec, candidate_vec).item()

        lowered_candidate = candidate.lower()
        keyword_hits = sum(1 for keyword in keywords if keyword.lower() in lowered_candidate)

        # Keyword score weight 0.1, adjustable
        scored.append((candidate, similarity + keyword_hits * 0.1))

    # Stable descending sort, then keep the top three.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:3]

def retrieve_context(question, contexts, top_k=8):
    """Run the two-stage retrieval and return the re-ranked top contexts.

    Keywords are extracted from the question, the `top_k` most similar
    contexts are shortlisted, and the shortlist is re-ranked by
    dpr_fine_ranking, whose result is returned unchanged.
    """
    question_keywords = extract_keywords(question)
    shortlisted = initial_retrieval(question, contexts, top_k=top_k)[0]
    return dpr_fine_ranking(question, shortlisted, question_keywords)

def retrieve_and_save(question, contexts, file_name, output_dir="row_retrieve"):
    """Retrieve the best contexts for a question and write them to JSON.

    The output file is "<file_name without extension>_retrieve.json" inside
    `output_dir` (created if missing). For each returned context i (1-based)
    the keys retrieve_content{i}, number{i} and score{i} are stored, where
    number{i} is the 1-based position of the context's first occurrence in
    `contexts`.
    """
    os.makedirs(output_dir, exist_ok=True)

    ranked_contexts = retrieve_context(question, contexts)

    table_id = os.path.splitext(file_name)[0]
    output_data = {"table_id": table_id, "question": question}

    position = 0
    for context, score in ranked_contexts:
        position += 1
        output_data[f"retrieve_content{position}"] = context
        # +1 because row numbers usually start from 1
        output_data[f"number{position}"] = contexts.index(context) + 1
        output_data[f"score{position}"] = score

    output_file = os.path.join(output_dir, f"{table_id}_retrieve.json")

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=4)
    print(f"Results have been saved to {output_file}")

# Single-question example: retrieve rows from one previously generated summary file.
question = "What was the nickname of the gold medal winner in the men's heavyweight greco-roman wrestling event of the 1932 Summer Olympics?"
file_name = "Sweden_at_the_1932_Summer_Olympics_0.txt"
file_path = os.path.join("row_summary", file_name)

# One context per non-empty line. Lines are only whitespace-stripped, so any
# quoting or separators present in the file remain part of each context.
with open(file_path, 'r', encoding='utf-8') as file:
    contexts = file.readlines()
contexts = [line.strip() for line in contexts if line.strip()]

retrieve_and_save(question, contexts, file_name)

In [16]:
import json
import os
from transformers import pipeline, BartTokenizer, BartForConditionalGeneration

# Load BART model and tokenizer
model_name = "facebook/bart-large"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

# Create text generation pipeline.
# `generator` is read as a module-level global by generate_sentence_for_row below.
generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)

def clean_generated_text(generated_text):
    """Strip echoed prompt fragments from the start of a generated sentence.

    Prefixes are tried in a fixed order and compared case-insensitively.
    Once a prefix has been removed, the remaining prefixes are still checked
    against the shortened text.

    Returns:
        The text with matching leading fragments removed (and trimmed), or
        the input unchanged when nothing matches.
    """
    known_prefixes = (
        "Generate a description for this table row:",
        "Description for this table row:",
        "a description for this table row:",
        "table row:",
    )
    cleaned = generated_text
    for candidate in known_prefixes:
        matches = cleaned.lower().startswith(candidate.lower())
        if not matches:
            continue
        # Cut by prefix length on the original-case text, then trim.
        cleaned = cleaned[len(candidate):].strip()
    return cleaned

def generate_sentence_for_row(row, headers):
    """Turn one table row into a generated descriptive sentence.

    Args:
        row: Sequence of cells; the text of each cell is read from index 0.
        headers: Column names, paired positionally with the cells of `row`.

    Returns:
        The cleaned generated sentence, "No relevant data" when every cell
        is empty, or "Error" when generation raises.
    """
    # Keep only cells with a truthy value at index 0, rendered as "header: value".
    labelled_cells = []
    for header, value in zip(headers, row):
        cell_text = value[0]
        if cell_text:
            labelled_cells.append(f"{header}: {cell_text}")
    row_text = ", ".join(labelled_cells)

    # Nothing to describe: skip the model call entirely.
    if not row_text:
        return "No relevant data"

    prompt = f"Generate a description for this table row: {row_text}"

    try:
        # `generator` is the module-level text2text pipeline.
        results = generator(prompt, max_length=50, num_return_sequences=1)
        raw_sentence = results[0]['generated_text'].strip()
        # Drop any prompt fragment echoed back by the model.
        return clean_generated_text(raw_sentence)
    except Exception as err:
        # Best-effort: report the failure and return a sentinel instead of raising.
        print(f"Error generating sentence: {err}")
        return "Error"

def process_table(json_data):
    """Generate one descriptive sentence per row of a table.

    Args:
        json_data: Mapping with a 'header' list (the column name is taken
            from index 0 of each entry) and a 'data' list of rows.

    Returns:
        A list of sentences, one per entry of json_data['data'], in order.
    """
    column_names = [header_cell[0] for header_cell in json_data['header']]
    generated = []
    for table_row in json_data['data']:
        generated.append(generate_sentence_for_row(table_row, column_names))
    return generated

def process_files(input_json_path, table_tok_dir, output_dir, num_files=None):
    """Generate row summaries for the tables referenced by an index file.

    Args:
        input_json_path: Path to a JSON list whose items each carry a
            'table_id' key.
        table_tok_dir: Directory containing one "<table_id>.json" per table.
        output_dir: Directory that receives one "<table_id>.txt" per table;
            created if it does not exist.
        num_files: Maximum number of distinct tables to process, or None
            for all of them.

    Side effects:
        Writes one text file per processed table (each sentence wrapped in
        double quotes, joined by ",\\n") and prints a progress line per table.
        Tables whose JSON file is missing are reported and skipped.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Read the index explicitly as UTF-8 so the result does not depend on the
    # platform default encoding (the output below is written as UTF-8 too).
    with open(input_json_path, 'r', encoding='utf-8') as file:
        input_data = json.load(file)

    # Several index entries can reference the same table. De-duplicate while
    # keeping first-seen order, so each table is generated once and
    # `num_files` counts distinct tables rather than index entries.
    table_ids = list(dict.fromkeys(item['table_id'] for item in input_data))

    # If number of files is specified, only process that many
    if num_files is not None:
        table_ids = table_ids[:num_files]

    for table_id in table_ids:
        input_file_path = os.path.join(table_tok_dir, f"{table_id}.json")

        if not os.path.exists(input_file_path):
            print(f"File not found: {input_file_path}")
            continue

        with open(input_file_path, 'r', encoding='utf-8') as file:
            json_data = json.load(file)

        sentences = process_table(json_data)

        # One quoted sentence per line; empty sentences are dropped.
        output_file_path = os.path.join(output_dir, f"{table_id}.txt")
        with open(output_file_path, 'w', encoding='utf-8') as file:
            formatted_sentences = [f'"{sentence}"' for sentence in sentences if sentence]
            file.write(',\n'.join(formatted_sentences))

        print(f"Summary has been written to {output_file_path}")

# Paths are relative to the notebook's working directory.
input_json_path = "released_data/dev_traced.json"
table_tok_dir = "WikiTables-WithLinks-master/tables_tok"  # Folder containing table data JSON files
output_dir = "row_summary"  # Output directory

# Interactive limit: an empty answer means "process everything".
# A non-numeric answer makes int() raise ValueError.
num_files = input("Enter the number of files to process (press Enter for all): ").strip()
num_files = int(num_files) if num_files else None

process_files(input_json_path, table_tok_dir, output_dir, num_files)


Enter the number of files to process (press Enter for all):  20


Summary has been written to row_summary/List_of_National_Football_League_rushing_yards_leaders_0.txt
Summary has been written to row_summary/Sweden_at_the_1932_Summer_Olympics_0.txt
Summary has been written to row_summary/2004_United_States_Grand_Prix_0.txt
Summary has been written to row_summary/List_of_museums_in_Atlanta_0.txt
Summary has been written to row_summary/2011_Berlin_Marathon_0.txt
Summary has been written to row_summary/List_of_football_stadiums_in_Paraguay_0.txt
Summary has been written to row_summary/List_of_wealthiest_non-inflated_historical_figures_13.txt
Summary has been written to row_summary/1929_International_Cross_Country_Championships_0.txt
Summary has been written to row_summary/List_of_Somali_cities_by_population_0.txt
Summary has been written to row_summary/Flora_and_fauna_of_Madhya_Pradesh_0.txt
Summary has been written to row_summary/List_of_Mohun_Bagan_A.C._managers_0.txt
Summary has been written to row_summary/List_of_Indonesian_dishes_3.txt
Summary has b

In [14]:
import torch
from transformers import AutoTokenizer, AutoModel, DPRQuestionEncoder, DPRContextEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch.nn.functional as F
import json
import os

class FlexibleRetrievalSystem:
    """Two-stage row retrieval: DPR similarity, then keyword-boosted re-ranking.

    All models are loaded once in __init__ and reused for every question.
    """

    def __init__(self):
        """Load the DPR encoders, the RoBERTa model and the T5 keyword model."""
        self.question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
        self.context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
        self.question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
        self.context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')

        # NOTE(review): no method of this class reads self.tokenizer or
        # self.roberta_model; they are kept because external code may.
        roberta_model_name = "deepset/roberta-base-squad2"
        self.tokenizer = AutoTokenizer.from_pretrained(roberta_model_name)
        self.roberta_model = AutoModel.from_pretrained(roberta_model_name)

        t5_model_name = "t5-small"
        self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_model_name)
        self.t5_model = T5ForConditionalGeneration.from_pretrained(t5_model_name)

    def encode_text_dpr(self, text, encoder, tokenizer):
        """Return `encoder`'s pooler_output for `text`.

        The input is truncated to 512 tokens and the forward pass runs
        without gradient tracking.
        """
        inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
        with torch.no_grad():
            return encoder(**inputs).pooler_output

    def extract_keywords(self, question):
        """Generate keywords for `question` with the T5 model.

        Returns:
            The whitespace-separated generated words longer than two characters.
        """
        input_text = "extract keywords: " + question
        inputs = self.t5_tokenizer(input_text, return_tensors="pt", padding=True)
        outputs = self.t5_model.generate(
            inputs['input_ids'],
            max_length=16,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=1
        )
        generated_text = self.t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Drop the task prefix in case the model echoed it back.
        cleaned_text = generated_text.replace("extract keywords:", "").strip()
        keywords = [word.strip() for word in cleaned_text.split() if len(word) > 2]
        return keywords

    def initial_retrieval(self, question, contexts, top_k=8):
        """Stage 1: keep the `top_k` contexts most similar to the question.

        Returns:
            A pair (top contexts, their indices into `contexts`), ordered
            from most to least similar.
        """
        question_embedding = self.encode_text_dpr(question, self.question_encoder, self.question_tokenizer)
        context_scores = []

        for idx, context in enumerate(contexts):
            context_embedding = self.encode_text_dpr(context, self.context_encoder, self.context_tokenizer)
            similarity_dpr = F.cosine_similarity(question_embedding, context_embedding).item()
            context_scores.append((idx, similarity_dpr))

        top_k_indices = sorted(range(len(context_scores)), key=lambda i: context_scores[i][1], reverse=True)[:top_k]
        return [contexts[i] for i in top_k_indices], top_k_indices

    def dpr_fine_ranking(self, question, candidates, keywords):
        """Stage 2: re-rank candidates by DPR similarity plus a keyword bonus.

        Returns:
            Up to three (candidate, score) tuples, highest score first.
        """
        question_embedding = self.encode_text_dpr(question, self.question_encoder, self.question_tokenizer)
        final_scores = []

        for candidate in candidates:
            candidate_embedding = self.encode_text_dpr(candidate, self.context_encoder, self.context_tokenizer)
            dpr_similarity = F.cosine_similarity(question_embedding, candidate_embedding).item()

            # Number of keywords occurring in the candidate (case-insensitive).
            keyword_score = sum(keyword.lower() in candidate.lower() for keyword in keywords)

            final_score = dpr_similarity + keyword_score * 0.1  # Keyword score weight 0.1, adjustable
            final_scores.append((candidate, final_score))

        sorted_results = sorted(final_scores, key=lambda x: x[1], reverse=True)
        return sorted_results[:3]

    def retrieve_context(self, question, contexts, top_k=8):
        """Run both retrieval stages and return the re-ranked top contexts."""
        keywords = self.extract_keywords(question)
        candidates, _ = self.initial_retrieval(question, contexts, top_k=top_k)
        top_contexts = self.dpr_fine_ranking(question, candidates, keywords)
        return top_contexts

    def retrieve_and_save(self, question, contexts, file_name, output_dir="row_retrieve"):
        """Retrieve the best contexts for a question and write them to JSON.

        The output file is "<file_name without extension>_retrieve.json"
        inside `output_dir` (created if missing). For each returned context
        i (1-based) the keys retrieve_content{i}, number{i} and score{i} are
        stored; number{i} is the 1-based position of the context's first
        occurrence in `contexts`.
        """
        os.makedirs(output_dir, exist_ok=True)

        top_contexts = self.retrieve_context(question, contexts)

        output_data = {
            "table_id": os.path.splitext(file_name)[0],
            "question": question
        }

        for i, (context, score) in enumerate(top_contexts, 1):
            # Row-summary lines look like `"sentence",` (the last one has no
            # comma). Remove the trailing separator comma first; otherwise the
            # closing quote is shielded from strip('"') and `",` leaks into
            # the saved content. Then unescape embedded quotes.
            cleaned_context = context.rstrip(',').strip('"').replace('\\"', '"')
            output_data[f"retrieve_content{i}"] = cleaned_context
            output_data[f"number{i}"] = contexts.index(context) + 1
            output_data[f"score{i}"] = score

        output_file = os.path.join(output_dir, f"{os.path.splitext(file_name)[0]}_retrieve.json")

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, ensure_ascii=False, indent=4)

        print(f"Results have been saved to {output_file}")


def process_questions_from_json(json_file_path, retrieval_system, contexts_dir="row_summary", output_dir="row_retrieve"):
    """Run retrieval for every question listed in a JSON file.

    Args:
        json_file_path: Path to a JSON list of items with 'question_id',
            'question' and 'table_id' keys.
        retrieval_system: Object exposing
            retrieve_and_save(question, contexts, file_name, output_dir).
        contexts_dir: Directory holding one "<table_id>.txt" per table.
        output_dir: Directory handed through to retrieve_and_save.

    Questions whose context file is missing are reported and skipped.
    """
    with open(json_file_path, 'r', encoding='utf-8') as f:
        questions_data = json.load(f)

    for entry in questions_data:
        question_id, question, table_id = entry['question_id'], entry['question'], entry['table_id']

        file_name = f"{table_id}.txt"
        file_path = os.path.join(contexts_dir, file_name)

        if os.path.exists(file_path):
            # One context per non-blank line, whitespace-stripped.
            with open(file_path, 'r', encoding='utf-8') as file:
                raw_lines = file.readlines()
            contexts = [raw_line.strip() for raw_line in raw_lines if raw_line.strip()]

            retrieval_system.retrieve_and_save(question, contexts, file_name, output_dir)
            print(f"Processed question ID: {question_id}")
        else:
            print(f"File not found: {file_path}")


In [15]:
def process_questions_from_json(json_file_path, retrieval_system, contexts_dir="row_summary", output_dir="row_retrieve", num_questions=None):
    """Run retrieval for the questions listed in a JSON file.

    Args:
        json_file_path: Path to a JSON list of items with 'question_id',
            'question' and 'table_id' keys.
        retrieval_system: Object exposing
            retrieve_and_save(question, contexts, file_name, output_dir).
        contexts_dir: Directory holding one "<table_id>.txt" per table.
        output_dir: Directory handed through to retrieve_and_save.
        num_questions: Only the first `num_questions` entries are used when
            given; None means all entries.

    Questions whose context file is missing are reported and skipped. The
    final count printed is the number of selected entries, including any
    that were skipped.
    """
    with open(json_file_path, 'r', encoding='utf-8') as f:
        questions_data = json.load(f)

    if num_questions is not None:
        questions_data = questions_data[:num_questions]

    for entry in questions_data:
        question_id, question, table_id = entry['question_id'], entry['question'], entry['table_id']

        file_name = f"{table_id}.txt"
        file_path = os.path.join(contexts_dir, file_name)

        if os.path.exists(file_path):
            # One context per non-blank line, whitespace-stripped.
            with open(file_path, 'r', encoding='utf-8') as file:
                raw_lines = file.readlines()
            contexts = [raw_line.strip() for raw_line in raw_lines if raw_line.strip()]

            retrieval_system.retrieve_and_save(question, contexts, file_name, output_dir)
            print(f"Processed question ID: {question_id}")
        else:
            print(f"File not found: {file_path}")

    print(f"Processed {len(questions_data)} questions.")


# Instantiating the system loads every model once; it is reused for all questions.
retrieval_system = FlexibleRetrievalSystem()
json_file_path = "released_data/dev_traced.json"
    
    # Ask the user how many questions to process (empty input means all)
num_questions = input("Enter the number of questions to process (press Enter for all): ").strip()
num_questions = int(num_questions) if num_questions else None
    
process_questions_from_json(json_file_path, retrieval_system, num_questions=num_questions)

Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the

Enter the number of questions to process (press Enter for all):  3


Results have been saved to row_retrieve/List_of_National_Football_League_rushing_yards_leaders_0_retrieve.json
Processed question ID: 00153f694413a536
Results have been saved to row_retrieve/Sweden_at_the_1932_Summer_Olympics_0_retrieve.json
Processed question ID: 001a9923f31d6a91
Results have been saved to row_retrieve/2004_United_States_Grand_Prix_0_retrieve.json
Processed question ID: 0035c791af3d9666
Processed 3 questions.
