## Weighted Sampling with `torchnlp`'s BalancedSampler

In [None]:
import torch
import pandas as pd

In [181]:
# Load each split as a list of row dicts. `to_dict` needs the full orientation
# name: the abbreviation 'record' was deprecated and is rejected by pandas 2.x.
train = pd.read_csv('./training.csv').to_dict(orient='records')
valid = pd.read_csv('./validation.csv').to_dict(orient='records')

In [182]:
# Peek at the first three training rows; each row is a dict of column -> value.
train[:3]

[{'label': 'neg', 'text': '  dame ! ! ! 1'},
 {'label': 'neg', 'text': '  thesaurus world sale ended '},
 {'label': 'pos', 'text': '  ight i let they lil white boy know. hahaha. '}]

In [188]:
from torchnlp.encoders.text import WhitespaceEncoder
from torchnlp.encoders import LabelEncoder

# Both vocabularies are built from the training split only; the validation
# split is encoded with those same fitted encoders.
text_encoder = WhitespaceEncoder(example['text'] for example in train)
label_encoder = LabelEncoder(example['label'] for example in train)


def _encode_split(examples):
    """Encode every row of a split into ``{'text': ..., 'label': ...}`` tensors."""
    encoded_rows = []
    for example in examples:
        encoded_rows.append({
            'text': text_encoder.encode(example['text']),
            'label': label_encoder.encode(example['label']),
        })
    return encoded_rows


train_encoded = _encode_split(train)
valid_encoded = _encode_split(valid)

In [187]:
# Peek at the first three encoded rows: a tensor of token ids plus a label-id tensor.
train_encoded[:3]

[{'text': tensor([5, 5, 6, 7, 7, 7, 8]), 'label': tensor(1)},
 {'text': tensor([ 5,  5,  9, 10, 11, 12,  5]), 'label': tensor(1)},
 {'text': tensor([ 5,  5, 13, 14, 15, 16, 17, 18, 19, 20, 21,  5]),
  'label': tensor(2)}]

In [122]:
import collections.abc as container_abcs
import importlib
import re

import torch
import torch.utils.data

# `torch._six` is not available in PyTorch 2.x. These are the values it used to
# export, kept under the same names so other cells that use them still work.
string_classes = (str, bytes)
int_classes = int

# Matches the dtype codes of numpy byte-string, unicode and object arrays,
# none of which can be converted to a tensor.
np_str_obj_array_pattern = re.compile(r'[SaUO]')

default_collate_err_msg_format = (
    "torchtext_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")


def _load_torchtext_batch_type():
    """Return torchtext's ``Batch`` class, or ``None`` when it cannot be imported."""
    # Some torchtext releases keep `Batch` under `torchtext.legacy.data`.
    for module_name in ('torchtext.data', 'torchtext.legacy.data'):
        try:
            module = importlib.import_module(module_name)
        except (ImportError, OSError):
            # Not installed, or the installed build cannot be loaded.
            continue
        batch_type = getattr(module, 'Batch', None)
        if batch_type is not None:
            return batch_type
    return None


_TORCHTEXT_BATCH = _load_torchtext_batch_type()


