In [None]:
# Index every MMLU test example by its position in the split:
# question text, answer options, subject, and the text of the correct option.
mmlu_queries, mmlu_choices, mmlu_answers, mmlu_subjects = {}, {}, {}, {}

for i, data in enumerate(mmlu_dataset['test']):
    mmlu_queries[i] = data['question']
    mmlu_choices[i] = data['choices']
    mmlu_subjects[i] = data['subject']
    # 'answer' is used as an index into the options; store the option itself.
    mmlu_answers[i] = mmlu_choices[i][data['answer']]

In [3]:
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from sentence_transformers import SentenceTransformer

import logging
import pathlib, os

#### Configure logging: timestamped INFO-level records handled by BEIR's LoggingHandler
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /logging configuration

#### Dataset name and the BEIR archive URL built from it.
#### Nothing is downloaded in this cell: `url` is only defined, and the data is
#### read from the local `data_path` below.
dataset = "MSMARCO"
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
# out_dir = os.path.join('/data/richard/taggerv2/test/test6/beir/outputs', "datasets")
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
data_path = '/data/richard/taggerv2/test/test6/beir/outputs/datasets/msmarco'

#### Load corpus, queries and relevance judgments (qrels) of the "dev" split from data_path
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="dev")

2025-06-03 01:52:08 - Loading Corpus...


  0%|          | 0/8841823 [00:00<?, ?it/s]

2025-06-03 01:52:32 - Loaded 8841823 DEV Documents.
2025-06-03 01:52:32 - Doc Example: {'text': 'The presence of communication amid scientific minds was equally important to the success of the Manhattan Project as scientific intellect was. The only cloud hanging over the impressive achievement of the atomic researchers and engineers is what their success truly meant; hundreds of thousands of innocent lives obliterated.', 'title': ''}
2025-06-03 01:52:32 - Loading Queries...
2025-06-03 01:52:33 - Loaded 6980 DEV Queries.
2025-06-03 01:52:33 - Query Example: how many years did william bradford serve as governor of plymouth colony?


In [4]:
list(qrels.values())[0]

{'7067032': 1}

In [6]:
import torch

# Positions to keep, held as a tensor, and the plain Python list to pick from.
idx = torch.tensor([0, 1, 2])
orig = [1, 2, 3, 4, 5]

# Select the listed positions from the list.
left = [orig[position] for position in idx.tolist()]
left

[1, 2, 3]

In [None]:
# Take the qrels record at position 123 and binarize it:
# any positive relevance score becomes 1, everything else 0.
doc2score = {
    doc_id: 1 if score > 0 else 0
    for doc_id, score in list(qrels.values())[123].items()
}
doc2score

In [None]:
doc2score

In [None]:
list(queries.keys())[0]

In [None]:
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from sentence_transformers import SentenceTransformer
import torch

def msmarco_collate_fn(batch):
    """Transpose a batch of (query, paragraphs, scores) samples into three parallel lists.

    Args:
        batch: Sequence of 3-tuples ``(query, paragraphs, scores)``.

    Returns:
        A tuple ``(queries, paras, scores)`` of lists; entry ``i`` of each list
        comes from ``batch[i]``. No padding or stacking is performed.
    """
    # One column per tuple position; unpacking enforces exactly three fields.
    queries, paras, scores = (list(column) for column in zip(*batch))
    return queries, paras, scores

