In [1]:
from data.data import generate_sample
from data.generator import get_token2pos
from utils.rnn_utils import make_1hot
import torch

In [2]:
def get_target_vocab_chars():
    """Return the set of plain characters allowed in a target sequence.

    These are the ten decimal digits plus the minus sign '-'. Special tokens
    (sos / eos / padding) are not included here.
    """
    allowed = '0123456789-'
    return {ch for ch in allowed}

In [3]:
def get_target_vocab_size():
    """Return the number of target classes.

    This counts every character from get_target_vocab_chars() plus the three
    special tokens: 'sos', 'eos' and the padding symbol '#'.
    """
    n_special_tokens = 3  # sos, eos, pad
    return n_special_tokens + len(get_target_vocab_chars())

In [4]:
def get_target_token2pos(chars=None):
    """Map every target token to its integer index.

    Vocabulary characters receive indices 0..len(chars)-1 in sorted order.
    They are followed by the special tokens 'sos' (start of string),
    'eos' (end of string) and '#' (padding), in that order.

    Args:
        chars: optional iterable of vocabulary characters. Defaults to
            get_target_vocab_chars(). Duplicates are ignored.

    Returns:
        dict mapping token (str) -> index (int).
    """
    if chars is None:
        chars = get_target_vocab_chars()
    # Sort instead of iterating the set directly. A set of str is iterated in
    # hash order, and str hashes are randomized per interpreter process, so
    # the indices would change after every kernel restart.
    token2pos = {t: p for p, t in enumerate(sorted(set(chars)))}
    token2pos['sos'] = len(token2pos)  # start of string
    token2pos['eos'] = len(token2pos)  # end of string
    token2pos['#'] = len(token2pos)    # padding
    return token2pos

In [5]:
def make_target_tensor(target, token2pos, vocab_size):
    """Encode a target sequence as a tensor of one-hot rows with a batch dim.

    The sequence is framed by the 'sos' and 'eos' markers. Each token is
    encoded with make_1hot(token, token2pos, vocab_size), the encodings are
    concatenated along dim 0, and a leading dimension of size 1 is added.

    NOTE(review): assumes make_1hot returns a (1, vocab_size) row, which would
    give an overall shape of (1, len(target) + 2, vocab_size). Confirm against
    utils.rnn_utils.
    """
    framed_tokens = ['sos', *target, 'eos']
    one_hot_rows = [make_1hot(tok, token2pos, vocab_size) for tok in framed_tokens]
    return torch.concat(one_hot_rows).unsqueeze(dim=0)

In [6]:
def target_tensors_to_str(y_t, pos2token=None):
    """Decode a sequence of one-hot (or score) rows back into a string.

    Each row is mapped to the token at its argmax index, and the tokens are
    concatenated. Special tokens appear verbatim, e.g. 'sos670eos'.

    Args:
        y_t: tensor of shape (seq_len, vocab_size). A tensor of shape
            (1, seq_len, vocab_size), as returned by make_target_tensor, is
            also accepted: the leading size-1 batch dim is dropped.
        pos2token: optional index -> token dict. Defaults to
            get_target_pos2token().

    Returns:
        The decoded string.
    """
    if pos2token is None:
        pos2token = get_target_pos2token()
    # Without this, iterating a (1, T, V) tensor yields a single (T, V) slice.
    # argmax would then return a flattened index into T*V and decode garbage
    # or raise KeyError.
    if y_t.dim() == 3 and y_t.size(0) == 1:
        y_t = y_t.squeeze(0)
    idx_outputs = [torch.argmax(o).item() for o in y_t]
    return ''.join(pos2token[idx] for idx in idx_outputs)

In [7]:
def get_target_pos2token():
    """Return the inverse of get_target_token2pos(): index -> token."""
    forward = get_target_token2pos()
    return dict(zip(forward.values(), forward.keys()))

In [8]:
# NOTE(review): no RNG seed is set before this call. If generate_sample is
# random, the sample (and every output derived from it) changes on each
# re-run. TODO: seed it, and document what the arguments 3 and 4 control.
x, y = generate_sample(3, 4)

In [9]:
# Encode the target string y as one-hot rows framed by sos/eos.
y_t = make_target_tensor(
    target=y,
    token2pos=get_target_token2pos(),
    vocab_size=get_target_vocab_size(),
)

In [10]:
# The raw target string that was encoded into y_t.
y

'670'

In [11]:
# Leading 1 comes from unsqueeze(dim=0); the middle dim is len(y) + 2
# (sos/eos); the last should equal get_target_vocab_size(). The last two
# assume make_1hot yields one row per token. TODO confirm.
y_t.size()

torch.Size([1, 5, 14])

In [12]:
# Round-trip check: squeeze() drops the size-1 leading dim added by
# make_target_tensor before decoding the rows back to tokens.
target_tensors_to_str(y_t.squeeze())

'sos670eos'