# train models

## wikitext2

In [None]:
import math

from torchtext.datasets import WikiText2

from model import TransformerModel, train_model, evaluate_model, save_model
from data_provider import LMData

vocab_iter = WikiText2(split='train')
train_iter, eval_iter, test_iter = WikiText2()
train_data = LMData(vocab_iter, train_iter, 20, 35)
test_data = LMData(vocab_iter, test_iter, 20, 35)

n_token = train_data.n_token  # size of vocabulary
model = TransformerModel(n_token)

#train_model(model, train_data, 10)
#save_model(model)

test_loss = evaluate_model(model, test_data)
print('perplexity on test data: {:.2f}'.format(math.exp(test_loss)))


## wikitext103

In [1]:
import math

from torchtext.datasets import WikiText103

from model import TransformerModel, train_model, evaluate_model, save_model
from data_provider import LMData

vocab_iter = WikiText103(split='train')
train_iter, eval_iter, test_iter = WikiText103()
train_data = LMData(vocab_iter, train_iter, 20, 35)
test_data = LMData(vocab_iter, test_iter, 20, 35)

n_token = train_data.n_token  # size of vocabulary
model = TransformerModel(n_token)

#train_model(model, train_data, 10)
#save_model(model)

test_loss = evaluate_model(model, test_data)
print('perplexity on test data: {:.2f}'.format(math.exp(test_loss)))

KeyboardInterrupt: 

## librispeech

In [None]:
import math
import torch

from model import TransformerModel, save_model, train_model, evaluate_model
from data_provider import LibriSpeechRawIter, LMData

train_iter = LibriSpeechRawIter('train')
train_data = LMData(train_iter, 20, 35)
test_iter = LibriSpeechRawIter('test')
test_data = LMData(test_iter, 20, 35)
n_token = train_data.n_token

model = TransformerModel(n_token)
train_model(model, train_data, 10)
save_model(model)
test_loss = evaluate_model(model, test_data)

print('perplexity on test dataset: {:.2f}'.format(math.exp(test_loss)))

## airbus

# load pretrained model

In [None]:
import torch

from torchtext.datasets import WikiText2, WikiText103, PennTreebank


from model import TransformerModel
from data_provider import LMData, LibriSpeechRawIter

vocab_iter = WikiText103(split='train')
train_iter, eval_iter, test_iter = WikiText103()
test_data = LMData(vocab_iter, test_iter, 20, 35)
n_token = test_data.n_token

path = 'saved_models/20210917_102159_IXOYXWR' # wiki103
model_dict = torch.load(path)
model = TransformerModel(n_token=n_token)
model.load_state_dict(model_dict['model_state_dict'])
model.eval();

# inference

In [None]:
import torch

from data_provider import str2seq, seq2str
from model import inference

input_str = 'i want to get the song'
input = str2seq(input_str, test_data.vocab, test_data.tokenizer)
input_length = 6

output = inference(model, input[:input_length])
prediction = torch.argmax(output, dim=1)
prediction_str = seq2str(prediction.unsqueeze_(dim=1)[-1], test_data.vocab)

print(input_str)
print(seq2str(input[:input_length], test_data.vocab))
print(prediction_str)


In [68]:
import torch
import jiwer

from model import inference
from data_provider import seq2str

model.eval()
wer = 0
for n_batches, batch in enumerate(test_data):
    #print(seq2str(inputs[:, 0], test_data.vocab))
    #print(seq2str(targets[::test_data.batch_size], test_data.vocab))
    inputs, targets, index = batch
    output_prob = inference(model, inputs)
    output_max = torch.argmax(output_prob, dim=1)
    output_max = output_max.reshape(35, 20)
    targets = targets.reshape(35, 20)

    for n_samples, sample in enumerate(zip(output_max.t(), targets.t())):
        prediction, target = sample
        prediction_str = seq2str(prediction, test_data.vocab)
        target_str = seq2str(target, test_data.vocab)
        wer += jiwer.wer(target_str, prediction_str)

print(wer/(n_batches * n_samples))


0.8398100514444204
