In [3]:
import torch
import torch.nn as nn
import numpy as np
from torchtext.data import Dataset, Example, Field
from torchtext.data import Iterator, BucketIterator

In [10]:
def generate_art_data(nr_classes=3):
    """Lazily yield ``(sentence, target)`` pairs of artificial data.

    Relies on three notebook-level globals: ``N`` (how many pairs to
    produce), ``words`` (the pool that tokens are drawn from) and
    ``random_sent_lengths`` (token count of the i-th sentence).

    Args:
        nr_classes: Length of every target vector.

    Yields:
        tuple: A space-joined sentence string and an integer array holding
        ``nr_classes`` values, each drawn from ``{0, 1}``.
    """
    for sample_idx in range(N):
        n_tokens = random_sent_lengths[sample_idx]
        tokens = np.random.choice(words, size=n_tokens)
        # high=2 is exclusive, so every target entry is either 0 or 1.
        target = np.random.randint(low=0, high=2, size=nr_classes)
        yield " ".join(tokens), target

In [11]:
def create_datasets(nr_classes):
    """Build train/validation ``torchtext.data.Dataset`` objects from random data.

    Args:
        nr_classes: Length of the target vector attached to every sentence;
            forwarded to ``generate_art_data``.

    Returns:
        tuple: ``(trn, vld, TEXT)`` -- the two datasets produced by a
        ``split_ratio=0.7`` split, and the text ``Field`` with its
        vocabulary already built.
    """
    TEXT = Field(sequential=True, tokenize=lambda x: x.split(), use_vocab=True, lower=True)
    # Targets are numeric arrays already, so the label field needs no vocabulary.
    LABEL = Field(sequential=False, use_vocab=False)
    trn_fields = [('text', TEXT), ('category', LABEL)]

    # generate_art_data is a generator and can be consumed only once, so all
    # examples are materialised here in a single pass.
    examples = [
        Example.fromlist(list(sample), fields=trn_fields)
        for sample in generate_art_data(nr_classes)
    ]
    full_dataset = Dataset(examples, fields=trn_fields)

    # Build the vocabulary from the dataset itself. Passing the (by now
    # exhausted) generator would yield a vocabulary containing only the
    # special tokens, mapping every word to <unk>.
    TEXT.build_vocab(full_dataset)

    trn, vld = full_dataset.split(split_ratio=0.7)
    return (trn, vld, TEXT)

In [12]:
def create_iterators(num_of_batches=4, nr_classes=3):
    """Create train and validation ``Iterator`` objects over generated data.

    Also prints the tokenised text of the first training example and its
    token count.

    Args:
        num_of_batches: Despite the name, this is the *batch size* handed to
            both iterators (the name is kept for backward compatibility).
        nr_classes: Length of each target vector; forwarded to
            ``create_datasets``.

    Returns:
        tuple: ``(train_iter, val_iter, T)`` where ``T`` is the text field
        returned by ``create_datasets``.
    """
    trn, vld, T = create_datasets(nr_classes)
    print()
    print("Generated string: ", trn[0].text)
    print("Length = ", len(trn[0].text))

    def text_length(example):
        # Sort key shared by both iterators: number of tokens in the example.
        return len(example.text)

    train_iter = Iterator(trn, batch_size=num_of_batches, sort_key=text_length)
    # train=False stops the validation data from being shuffled; with the
    # default (train=True) the validation iterator behaves like a training one.
    val_iter = Iterator(vld, batch_size=num_of_batches, sort_key=text_length, train=False)

    return train_iter, val_iter, T

In [13]:
# Notebook-wide settings read by generate_art_data.
N = 100  # how many random sentences to generate
words = [
    'world', 'hello', 'country',
    'moon', 'planet', 'earth',
]  # pool that sentence tokens are sampled from
# One length per sentence, each in 1..9 (the upper bound 10 is exclusive).
# No seed is set here, so the drawn lengths differ from run to run.
random_sent_lengths = np.random.randint(low=1, high=10, size=N)

In [18]:
# Display the drawn sentence lengths (one entry per sentence to generate).
random_sent_lengths

array([5, 4, 5, 9, 8, 8, 3, 8, 7, 9, 6, 2, 1, 1, 7, 9, 9, 9, 1, 3, 6, 4,
       9, 6, 9, 1, 5, 2, 8, 9, 4, 7, 7, 1, 2, 6, 6, 3, 9, 9, 7, 7, 6, 4,
       4, 9, 6, 5, 1, 1, 8, 4, 9, 4, 9, 8, 5, 4, 5, 9, 5, 3, 3, 6, 9, 2,
       3, 2, 4, 7, 2, 9, 9, 3, 3, 9, 1, 9, 3, 8, 5, 1, 2, 9, 3, 1, 3, 4,
       1, 7, 3, 5, 7, 9, 9, 7, 8, 9, 7, 7])

In [20]:
# Build both iterators: num_of_batches is used as the batch size (10 here),
# and every target vector gets nr_classes=5 entries.
train_iter, val_iter, T = create_iterators(num_of_batches=10, nr_classes=5)


Generated string:  ['country']
Length =  1


In [17]:
# Walk through every training batch, but report tensor sizes only for the
# first one (text tensor first, then the category tensor).
for batch_idx, batch in enumerate(train_iter):
    if batch_idx != 0:
        continue
    print()
    print(batch.text.size(), batch.category.size())


torch.Size([9, 10]) torch.Size([10, 5])