class MSMARCO_dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each judged query with its judged passages.

    Item ``i`` is ``(query, paragraphs, scores)``:
      * ``query``      - the query stored for the i-th qrels record,
      * ``paragraphs`` - the corpus entries (as stored in the corpus mapping)
                         named in that qrels record,
      * ``scores``     - the matching relevance labels binarized to 0/1
                         (1 when the original label is > 0).
    """

    # Splits accepted by __init__.
    _SPLITS = ('train', 'dev', 'test')

    def __init__(self, split, data_folder=None):
        """Load one split through BEIR's GenericDataLoader.

        Args:
            split: One of 'train', 'dev' or 'test'.
            data_folder: Directory holding the dataset files. When omitted, the
                notebook-level ``data_path`` variable is used.

        Raises:
            ValueError: If ``split`` is not a supported split name. (Previously an
                unsupported split failed later with an UnboundLocalError.)
        """
        if split not in self._SPLITS:
            raise ValueError(
                f"Unknown split {split!r}; expected one of {self._SPLITS}"
            )
        if data_folder is None:
            data_folder = data_path

        corpus, queries, qrels = GenericDataLoader(data_folder=data_folder).load(split=split)

        self.corpus = corpus

        # Kept for backward compatibility with code that reads these attributes.
        self.corpus_keys = list(self.corpus.keys())
        self.corpus_values = list(self.corpus.values())

        self.qrels_keys = list(qrels.keys())
        self.qrels_values = list(qrels.values())
        # Look each query up by its qrels id so that position i in `queries`
        # and `qrels_values` always describes the same query, independent of
        # the iteration order of the two mappings.
        self.queries = [queries[query_id] for query_id in self.qrels_keys]

    def __getitem__(self, idx):
        """Return ``(query, paragraphs, scores)`` for the idx-th qrels record.

        A record without any judged document yields two empty lists instead of
        raising.
        """
        query = self.queries[idx]
        doc2score = self.qrels_values[idx]            # a dict {doc_id: score}
        doc_ids = list(doc2score)
        scores = [int(doc2score[doc_id] > 0) for doc_id in doc_ids]
        paragraphs = [self.corpus[doc_id] for doc_id in doc_ids]
        return query, paragraphs, scores

    def __len__(self):
        """Number of qrels records (one item per judged query)."""
        return len(self.qrels_keys)
    

from torch.utils.data import DataLoader

# 1) Instantiate the dataset (e.g. the “train” split)
ds = MSMARCO_dataset(split='train')


In [11]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM

# Hugging Face Hub id of the chat model used for generation.
llm_model_name = 'meta-llama/Llama-2-7b-chat-hf'
# Tokenizer configured to pad on the left, so in a padded batch every prompt ends at
# the last input position (presumably chosen for batched generation - confirm).
tokenizer = AutoTokenizer.from_pretrained(llm_model_name, padding_side="left")
# Load the weights with from_pretrained's default settings, then move the model to GPU 0.
# NOTE(review): no dtype is requested here; consider half precision if GPU memory is tight.
llm_model = AutoModelForCausalLM.from_pretrained(llm_model_name).to('cuda:0')



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
# from beir import util, LoggingHandler
# from beir.retrieval import models
# from beir.datasets.data_loader import GenericDataLoader
# from beir.retrieval.evaluation import EvaluateRetrieval
# from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
# from sentence_transformers import SentenceTransformer
# import torch

# def llm_output(llm_model, tokenizer, queries, query_facts, 
#                max_length=256,
#                batch_size=None, **generate_kwargs):
#     """
#     Perform batch inference with a causal LLM, combining each query with its associated facts.

#     Inputs:
#         llm_model:      An accelerator-wrapped AutoModelForCausalLM already on the correct device.
#         tokenizer:      The corresponding AutoTokenizer (not wrapped by accelerator).
#         queries:        List[str] of length B, each a user query.
#         query_facts:    List[List[str]] of length B, where query_facts[i] is a list of fact-strings for queries[i].
#         max_length:     Maximum total generation length (including prompt).
#         batch_size:     Optional int. If provided, processes inputs in chunks of this size; 
#                         otherwise, all B inputs are processed in one batch.
#         **generate_kwargs: Additional keyword arguments to pass to llm_model.generate().

#     Returns:
#         actions: List[str] of length B, where each element is the decoded output for the corresponding query+facts.
#     """
#     device = 'cuda:0'
#     tokenizer.pad_token = tokenizer.eos_token

#     # 1. Build full prompts by concatenating each query with its facts
#     prompts = []
#     for q, facts in zip(queries, query_facts):
#         # Example prompt template: you can adjust this to your preferred formatting
#         # Here, we prefix with "Question:" and list facts under "Facts:"
#         fact_section = ""
#         if facts:
#             fact_section = "Facts:\n" + "\n".join(f"- {f}" for f in facts) + "\n"
#         prompt = f"Question: {q}\n{fact_section}Answer: (Please choose from 'A, B, C, D' and answer with the letter at the very beginning or very end of your response.)"
#         prompts.append(prompt)

#     # 2. Decide batch size
#     B = len(prompts)
#     if batch_size is None or batch_size >= B:
#         batch_size = B

