# Batching Tests

In [1]:
import sys
sys.path.append('../..')

from src.data_loader import DataLoader
from src.pickle_loader import PickleLoader

from src.preprocessing.dictionary import Dictionary
from src.preprocessing.batching import Batcher

# Passed to Batcher as batch_size; the recorded progress output reports rows
# per batch as "Row=n/200".
BATCH_SIZE = 200
# Passed to Batcher as window. In the recorded output each target window holds
# 4 tokens and each source window holds 9 (i.e. 2 * WINDOW_SIZE + 1).
WINDOW_SIZE = 4
# Presumably the number of BPE merge operations (TODO confirm); the value
# matches the "5000" embedded in the data and dictionary file names.
BPE_OPERATIONS = 5000


### Creating exercise solutions

In [3]:
# Load the BPE-indexed corpora (German source, English target) and the matching
# dictionaries. The BPE size in every file name is taken from BPE_OPERATIONS
# (defined in the configuration cell) instead of repeating the literal 5000, so
# changing the constant switches all four files consistently.
# NOTE(review): PickleLoader presumably unpickles these files; only load
# pickles that were produced by this project.
source_indexed_data = PickleLoader.load(f'../../data/data_v2/multi30k.de.{BPE_OPERATIONS}_BPE.indexed.pickle')
target_indexed_data = PickleLoader.load(f'../../data/data_v2/multi30k.en.{BPE_OPERATIONS}_BPE.indexed.pickle')

source_dictionary = PickleLoader.load(f'../../data/dictionaries/dict_DE_{BPE_OPERATIONS}.pkl')
target_dictionary = PickleLoader.load(f'../../data/dictionaries/dict_EN_{BPE_OPERATIONS}.pkl')

