In [1]:
import torch
import torch.nn as nn
from torch.functional import F
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from pathlib import Path
import re
from collections import Counter
import random
from IPython.display import clear_output
import math

The helper below checks whether two tensors are element-wise equal:

In [2]:
def equal_tensors(a, b):
    """Return a 0-dim bool tensor that is True iff `a` and `b` match exactly.

    Two tensors match when they have the same shape and every pair of
    corresponding elements compares equal.

    Args:
        a: first tensor.
        b: second tensor.

    Returns:
        A 0-dim boolean tensor, so the result can be used directly in `assert`.
    """
    # torch.eq broadcasts its operands: without this guard [1, 1] would compare
    # equal to [1], and non-broadcastable shapes would raise instead of simply
    # reporting "not equal".
    if a.shape != b.shape:
        return torch.tensor(False, device=a.device)
    return torch.all(torch.eq(a, b))

# BaseTransform

In [3]:
from helpers import BaseTransform

In [4]:
# Build a BaseTransform over a three-symbol vocabulary; the cells below inspect it.
bt = BaseTransform(['a', 'b', 'c'])

In [5]:
# The vocabulary should be exposed unchanged, and count should be 3 for it.
assert bt.vocab == ['a', 'b', 'c']
assert bt.count == 3

In [6]:
# item -> index: expected to follow the order of the vocabulary list
assert bt.item2num['a'] == 0
assert bt.item2num['b'] == 1
assert bt.item2num['c'] == 2
# index -> item: expected to be the inverse mapping
assert bt.num2item[0] == 'a'
assert bt.num2item[1] == 'b'
assert bt.num2item[2] == 'c'

In [7]:
# Encoding the whole vocabulary in order should give the indices 0, 1, 2.
encoded = bt.encode(['a', 'b', 'c'])
assert equal_tensors(encoded, torch.tensor([0, 1, 2]))

In [8]:
# Round trip: decoding the tensor produced above should give back the items.
decoded = bt.decode(encoded)
assert decoded == ['a', 'b', 'c']

# TokTransform

In [9]:
from helpers import TokTransform

In [10]:
# Build a TokTransform from two known tokens; the cells below inspect its mappings.
tt = TokTransform(['a', 'b'])

In [11]:
# 'xxunk' is expected at index 0, with the supplied tokens after it.
assert tt.item2num['xxunk'] == 0
assert tt.item2num['a'] == 1
assert tt.item2num['b'] == 2

In [12]:
# The reverse table should mirror item2num, including the 'xxunk' slot at 0.
assert tt.num2item[0] == 'xxunk'
assert tt.num2item[1] == 'a'
assert tt.num2item[2] == 'b'

In [13]:
# 'c' is not in the vocabulary, so it is expected to encode to index 0 ('xxunk').
encoded = tt.encode(['a', 'b', 'c'])
assert equal_tensors(encoded, torch.tensor([1, 2, 0]))

In [14]:
# Decoding index 0 should surface the 'xxunk' placeholder, not the original 'c'.
decoded = tt.decode(encoded)
assert decoded == ['a', 'b', 'xxunk']

# DataLoader

In [15]:
from helpers import DataLoader

In [16]:
# Fixture: three x items starting at 10, 20 and 30, each holding three
# consecutive integers. Every y item equals its x counterpart plus one, which
# lets later cells detect any x/y misalignment.
x_set = [torch.tensor([start, start + 1, start + 2]) for start in (10, 20, 30)]
y_set = [torch.tensor([start + 1, start + 2, start + 3]) for start in (10, 20, 30)]

In [17]:
# Wrap the paired fixtures in a DataLoader.
# NOTE(review): the third argument (2) is presumably the batch size — confirm
# against helpers.DataLoader.
dl = DataLoader(x_set, y_set, 2)

In [18]:
# x_set / y_set each hold three items
assert dl.n_items == 3

In [19]:
# three items loaded with the argument 2 are expected to give two batches
assert dl.n_batches == 2

In [20]:
n_batches_checked = 0
for xb, yb in dl.get_batches():
    # Every y item was built as x + 1, so a constant difference of 1 shows
    # that xb and yb have remained aligned.
    assert torch.all(yb - xb == 1)
    n_batches_checked += 1
# If get_batches() yielded nothing, the loop above would pass without checking
# anything; require that at least one batch was actually inspected.
assert n_batches_checked > 0

# DataLoaders

In [21]:
from helpers import DataLoaders

In [22]:
# Two minimal single-item loaders; the payloads differ (0 vs 1) so the train
# and validation loaders are distinguishable.
dl_train = DataLoader(
    [torch.tensor([0])],
    [torch.tensor([0])],
    1,
)
dl_val = DataLoader(
    [torch.tensor([1])],
    [torch.tensor([1])],
    1,
)

In [23]:
# Bundle the two loaders; the next cell checks they are exposed as .train / .val.
dls = DataLoaders(dl_train, dl_val)

In [24]:
# The container should hand back the loaders it was constructed with.
assert dls.train == dl_train
assert dls.val == dl_val