# Тестирование возможности последовательной загрузки датасета!


In [2]:
from datasets import load_from_disk






  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Load the dataset previously saved under "msmarco-ru/triplets" on local disk.
triplet_dataset = load_from_disk("msmarco-ru/triplets")


In [21]:
# Check which container type load_from_disk returned (output shown below).
type(triplet_dataset)

datasets.dataset_dict.DatasetDict

In [8]:
# Number of rows in the "train" split.
len(triplet_dataset["train"])

39780811

In [10]:
# Obtain an iterator over the "train" split with the built-in iter().
triplet_dataset_iterable = iter(triplet_dataset["train"])



In [14]:
# Check the type of a single split (output shown below).
type(triplet_dataset["train"])

datasets.arrow_dataset.Dataset

In [18]:
# Probe: calling __next__ directly on the split fails with the AttributeError
# shown below; rows have to be pulled through an iterator obtained with iter().
triplet_dataset["train"].__next__()

AttributeError: 'Dataset' object has no attribute '__next__'

In [45]:

class DictDatasetWrapper:
    """Stream (query, positive, negative) text triplets from a saved dataset.

    The dataset is loaded with ``load_from_disk`` and only its ``"train"``
    split is read. Every row is indexed by the keys ``"query"``,
    ``"positive"`` and ``"negative"``, and ``.strip()`` is applied to each
    value before it is returned.
    """

    def __init__(self, data_dir: str):
        """Load the dataset stored under ``data_dir``."""
        self.dataset = load_from_disk(data_dir)
        # Cursor consumed by ``__next__``; replaced on every ``__iter__`` call.
        self.dataset_iterable = None

    @staticmethod
    def _to_triplet(row):
        """Convert one dataset row into a whitespace-stripped 3-tuple."""
        return (
            row["query"].strip(),
            row["positive"].strip(),
            row["negative"].strip(),
        )

    def __iter__(self):
        """Return a fresh iterator over the train split.

        Each call yields an independent iterator, so two iterators obtained
        from the same wrapper no longer share (and reset) a single cursor.
        The newest iterator is also remembered as the cursor for ``__next__``
        so that ``iter(w)`` followed by ``next(w)`` keeps working.
        """
        # The split lookup happens right here (not lazily), so a missing
        # "train" split is reported at iter() time.
        self.dataset_iterable = (
            self._to_triplet(row) for row in self.dataset["train"]
        )
        return self.dataset_iterable

    def __next__(self):
        """Return the next triplet from the most recently created iterator.

        Iteration is started on demand when ``__iter__`` has not been called
        yet, instead of failing on the ``None`` placeholder.

        Raises:
            StopIteration: when the current iterator is exhausted.
        """
        if self.dataset_iterable is None:
            iter(self)
        return next(self.dataset_iterable)

    def __len__(self):
        """Return the number of rows in the train split."""
        return len(self.dataset["train"])

In [46]:
from torch.utils.data import IterableDataset
import os
from tqdm import tqdm
from datasets import load_from_disk

class PairsDatasetPreLoad(IterableDataset):
    """Iterable dataset of (query, positive, negative) text triplets.

    Items come from a ``DictDatasetWrapper`` built on ``data_dir`` and are
    pulled one at a time during iteration; this class itself does not
    materialise the rows, despite its name.

    When the dataset is consumed by a ``DataLoader`` with several worker
    processes, each worker receives a disjoint, interleaved share of the
    stream (worker ``i`` of ``k`` gets items ``i, i + k, i + 2k, ...``)
    instead of every worker replaying the full dataset.
    """

    def __init__(self, data_dir):
        """Wrap the dataset stored under ``data_dir``."""
        super().__init__()

        self.data_dir = data_dir
        # Not used inside this class; kept as a public attribute.
        self.id_style = "row_id"

        self.dataset = DictDatasetWrapper(data_dir)

        # Total number of examples, computed once at construction time.
        self.nb_ex = len(self.dataset)

    def __iter__(self):
        """Return an iterator over the triplets visible to this process."""
        from itertools import islice

        from torch.utils.data import get_worker_info

        triplets = iter(self.dataset)
        worker = get_worker_info()
        # Main process or a single worker: hand out the full stream unchanged.
        if worker is None or worker.num_workers <= 1:
            return triplets
        # Several workers: keep every num_workers-th item, offset by worker id,
        # so that samples are not duplicated across workers.
        return islice(triplets, worker.id, None, worker.num_workers)

    def __len__(self):
        """Return the total number of examples (independent of worker count)."""
        return self.nb_ex

    # IterableDataset offers no random access, so __getitem__ is deliberately
    # not implemented.

