In [1]:
import pandas as pd
from sentence_transformers import CrossEncoder
import torch
from tqdm import tqdm

# Read the retrieval results; the ranking code below expects a 'question'
# column plus 'top{i}_context' columns in every row.
INPUT_CSV_PATH = '../data/pipeline/BM25Ensemble_top100_original.csv'
df = pd.read_csv(INPUT_CSV_PATH)

# Build the cross-encoder reranker, passing a Sigmoid module as its default
# activation function.
RERANKER_NAME = 'dragonkue/bge-reranker-v2-m3-ko'
model = CrossEncoder(RERANKER_NAME, default_activation_function=torch.nn.Sigmoid())

# Function to rank and sort contexts based on the reranker's score
def rank_contexts(row, top_k=100, reranker=None):
    """Return the row's candidate contexts ordered from highest to lowest score.

    Parameters
    ----------
    row : mapping (e.g. a pandas Series or a dict)
        Must provide ``row['question']`` and ``row['top{i}_context']`` for
        every ``i`` in ``1..top_k``.
    top_k : int, default 100
        Number of ``top{i}_context`` entries to read and re-rank.
    reranker : object with a ``predict(pairs)`` method, optional
        Scorer that receives a list of ``[question, context]`` pairs and
        returns one score per pair. When omitted, the module-level ``model``
        is used.

    Returns
    -------
    list
        The ``top_k`` contexts, best-scoring first. Contexts with equal scores
        keep their original relative order. Entries that are not strings
        (for example NaN from an empty CSV cell) are not sent to the scorer
        and are appended, in their original order, after the ranked ones, so
        the result always has exactly ``top_k`` items.

    Raises
    ------
    KeyError
        If ``row`` lacks ``'question'`` or one of the expected context keys.
    """
    scorer = model if reranker is None else reranker
    question = row['question']

    # Collect the top-k contexts in their current order
    contexts = [row[f'top{i}_context'] for i in range(1, top_k + 1)]

    # Only real text can be paired with the question; remember where it sits
    text_positions = [pos for pos, context in enumerate(contexts) if isinstance(context, str)]
    unranked_tail = [context for context in contexts if not isinstance(context, str)]
    if not text_positions:
        # Nothing to score, so skip the scorer call entirely
        return contexts

    # Prepare pairs of question and context, then score them in one call
    input_pairs = [[question, contexts[pos]] for pos in text_positions]
    scores = scorer.predict(input_pairs)

    # Sort pair indices by score, descending; sorted() is stable, so ties
    # keep their original relative order
    best_first = sorted(range(len(text_positions)), key=lambda j: scores[j], reverse=True)

    # Return only the contexts (not the scores), unscored entries last
    return [contexts[text_positions[j]] for j in best_first] + unranked_tail

# Re-rank the contexts of every row, overwriting the dataframe in place
for row_label, record in tqdm(df.iterrows(), total=len(df), desc="Ranking contexts"):
    sorted_contexts = rank_contexts(record)

    # Position i-1 of the ranked list is written to column top{i}_context
    for i in range(1, 101):
        column = f'top{i}_context'
        df.at[row_label, column] = sorted_contexts[i - 1]

# Write the re-ranked dataframe to a new CSV file (row index omitted)
output_path = 'BM25Ensemble_top100_bge-reranker.csv'  # path where the sorted results are saved
df.to_csv(output_path, index=False)

print(f"Sorted CSV saved to {output_path}")


  from tqdm.autonotebook import tqdm, trange
Ranking contexts:   2%|▏         | 5/240 [03:08<2:27:39, 37.70s/it]


KeyboardInterrupt: 

In [1]:
from sentence_transformers import SentenceTransformer

# Load the embedding model (trust_remote_code=True is passed to the loader)
model = SentenceTransformer("Alibaba-NLP/gte-Qwen2-7B-instruct", trust_remote_code=True)
# Optional: cap the maximum sequence length
model.max_seq_length = 512

queries = [
    "how much protein should a female eat",
    "summit define",
]
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top of a mountain. : 2  the highest level. : 3  a meeting or series of meetings between the leaders of two or more governments.",
]

# First pass: queries are encoded with the "query" prompt name, documents
# without a prompt name; no batch size is given.
query_embeddings = model.encode(queries, prompt_name="query")
document_embeddings = model.encode(documents)
print(query_embeddings.shape, document_embeddings.shape, sep="\n")

# Second pass: same inputs with an explicit batch_size=4. These results
# replace the first-pass embeddings and are the ones scored below.
query_embeddings = model.encode(queries, batch_size=4, prompt_name="query")
document_embeddings = model.encode(documents, batch_size=4)
print(query_embeddings.shape, document_embeddings.shape, sep="\n")

# Query-by-document dot products, scaled by 100
scores = 100 * (query_embeddings @ document_embeddings.T)
print(scores.tolist())


  from tqdm.autonotebook import tqdm, trange
Loading checkpoint shards: 100%|██████████| 7/7 [00:01<00:00,  6.16it/s]


(2, 3584)
(2, 3584)
(2, 3584)
(2, 3584)
[[70.39698028564453, 3.4318222999572754], [4.516165733337402, 81.91805267333984]]
