In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
import random
import json
import os
from typing import List, Tuple
import datasets

class data_set_retrieval(Dataset):
    def __init__(self, args):
        if os.path.isdir(args.train_data):
            train_datasets = []
            for file in os.listdir(args.train_data):
                temp_dataset = datasets.load_dataset('json', data_files=os.path.join(args.train_data, file), split='train')
                train_datasets.append(temp_dataset)    
            self.dataset = datasets.concatenate_datasets(train_datasets)
            
        else:
            self.dataset = datasets.load_dataset('json', data_files=args.train_data, split='train')

        self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
        self.args = args
        self.total_len = len(self.dataset)
        self.passage_max_len = args.passage_max_len
        self.query_max_len = args.query_max_len

    def __len__(self):
        return self.total_len
    
    
    def __getitem__(self, item) -> Tuple[str, List[str]]:
        query = self.dataset[item]['query']
        
        passages_negative, passages_positive = [], []

        assert isinstance(self.dataset[item]['pos'], list)
        passages_positive.extend(self.dataset[item]['pos'])

        passages_negative.extend(self.dataset[item]['neg'])
        
        return query, passages_positive, passages_negative


    def collate_fn(self, batch):
        query, positive, negative = zip(*batch)

        if isinstance(query[0], list):
            query = sum(query, [])
        if isinstance(positive[0], list):
            positive = sum(positive, [])
        if isinstance(negative[0], list):
            negative = sum(negative, [])

        query_token = self.tokenizer(
            query,
            padding=True,
            truncation=True,
            max_length=self.query_max_len,
            return_tensors="pt",
        )
        positive_token = self.tokenizer(
            positive,
            padding=True,
            truncation=True,
            max_length=self.passage_max_len,
            return_tensors="pt",
        )
        
        negative_token = self.tokenizer(
            negative,
            padding=True,
            truncation=True,
            max_length=self.passage_max_len,
            return_tensors="pt",
        )
        return query_token, positive_token, negative_token
    

In [4]:
class Args:
    train_data = '/home/thhiep/dta/Thesis-2023.2/datasets/train_data/train_step_0/train_step_0.jsonl'
    passage_max_len = 512
    query_max_len = 64
    
args = Args

In [5]:
x = data_set_retrieval(args=args)

In [6]:
dataloader = DataLoader(x, batch_size=32, collate_fn=x.collate_fn)

In [13]:
from tqdm import tqdm

for a, b in enumerate(tqdm(dataloader)):
    print(b)

  7%|▋         | 2/30 [00:00<00:02, 13.19it/s]

