In [1]:
import sys
from pathlib import Path

# Make the local checkout two directories up (presumably the repository root that
# contains the `pytorch_beam_search` package imported in the next cell) importable.
# Resolve it to an absolute path now, so a later os.chdir() cannot change what the
# entry points to. Only add it once, so re-running this cell does not keep
# appending duplicates to sys.path.
REPO_ROOT = str(Path("../..").resolve())
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

In [2]:
import random

import torch

from pytorch_beam_search import seq2seq

# Seed the RNGs so that weight initialisation (and any shuffling/dropout) is
# reproducible under Restart & Run All. Only Python's and torch's RNGs are seeded
# here; if the library also draws from numpy, seed that too (not verified).
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Create vocabularies
source = [list("abcdefghijkl"), list("mnopqrstwxyz")]    # tokenize the way you need
target = [list("ABCDEFGHIJKL"), list("MNOPQRSTWXYZ")]    # tokenize the way you need
source_index = seq2seq.Index(source)
target_index = seq2seq.Index(target)

# Create tensors
X = source_index.text2tensor(source)
Y = target_index.text2tensor(target)
# X.shape == (n_source_examples, len_source_examples) == (2, 14)
# Y.shape == (n_target_examples, len_target_examples) == (2, 14)
# Each example is 12 characters; the two extra positions are presumably the
# <START>/<END> markers. This matches the "X_train.shape"/"Y_train.shape" lines
# that model.fit reports in the output below.

# Create and train the model
model = seq2seq.Transformer(source_index, target_index)    # just a PyTorch model
# NOTE: with the defaults reported in the training log (5 epochs, learning rate
# 1e-4) on only two examples, the error rate stays at 100%. This cell demonstrates
# the API, not a trained model, so the predictions below are not meaningful.
model.fit(X, Y)    # basic method included

# Generate new predictions
# NOTE(review): these strings contain characters that are not in the vocabularies
# built above: spaces in both, and lowercase letters in the target, whose vocabulary
# is uppercase only. Presumably the Index maps them to an unknown token (confirm).
# Either way, the loss/error rate computed on them is not a meaningful score.
new_source = [list("new first in"), list("new second in")]
new_target = [list("new first out"), list("new second out")]
X_new = source_index.text2tensor(new_source)
Y_new = target_index.text2tensor(new_target)
loss, error_rate = model.evaluate(X_new, Y_new)    # basic method included
predictions, log_probabilities = seq2seq.beam_search(model, X_new)
# One list of decoded beam candidates per input example.
output = [target_index.tensor2text(p) for p in predictions]
output

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

Model: Seq2Seq Transformer
Source index: <Seq2Seq Index with 28 items>
Target index: <Seq2Seq Index with 28 items>
Max sequence length: 32
Embedding dimension: 32
Feedforward dimension: 128
Encoder layers: 2
Decoder layers: 2
Attention heads: 2
Activation: relu
Dropout: 0.0
Trainable parameters: 63,260



  0%|          | 0/5 [00:00<?, ?it/s]

Training started
X_train.shape: torch.Size([2, 14])
Y_train.shape: torch.Size([2, 14])
Epochs: 5
Learning rate: 0.0001
Weight decay: 0
Epoch | Train                 | Minutes
      | Loss     | Error Rate |
---------------------------------------


  0%|          | 0/1 [00:00<?, ?it/s]

    1 |   3.6379 |    100.000 |     0.0


  0%|          | 0/1 [00:00<?, ?it/s]

    2 |   3.6148 |    100.000 |     0.0


  0%|          | 0/1 [00:00<?, ?it/s]

    3 |   3.5921 |    100.000 |     0.0


  0%|          | 0/1 [00:00<?, ?it/s]

    4 |   3.5693 |    100.000 |     0.0


  0%|          | 0/1 [00:00<?, ?it/s]

    5 |   3.5469 |    100.000 |     0.0


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

[['<START>JS<END>HJK<END>JHJX<END>JSXTJ<END>JX',
  '<START>JS<END>HJK<END>JHJX<END>JSXTJSXX',
  '<START>JS<END>HJK<END>JHJX<END>JCXTJSXX',
  '<START>JS<END>HJK<END>JHJX<END>JSXTJ<END>XX',
  '<START>JS<END>HJK<END>JHJX<END>JCXTJ<END>XX'],
 ['<START>JS<END>JKK<END>JTJMKTJXTJ<END>XX',
  '<START>JS<END>JKK<END>JTJMKTJXTJSXX',
  '<START>JS<END>JKK<END>JTJMKTJXTJ<END>JX',
  '<START>JS<END>JKK<END>JTJMKKJXTJ<END>XX',
  '<START>JS<END>JKK<END>JTJMKKJXTJSXX']]