In [51]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from datasets import load_dataset
from transformers import AutoTokenizer

In [52]:
# Number of sentence pairs per mini-batch, shared by every DataLoader below.
BATCH_SIZE = 2

In [76]:
class AllNLI(Dataset):
    """Map-style dataset over the ``pair-score`` config of ``sentence-transformers/all-nli``.

    Each item is a raw ``(sentence1, sentence2, score)`` triple; tokenization is
    deferred to :meth:`collate_fn` so that padding is done per batch.

    Args:
        split: Name of the dataset split to load (the notebook uses
            ``'train'``, ``'dev'`` and ``'test'``).
        tokenizer: Optional pre-built tokenizer to reuse across several
            instances. When ``None`` (the default) a ``bert-base-uncased``
            tokenizer is loaded, matching the previous behaviour.
    """

    def __init__(self, split, tokenizer=None):
        super().__init__()
        # Request the split directly so only that split's Dataset is returned,
        # rather than building the full DatasetDict and indexing into it.
        self.ds = load_dataset("sentence-transformers/all-nli", "pair-score", split=split)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        """Return the untokenized ``(sentence1, sentence2, score)`` triple at ``idx``."""
        row = self.ds[idx]
        return row['sentence1'], row['sentence2'], row['score']

    def __len__(self):
        """Return the number of sentence pairs in the loaded split."""
        return len(self.ds)

    def collate_fn(self, batch):
        """Tokenize and pad a list of ``__getitem__`` triples into one batch.

        Args:
            batch: Sequence of ``(sentence1, sentence2, score)`` triples.

        Returns:
            ``(tokens1, tokens2, score)`` where ``tokens1`` / ``tokens2`` are the
            tokenizer outputs for the first / second sentences and ``score`` is
            a 1-D float32 tensor with one entry per triple.
        """
        sen1 = [item[0] for item in batch]
        sen2 = [item[1] for item in batch]
        score = [item[2] for item in batch]

        # The two sides are tokenized independently, so each is padded only to
        # the longest sequence on its own side of this batch.
        tokens1 = self.tokenizer(sen1, return_tensors='pt', truncation=True, padding=True)
        tokens2 = self.tokenizer(sen2, return_tensors='pt', truncation=True, padding=True)
        # Build the target directly as float32 instead of creating it and casting.
        score = torch.tensor(score, dtype=torch.float32)
        return tokens1, tokens2, score

In [77]:
# One dataset object per split of all-nli (pair-score).
train_ds = AllNLI(split='train')
dev_ds = AllNLI(split='dev')
test_ds = AllNLI(split='test')

In [78]:
# Training loader: shuffled, batches assembled in the main process.
train_dl = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=train_ds.collate_fn,
)
# Evaluation loaders: fixed order, two worker processes each.
dev_dl = DataLoader(
    dev_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    collate_fn=dev_ds.collate_fn,
)
test_dl = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    collate_fn=test_ds.collate_fn,
)

In [80]:
# Sanity check: pull a single training batch and show the tensor shapes.
for batch in train_dl:
    sen1, sen2, score = batch
    print(
        sen1['input_ids'].shape,
        sen2['input_ids'].shape,
        score.shape,
        sep='\n',
    )
    break  # one batch is enough for the shape check

torch.Size([2, 20])
torch.Size([2, 16])
torch.Size([2])
