In [3]:
import importlib

import sys
sys.path.append("../../")

# importlib.reload(index)
from pytorch_beam_search import autoregressive

# Create vocabulary and examples
corpus = list("abcdefghijklmnopqrstwxyz ")    # tokenize the way you need
# len(corpus) == 25
index = autoregressive.Index(corpus)
n_gram_size = 17    # 16 with an offset of 1 
n_grams = [corpus[i:n_gram_size + i] for i in range(len(corpus))[:-n_gram_size + 1]]

# Create tensor
X = index.text2tensor(n_grams)
# X.shape == (n_examples, len_examples) == (25 - 17 + 1 = 9, 17)

# Create and train the model
model = autoregressive.TransformerEncoder(index)    # just a PyTorch model
model.fit(X)    # basic method included

# Generate new predictions
new_examples = ["new first", "new second"]
X_new = index.text2tensor(new_examples)
loss, error_rate = model.evaluate(X_new)    # basic method included
predictions, log_probabilities = autoregressive.beam_search(model, X_new)
# every element in predictions is the list of candidates for each example
output = [index.tensor2text(p) for p in predictions]
output

  0%|          | 0/9 [00:00<?, ?it/s]

Model: Autoregressive Transformer Encoder
Index: <Autoregressive Index with 27 items>
Max sequence length: 16
Embedding dimension: 32
Feedforward dimension: 128
Layers: 2
Attention heads: 2
Activation: relu
Dropout: 0.0
Trainable parameters: 40,379



  0%|          | 0/5 [00:00<?, ?it/s]

Training started
X_train.shape: torch.Size([9, 17])
Epochs: 5
Learning rate: 0.0001
Weight decay: 0
Epoch | Train                 | Minutes
      | Loss     | Error Rate |
---------------------------------------


  0%|          | 0/1 [00:00<?, ?it/s]

    1 |   3.5582 |     98.611 |     0.0


  0%|          | 0/1 [00:00<?, ?it/s]

    2 |   3.5481 |     98.611 |     0.0


  0%|          | 0/1 [00:00<?, ?it/s]

    3 |   3.5380 |     98.611 |     0.0


  0%|          | 0/1 [00:00<?, ?it/s]

    4 |   3.5279 |     98.611 |     0.0


  0%|          | 0/1 [00:00<?, ?it/s]

    5 |   3.5178 |     98.611 |     0.0


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

[['<PAD>new firstgjnshshhhhhhhhhhhhhh',
  '<PAD>new firstgjnshhhhhhhhhhhhhhhh',
  '<PAD>new firstgjnshshqhhhhhhhhhhhh',
  '<PAD>new firstgjnshshhqhhhhhhhhhhh',
  '<PAD>new firstgjnshshhhqhhhhhhhhhh'],
 ['new second<PAD>tgnjhhhhhhhhhhhhhhh',
  'new secondltgnjshhhhhhhhhhhhhh',
  'new secondltgnjhhhhhhhhhhhhhhh',
  'new secondltgnjshqhhhhhhhhhhhh',
  'new secondltgnjshhhqhhhhhhhhhh']]