In [1]:
from random import choices
from string import ascii_lowercase

import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import get_text_encoder_decoder, customize_dataloader_func
from pytorch_utils import get_model, train_and_evaluate, customize_predictor

In [2]:
# Input alphabet: the 26 lowercase ASCII letters.
in_vocab = ascii_lowercase
# Output alphabet: consecutive integer codes starting at OUT_VOCAB_OFFSET.
# The stop value is derived from len(in_vocab) rather than hard-coded (36), so
# both vocabularies always have the same size, which the one-to-one
# letter -> code mapping built from them relies on.
OUT_VOCAB_OFFSET = 10
out_vocab = list(range(OUT_VOCAB_OFFSET, OUT_VOCAB_OFFSET + len(in_vocab)))
len(in_vocab), len(out_vocab)

(26, 26)

In [3]:
# One-to-one mapping from each input letter to its integer code.
# zip() silently stops at the shorter input, so a size mismatch between the
# vocabularies would drop letters without any error; check it explicitly.
assert len(in_vocab) == len(out_vocab), "in_vocab and out_vocab must have the same length"
trans_map = dict(zip(in_vocab, out_vocab))
trans_map

{'a': 10,
 'b': 11,
 'c': 12,
 'd': 13,
 'e': 14,
 'f': 15,
 'g': 16,
 'h': 17,
 'i': 18,
 'j': 19,
 'k': 20,
 'l': 21,
 'm': 22,
 'n': 23,
 'o': 24,
 'p': 25,
 'q': 26,
 'r': 27,
 's': 28,
 't': 29,
 'u': 30,
 'v': 31,
 'w': 32,
 'x': 33,
 'y': 34,
 'z': 35}

In [4]:
# get_text_encoder_decoder returns an (encoder, decoder) pair for a vocabulary.
# The input-side decoder is never used in this notebook, so it is discarded into `_`;
# the output-side decoder is kept because it is passed to customize_predictor later.
in_seq_encoder, _ = get_text_encoder_decoder(in_vocab)
out_seq_encoder, out_seq_decoder = get_text_encoder_decoder(out_vocab)

In [5]:
def transduce(text, mapping=None):
    """Map each character of ``text`` to its integer code.

    Args:
        text: Iterable of characters; every character must be a key of ``mapping``.
        mapping: Character -> code dict. Defaults to the notebook-level
            ``trans_map``, so existing ``transduce(text)`` calls are unchanged.

    Returns:
        List of codes, one per character, in input order.

    Raises:
        KeyError: If a character has no entry in ``mapping``.
    """
    if mapping is None:
        # Resolved at call time so later edits to trans_map are picked up.
        mapping = trans_map
    return [mapping[ch] for ch in text]

In [6]:
# Spot-check the mapping on a short string (a -> 10, d -> 13, s -> 28).
transduce("adsdsd")

[10, 13, 28, 13, 28, 13]

In [7]:
# Toy dataset: N_EXAMPLES random strings of SEQ_LEN lowercase letters for
# training; the dev split holds the same strings reversed (presumably to check
# that the learned letter -> code mapping does not depend on order).
# Seeding makes Restart & Run All reproducible: `rng` fixes the sampled
# strings, and torch.manual_seed seeds torch's global RNG before the model is
# built in a later cell (assumes get_model draws from that RNG — confirm).
from random import Random

SEED = 42
N_EXAMPLES = 5
SEQ_LEN = 26

rng = Random(SEED)
torch.manual_seed(SEED)

train = []
dev = []
for _ in range(N_EXAMPLES):
    text = "".join(rng.choices(ascii_lowercase, k=SEQ_LEN))
    train.append([text, transduce(text)])
    reversed_text = text[::-1]
    dev.append([reversed_text, transduce(reversed_text)])

In [8]:
# Factory applied below to the [text, codes] pair lists (train / dev) to build
# train_dl / dev_dl with shared encoders and batch size.
# NOTE(review): the positional 27 is unexplained — presumably a maximum
# sequence length (26 characters + 1, e.g. for a special token); confirm
# against dataloader.customize_dataloader_func.
dataloader_func = customize_dataloader_func(
    in_seq_encoder,
    out_seq_encoder,
    27,
    batch_size=1000,
)

In [9]:
# Same loader settings for both splits (train first, then dev); with only a
# handful of examples and batch_size=1000, presumably one batch per split.
train_dl, dev_dl = dataloader_func(train), dataloader_func(dev)