({'input_ids': tensor([[     0,    581,  41911,  ...,     83,     12,      2],
        [     0,    581,  41911,  ...,    268, 119893,      2],
        [     0,    581,  41911,  ...,    647,   1530,      2],
        ...,
        [     0,    581,  41911,  ...,    268,   5367,      2],
        [     0,    581,  41911,  ...,     47,    387,      2],
        [     0,    581,  41911,  ...,     12,      6,      2]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])}, {'input_ids': tensor([[     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        ...,
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      5,      6,      2],
 

 20%|██        | 6/30 [00:00<00:01, 13.73it/s]

({'input_ids': tensor([[    0,   581, 41911,  ...,   242,     7,     2],
        [    0,   581, 41911,  ...,    83,    12,     2],
        [    0,   581, 41911,  ..., 26349,   242,     2],
        ...,
        [    0,   581, 41911,  ...,     1,     1,     1],
        [    0,   581, 41911,  ...,     9, 19729,     2],
        [    0,   581, 41911,  ...,     4, 43799,     2]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])}, {'input_ids': tensor([[     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        ...,
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...

 27%|██▋       | 8/30 [00:00<00:01, 13.95it/s]

({'input_ids': tensor([[    0,   581, 41911,  ...,    83,    12,     2],
        [    0,   581, 41911,  ...,  1919,   378,     2],
        [    0,   581, 41911,  ...,     1,     1,     1],
        ...,
        [    0,   581, 41911,  ...,  3249,  9077,     2],
        [    0,   581, 41911,  ...,     1,     1,     1],
        [    0,   581, 41911,  ...,     6,     2,     1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 0]])}, {'input_ids': tensor([[     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        ...,
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...

 40%|████      | 12/30 [00:00<00:01, 14.26it/s]

({'input_ids': tensor([[     0,    581,  41911,  ...,  25188,     56,      2],
        [     0,    581,  41911,  ...,     42,  13471,      2],
        [     0,    581,  41911,  ...,   6338,   2450,      2],
        ...,
        [     0,    581,  41911,  ...,    647,   1530,      2],
        [     0,    581,  41911,  ..., 109513,      6,      2],
        [     0,    581,  41911,  ...,  49602, 180732,      2]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])}, {'input_ids': tensor([[     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        ...,
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
 

 47%|████▋     | 14/30 [00:00<00:01, 14.27it/s]

({'input_ids': tensor([[    0,   581, 41911,  ...,   647,  1530,     2],
        [    0,   581, 41911,  ..., 19725,     6,     2],
        [    0,   581, 41911,  ..., 16177,  1363,     2],
        ...,
        [    0,   581, 41911,  ..., 54100,   977,     2],
        [    0,   581, 41911,  ...,    10,    67,     2],
        [    0,   581, 41911,  ...,     1,     1,     1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]])}, {'input_ids': tensor([[     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        ...,
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...

 60%|██████    | 18/30 [00:01<00:00, 14.34it/s]

({'input_ids': tensor([[    0,   581, 41911,  ...,  1919, 67373,     2],
        [    0,   581, 41911,  ...,     1,     1,     1],
        [    0,   581, 41911,  ...,    71,   390,     2],
        ...,
        [    0,   581, 41911,  ...,   378,   647,     2],
        [    0,   581, 41911,  ...,   268,     6,     2],
        [    0,   581, 41911,  ..., 27759,     6,     2]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])}, {'input_ids': tensor([[     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      5,      6,      2],
        ...,
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...

 67%|██████▋   | 20/30 [00:01<00:00, 14.24it/s]

({'input_ids': tensor([[     0,    581,  41911,  ...,    647,   1530,      2],
        [     0,    581,  41911,  ...,      1,      1,      1],
        [     0,    581,  41911,  ...,      5,     20,      2],
        ...,
        [     0,    581,  41911,  ...,   6827, 119805,      2],
        [     0,    581,  41911,  ...,    647,   3117,      2],
        [     0,    581,  41911,  ...,    647,   1530,      2]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])}, {'input_ids': tensor([[     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        ...,
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
 

 80%|████████  | 24/30 [00:01<00:00, 14.27it/s]

({'input_ids': tensor([[    0,   581, 41911,  ...,   927, 10740,     2],
        [    0,   581, 41911,  ...,  4015,   268,     2],
        [    0,   581, 41911,  ...,  4015,   268,     2],
        ...,
        [    0,   581, 41911,  ...,   268,   136,     2],
        [    0,   581, 41911,  ..., 13924,  3584,     2],
        [    0,   581, 41911,  ...,  1507,   378,     2]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])}, {'input_ids': tensor([[     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        ...,
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...

 87%|████████▋ | 26/30 [00:01<00:00, 14.33it/s]

({'input_ids': tensor([[     0,    581,  41911,  ...,      6,      4,      2],
        [     0,    581,  41911,  ...,      1,      1,      1],
        [     0,    581,  41911,  ...,     23, 111880,      2],
        ...,
        [     0,    581,  41911,  ...,  21652,   1363,      2],
        [     0,    581,  41911,  ...,      6,      2,      1],
        [     0,    581,  41911,  ...,    191, 109412,      2]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 0],
        [1, 1, 1,  ..., 1, 1, 1]])}, {'input_ids': tensor([[     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        ...,
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
 

100%|██████████| 30/30 [00:02<00:00, 14.38it/s]

({'input_ids': tensor([[     0,    581,  41911,  ...,    452,  10033,      2],
        [     0,    581,  41911,  ...,     23, 101085,      2],
        [     0,    581,  41911,  ...,      1,      1,      1],
        ...,
        [     0,    581,  41911,  ...,      1,      1,      1],
        [     0,    581,  41911,  ...,     71,  17265,      2],
        [     0,    581,  41911,  ...,     70,     47,      2]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])}, {'input_ids': tensor([[     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
        ...,
        [     0, 137399,     70,  ...,      1,      1,      1],
        [     0, 137399,     70,  ...,      1,      1,      1],
 