#     all_outputs = []
#     llm_model.eval()
#     with torch.no_grad():
#         # 3. Process in chunks of size batch_size
#         for i in range(0, B, batch_size):
#             chunk_prompts = prompts[i : i + batch_size]

#             # 3a. Tokenize the chunk of prompts
#             encoding = tokenizer(
#                 chunk_prompts,
#                 return_tensors="pt",
#                 padding=True,
#                 truncation=True,
#             )
#             input_ids = encoding["input_ids"].to(device)
#             attention_mask = encoding["attention_mask"].to(device)

#             # 3b. Generate output IDs
#             generated_ids = llm_model.generate(
#                 input_ids=input_ids,
#                 attention_mask=attention_mask,
#                 max_length=max_length,
#                 pad_token_id=tokenizer.eos_token_id,
#                 **generate_kwargs,
#             )

#             # 3c. Decode each generated sequence (skip the prompt tokens if desired)
#             #    Here, we decode the entire generated_ids and return full text.
#             decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
#             all_outputs.extend(decoded)

#     return all_outputs



In [31]:
import torch

def llm_output(llm_model, tokenizer, queries, query_facts, 
               max_length=256,
               batch_size=None, **generate_kwargs):
    """
    Perform batch inference with a causal LLM, combining each query with its associated facts.
    Returns only the model's generated responses (i.e., strips off the prompt).

    Inputs:
        llm_model:      A causal LM exposing `.device`, `.eval()` and `.generate()`
                        (e.g. an AutoModelForCausalLM already on the target device).
        tokenizer:      The matching tokenizer. If it has no pad token, its EOS token is
                        installed as the pad token (this mutates the tokenizer).
        queries:        List[str] of length B, each a user query.
        query_facts:    List[List[str]] of length B, where query_facts[i] is a list of fact-strings for queries[i].
        max_length:     Maximum total generation length (including prompt). Ignored when
                        `max_new_tokens` is supplied through generate_kwargs.
        batch_size:     Optional positive int. If provided, processes inputs in chunks of this size;
                        otherwise, all B inputs are processed in one batch.
        **generate_kwargs: Additional keyword arguments to pass to llm_model.generate().
                        A caller-supplied `pad_token_id` takes precedence over the tokenizer's.

    Returns:
        responses: List[str] of length B, where each element is the model's decoded output (response-only).

    Raises:
        ValueError: If queries and query_facts differ in length, or batch_size is not positive.
    """
    if len(queries) != len(query_facts):
        raise ValueError(
            f"queries and query_facts must have the same length "
            f"(got {len(queries)} and {len(query_facts)})"
        )

    device = llm_model.device
    # Ensure we have a pad token to avoid generate() errors
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 1. Build full prompts by concatenating each query with its facts
    prompts = []
    for q, facts in zip(queries, query_facts):
        fact_section = ""
        if facts:
            fact_section = "Facts:\n" + "\n".join(f"- {f}" for f in facts) + "\n"
        prompt = (
            f"Question: {q}\n"
            f"{fact_section}"
            "Answer: (Please choose from 'A, B, C, D' and answer with the letter at the very beginning or very end of your response.)"
        )
        prompts.append(prompt)

    # 2. Decide batch size
    B = len(prompts)
    if B == 0:
        # Nothing to generate; also avoids range() with a zero step below.
        return []
    if batch_size is None or batch_size >= B:
        batch_size = B
    elif batch_size < 1:
        raise ValueError("batch_size must be a positive integer or None")

    # Arguments forwarded to generate(); explicit caller choices win over defaults.
    gen_kwargs = dict(generate_kwargs)
    gen_kwargs.setdefault("pad_token_id", tokenizer.pad_token_id)
    if "max_new_tokens" not in gen_kwargs:
        gen_kwargs["max_length"] = max_length

    all_responses = []
    llm_model.eval()
    with torch.no_grad():
        # 3. Process in chunks of size batch_size
        for start in range(0, B, batch_size):
            chunk_prompts = prompts[start : start + batch_size]

            # 3a. Tokenize the chunk of prompts
            encoding = tokenizer(
                chunk_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            input_ids = encoding["input_ids"].to(device)
            attention_mask = encoding["attention_mask"].to(device)

            # Every returned row starts with the full padded input, so the
            # continuation begins at the padded width for all rows. Using the
            # per-row count of non-pad tokens would leak prompt tokens into the
            # response of every prompt shorter than the longest one.
            prompt_width = input_ids.shape[1]

            # 3b. Generate output IDs, shape (chunk_size, seq_len_total)
            generated_ids = llm_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **gen_kwargs,
            )

            # 3c. Decode only the newly generated tokens of each sequence
            for gen_ids in generated_ids:
                continuation_ids = gen_ids[prompt_width:]
                clean_text = tokenizer.decode(continuation_ids, skip_special_tokens=True).strip()
                all_responses.append(clean_text)

    return all_responses


