In [None]:
# Reload all imported modules before each cell executes, so edits to the
# local src/ package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Data

In [None]:
# Load the WebBeteg corpus and join the text of its first N_WEBBETEG_DOCS
# documents into a single space-separated string.
import json

N_WEBBETEG_DOCS = 2000  # number of documents kept; raise it for a larger corpus

# json.load accepts a binary file and detects its UTF-8/16/32 encoding itself.
# NOTE(review): the '_default' key suggests a TinyDB dump whose documents each
# carry a 'text' field — confirm against the data file.
with open('./data/WebBeteg', 'rb') as f:
    database = json.load(f)['_default']

keys = list(database.keys())

webbeteg = ' '.join(database[k]['text'] for k in keys[:N_WEBBETEG_DOCS])

In [None]:
# load shakespeare db
import tensorflow as tf
from src.data import Dataset

# Fetch the Tiny Shakespeare text (Keras caches the download locally).
SHAKESPEARE_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt'
path_to_file = tf.keras.utils.get_file('shakespeare.txt', SHAKESPEARE_URL)

# Read raw bytes and decode explicitly as UTF-8 so the result does not depend on
# the platform's default encoding; the `with` block guarantees the file is closed.
with open(path_to_file, 'rb') as f:
    shakespeare = f.read().decode(encoding='utf-8')

In [None]:
# Load the Molière corpus as one string.
# The encoding is explicit so decoding does not depend on the platform locale
# (e.g. cp1252 on Windows would mangle accented characters).
# NOTE(review): assumes the file is UTF-8 encoded — confirm.
with open('./data/moliere_complete.txt', 'r', encoding='utf-8') as f:
    moliere = f.read()

In [None]:
# Create the training dataset from one of the loaded corpora.
# Set CORPUS to 'webbeteg', 'shakespeare' or 'moliere' to pick the training text.
CORPORA = {'webbeteg': webbeteg, 'shakespeare': shakespeare, 'moliere': moliere}
CORPUS = 'webbeteg'

dataset = Dataset(CORPORA[CORPUS])

## Model

In [None]:
from src.transformer_model import TransformerDecoder
from src.train import train

num_layers = 6  # number of decoder layers
d_model = 128  # dimension of the self-attention feature space
dff = 512  # dimension of the fully-connected layer's feature space
num_heads = 8  # number of heads in each self-attention layer
dropout_rate = 0.1  # dropout rate
# Number of positions covered by the positional encoding.
# NOTE(review): presumably must be >= the sequence length produced by Dataset — confirm in src.data.
max_position_encoding = 100
transformer_epochs = 20  # number of training epochs

# Instantiate and train the transformer-decoder language model.
# NOTE(review): if dataset.ids_from_chars reserves an extra id (e.g. an [UNK]
# token), len(dataset.vocab) could be one smaller than the id range it emits —
# confirm in src.data.
transformer = TransformerDecoder(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    target_vocab_size=len(dataset.vocab),
    maximum_position_encoding=max_position_encoding,
    rate=dropout_rate)

train(transformer, dataset.dataset, transformer_epochs)
transformer.summary()

In [None]:
from src.rnn_model import LSTMModel
from src.train import train

embedding_dim = 128  # dimension of the character embeddings
rnn_units = 512  # number of recurrent (LSTM) units
lstm_epochs = 20  # number of training epochs

# Instantiate and train the LSTM language model.
rnn_model = LSTMModel(
    vocab_size=len(dataset.vocab),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units)

train(rnn_model, dataset.dataset, lstm_epochs)
rnn_model.summary()

## Text generation

In [None]:
from src.rnn_model import RNNGenerator
import pickle as pkl

# Create an LSTM-based text generator and print the generated text.
# The prompt is 'Doktor' ("doctor") and the requested length is 200 characters.
prompt = 'Doktor'
generation_length = 200

rnn_generator = RNNGenerator(rnn_model, dataset.chars_from_ids, dataset.ids_from_chars)
print(rnn_generator.generate_text(prompt, generation_length))

In [None]:
from src.transformer_model import TransformerGenerator

# Create a transformer-based text generator and print the generated text.
# The prompt is 'Doktor' ("doctor") and the requested length is 200 characters.
prompt = 'Doktor'
generation_length = 200

transformer_generator = TransformerGenerator(transformer, dataset.chars_from_ids, dataset.ids_from_chars)
print(transformer_generator.generate_text(prompt, generation_length))