## Weighted Sampling with `torchnlp`'s BalancedSampler

In [1]:
import torch
import pandas as pd

In [2]:
# Load each CSV split as a list of per-row dicts ({'label': ..., 'text': ...}).
# 'records' is the documented orient name; the abbreviated 'record' spelling
# was deprecated by pandas and raises ValueError from pandas 2.0 onwards.
train = pd.read_csv('./training.csv').to_dict(orient='records')
valid = pd.read_csv('./validation.csv').to_dict(orient='records')

In [3]:
# Peek at the first three raw training records (dicts with 'label' and 'text').
train[:3]

[{'label': 'neg', 'text': '  dame ! ! ! 1'},
 {'label': 'neg', 'text': '  thesaurus world sale ended '},
 {'label': 'pos', 'text': '  ight i let they lil white boy know. hahaha. '}]

In [4]:
from torchnlp.encoders.text import WhitespaceEncoder
from torchnlp.encoders import LabelEncoder

# Build both encoders from the training split only; the validation split is
# encoded with the same (already built) encoders.
text_encoder = WhitespaceEncoder(example['text'] for example in train)
label_encoder = LabelEncoder(example['label'] for example in train)


def _encode_split(examples):
    """Return `examples` with 'text' and 'label' replaced by their encoded tensors."""
    encoded = []
    for example in examples:
        encoded.append({
            'text': text_encoder.encode(example['text']),
            'label': label_encoder.encode(example['label']),
        })
    return encoded


train_encoded = _encode_split(train)
valid_encoded = _encode_split(valid)

In [5]:
# Peek at the first three encoded training examples (token-id tensor plus label tensor).
train_encoded[:3]

[{'text': tensor([5, 5, 6, 7, 7, 7, 8]), 'label': tensor(1)},
 {'text': tensor([ 5,  5,  9, 10, 11, 12,  5]), 'label': tensor(1)},
 {'text': tensor([ 5,  5, 13, 14, 15, 16, 17, 18, 19, 20, 21,  5]),
  'label': tensor(2)}]

In [6]:
from torchnlp.samplers import BucketBatchSampler, BalancedSampler
from torchnlp.encoders.text import stack_and_pad_tensors

# The encoded labels are 0-d tensors, and torch tensors hash by object identity
# rather than by value. Pass the sampler plain Python ints so that grouping
# examples into classes does not depend on how it hashes or compares tensors.
train_sampler = BalancedSampler(
    train_encoded, get_class=lambda example: example['label'].item())

# `sort_key` receives an index into `train_encoded`, so it looks the example up
# and sorts by its token count; batches hold up to 10 examples and the final
# short batch is kept (drop_last=False).
train_batch_sampler = BucketBatchSampler(
    train_sampler,
    batch_size=10,
    drop_last=False,
    sort_key=lambda index: train_encoded[index]['text'].shape[0])

In [7]:
from torch.utils.data import DataLoader

def collate_fn(batch, train=True):
    """Collate encoded examples into a padded, sequence-first batch.

    Args:
        batch: List of dicts, each holding a 1-D ``'text'`` tensor of token
            ids and a 0-d ``'label'`` tensor.
        train: Unused; kept so existing callers that pass it keep working.

    Returns:
        A ``(text, label)`` tuple. ``text`` has shape
        ``(max_seq_len, batch_size)`` and ``label`` has shape
        ``(batch_size,)``, including when the batch holds a single example.
    """
    text_batch, _ = stack_and_pad_tensors([row['text'] for row in batch])
    label_batch = torch.stack([row['label'] for row in batch])

    # Sequence-first layout for RNN consumption. Use the out-of-place t() and
    # do not squeeze: squeeze(0) would drop the batch dimension of the labels
    # (and of a 1x1 text batch) whenever the final batch has one example.
    return (text_batch.t().contiguous(), label_batch)

# Batches are composed by `train_batch_sampler` and assembled by `collate_fn`.
# Pinned memory is requested only when CUDA is available, and num_workers=0
# keeps loading in the main process.
train_iterator = DataLoader(
    dataset=train_encoded,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
    collate_fn=collate_fn,
    batch_sampler=train_batch_sampler,
)

In [8]:
# Dump every batch: a header line, the text tensor, the label tensor, then a
# blank separator line.
for batch_index, batch_tensors in enumerate(train_iterator):
    text, label = batch_tensors
    print(f'Batch {batch_index}', 'text:', text, 'label:', label, '', sep='\n')

Batch 0
text:
tensor([[  5,   5,   5,   5,   5,   5,   5,   5,   5,   5],
        [304,   5,  60, 369,   5,   5,   5, 357, 184,  72],
        [196,  40,  61,  24,  40,  77,   6,  48, 612,  73],
        [305,  41,  62,  49,  41,  85,   7, 358, 100,  24],
        [  5,  42,  56, 370,  42, 157,   7,  56, 203,  74],
        [  0,   5,  63,   5,   5, 196,   7, 359, 613,  75],
        [  0,   0,   0,   0,   0,   5,   8,   5,   5,  76],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   5]])
label:
tensor([2, 2, 2, 2, 2, 2, 1, 1, 1, 2])

Batch 1
text:
tensor([[  5,   5,   5,   5,   5,   5,   5,   5,   5,   5],
        [  5,  14,   5, 467, 467,   5, 327, 327,   5,   5],
        [ 14, 458,  14, 468, 468, 153, 197, 197, 237, 440],
        [614,  27, 614, 469, 469, 154, 107, 107, 207,  40],
        [615,  28, 615, 470, 470,  14,  46,  46, 350, 507],
        [231, 459, 231,  83,  83, 155,  25,  25, 210, 178],
        [616, 122, 616, 471, 471,  25, 328, 328,  85, 508],
        [213,  40, 213, 