In [52]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
import random
import json
import os
from typing import List, Tuple
import datasets

class data_set_retrieval(Dataset):
    """Dataset of ``(query, positives, negatives)`` records for retrieval training.

    Every record of the underlying JSON data is read through the keys
    ``'query'``, ``'pos'`` and ``'neg'``.  ``collate_fn`` turns a batch of such
    records into three tokenized batches and is meant to be handed to
    ``torch.utils.data.DataLoader`` as its ``collate_fn``.
    """

    def __init__(self, args, tokenizer_name: str = 'BAAI/bge-m3'):
        """Load the training data and the tokenizer.

        Args:
            args: Object exposing ``train_data`` (path to a JSON file, or to a
                directory of such files), ``passage_max_len`` and
                ``query_max_len``.
            tokenizer_name: Name or path passed to
                ``AutoTokenizer.from_pretrained``.

        Raises:
            ValueError: If ``args.train_data`` is a directory that contains no
                regular file.
        """
        if os.path.isdir(args.train_data):
            # Sort so the concatenation order is reproducible (os.listdir order
            # is arbitrary) and skip sub-directories such as the
            # `.ipynb_checkpoints` folder Jupyter creates.
            candidate_paths = (
                os.path.join(args.train_data, name)
                for name in os.listdir(args.train_data)
            )
            data_files = sorted(
                path for path in candidate_paths if os.path.isfile(path)
            )
            if not data_files:
                raise ValueError(
                    f"No data files found in directory: {args.train_data}"
                )
            self.dataset = datasets.concatenate_datasets([
                datasets.load_dataset('json', data_files=path, split='train')
                for path in data_files
            ])
        else:
            self.dataset = datasets.load_dataset(
                'json', data_files=args.train_data, split='train'
            )

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.args = args
        self.total_len = len(self.dataset)
        self.passage_max_len = args.passage_max_len
        self.query_max_len = args.query_max_len

    def __len__(self):
        """Return the number of records counted when the dataset was loaded."""
        return self.total_len

    def __getitem__(self, item) -> Tuple[str, List[str], List[str]]:
        """Return ``(query, positives, negatives)`` for record ``item``.

        The query is returned exactly as stored; the positives and negatives
        are returned as fresh lists, so callers may mutate them freely.
        """
        # Fetch the record once: each subscript of the dataset is a separate
        # row lookup.
        record = self.dataset[item]
        positives, negatives = record['pos'], record['neg']

        # A bare string would otherwise be silently split into characters.
        assert isinstance(positives, list)
        assert isinstance(negatives, list)

        return record['query'], list(positives), list(negatives)

    @staticmethod
    def _flatten(groups):
        """Return ``groups`` as a flat list, merging one level of nested lists.

        Whether to merge is decided by the first element only, mirroring how a
        batch column is either all plain values or all lists.
        """
        if groups and isinstance(groups[0], list):
            # Linear-time flatten (``sum(groups, [])`` is quadratic).
            return [text for group in groups for text in group]
        return list(groups)

    def collate_fn(self, batch):
        """Tokenize a batch of ``__getitem__`` results.

        Args:
            batch: Sequence of ``(query, positives, negatives)`` tuples.

        Returns:
            ``(query_token, positive_token, negative_token)``: the tokenizer
            outputs for all queries, all positives and all negatives of the
            batch.  Positives and negatives are flattened across the batch, so
            the per-query grouping is not preserved in the encodings.
        """
        query, positive, negative = zip(*batch)

        query = self._flatten(query)
        positive = self._flatten(positive)
        negative = self._flatten(negative)

        query_token = self.tokenizer(
            query,
            padding=True,
            truncation=True,
            max_length=self.query_max_len,
            return_tensors="pt",
        )
        positive_token = self.tokenizer(
            positive,
            padding=True,
            truncation=True,
            max_length=self.passage_max_len,
            return_tensors="pt",
        )
        negative_token = self.tokenizer(
            negative,
            padding=True,
            truncation=True,
            max_length=self.passage_max_len,
            return_tensors="pt",
        )
        return query_token, positive_token, negative_token
    

In [53]:
class Args:
    """Configuration consumed by ``data_set_retrieval``."""

    # Training file, or directory of training files. The default is a
    # machine-specific absolute path; set the TRAIN_DATA environment variable
    # to point elsewhere without editing the notebook.
    train_data = os.environ.get(
        'TRAIN_DATA',
        '/home/ltngoc/ngoclt/Thesis-2023.2/datasets/train_data/train_step_0/train_step_0.jsonl',
    )
    # Maximum token lengths used when tokenizing passages and queries.
    passage_max_len = 512
    query_max_len = 64


# Use an instance rather than the class itself, so later tweaks to `args`
# do not mutate the shared class attributes.
args = Args()

In [63]:
# Build the dataset: __init__ loads the JSON data at args.train_data and the tokenizer.
x = data_set_retrieval(args=args)

In [66]:
# Batches of 32 records; x.collate_fn tokenizes the queries, positives and negatives of each batch.
# NOTE(review): `shuffle` is not passed — confirm the default ordering is intended for training.
dataloader = DataLoader(x, batch_size=32, collate_fn=x.collate_fn)