In [9]:
from torch.utils.data import Dataset
import torch
import numpy as np


class KorSTSDatasets(Dataset):
    def __init__(self, dir_x, dir_y):
        self.x = np.load(dir_x, allow_pickle=True)
        self.y = np.load(dir_y, allow_pickle=True)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        sentence1, sentence2 = self.x[idx]
        data = torch.IntTensor(sentence1), torch.IntTensor(sentence2)
        label = int(float(self.y[idx]))
        return data, label

dataset = KorSTSDatasets("../KorSTS/train_x.npy", "../KorSTS/train_y.npy")

In [10]:
for data in dataset:
    print(data)
    break

((tensor([    2,  7046,  2116, 31389, 19521,  1513,  2062,    18,     3],
       dtype=torch.int32), tensor([    2,  7046,  2116, 31389, 19521,  1513,  2062,    18,     3],
       dtype=torch.int32)), 5)


In [10]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import random
from collections import defaultdict
from typing import List, Tuple


def collate_fn_(batch):
    # batch = list([((s1, s2), label), ((s1, s2), label), ...])
    s1_batches = []
    s2_batches = []
    labels = []
    for b in batch:
        data, label = b
        s1, s2 = data
        s1_batches.append(s1)
        s2_batches.append(s2)
        labels.append(label)
        
    s1_batch = pad_sequence(s1_batches, batch_first=True, padding_value=0)
    s2_batch = pad_sequence(s2_batches, batch_first=True, padding_value=0)
    return s1_batch, s2_batch, torch.Tensor(labels)

def bucketed_batch_indices(
    sentence_length: List[Tuple[int, int]],
    batch_size: int,
    max_pad_len: int
):
    batch_indices_list = []
    bucket = defaultdict(list)
    for idx, length in enumerate(sentence_length):
        s1_len, s2_len = length
        x = s1_len//max_pad_len
        y = s2_len//max_pad_len
        bucket[(x, y)].append(idx)
        if len(bucket[(x, y)]) == 64:
            batch_indices_list.append(bucket[(x, y)])
            bucket[(x, y)] = []
    for key in bucket.keys():
        batch_indices_list.append(bucket[key])

    random.shuffle(batch_indices_list)

    return batch_indices_list

sentence_length = []
for s1, s2 in dataset.x: # [(s1, s2), (s1, s2), ...]
    sentence_length.append((len(s1), len(s2)))

sampler = bucketed_batch_indices(sentence_length, batch_size=64, max_pad_len=10)
train_dataloader = DataLoader(dataset, collate_fn=collate_fn_, batch_sampler=sampler)

for data in train_dataloader:
    s1, s2, label = data
    print(s1.shape)
    print(s2.shape)
    print(label)
    break

torch.Size([64, 19])
torch.Size([64, 19])
tensor([1.2000, 4.6000, 4.6000, 0.0000, 2.6000, 3.8000, 3.6000, 3.8000, 3.6000,
        3.6000, 3.6000, 1.6000, 4.0000, 0.0000, 4.2000, 4.6000, 2.2500, 0.4000,
        0.4000, 4.4000, 3.8000, 2.0000, 3.6000, 2.2000, 5.0000, 0.2000, 3.6000,
        1.8000, 0.8000, 4.0000, 0.0000, 2.6000, 1.6000, 5.0000, 4.0000, 4.4000,
        3.6000, 4.0000, 2.4000, 2.6000, 4.0000, 0.2000, 3.2000, 3.0000, 4.2000,
        4.2500, 4.2000, 3.8000, 0.8000, 3.0000, 3.8000, 0.0000, 0.2000, 4.6000,
        3.0000, 4.4000, 5.0000, 3.6000, 0.0000, 4.0000, 1.4000, 3.2000, 4.2000,
        3.4000])