dict_keys(['<UNK>', '<s>', '</s>', 'algen', 'turners', 'parkartigen', 'hill', 'house', 'schreibmaschine', 'eiswaffel', 'brünettem', 'broad', 'beschwerden', 'neongrün', 'pflanzung', 'hochstuhl', 'kite@@', 'ferien@@', 'spatel', 'gläsernen', 'kapuzenteil', 'finden', 'golf', 'ie-@@', 'me@@', 'vielerlei', '-bh', 'valen', 'schülers', 'mikrofon', 'telefoniert', 'gel@@', 'rustikal', 'schwester', 'pie', 'konzert@@', 'uren@@', 'hochhält', 'ur', 'stein@@', 'plastiktisch', 'keramikteile', 'richtige', 'feuerrotem', 'auschecken', 'neben', 'steh@@', 'schnarrtrommel', 'messern', 'einkäuferin', 'aufregendes', 'eisschnell@@', 'nachspringt', 'land@@', 'sattel@@', 'bodies', 'kopfbedeckung', 'gelb-@@', 'abendessen', 'mendes', 'vatikan', 'kokosnuss', 'gepunktet', ')', 'wandschmuck', 'nächstes', 'postsendungen', 'doppelüber@@', 'abzuwerfen', 'weisen', 'flachrelie@@', 'ganzen', 'farbigen', 'tattoo-@@', 'elektrogitarre', 'zerbrochener', 'odschaukel', 'cracker', 'metallbauwerk', 'verschwitzter', 'mitten', 'mend

In [4]:
# Restrict batching to the entries at indices 1100-1199 of both corpora.
line_range = slice(1100, 1200)
limited_batcher = Batcher(
    source_indexed_data[line_range],
    target_indexed_data[line_range],
    batch_size=BATCH_SIZE,
    window=WINDOW_SIZE,
    torch_device="cpu",
)
limited_batcher.batch()
limited_batches = limited_batcher.getBatches()

Creating batches: 100%|██████████| 100/100 [00:00<00:00, 1904.11it/s, Batch=8/8, Row=16/200]


In [5]:
# Dump every batch row as raw indices: "[ source window ] [ target window ] [ label ]".
# Each bracketed group is emitted by a single print call; print's default
# single-space separator supplies the spacing between the entries.
for src_windows, tgt_windows, labels in limited_batches:
    for src_row, tgt_row, label_row in zip(src_windows, tgt_windows, labels):
        print("[", *(f"{idx}\t" for idx in src_row.tolist()), "]", end=" ")
        print("[", *(f"{idx}\t" for idx in tgt_row.tolist()), "]", end=" ")
        print("[", *(f"{idx}" for idx in label_row.tolist()), "]")
    # Blank lines visually separate consecutive batches.
    print("\n\n\n")

[ 0.0	 0.0	 0.0	 0.0	 0.0	 3369.0	 14563.0	 1855.0	 13221.0	 ] [ 0.0	 0.0	 0.0	 0.0	 ] [ 1391.0 ]
[ 0.0	 0.0	 0.0	 0.0	 3369.0	 14563.0	 1855.0	 13221.0	 16024.0	 ] [ 0.0	 0.0	 0.0	 1391.0	 ] [ 426.0 ]
[ 0.0	 0.0	 0.0	 3369.0	 14563.0	 1855.0	 13221.0	 16024.0	 13075.0	 ] [ 0.0	 0.0	 1391.0	 426.0	 ] [ 9316.0 ]
[ 0.0	 0.0	 3369.0	 14563.0	 1855.0	 13221.0	 16024.0	 13075.0	 14334.0	 ] [ 0.0	 1391.0	 426.0	 9316.0	 ] [ 513.0 ]
[ 0.0	 3369.0	 14563.0	 1855.0	 13221.0	 16024.0	 13075.0	 14334.0	 1914.0	 ] [ 1391.0	 426.0	 9316.0	 513.0	 ] [ 4114.0 ]
[ 3369.0	 14563.0	 1855.0	 13221.0	 16024.0	 13075.0	 14334.0	 1914.0	 7966.0	 ] [ 426.0	 9316.0	 513.0	 4114.0	 ] [ 1391.0 ]
[ 14563.0	 1855.0	 13221.0	 16024.0	 13075.0	 14334.0	 1914.0	 7966.0	 1.0	 ] [ 9316.0	 513.0	 4114.0	 1391.0	 ] [ 513.0 ]
[ 1855.0	 13221.0	 16024.0	 13075.0	 14334.0	 1914.0	 7966.0	 1.0	 1.0	 ] [ 513.0	 4114.0	 1391.0	 513.0	 ] [ 10202.0 ]
[ 13221.0	 16024.0	 13075.0	 14334.0	 1914.0	 7966.0	 1.0	 1.0	 1.0	 ] [ 4114.

In [13]:
# Dump every batch row as tokens: "[ source window ] [ target window ] [ label ]".
# Source indices are looked up in source_dictionary; target windows and labels
# are looked up in target_dictionary.
for src_windows, tgt_windows, labels in limited_batches:
    for src_row, tgt_row, label_row in zip(src_windows, tgt_windows, labels):
        print("[", *map(source_dictionary.getToken, src_row.tolist()), "]", end=" ")
        print("[", *map(target_dictionary.getToken, tgt_row.tolist()), "]", end=" ")
        print("[", *map(target_dictionary.getToken, label_row.tolist()), "]")
    # Blank lines visually separate consecutive batches.
    print("\n\n\n")

[ <s> <s> <s> <s> <s> ein mädchen spielt in ] [ <s> <s> <s> <s> ] [ a ]
[ <s> <s> <s> <s> ein mädchen spielt in einer ] [ <s> <s> <s> a ] [ girl ]
[ <s> <s> <s> ein mädchen spielt in einer billard@@ ] [ <s> <s> a girl ] [ playing ]
[ <s> <s> ein mädchen spielt in einer billard@@ halle ] [ <s> a girl playing ] [ pool ]
[ <s> ein mädchen spielt in einer billard@@ halle billard ] [ a girl playing pool ] [ at ]
[ ein mädchen spielt in einer billard@@ halle billard . ] [ girl playing pool at ] [ a ]
[ mädchen spielt in einer billard@@ halle billard . </s> ] [ playing pool at a ] [ pool ]
[ spielt in einer billard@@ halle billard . </s> </s> ] [ pool at a pool ] [ hall ]
[ in einer billard@@ halle billard . </s> </s> </s> ] [ at a pool hall ] [ . ]
[ einer billard@@ halle billard . </s> </s> </s> </s> ] [ a pool hall . ] [ </s> ]
[ <s> <s> <s> <s> <s> ein mann klettert eine ] [ <s> <s> <s> <s> ] [ a ]
[ <s> <s> <s> <s> ein mann klettert eine felswand ] [ <s> <s> <s> a ] [ man ]
[ <s> <s> <s>

### Real test data

In [9]:
# Load the BPE-indexed source (German) and target (English) corpora for the
# full-data run. The BPE size in the file names is taken from BPE_OPERATIONS
# (defined in the configuration cell) instead of repeating the literal 5000.
# NOTE(review): PickleLoader presumably unpickles these files; only load
# pickles that were produced by this project.
source_indexed_data = PickleLoader.load(f'../../data/data_v2/multi30k.de.{BPE_OPERATIONS}_BPE.indexed.pickle')
target_indexed_data = PickleLoader.load(f'../../data/data_v2/multi30k.en.{BPE_OPERATIONS}_BPE.indexed.pickle')

In [10]:
# Build a Batcher over the complete source/target corpora (no slicing), using
# the notebook-wide BATCH_SIZE and WINDOW_SIZE with torch_device set to "cpu".
# This only constructs the object; batcher.batch() is called in the next cell.
batcher = Batcher(source_indexed_data, target_indexed_data, batch_size=BATCH_SIZE, window=WINDOW_SIZE, torch_device="cpu")

In [11]:
# Run the batching pass over the full dataset. The recorded run processed
# 29000 entries into 2038 batches in roughly 12 seconds.
batcher.batch()

Creating batches: 100%|██████████| 29000/29000 [00:12<00:00, 2254.34it/s, Batch=2038/2038, Row=24/200] 


In [12]:
# Retrieve the batches produced by the preceding batcher.batch() call.
batches = batcher.getBatches()
# Optional: persist the batches to disk (currently disabled).
# NOTE(review): this path is relative to the notebook's working directory,
# whereas the load calls use '../../data/...'; confirm the intended location
# before enabling.
# PickleLoader.save('data/data_v2/multi30k.batches.pickle', batches)

### Toy data

In [None]:
# Hand-written toy corpus: five source sequences of token indices with varying
# lengths. Every sequence starts with 0 and ends with 1; all other values lie
# between 2 and 99.
# NOTE(review): 0 / 1 presumably act as start / end markers; confirm against
# the convention Batcher expects.
source = [
    [0, 2, 95, 59, 63, 52, 48, 4, 1],
    [0, 42, 63, 4, 82, 48, 1],
    [0, 17, 70, 21, 89, 1],
    [0, 49, 92, 48, 39, 79, 99, 55, 1],
    [0, 5, 93, 78, 53, 1]
]

# Five target sequences, one per source sequence. Every sequence starts with
# 100 and ends with 101, and all values are >= 100, so target indices are easy
# to tell apart from source indices in the printed batches.
# NOTE(review): 100 / 101 presumably act as start / end markers; confirm.
target = [
    [100, 159, 187, 197, 147, 169, 113, 191, 197, 101],
    [100, 148, 152, 154, 139, 101],
    [100, 164, 151, 133, 101],
    [100, 151, 131, 190, 101],
    [100, 187, 154, 191, 163, 101]
]

In [None]:
# Batch the toy corpus and print every batch so the windowing can be checked
# by eye.
# torch_device is passed explicitly, as in the other Batcher(...) calls in this
# notebook, so the cell does not depend on Batcher's default device.
# NOTE: this rebinds `batcher`, which earlier cells use for the full-data
# Batcher; re-run those cells before using `batcher` for the full dataset again.
batcher = Batcher(source, target, batch_size=10, window=3, torch_device="cpu")
batcher.batch()
for (source_windows, target_windows, labels) in batcher.getBatches():
    sections = (
        ("Source windows:", source_windows),
        ("Target windows:", target_windows),
        ("Labels:", labels),
    )
    # Print each part of the batch under its heading, followed by a separator.
    for heading, values in sections:
        print(heading)
        print(values)
        print("\n--------\n")