In [10]:
# Hyper-parameters for a very small model: one recurrent layer with 2-d
# embeddings and 2 hidden units (142 trainable parameters per the output below).
# Vocabulary sizes are taken from the vocabularies defined earlier.
ModelConfig = dict(
    bias=True,
    rnn_type="SRNN",
    embd_dim=2,
    num_layers=1,
    hidden_size=2,
    dropout_rate=0.0,
    bidirectional=False,
    in_vocab_size=len(in_vocab),
    out_vocab_size=len(out_vocab),
    reduction_method="sum",
)

model = get_model(ModelConfig)

The model has 142 trainable parameters


In [11]:
# Cross-entropy over output codes; Adam with a small weight decay.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-5)

# Metrics are printed every 200 epochs, for at most 10000 epochs; the exit
# accuracies of 1.0 presumably stop training once both train and dev are
# perfect (the recorded run stops at epoch 5400 with both at 1.0).
# NOTE(review): acc_threshold=1.1 exceeds any reachable accuracy — presumably
# this disables accuracy-triggered saving to model.pt; confirm in
# pytorch_utils.train_and_evaluate.
log = train_and_evaluate(
    model,
    train_dl,
    dev_dl,
    criterion,
    optimizer,
    saved_model_fp="model.pt",
    acc_threshold=1.1,
    print_eval_freq=200,
    max_epoch_num=10000,
    train_exit_acc=1.0,
    eval_exit_acc=1.0,
)

Current epoch: 200, 
training performance: {'loss': 1.8932832479476929, 'full sequence accuracy': 0.0, 'first n-symbol accuracy': 0.046153846153846156, 'overlap rate': 0.3384615384615385}
evaluation performance: {'loss': 1.8951029777526855, 'full sequence accuracy': 0.0, 'first n-symbol accuracy': 0.015384615384615385, 'overlap rate': 0.3384615384615385}

Current epoch: 400, 
training performance: {'loss': 1.5397619009017944, 'full sequence accuracy': 0.0, 'first n-symbol accuracy': 0.046153846153846156, 'overlap rate': 0.4384615384615384}
evaluation performance: {'loss': 1.539762020111084, 'full sequence accuracy': 0.0, 'first n-symbol accuracy': 0.015384615384615385, 'overlap rate': 0.4384615384615384}

Current epoch: 600, 
training performance: {'loss': 1.3509770631790161, 'full sequence accuracy': 0.0, 'first n-symbol accuracy': 0.046153846153846156, 'overlap rate': 0.5230769230769231}
evaluation performance: {'loss': 1.3509780168533325, 'full sequence accuracy': 0.0, 'first n-symb

Current epoch: 5000, 
training performance: {'loss': 0.26772400736808777, 'full sequence accuracy': 0.0, 'first n-symbol accuracy': 0.3076923076923077, 'overlap rate': 0.9153846153846154}
evaluation performance: {'loss': 0.26772400736808777, 'full sequence accuracy': 0.0, 'first n-symbol accuracy': 0.2846153846153846, 'overlap rate': 0.9153846153846154}

Current epoch: 5200, 
training performance: {'loss': 0.25871559977531433, 'full sequence accuracy': 0.0, 'first n-symbol accuracy': 0.3076923076923077, 'overlap rate': 0.9153846153846154}
evaluation performance: {'loss': 0.25871556997299194, 'full sequence accuracy': 0.0, 'first n-symbol accuracy': 0.2846153846153846, 'overlap rate': 0.9153846153846154}

Current epoch: 5400, 
training performance: {'loss': 0.20316016674041748, 'full sequence accuracy': 1.0, 'first n-symbol accuracy': 1.0, 'overlap rate': 1.0}
evaluation performance: {'loss': 0.20316922664642334, 'full sequence accuracy': 1.0, 'first n-symbol accuracy': 1.0, 'overlap ra

In [12]:
# Wrap the trained model, the data pipeline and the output decoder into a
# callable that takes raw text and (per the output below) returns a list of
# decoded code sequences.
predictor = customize_predictor(model, dataloader_func, out_seq_decoder)

In [13]:
# Predict on the whole alphabet in order; a perfectly learned mapping yields
# [[10, 11, ..., 35]], i.e. [transduce(ascii_lowercase)].
predictor(ascii_lowercase)

[[10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35]]