In [34]:
# Assume you already have:
#   tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name)
#   llm_model = AutoModelForCausalLM.from_pretrained(args.llm_model_name)
#   llm_model, optimizer, scheduler = accelerator.prepare(llm_model, optimizer, scheduler)

# Example batch of queries and associated facts. Only the first query/fact pair is
# active; the other three are commented out, so the batch currently has one prompt.
# NOTE(review): this rebinds `queries` to a list; cells that index `queries` by a
# query-id string expect the dict returned by the BEIR loader and must reload it first.
queries = [
    "What causes rainbows? A. White sunlight. B. Air. C. Blue sunlight",
    # "How do mitochondria generate energy?",
    # "Explain the significance of the Battle of Hastings.",
    # "Describe the process of polymerase chain reaction."
]
# query_facts[i] holds the fact strings shown to the model for queries[i].
query_facts = [
    ["Sunlight is refracted in raindrops", "White light splits into colors"],
    # ["Mitochondria have inner membranes", "Oxidative phosphorylation occurs"],
    # ["Battle occurred in 1066", "Involved William the Conqueror and Harold Godwinson"],
    # ["PCR requires cycles of heating and cooling", "Uses DNA polymerase enzymes"]
]

# temperature and top_p are not parameters of llm_output; they travel through
# **generate_kwargs to llm_model.generate().
# NOTE(review): whether they take effect depends on sampling being enabled for the
# model - confirm, or pass do_sample=True explicitly.
outputs = llm_output(
    llm_model=llm_model,
    tokenizer=tokenizer,
    queries=queries,
    query_facts=query_facts,
    max_length=1024,
    temperature=0.7,
    top_p=0.95,
    batch_size=2  # chunk size; with a single active prompt everything runs in one chunk
)

# Print each response on its own line.
for out in outputs:
    print(out)


A. White sunlight

Rainbows are formed when sunlight passes through water droplets in the air. The light is refracted, or bent, as it passes through the droplets, and is split into its individual colors, which are then dispersed by the shape of the droplets. This is why the colors of the rainbow always appear in the same order: red, orange, yellow, green, blue, indigo, and violet. The color of the rainbow is determined by the wavelength of the light, with red light having the longest wavelength and violet light having the shortest.


In [35]:
print(outputs[0])

A. White sunlight

Rainbows are formed when sunlight passes through water droplets in the air. The light is refracted, or bent, as it passes through the droplets, and is split into its individual colors, which are then dispersed by the shape of the droplets. This is why the colors of the rainbow always appear in the same order: red, orange, yellow, green, blue, indigo, and violet. The color of the rainbow is determined by the wavelength of the light, with red light having the longest wavelength and violet light having the shortest.


In [None]:

# 2) Pull out a single example by index
query, paragraphs, scores = ds[5]

# 3) Print or inspect
print("Query:", query)
print("Number of paragraphs:", len(paragraphs))
print("Paragraph text:", paragraphs)
# `scores` holds one binarized label per paragraph, so label it as the full list
# rather than as the score of the first paragraph only.
print("Corresponding scores (one per paragraph):", scores)

In [None]:
from sentence_transformers import SentenceTransformer, util
import torch
# query = "How many people live in London?"
# docs = ["Around 9 Million people live in London", "London is known for its financial district"]

#Load the model
# model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-tas-b')    # TAS-B
# model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-v2')     # SBERT
model = SentenceTransformer('sentence-transformers/gtr-t5-xl').to('cuda:0')      # gtr-t5-xl
# model = SentenceTransformer('BAAI/bge-large-en-v1.5')       # BGE
# model = SentenceTransformer('BAAI/llm-embedder')
# model.load_state_dict(torch.load('/data/richard/taggerv2/test/test6/beir/outputs/ckpts/2025_05_27_17h55m37s/model_step_440075.pth'))