def torchtext_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.

    Same dispatch as PyTorch's ``default_collate``, plus one extra case: a list
    of torchtext ``Batch`` objects is collated field by field into a dict.
    Every container case recurses into this function, so torchtext batches
    nested inside dicts, tuples or lists are handled as well.

    Args:
        batch: Non-empty list of samples that all have the type of ``batch[0]``.

    Returns:
        A stacked tensor for tensors, numpy arrays and numbers; the input list
        unchanged for strings; a dict, namedtuple or list of collated values
        for mappings, namedtuples, other sequences and torchtext batches.

    Raises:
        TypeError: If the element type is unsupported, or a numpy array holds
            strings or objects.
    """
    elem = batch[0]
    elem_type = type(elem)

    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)

    if (elem_type.__module__ == 'numpy'
            and elem_type.__name__ not in ('str_', 'string_')):
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return torchtext_collate([torch.as_tensor(b) for b in batch])
        if elem.shape == ():  # scalars
            return torch.as_tensor(batch)
        # Any other numpy type is unsupported; do not try the branches below.
        raise TypeError(default_collate_err_msg_format.format(elem_type))

    if isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    if isinstance(elem, int_classes):
        return torch.tensor(batch)
    if isinstance(elem, string_classes):
        return batch
    if isinstance(elem, container_abcs.Mapping):
        return {key: torchtext_collate([d[key] for d in batch]) for key in elem}
    if _TORCHTEXT_BATCH is not None and isinstance(elem, _TORCHTEXT_BATCH):
        # Only fields that are actually set on the dataset carry data.
        field_names = [name for name, field in elem.dataset.fields.items()
                       if field is not None]
        return {name: torchtext_collate([getattr(item, name) for item in batch])
                for name in field_names}
    if isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(torchtext_collate(samples) for samples in zip(*batch)))
    if isinstance(elem, container_abcs.Sequence):
        return [torchtext_collate(samples) for samples in zip(*batch)]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

In [207]:
from torchnlp.samplers import BucketBatchSampler, BalancedSampler
from torchnlp.encoders.text import stack_and_pad_tensors

# Balance classes by the label's plain Python value. Tensors hash by identity,
# which makes them poor keys for grouping examples by class, so `.item()`
# hands the sampler a hashable int instead.
train_sampler = BalancedSampler(train_encoded, get_class=lambda x: x['label'].item())

# Batch examples drawn by the sampler, using the number of tokens in each
# example's text as the sort key.
train_batch_sampler = BucketBatchSampler(
    train_sampler,
    batch_size=10,
    drop_last=False,
    sort_key=lambda i: train_encoded[i]['text'].shape[0])

In [214]:
from torch.utils.data import DataLoader

def collate_fn(batch, train=True):
    """Turn a list of encoded examples into a ``(text, label)`` batch.

    Args:
        batch: List of dicts, each holding a 1-D ``'text'`` tensor of token ids
            and a 0-d ``'label'`` tensor.
        train: Unused; kept so callers that pass it keep working.

    Returns:
        Tuple ``(text, label)``. ``text`` is the zero-padded token matrix laid
        out as ``[longest_sequence, batch]``; ``label`` has shape ``[batch]``.
    """
    text_batch, _ = stack_and_pad_tensors([row['text'] for row in batch])
    label_batch = torch.stack([row['label'] for row in batch])

    # PyTorch RNN requires batches to be transposed for speed and integration with CUDA.
    # No squeeze here: it would drop the sequence dimension when the longest
    # text has one token, and turn a one-example label batch into a scalar.
    return (text_batch.t().contiguous(), label_batch.contiguous())

# Request pinned host memory only when a CUDA device is available.
use_pinned_memory = torch.cuda.is_available()

# The batch sampler decides which indices form each batch; `collate_fn` turns
# the selected examples into tensors. `num_workers=0` loads in the main process.
train_iterator = DataLoader(
    dataset=train_encoded,
    batch_sampler=train_batch_sampler,
    collate_fn=collate_fn,
    num_workers=0,
    pin_memory=use_pinned_memory,
)

In [218]:
# Walk through the loader once and dump every batch for visual inspection.
for batch_index, (text, label) in enumerate(train_iterator):
    print(f'Batch {batch_index}')
    # Header and tensor on separate lines, followed by a blank separator line.
    print('text:', text, sep='\n')
    print('label:', label, sep='\n')
    print()

Batch 0
text:
tensor([[  5,   5,   5,   5,   5,   5,   5,   5,   5,   5],
        [  5, 500,   5, 122,   5, 517,   5, 214, 214, 254],
        [445,  27, 539,  49, 445, 295, 445,  49,  49, 271],
        [270,  70,  40, 377, 270, 518, 270, 371, 371, 267],
        [ 40,  29,  41,  69,  40, 295,  40,  41,  41, 238],
        [231,  25, 540,  25, 231,  14, 231, 372, 372,  46],
        [128, 115,   7, 378, 128,  33, 128, 373, 373,  25],
        [ 25, 327, 184, 379,  25, 164,  25,   7,   7, 272],
        [446,  56, 361, 380, 446, 519, 446,   7,   7, 273],
        [447, 528, 134, 381, 447,  25, 447,   7,   7, 173],
        [448,  73,  52,  69, 448, 520, 448, 374, 374, 133],
        [449, 529,  40, 382, 449, 521, 449,  35,  35, 274],
        [125,   7, 541, 383, 125, 522, 125, 375, 375, 214],
        [  5,   5,   5, 384,   5,   5,   5, 376, 376, 190],
        [  0,   0,   0,   0,   0,   0,   0,   5,   5, 275],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   5]])
label:
tensor([1, 2, 2, 1