In [47]:
# Build the iterable triplet dataset on top of the saved "msmarco-ru/triplets" data.
dataset = PairsDatasetPreLoad("msmarco-ru/triplets")

In [53]:
import torch
from torch.utils.data.dataloader import DataLoader
from transformers import AutoTokenizer

def rename_keys(d, prefix):
    """Return a new dict whose keys are those of ``d`` prefixed with ``prefix + "_"``.

    Values are carried over untouched and ``d`` itself is not modified.
    """
    renamed = {}
    for key, value in d.items():
        renamed[prefix + "_" + key] = value
    return renamed

class DataLoaderWrapper(DataLoader):
    """DataLoader base class that owns a tokenizer and supplies its own collate function.

    Subclasses must override ``collate_fn``; the base implementation only raises.
    """

    def __init__(self, tokenizer_type, max_length, **kwargs):
        """Create the tokenizer and initialise the underlying DataLoader.

        Args:
            tokenizer_type: identifier passed to ``AutoTokenizer.from_pretrained``.
            max_length: stored on the instance for use by ``collate_fn``.
            **kwargs: forwarded to ``DataLoader.__init__``. ``pin_memory``
                defaults to ``True`` and can be overridden by the caller.
                ``collate_fn`` must not be passed: it is always this
                object's ``collate_fn`` method.
        """
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_type)
        # Default rather than force pin_memory: an explicit caller value used to
        # collide with the hard-coded keyword and raise TypeError.
        kwargs.setdefault("pin_memory", True)
        super().__init__(collate_fn=self.collate_fn, **kwargs)

    def collate_fn(self, batch):
        """Turn a list of samples into a batch; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("must implement this method")


class SiamesePairsDataLoader(DataLoaderWrapper):
    """Loader for siamese training on triplets.

    Queries, positive documents and negative documents are tokenized
    independently of one another.
    """

    def collate_fn(self, batch):
        """Tokenize a batch of (q, d_pos, d_neg) text tuples into tensors.

        ``batch`` is a list of 3-tuples of strings. The result maps every key
        returned by the tokenizer, prefixed with ``q_``, ``pos_`` or ``neg_``,
        to a ``torch.tensor`` built from the corresponding value.
        """
        # Transpose the batch: one tuple of queries, one of positives, one of negatives.
        queries, positives, negatives = zip(*batch)

        # Shared tokenizer settings: special tokens added, every group padded to
        # its own longest sequence, truncated to self.max_length.
        tokenizer_options = dict(
            add_special_tokens=True,
            padding="longest",
            truncation="longest_first",
            max_length=self.max_length,
            return_attention_mask=True,
        )

        sample = {}
        for prefix, texts in (("q", queries), ("pos", positives), ("neg", negatives)):
            encoded = self.tokenizer(list(texts), **tokenizer_options)
            # Prefix every tokenizer output key so the three groups do not collide.
            sample.update(rename_keys(encoded, prefix))

        return {name: torch.tensor(values) for name, values in sample.items()}

In [67]:
# Per the author's note, num_workers must be 1 for this loader.
dataloader = SiamesePairsDataLoader(
    dataset=dataset,
    tokenizer_type="ai-forever/ruBert-base",
    max_length=256,
    batch_size=16,
    shuffle=False,
    drop_last=True,
    num_workers=1,
)

In [75]:
# Create an iterator over the batches produced by the loader.
dataloader_iter = iter(dataloader)

In [78]:
# Fetch one collated batch (its contents are printed below).
next(dataloader_iter)

{'q_input_ids': tensor([[   101,    785,   2602,  26993,  11908,  12866,    378,  40052,    110,
           55980,   1754,    676,    102,      0],
         [   101,   4277,  10640,  78837,    102,      0,      0,      0,      0,
               0,      0,      0,      0,      0],
         [   101,    693,    113,    776,    377,  13951,    376,    102,      0,
               0,      0,      0,      0,      0],
         [   101,    693,   2432,  36298,  32207,   1914,  27401,    102,      0,
               0,      0,      0,      0,      0],
         [   101,   1079,    797, 118264,  42699,   5917,  12495,  68542,    102,
               0,      0,      0,      0,      0],
         [   101,    693,   9119,    118,   1024,  41801,   1415,  73735,  13648,
             107,  12560,    113,  27615,    102],
         [   101,   1079,    797,    118,   1721,  24242,   9647,  96685,    110,
           14238,  23085,  59145,    102,      0],
         [   101,  50690,   7580,    133,  31742,    8