# Fine-tuned weights to load into the base model.
checkpoint_path = '/data/richard/taggerv2/test/test6/beir/outputs/ckpts/2025_05_30_21h52m36s/model_step_251471.pth'
# Deserialize onto the CPU so the load does not depend on the device the checkpoint
# was saved from and does not hold a second copy of the weights on the GPU;
# load_state_dict then copies the values into the model's existing parameters.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust
# (consider weights_only=True where the installed torch supports it).
state_dict = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict)


In [None]:
queries

In [None]:
q_id = '182539'
query = queries[q_id]
query

In [None]:
# Texts of the passages whose relevance label for the selected query is exactly 1.
docs = [
    corpus[doc_id]['text']
    for doc_id, label in qrels[q_id].items()
    if label == 1
]

In [None]:
# Embed the query and the candidate passages with the retriever.
query_emb = model.encode(query)
doc_emb = model.encode(docs)

# Dot-product similarity of the query against every passage (row 0 of the score matrix).
scores = util.dot_score(query_emb, doc_emb)[0].cpu().tolist()

# Pair each passage with its score and order the pairs from highest to lowest score.
doc_score_pairs = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)

# Show the ranking, score first.
for doc, score in doc_score_pairs:
    print(score, doc)




In [None]:
import torch

# Stand-in for a real score matrix: 200 rows x 8,000,000 columns of random values,
# allocated directly on the GPU.
# NOTE(review): about 6.4 GB if the default dtype is float32 - confirm it fits on the device.
# This also rebinds the name `scores` used by other cells.
scores = torch.randn(200, 8_000_000, device='cuda')

# topk returns (values, indices); we only need indices here.
# dim=1 ranks within each row; sorted=True orders the 10 hits from highest to lowest.
_, top10_idx = scores.topk(k=10, dim=1, largest=True, sorted=True)

print(top10_idx.shape)   # torch.Size([200, 10])
# top10_idx[i] is a length-10 LongTensor of column-IDs for the i-th row's top scores


In [None]:
top10_idx

In [None]:
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from sentence_transformers import SentenceTransformer

#### Dataset name and local dataset folder (nothing is downloaded in this cell)
dataset = "MSMARCO"
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
data_path = '/data/richard/taggerv2/test/test6/beir/outputs/datasets/msmarco'

#### Load corpus, queries and relevance judgments (qrels) of the "test" split from data_path
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

In [None]:
queries.keys()

In [None]:
list(corpus.keys())[:5]

In [None]:
list(qrels.values())[0]

In [None]:
len(corpus.values())

In [None]:
from sentence_transformers import SentenceTransformer, util

# query = "How many people live in London?"
# docs = ["Around 9 Million people live in London", "London is known for its financial district"]

# Flatten the loaded split into plain lists (ids are dropped, order follows the dicts).
# NOTE(review): the loader's "Doc Example" log line shows corpus values as dicts with
# 'text' and 'title' keys, not plain strings - confirm model.encode handles them as intended.
input_queries = list(queries.values())
input_docs = list(corpus.values())

#Load the model
# model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-tas-b')    # TAS-B
# model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-v2')     # SBERT
model = SentenceTransformer('sentence-transformers/gtr-t5-xl').to('cuda:0')      # gtr-t5-xl
# model = SentenceTransformer('BAAI/bge-large-en-v1.5')       # BGE

#Encode query and documents
# NOTE(review): this embeds every query and the entire corpus in one call each;
# expect a long run and a large in-memory result.
query_emb = model.encode(input_queries)
doc_emb = model.encode(input_docs)
print(doc_emb.shape)
# Deliberate hard stop after printing the embedding shape: everything below this
# line is unreachable for as long as the raise stays in place.
raise ValueError


#Compute dot score between query and all document embeddings
# [0] keeps only the first row of the score matrix (presumably the first query - confirm).
scores = util.dot_score(query_emb, doc_emb)[0].cpu().tolist()

#Combine docs & scores
# NOTE(review): `docs` is not defined in this cell (only `input_docs` is); it would be
# whatever an earlier cell left in the kernel. `input_docs` is probably what was meant.
doc_score_pairs = list(zip(docs, scores))

