In [177]:
import json
import torch
from torch.utils.data import Dataset
from enum import Enum

class DatasetType(Enum):
    """Kind of records a dataset holds: search queries or corpus documents."""

    # No trailing comma here: `QUERY = 0,` would make the member's value the
    # one-element tuple (0,) rather than the integer 0.
    QUERY = 0
    DOC = 1

class EncoderDataset(Dataset):
    """JSONL-backed dataset of queries or documents, tokenized per batch in ``collate_fn``.

    Every non-blank line of ``input_path`` is parsed as a JSON object with the
    keys ``_id`` and ``text`` and an optional ``title``. Samples are kept in
    memory as ``{"id": int, "text": str}``. For ``DatasetType.QUERY`` datasets
    ``collate_fn`` prepends a few-shot example prefix to every query before
    tokenizing.

    Args:
        dataset_type: ``DatasetType.QUERY`` or ``DatasetType.DOC``.
        input_path: Path of the JSONL file to read.
        tokenizer: Callable tokenizer that also provides ``batch_decode``.
        max_seq_len: Token limit passed to the tokenizer. Must be set for
            query datasets, because ``get_new_queries`` does arithmetic on it.
        max_lines: Stop reading once this many lines of the file were seen.
            A falsy value (``None`` or ``0``) means no limit.
        prefix_examples: Optional list of dicts with the keys ``instruct``,
            ``query`` and ``response``; two built-in examples are used when
            ``None``.
        qrels_filter_path: Optional qrels file. For query datasets only the
            queries whose id appears in its first column are kept.
    """

    # Substrings removed from raw ids so that the remainder parses as an int.
    _ID_PREFIXES = ("doc", "test")

    def __init__(self, dataset_type: "DatasetType", input_path, tokenizer, max_seq_len=None, max_lines=None, prefix_examples=None, qrels_filter_path=None):
        self.dataset_type = dataset_type
        self.input_path = input_path
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.max_lines = max_lines
        self.task = 'Given a query, retrieve relevant passages that answer the query.'
        self.data = []
        self.qrels_filter_path = qrels_filter_path

        self._load_examples_prefix(prefix_examples)
        self._load_data(qrels_filter_path)

    @classmethod
    def _normalize_id(cls, raw_id):
        """Return ``raw_id`` with every occurrence of the known id prefixes removed."""
        for prefix in cls._ID_PREFIXES:
            raw_id = raw_id.replace(prefix, "")
        return raw_id

    def _load_qrels(self, qrels_filter_path):
        """Collect the normalised ids found in the first column of a qrels file.

        Blank lines are skipped. Ids are normalised exactly like record ids so
        that the membership test in ``_load_data`` compares like with like.
        """
        qids = set()
        with open(qrels_filter_path, 'r', encoding='utf-8') as file:
            for line in file:
                columns = line.split()
                if not columns:
                    continue
                qids.add(self._normalize_id(columns[0]))
        return qids

    def _load_data(self, qrels_filter_path=None):
        """Read ``self.input_path`` and append the parsed samples to ``self.data``.

        Args:
            qrels_filter_path: Qrels file used to filter queries. Falls back to
                ``self.qrels_filter_path`` when omitted.

        Raises:
            KeyError: If a record lacks ``_id`` or ``text``.
            ValueError: If a line is not valid JSON or an id is not numeric
                after prefix removal.
        """
        # Gate on and open the same path: an explicit argument wins, otherwise
        # the path stored on the instance is used.
        filter_path = qrels_filter_path or self.qrels_filter_path
        qids_filter = set()
        if filter_path and self.dataset_type == DatasetType.QUERY:
            qids_filter = self._load_qrels(filter_path)
            print(f"Loaded {len(qids_filter)} qids from qrels filter.")

        with open(self.input_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                # The limit counts lines of the file, not samples kept.
                if self.max_lines and i >= self.max_lines:
                    break
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing newline
                record = json.loads(line)
                record_id = self._normalize_id(record["_id"])

                # The filter set is only populated for query datasets.
                if qids_filter and record_id not in qids_filter:
                    continue

                title = record.get("title", "")
                text = record["text"]
                passage = title + " " + text if title else text
                self.data.append({"id": int(record_id), "text": passage})

    def _load_examples_prefix(self, examples):
        """Build ``self.examples_prefix`` from few-shot examples (built-in ones when ``None``)."""
        if examples is None:
            examples = [
                {'instruct': self.task,
                'query': 'what is a virtual interface',
                'response': "A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes."},
                {'instruct': self.task,
                'query': 'causes of back pain in female for a week',
                'response': "Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. Proper diagnosis by a healthcare professional is crucial for effective treatment and management."}
                ]
        examples = [self.get_detailed_example(e['instruct'], e['query'], e['response']) for e in examples]
        # Examples are separated by a blank line; the prefix also ends with one.
        self.examples_prefix = '\n\n'.join(examples) + '\n\n'

    def get_detailed_example(self, task_description: str, query: str, response: str) -> str:
        """Format one few-shot example as ``<instruct>…\\n<query>…\\n<response>…``."""
        return f'<instruct>{task_description}\n<query>{query}\n<response>{response}'

    def __len__(self):
        """Returns the size of the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Returns the raw (untokenized) sample at ``idx`` as ``{"id", "text"}``."""
        sample = self.data[idx]
        return {
            "id": sample["id"],
            "text": sample["text"]
        }

    def collate_fn(self, batch):
        """Tokenize a list of samples into one padded batch.

        Returns:
            Dict with ``ids`` (long tensor of sample ids) plus the tokenizer's
            ``input_ids`` and ``attention_mask``.
        """
        ids = [sample['id'] for sample in batch]
        texts = [sample['text'] for sample in batch]

        max_len = self.max_seq_len
        if self.dataset_type == DatasetType.QUERY:
            # Queries get the few-shot prefix and a correspondingly larger limit.
            max_len, texts = self.get_new_queries(texts)

        tokenized = self.tokenizer(
            texts,
            max_length=max_len,
            padding=True,
            truncation=True,
            return_tensors="pt"
        )

        return {
            "ids": torch.tensor(ids, dtype=torch.long),  # Convert ids to tensor
            "input_ids": tokenized["input_ids"],
            "attention_mask": tokenized["attention_mask"],
        }

    def get_new_queries(self, queries):
        """Wrap queries in the few-shot prompt and compute the enlarged token limit.

        Each query is truncated to the ``max_seq_len`` budget (minus the token
        counts of ``<s>`` and ``\\n<response></s>``), decoded back to text, then
        surrounded by ``self.examples_prefix`` and ``\\n<response>``.

        Returns:
            Tuple ``(new_max_length, new_queries)``.

        Raises:
            ValueError: If ``self.max_seq_len`` is ``None``.
        """
        if self.max_seq_len is None:
            raise ValueError("max_seq_len must be set to build few-shot query prompts")

        bos_len = len(self.tokenizer('<s>', add_special_tokens=False)['input_ids'])
        tail_len = len(self.tokenizer('\n<response></s>', add_special_tokens=False)['input_ids'])
        inputs = self.tokenizer(
            queries,
            max_length=self.max_seq_len - bos_len - tail_len,
            return_token_type_ids=False,
            truncation=True,
            return_tensors=None,
            add_special_tokens=False
        )
        prefix_ids = self.tokenizer(self.examples_prefix, add_special_tokens=False)['input_ids']
        suffix_ids = self.tokenizer('\n<response>', add_special_tokens=False)['input_ids']
        # Next multiple of 8 above the combined length, plus 8 tokens of slack.
        new_max_length = (len(prefix_ids) + len(suffix_ids) + self.max_seq_len + 8) // 8 * 8 + 8
        # NOTE(review): queries are appended to the prefix without an
        # `<instruct>…\n<query>` header of their own, unlike the few-shot
        # examples; confirm against the model card that this is intended.
        new_queries = [
            self.examples_prefix + truncated + '\n<response>'
            for truncated in self.tokenizer.batch_decode(inputs['input_ids'])
        ]
        return new_max_length, new_queries


In [178]:
from transformers import AutoTokenizer

# Load the tokenizer for the 'BAAI/bge-en-icl' checkpoint.
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-en-icl')
# Query dataset over the NQ queries file; max_lines=1000 stops reading after
# the first 1000 lines, and max_seq_len=512 is the per-query token budget.
dataset = EncoderDataset(DatasetType.QUERY, '../data/nq/queries.jsonl', tokenizer, max_seq_len=512, max_lines=1000)

In [179]:
from torch.utils.data import DataLoader
# Batches of two in shuffled order; the dataset's own collate_fn builds the
# prompts and tokenizes each batch. No seed is set in the visible cells, so
# the sampled batch may differ between runs.
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=dataset.collate_fn)

