In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from dataclasses import dataclass
from functools import partial
from itertools import cycle
import os
from pathlib import Path
import re
import sys
from typing import Dict, List, Optional, Tuple, Any

if '..' not in sys.path: sys.path.append('..')

import numpy as np

from datasets import load_dataset
from transformers import AutoTokenizer, PreTrainedTokenizer

import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import Dataset, DataLoader, DistributedSampler


# from mllm.train.utils import WordToks

## Dataset loading

In [3]:
# DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'
DATA_PATH = Path('Q:/data')
print(f"DATA_PATH: {DATA_PATH}")

DATA_PATH: Q:\data


In [4]:
# ds = load_dataset('wikipedia', '20200501.en',  beam_runner='DirectRunner', cache_dir=str(DATA_PATH), trust_remote_code=True)
ds = load_dataset('wikipedia', '20220301.en', cache_dir=str(DATA_PATH), trust_remote_code=True)
# ds = load_dataset('wikipedia', '20220301.en', cache_dir=str(DATA_PATH))

Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

In [5]:
list(ds.keys())

['train']

In [6]:
dst = ds['train']

In [7]:
len(dst)

6458670

In [9]:
dst[6078421]

{'title': 'Overload (novel)',
 'text': 'Overload (1979) is a novel by Arthur Hailey, concerning the electricity production industry in California and the activities of the employees and others involved with Golden State Power and Light, a fictional California public service company. The plot follows many of the issues of the day, including race relations, corporate politics, business ethics, terrorism and journalism. (Hailey would later explore (television) journalism in another novel, The Evening News.)\n\nPlot Synopsis\nThe novel is described from the point of view of vice-president of Golden State Power and Light, Nimrod "Nim" Goldman, who, despite being married, tends to be somewhat of a Lothario and has many extramarital affairs. The geographic area of service of the fictional electric utility, Golden State Power and Light, matches the actual Northern California footprint of the real-life Pacific Gas and Electric Company.\n\nGolden State Power and Light is a public utility, supply

In [10]:
dst.column_names

['title', 'text']

## Tokenization

In [8]:
pretrained_model_name = 'bert-base-uncased'
tkz = AutoTokenizer.from_pretrained(pretrained_model_name)
tkz

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [30]:
item = dst[20]
item

{'id': '569',
 'url': 'https://en.wikipedia.org/wiki/Anthropology',
 'title': 'Anthropology',
 'text': 'Anthropology is the scientific study of humanity, concerned with human behavior, human biology, cultures, societies, and linguistics, in both the present and past, including past human species. Social anthropology studies patterns of behaviour, while cultural anthropology studies cultural meaning, including norms and values. A portmanteau sociocultural anthropology is commonly used today. Linguistic anthropology studies how language influences social life. Biological or physical anthropology studies the biological development of humans.\n\nArchaeological anthropology, often termed as \'anthropology of the past\', studies human activity through investigation of physical evidence. It is considered a branch of anthropology in North America and Asia, while in Europe archaeology is viewed as a discipline in its own right or grouped under other related disciplines, such as history.\n\nEtym

In [32]:
title, text = item['title'], item['text'][:150]
print(f'{title}. Text len: {len(text)}')
inp_toks = tkz.tokenize(text)
print(inp_toks)

Anthropology. Text len: 150
['anthropology', 'is', 'the', 'scientific', 'study', 'of', 'humanity', ',', 'concerned', 'with', 'human', 'behavior', ',', 'human', 'biology', ',', 'cultures', ',', 'societies', ',', 'and', 'linguistics', ',', 'in', 'both', 'the', 'pre', '##s']


In [33]:
res1 = tkz(text, is_split_into_words=False)
print(res1)

{'input_ids': [101, 12795, 2003, 1996, 4045, 2817, 1997, 8438, 1010, 4986, 2007, 2529, 5248, 1010, 2529, 7366, 1010, 8578, 1010, 8384, 1010, 1998, 15397, 1010, 1999, 2119, 1996, 3653, 2015, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [34]:
res2 = tkz(text.split(), is_split_into_words=True)
print(res2)

{'input_ids': [101, 12795, 2003, 1996, 4045, 2817, 1997, 8438, 1010, 4986, 2007, 2529, 5248, 1010, 2529, 7366, 1010, 8578, 1010, 8384, 1010, 1998, 15397, 1010, 1999, 2119, 1996, 3653, 2015, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [35]:
res1 == res2

True

In [36]:
res3 = tkz(text.split(), add_special_tokens=False)
print(text.split())
res3


['Anthropology', 'is', 'the', 'scientific', 'study', 'of', 'humanity,', 'concerned', 'with', 'human', 'behavior,', 'human', 'biology,', 'cultures,', 'societies,', 'and', 'linguistics,', 'in', 'both', 'the', 'pres']


{'input_ids': [[12795], [2003], [1996], [4045], [2817], [1997], [8438, 1010], [4986], [2007], [2529], [5248, 1010], [2529], [7366, 1010], [8578, 1010], [8384, 1010], [1998], [15397, 1010], [1999], [2119], [1996], [3653, 2015]], 'token_type_ids': [[0], [0], [0], [0], [0], [0], [0, 0], [0], [0], [0], [0, 0], [0], [0, 0], [0, 0], [0, 0], [0], [0, 0], [0], [0], [0], [0, 0]], 'attention_mask': [[1], [1], [1], [1], [1], [1], [1, 1], [1], [1], [1], [1, 1], [1], [1, 1], [1, 1], [1, 1], [1], [1, 1], [1], [1], [1], [1, 1]]}

In [37]:
words3 = tkz.batch_decode(res3.input_ids)
words3

['anthropology',
 'is',
 'the',
 'scientific',
 'study',
 'of',
 'humanity,',
 'concerned',
 'with',
 'human',
 'behavior,',
 'human',
 'biology,',
 'cultures,',
 'societies,',
 'and',
 'linguistics,',
 'in',
 'both',
 'the',
 'pres']

In [38]:
s1 = tkz.decode(res1.input_ids, skip_special_tokens=True)
s2 = ' '.join(words3)
print(s2)
s1 == s2

anthropology is the scientific study of humanity, concerned with human behavior, human biology, cultures, societies, and linguistics, in both the pres


True

In [40]:
pat = re.compile(r'\b')
parts = pat.split(text)
''.join(parts) == text

True

In [41]:
parts_nospace = [p.lower().strip() for p in parts]
parts_nospace = filter(None, parts_nospace)
parts_nospace = list(parts_nospace)
parts_nospace

['anthropology',
 'is',
 'the',
 'scientific',
 'study',
 'of',
 'humanity',
 ',',
 'concerned',
 'with',
 'human',
 'behavior',
 ',',
 'human',
 'biology',
 ',',
 'cultures',
 ',',
 'societies',
 ',',
 'and',
 'linguistics',
 ',',
 'in',
 'both',
 'the',
 'pres']

In [42]:
s2 = ' '.join([p.lower() for p in parts_nospace])
s1 == s2

False

In [43]:
s1

'anthropology is the scientific study of humanity, concerned with human behavior, human biology, cultures, societies, and linguistics, in both the pres'

In [44]:
s2

'anthropology is the scientific study of humanity , concerned with human behavior , human biology , cultures , societies , and linguistics , in both the pres'

In [45]:
s = s1 + ' abc## ## ####xxx'
toks_ids = tkz(s, add_special_tokens=False).input_ids
toks_str = tkz.convert_ids_to_tokens(toks_ids)
assert len(toks_ids) == len(toks_str)
ids_to_str = [(tid, tis) for tid, tis in zip(toks_ids, toks_str)]
ids_to_str

[(12795, 'anthropology'),
 (2003, 'is'),
 (1996, 'the'),
 (4045, 'scientific'),
 (2817, 'study'),
 (1997, 'of'),
 (8438, 'humanity'),
 (1010, ','),
 (4986, 'concerned'),
 (2007, 'with'),
 (2529, 'human'),
 (5248, 'behavior'),
 (1010, ','),
 (2529, 'human'),
 (7366, 'biology'),
 (1010, ','),
 (8578, 'cultures'),
 (1010, ','),
 (8384, 'societies'),
 (1010, ','),
 (1998, 'and'),
 (15397, 'linguistics'),
 (1010, ','),
 (1999, 'in'),
 (2119, 'both'),
 (1996, 'the'),
 (3653, 'pre'),
 (2015, '##s'),
 (5925, 'abc'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (22038, 'xx'),
 (2595, '##x')]

In [46]:
tkz.decode(toks_ids)

'anthropology is the scientific study of humanity, concerned with human behavior, human biology, cultures, societies, and linguistics, in both the pres abc # # # # # # # # xxx'

In [None]:
wt = WordToks(tkz, s, max_tgt_len_freq=0.2, max_tgt_len=10)
inp_str, inp_masked_str, tgt_str = tkz.decode(wt.inp_toks), tkz.decode(wt.inp_masked_toks), tkz.decode(wt.tgt_toks)
print(inp_str)
print(inp_masked_str)
print(tgt_str)

anthropology is the scientific study of humanity, concerned with human behavior, < | cite _ begin | > human biology, cultures < | cite _ end | >, societies, and linguistics, in both the pres abc # # # # # # # # xxx
anthropology is the scientific study of humanity, concerned with human behavior, < | cite _ begin | > [MASK] [MASK] [MASK] [MASK] < | cite _ end | >, societies, and linguistics, in both the pres abc # # # # # # # # xxx
human biology, cultures


## Mapping and processing the dataset

In [5]:
tkz = AutoTokenizer.from_pretrained('bert-base-uncased', use_fast=True)
tkz

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [45]:
n = 5
ds_sub = ds['train'].select(range(n))
len(ds_sub), ds_sub.column_names

(5, ['id', 'url', 'title', 'text'])

In [69]:
def tokenize_item(item, tokenizer: PreTrainedTokenizer):
    text = item['text']
    toks = tokenizer(text, add_special_tokens=False).input_ids
    return {
        **item,
        'toks': toks,
        'toks_len': len(toks),
    }


class MaskDataset:
    def __init__(self, pad_token_id: int = 0, max_seq_len: int = 512, mask_token_id: int = 103, min_mask_toks: int = 0, max_mask_toks: int = 10):
        self.pad_token_id = pad_token_id
        self.max_seq_len = max_seq_len
        self.mask_token_id = mask_token_id
        self.min_mask_toks = min_mask_toks
        self.max_mask_toks = max_mask_toks
        self.inds = np.arange(self.max_seq_len)

    def extract_masked_input(self, item: Dict) -> Dict:
        cur_len = item['toks_len']
        cur_len = min(cur_len, 9)
        max_seq_len = min(cur_len, self.max_seq_len)
        input_ids = item['toks']
        if max_seq_len < cur_len:
            ind_off_max = cur_len - max_seq_len + 1
            ind_off_max = min(ind_off_max, 3)
            ind_off = np.random.randint(0, ind_off_max)
            input_ids = input_ids[ind_off:ind_off + max_seq_len]
        input_ids_masked = input_ids
        mask_toks_num = np.random.randint(self.min_mask_toks, self.max_mask_toks + 1)
        mask_toks_num = min(mask_toks_num, max_seq_len // 2)
        if mask_toks_num > 0:
            mask_inds = np.random.choice(self.inds[:max_seq_len], size=mask_toks_num, replace=False)
            input_ids_masked = np.array(input_ids)
            input_ids_masked[mask_inds] = self.mask_token_id
    
        return {
            **item,
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'input_ids_masked': torch.tensor(input_ids_masked, dtype=torch.long),
        }


class DatasetMaskCollator:
    def __init__(self, pad_token_id: int = 0, max_seq_len: int = 512, mask_token_id: int = 103, min_mask_toks: int = 0, max_mask_toks: int = 10):
        self.pad_token_id = pad_token_id
        self.max_seq_len = max_seq_len
        self.mask_token_id = mask_token_id
        self.min_mask_toks = min_mask_toks
        self.max_mask_toks = max_mask_toks
        self.inds = np.arange(self.max_seq_len)

    def extract_masked_input(self, item: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        print('!!!')
        cur_len = item['toks_len']
        max_seq_len = min(cur_len, self.max_seq_len)
        input_ids = item['toks']
        if max_seq_len < cur_len:
            ind_off_max = cur_len - max_seq_len + 1
            ind_off_max = min(ind_off_max, 3)
            ind_off = np.random.randint(0, ind_off_max)
            input_ids = input_ids[ind_off:ind_off + max_seq_len]
        input_ids_masked = input_ids
        mask_toks_num = np.random.randint(self.min_mask_toks, self.max_mask_toks + 1)
        mask_toks_num = min(mask_toks_num, max_seq_len // 2)
        if mask_toks_num > 0:
            mask_inds = np.random.choice(self.inds[:max_seq_len], size=mask_toks_num, replace=False)
            input_ids_masked = np.array(input_ids)
            input_ids_masked[mask_inds] = self.mask_token_id
        
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(input_ids_masked, dtype=torch.long)
    
    def collate(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
        input_ids_batch = []
        input_ids_masked_batch = []
        for item in batch:
            input_ids, input_ids_masked = self.extract_masked_input(item)
            input_ids_batch.append(input_ids)
            input_ids_masked_batch.append(input_ids_masked)
        input_ids = nn.utils.rnn.pad_sequence(input_ids_batch, batch_first=True, padding_value=self.pad_token_id)
        input_ids_masked = nn.utils.rnn.pad_sequence(input_ids_masked_batch, batch_first=True, padding_value=self.pad_token_id)
        return input_ids, input_ids_masked


In [70]:
dataset = ds_sub.map(tokenize_item, fn_kwargs={'tokenizer': tkz})
# dataset = dataset.map(MaskDataset(pad_token_id=tkz.pad_token_id, max_seq_len=5, mask_token_id=tkz.mask_token_id, min_mask_toks=0, max_mask_toks=10).extract_masked_input)

In [71]:
dataset.column_names

['id', 'url', 'title', 'text', 'toks', 'toks_len']

In [None]:
batch_size = 3
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group(backend='gloo', rank=0, world_size=1)

In [117]:
class MaskedDataset(Dataset):
    def __init__(self, dataset: Dataset, pad_token_id: int = 0, max_seq_len: int = 512, mask_token_id: int = 103, min_mask_toks: int = 0, max_mask_toks: int = 10):
        self.dataset = dataset.map(tokenize_item, fn_kwargs={'tokenizer': tkz})
        self.len = len(dataset)
        self.pad_token_id = pad_token_id
        self.max_seq_len = max_seq_len
        self.mask_token_id = mask_token_id
        self.min_mask_toks = min_mask_toks
        self.max_mask_toks = max_mask_toks
        self.inds = np.arange(self.max_seq_len)

    def __len__(self):
        return self.len

    def extract_masked_input(self, item: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        print('!!!')
        cur_len = item['toks_len']
        max_seq_len = min(cur_len, self.max_seq_len)
        input_ids = item['toks']
        if max_seq_len < cur_len:
            ind_off_max = cur_len - max_seq_len + 1
            ind_off_max = min(ind_off_max, 3)
            ind_off = np.random.randint(0, ind_off_max)
            input_ids = input_ids[ind_off:ind_off + max_seq_len]
        input_ids_masked = input_ids
        mask_toks_num = np.random.randint(self.min_mask_toks, self.max_mask_toks + 1)
        mask_toks_num = min(mask_toks_num, max_seq_len // 2)
        if mask_toks_num > 0:
            mask_inds = np.random.choice(self.inds[:max_seq_len], size=mask_toks_num, replace=False)
            input_ids_masked = np.array(input_ids)
            input_ids_masked[mask_inds] = self.mask_token_id
        
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(input_ids_masked, dtype=torch.long)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        idx = idx % len(self.dataset)
        item = self.dataset[idx]
        input_ids, input_ids_masked = self.extract_masked_input(item)
        return {
            **item,
            'input_ids': input_ids,
            'input_ids_masked': input_ids_masked,
        }
    

def collate_masked_batch(batch: List[Dict[str, Any]]) -> Tuple[torch.Tensor, torch.Tensor]:
    input_ids_batch = []
    input_ids_masked_batch = []
    for item in batch:
        # print(item)
        input_ids = item['input_ids']
        input_ids_masked = item['input_ids_masked']
        input_ids_batch.append(input_ids)
        input_ids_masked_batch.append(input_ids_masked)
    input_ids = nn.utils.rnn.pad_sequence(input_ids_batch, batch_first=True, padding_value=tkz.pad_token_id)
    input_ids_masked = nn.utils.rnn.pad_sequence(input_ids_masked_batch, batch_first=True, padding_value=tkz.pad_token_id)
    return input_ids, input_ids_masked


In [118]:
dataset = MaskedDataset(ds_sub,
    pad_token_id=tkz.pad_token_id,
    max_seq_len=5,
    mask_token_id=tkz.mask_token_id,
    min_mask_toks=0,
    max_mask_toks=10
)

In [90]:


train_sampler = DistributedSampler(dataset)

# Create a DataLoader with the DistributedSampler
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,  # Don't shuffle for distributed training (handled by the sampler)
    sampler=train_sampler,
    num_workers=0, # Sets the number of subprocesses to load data in parallel, avoiding I/O bottlenecks
    pin_memory=True, # Copies data to pinned memory, which is faster to transfer to the GPU.
    # prefetch_factor=2, # Sets the number of batches that each worker will prepare in advance.
    collate_fn=collate_masked_batch,
)


In [91]:
for i, (inp, tgt) in enumerate(cycle(train_loader)):
    print(i, inp.shape, tgt.shape)
    print(inp)
    print(tgt)
    if i == 6:
        break


!!!
!!!
!!!
0 torch.Size([3, 5]) torch.Size([3, 5])
tensor([[ 1006,  1007,  2003,  1037,  2110],
        [ 1037, 11265, 10976, 24844, 18349],
        [ 2003,  1037, 11265, 10976, 24844]])
tensor([[ 1006,   103,  2003,  1037,   103],
        [ 1037,   103, 10976, 24844,   103],
        [ 2003,  1037, 11265, 10976, 24844]])
!!!
!!!
!!!
1 torch.Size([3, 5]) torch.Size([3, 5])
tensor([[2964, 2003, 1037, 2576, 4695],
        [1010, 2030, 1037, 1010, 2003],
        [1007, 2003, 1037, 2110, 1999]])
tensor([[2964,  103, 1037,  103, 4695],
        [1010,  103, 1037,  103, 2003],
        [ 103,  103, 1037, 2110, 1999]])
!!!
!!!
!!!
2 torch.Size([3, 5]) torch.Size([3, 5])
tensor([[ 2003,  1037, 11265, 10976, 24844],
        [ 1037,  1010,  2030,  1037,  1010],
        [28759,  1006,  1025,  1007,  2003]])
tensor([[  103,  1037,   103, 10976, 24844],
        [  103,  1010,  2030,   103,  1010],
        [28759,   103,   103,  1007,  2003]])
!!!
!!!
!!!
3 torch.Size([3, 5]) torch.Size([3, 5])
tensor

In [119]:
def create_dataloader_iter(dataset: Dataset, batch_size: int, num_workers: int, collate_fn: Any):
    while True:
        print(f'Generate Dataloader')
        train_sampler = DistributedSampler(dataset)

        # Create a DataLoader with the DistributedSampler
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,  # Don't shuffle for distributed training (handled by the sampler)
            sampler=train_sampler,
            num_workers=num_workers, # Sets the number of subprocesses to load data in parallel, avoiding I/O bottlenecks
            pin_memory=True, # Copies data to pinned memory, which is faster to transfer to the GPU.
            # prefetch_factor=2, # Sets the number of batches that each worker will prepare in advance.
            collate_fn=collate_fn,
        )
        
        for item in dataloader:
            yield item


In [120]:
dataloader = create_dataloader_iter(
    dataset,
    batch_size=batch_size,
    num_workers=0,
    collate_fn=collate_masked_batch,
)

In [121]:
for i, (inp, tgt) in enumerate(dataloader):
    print(i, inp.shape, tgt.shape)
    print(inp)
    print(tgt)
    if i == 6:
        break

Generate Dataloader
!!!
!!!
!!!
0 torch.Size([3, 5]) torch.Size([3, 5])
tensor([[ 6041,  1006,  1007,  2003,  1037],
        [ 9617, 11140,  2964,  2003,  1037],
        [ 2003,  1037, 11265, 10976, 24844]])
tensor([[ 6041,   103,  1007,  2003,   103],
        [ 9617,   103,  2964,   103,  1037],
        [ 2003,   103,   103, 10976, 24844]])
!!!
!!!
1 torch.Size([2, 5]) torch.Size([2, 5])
tensor([[ 1037,  1010,  2030,  1037,  1010],
        [ 2632, 28759,  1006,  1025,  1007]])
tensor([[  103,  1010,  2030,  1037,  1010],
        [  103, 28759,  1006,  1025,  1007]])
Generate Dataloader
!!!
!!!
!!!
2 torch.Size([3, 5]) torch.Size([3, 5])
tensor([[ 1007,  2003,  1037,  2110,  1999],
        [ 2964,  2003,  1037,  2576,  4695],
        [19465,  2003,  1037, 11265, 10976]])
tensor([[ 1007,   103,  1037,  2110,   103],
        [ 2964,  2003,  1037,   103,  4695],
        [19465,  2003,   103,   103, 10976]])
!!!
!!!
3 torch.Size([2, 5]) torch.Size([2, 5])
tensor([[ 2030,  1037,  1010,  200

In [16]:
n = 10
# ds_sub = ds['train'].shuffle(seed=42)[:n]
ds_sub = ds['train'].shuffle(seed=42).select(range(n))

In [17]:
for i in range(n):
    print(ds_sub[i]['title'])

William Whitehouse
Cheryl S. McWatters
Lithuanian Lands Militia
Mizoram–Manipur–Kachin rain forests
Salesbury
Maurice Eustace (Lord Chancellor)
Heleen Mees
Diogo Dalot
CBCM
Okka Rau


In [19]:
ds1 = ds_sub.select(range(6))
ds2 = ds_sub.select(range(6, n, 1))

In [20]:
for i in range(len(ds1)):
    print(ds1[i]['title'])
print('---')
for i in range(len(ds2)):
    print(ds2[i]['title'])

William Whitehouse
Cheryl S. McWatters
Lithuanian Lands Militia
Mizoram–Manipur–Kachin rain forests
Salesbury
Maurice Eustace (Lord Chancellor)
---
Heleen Mees
Diogo Dalot
CBCM
Okka Rau


In [8]:
from torch.utils.data import TensorDataset, DataLoader

class SimpleCustomBatch:
    def __init__(self, data):
        print(len(data), data)
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = -torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(type(sample), sample.inp.shape, sample.tgt.shape)
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())

2 [(tensor([0., 1., 2., 3., 4.]), tensor([-0., -1., -2., -3., -4.])), (tensor([5., 6., 7., 8., 9.]), tensor([-5., -6., -7., -8., -9.]))]
<class '__main__.SimpleCustomBatch'> torch.Size([2, 5]) torch.Size([2, 5])
False
False
2 [(tensor([10., 11., 12., 13., 14.]), tensor([-10., -11., -12., -13., -14.])), (tensor([15., 16., 17., 18., 19.]), tensor([-15., -16., -17., -18., -19.]))]
<class '__main__.SimpleCustomBatch'> torch.Size([2, 5]) torch.Size([2, 5])
False
False
2 [(tensor([20., 21., 22., 23., 24.]), tensor([-20., -21., -22., -23., -24.])), (tensor([25., 26., 27., 28., 29.]), tensor([-25., -26., -27., -28., -29.]))]
<class '__main__.SimpleCustomBatch'> torch.Size([2, 5]) torch.Size([2, 5])
False
False
2 [(tensor([30., 31., 32., 33., 34.]), tensor([-30., -31., -32., -33., -34.])), (tensor([35., 36., 37., 38., 39.]), tensor([-35., -36., -37., -38., -39.]))]
<class '__main__.SimpleCustomBatch'> torch.Size([2, 5]) torch.Size([2, 5])
False
False
2 [(tensor([40., 41., 42., 43., 44.]), tenso