#Sort by decreasing score
doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)

#Output passages & scores
for doc, score in doc_score_pairs:
    print(score, doc)


In [None]:
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from sentence_transformers import SentenceTransformer

import logging
import pathlib, os

#### Configure logging: timestamped INFO-level records handled by BEIR's LoggingHandler
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /logging configuration

#### Dataset name and the BEIR archive URL built from it.
#### Nothing is downloaded in this cell: `url` is only defined, and the data is
#### read from the local `data_path` below.
dataset = "MSMARCO"
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
# out_dir = os.path.join('/data/richard/taggerv2/test/test6/beir/outputs', "datasets")
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
data_path = '/data/richard/taggerv2/test/test6/beir/outputs/datasets/msmarco'

#### Load corpus, queries and relevance judgments (qrels) of the "test" split from data_path
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

#### Wrap the gtr-t5-xl sentence-transformers model for exact (brute-force) dense retrieval
# model = DRES(models.SentenceBERT("Alibaba-NLP/gte-modernbert-base"), batch_size=16)
# model = DRES(models.SentenceBERT("msmarco-roberta-base-ance-firstp"))
model = DRES(models.SentenceBERT('sentence-transformers/gtr-t5-xl'))

#### Or load models directly from HuggingFace
# model = DRES(models.HuggingFace(
#     "intfloat/e5-large-unsupervised",
#     max_length=512,
#     pooling="mean",
#     normalize=True,
#     prompts={"query": "query: ", "passage": "passage: "}), batch_size=16)

# model = SentenceTransformer('sentence-transformers/gtr-t5-xl')      # gtr-t5-xl

#### Retrieve for every query over the whole corpus, scoring with cosine similarity
retriever = EvaluateRetrieval(model, score_function="cos_sim") # or "dot" for dot product
results = retriever.retrieve(corpus, queries)

#### Evaluate your model with NDCG@k, MAP@K, Recall@K and Precision@K at the cutoffs in retriever.k_values
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
print(f'ndcg, _map, recall, precision: {ndcg, _map, recall, precision}')
#### MRR at the same cutoffs (computed but not printed here)
mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")

### If you want to save your results and runfile (useful for reranking)
# `__file__` exists when this code runs as a script but is undefined inside a
# notebook kernel; fall back to the current working directory there so the cell
# does not fail with NameError after the retrieval run has already finished.
try:
    base_dir = pathlib.Path(__file__).parent.absolute()
except NameError:
    base_dir = pathlib.Path.cwd()
results_dir = os.path.join(base_dir, "results")
os.makedirs(results_dir, exist_ok=True)

#### Save the evaluation runfile & results
util.save_runfile(os.path.join(results_dir, f"{dataset}.run.trec"), results)
util.save_results(os.path.join(results_dir, f"{dataset}.json"), ndcg, _map, recall, precision, mrr)

In [None]:
len(corpus.keys())

In [None]:
models.SentenceBERT("Alibaba-NLP/gte-modernbert-base")

In [None]:
from sentence_transformers import SentenceTransformer, export_optimized_onnx_model

model = SentenceTransformer("sentence-transformers/gtr-t5-xl", device='cuda:0', backend="onnx")


In [None]:
from sentence_transformers import SentenceTransformer

# # load (and auto-export) your ONNX model
# model = SentenceTransformer("sentence-transformers/gtr-t5-xl",
#                             device="cuda:0",
#                             backend="onnx")

# save the ONNX graph + config/modules to disk
model.save_pretrained("./local-gtr-t5-xl-onnx")

model = SentenceTransformer("./local-gtr-t5-xl-onnx", backend="onnx")



In [None]:
export_optimized_onnx_model(model, "O4", "/data/richard/taggerv2/test/test6/onnx")

In [None]:
# Load a SentenceTransformer through the ONNX backend, selecting a specific ONNX
# file inside the model directory via model_kwargs["file_name"].
# NOTE(review): the model path looks like a placeholder rather than a real location,
# and the file name refers to an O3 export while the export cell above requested
# "O4" - confirm both before running.
model = SentenceTransformer(
    "path/to/my/mpnet-legal-finetuned",
    backend="onnx",
    model_kwargs={"file_name": "onnx/model_O3.onnx"},
    device='cuda:0'
)