In [180]:
# Smoke test: print the first collated batch only, then stop iterating.
for batch in dataloader:
    print(batch)
    break

{'ids': tensor([500, 720]), 'input_ids': tensor([[    0,     1, 32000, 12628,   264,  5709, 28725, 20132,  8598,  1455,
          1291,   369,  4372,   272,  5709, 28723,    13, 32001,   767,   349,
           264,  8252,  4971,    13, 32002,   330,  8252,  4971,   349,   264,
          3930, 28733, 11498,   534,  6781,   445,   369, 26302,  1063,   272,
          6174,   304, 15559,   302,   264,  5277,  3681,  4971, 28723,   661,
          5976,  5166, 16441,  3681, 12284,   298,  4098,   272,  1348,  5277,
          3681,  4971, 28725, 25748,  9096,  4479,  1837,   302,  3681,  5823,
         28723, 19032,   791,  9288,   460, 14473,  1307,   297,  8252,  1837,
         14880,  1259,   390,  8252, 12155,   304, 25399,   298,  3084,  3681,
          5789,  2574,  1671, 22579, 10383, 13218, 28723,  1306, 25729, 17574,
          3681, 24991,   304,  1316,   297,  9777,  1077,  3681,  8475,   354,
          4908,   304,  5411, 10700, 28723,    13,    13, 32000, 12628,   264,
          5

In [181]:
# Decode the batch back to text with special tokens kept, so the assembled
# prompt (prefix, query, padding) can be inspected.
tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=False)

['<unk><s><instruct> Given a query, retrieve relevant passages that answer the query.\n<query> what is a virtual interface\n<response> A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes.\n\n<instruct> Given a query, retrieve relevant passages that answer the query.\n<query> causes of back pain in female for a week\n<response> Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper 

In [5]:
from enum import Enum
import json

class DatasetType(Enum):
    """Kind of records a dataset holds: search queries or corpus documents."""

    # No trailing comma here: `QUERY = 0,` would make the member's value the
    # one-element tuple (0,) rather than the integer 0.
    QUERY = 0
    DOC = 1

def load_qrels(qrels_path):
    """Return the set of query ids listed in the first column of a qrels file.

    The substrings "query", "test", "train" and "dev" are removed from each
    id. Blank lines are skipped instead of raising ``IndexError``.

    Args:
        qrels_path: Path of a whitespace-separated qrels file.

    Returns:
        Set of id strings after substring removal.
    """
    qids = set()
    with open(qrels_path, 'r', encoding='utf-8') as file:
        for line in file:
            columns = line.split()
            if not columns:
                continue  # blank or whitespace-only line
            qid = columns[0]
            for marker in ("query", "test", "train", "dev"):
                qid = qid.replace(marker, "")
            qids.add(qid)
    return qids

def load_data_from_jsonl(dataset_type, input_path, qrels_filter_path=None, start_line=0, max_lines=None):
    """Read a window of a JSONL file into ``{"id": int, "text": str}`` records.

    Each non-blank line must be a JSON object with ``_id`` and ``text`` and an
    optional ``title``. The substrings "doc", "test", "train" and "dev" are
    removed from ``_id`` before it is converted to ``int``. A non-empty title
    is joined to the text with a newline.

    Args:
        dataset_type: ``DatasetType`` member; the qrels filter is only applied
            for ``DatasetType.QUERY``.
        input_path: Path of the JSONL file.
        qrels_filter_path: Optional qrels file; when given for a query dataset,
            records whose id is not listed there are skipped.
        start_line: Zero-based index of the first line to read.
        max_lines: Number of lines in the window. A falsy value means no
            limit. Blank and filtered-out lines still count towards it.

    Returns:
        List of record dicts in file order.

    Raises:
        KeyError: If a record lacks ``_id`` or ``text``.
        ValueError: If a line is not valid JSON or an id is not numeric after
            substring removal.
    """
    data_arr = []
    qids_filter = set()
    # Path is checked first so the enum is only consulted when filtering is requested.
    if qrels_filter_path and dataset_type == DatasetType.QUERY:
        qids_filter = load_qrels(qrels_filter_path)
        print(f"Loaded {len(qids_filter)} qids from qrels filter.")

    with open(input_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            # continue until start line reached
            if i < start_line:
                continue

            # break if max lines reached
            if max_lines and i - start_line >= max_lines:
                break

            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline

            record = json.loads(line)
            record_id = record["_id"]
            for marker in ("doc", "test", "train", "dev"):
                record_id = record_id.replace(marker, "")

            # The filter set is only populated for query datasets.
            if qids_filter and record_id not in qids_filter:
                print("Skipping query", record_id)
                continue

            title = record.get("title", "")
            text = record["text"]
            passage = title + "\n" + text if title else text
            data_arr.append({"id": int(record_id), "text": passage})

    return data_arr

In [6]:
from torch.utils.data import Dataset

class RawTextDataset(Dataset):
    """In-memory dataset of raw ``{"id", "text"}`` records read from a JSONL file.

    No tokenization happens here: ``collate_fn`` returns plain Python lists,
    adding an instruction header to each text for query datasets.
    """

    def __init__(self, dataset_type: DatasetType, input_path, start_line=0, max_lines=None, qrels_filter_path=None):
        """Store the configuration and load the requested window of the file."""
        self.dataset_type = dataset_type
        self.input_path = input_path
        self.max_lines = max_lines
        self.qrels_filter_path = qrels_filter_path

        # All records are read once, up front, via the module-level loader.
        self.data = load_data_from_jsonl(
            dataset_type,
            input_path,
            qrels_filter_path,
            start_line,
            max_lines,
        )

    def __len__(self):
        """Number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return a fresh ``{"id", "text"}`` dict for the record at ``idx``."""
        record = self.data[idx]
        return {field: record[field] for field in ("id", "text")}

    def collate_fn(self, batch):
        """Gather a list of samples into ``{"ids": [...], "text": [...]}``."""
        ids = [entry['id'] for entry in batch]
        is_query = self.dataset_type == DatasetType.QUERY
        # Query texts get the instruction header; document texts pass through as-is.
        texts = [
            f'Instruct: Retrieve relevant passages.\nQuery: {entry["text"]}' if is_query else entry['text']
            for entry in batch
        ]
        return {"ids": ids, "text": texts}

In [7]:
# Document dataset over the NQ corpus file, limited to a 1000-line window
# starting at the first line (start_line defaults to 0).
dataset = RawTextDataset(DatasetType.DOC, '../data/nq/corpus.jsonl', max_lines=1000)

In [10]:
from torch.utils.data import DataLoader
# Batches of two in shuffled order; the dataset's collate_fn returns raw ids
# and text lists rather than tensors. No seed is set in the visible cells, so
# the sampled batch may differ between runs.
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=dataset.collate_fn)

In [11]:
# Smoke test: print the first collated batch only, then stop iterating.
for batch in dataloader:
    print(batch)
    break

{'ids': [924, 879], 'text': ['Geography of Spain\nSpain also has a small exclave inside France called Llívia.', 'Comanche\nThe Comanche sheathed their tipis with a covering made of buffalo hides sewn together. To prepare the buffalo hides, women first spread them on the ground, then scraped away the fat and flesh with blades made from bones or antlers, and left them in the sun. When the hides were dry, they scraped off the thick hair, and then soaked them in water. After several days, they vigorously rubbed the hides in a mixture of animal fat, brains, and liver to soften the hides. The hides were made even more supple by further rinsing and working back and forth over a rawhide thong. Finally, they were smoked over a fire, which gave the hides a light tan color. To finish the tipi covering, women laid the tanned hides side by side and stitched them together. As many as 22 hides could be used, but 14 was the average. When finished, the hide covering was tied to a pole and raised, wrapp

In [12]:
# Show the text of the second sample in the batch (title and body joined by a newline).
batch['text'][1]

'Comanche\nThe Comanche sheathed their tipis with a covering made of buffalo hides sewn together. To prepare the buffalo hides, women first spread them on the ground, then scraped away the fat and flesh with blades made from bones or antlers, and left them in the sun. When the hides were dry, they scraped off the thick hair, and then soaked them in water. After several days, they vigorously rubbed the hides in a mixture of animal fat, brains, and liver to soften the hides. The hides were made even more supple by further rinsing and working back and forth over a rawhide thong. Finally, they were smoked over a fire, which gave the hides a light tan color. To finish the tipi covering, women laid the tanned hides side by side and stitched them together. As many as 22 hides could be used, but 14 was the average. When finished, the hide covering was tied to a pole and raised, wrapped around the cone-shaped frame, and pinned together with pencil-sized wooden skewers. Two wing-shaped flaps at 