In [None]:
from fastai.text.all import *
from fastai.vision.widgets import *

In [None]:
import urllib.request
import gc

# Location of the pickled fastai Learner; get_learn() downloads it on first use.
MODEL_URL = (
    "https://github.com/mikegarts/remarqueapp"
    "/raw/master/remarque.pkl"
)

# Remarquefy - What would Remarque say?
## A neural network trained to expand a few words of input into a paragraph that reads as if it came from a book by Erich Maria Remarque

In [None]:
# Minimum number of generated words before a paragraph may end; the
# generator also caps output at twice this value.
MIN_WORDS = 60
path = Path()

# ---- UI widgets ---------------------------------------------------------
inp_text = widgets.Text(placeholder='I was drinking because')
btn_generate = widgets.Button(description='Generate paragraph')

# Progress bar spanning the generator's full word budget (2 * MIN_WORDS).
progress = widgets.IntProgress(
    description='Generating: ',
    min=0,
    max=2 * MIN_WORDS,
    value=0,
    orientation='horizontal',
    bar_style='info',  # one of 'success', 'info', 'warning', 'danger' or ''
    style=dict(bar_color='maroon'),
)

# Read-only area that shows status messages and the generated paragraph.
result = widgets.Textarea(disabled=True, rows=16)

def get_learn():
    """Load and return the fastai Learner for the Remarque language model.

    On first use the model file is downloaded from MODEL_URL and cached in the
    working directory as ``dw_<file name>``; later calls reuse the cached copy.
    Status messages are shown in the ``result`` text area while this runs and
    cleared once the model is loaded.

    NOTE(security): ``load_learner`` unpickles the file, which can execute
    arbitrary code. Only point MODEL_URL at a source you trust.
    """
    model_filename = 'dw_' + Path(MODEL_URL).name
    if not Path(model_filename).is_file():
        result.value = 'Downloading model...'
        # Download under a temporary name and rename only once complete, so an
        # interrupted download never leaves a truncated file that the
        # is_file() check above would later mistake for a valid cached model.
        partial_filename = model_filename + '.part'
        urllib.request.urlretrieve(MODEL_URL, partial_filename)
        Path(partial_filename).replace(model_filename)

    # Keep the message visible for the whole (slow) load, not only after a download.
    result.value = 'Loading model...'
    learn = load_learner(model_filename)
    result.value = ''
    return learn

def generate_until(text, min_words=MIN_WORDS, stop='.', temperature=0.75):
    """Extend ``text`` with the language model until it reads as a paragraph.

    Words are generated in small batches until the text ends with ``stop``
    after more than ``min_words`` generated words, or until roughly
    ``2 * min_words`` words have been added, whichever comes first.
    Intermediate text and progress are written to the ``result`` and
    ``progress`` widgets while generating.

    Args:
        text: the seed words to continue.
        min_words: minimum number of words to generate before stopping at ``stop``.
        stop: suffix that marks the end of the paragraph; appended if missing.
        temperature: sampling temperature passed to ``learn.predict``.

    Returns:
        The generated paragraph with ``xxunk`` tokens removed and the
        tokenizer's spacing before '.' and ',' tidied up.
    """
    import re

    learn = get_learn()
    # Scale the bar to this call's word budget (identical to the widget's
    # initial max when min_words == MIN_WORDS).
    progress.max = min_words * 2
    try:
        # One larger batch first, then small batches so generation can stop
        # shortly after the model emits `stop`.
        first_batch = min_words // 3
        text = learn.predict(text, first_batch, temperature=temperature)
        added_words = first_batch
        # Count the first batch too, so the bar tracks every generated word.
        progress.value += first_batch
        micro_batch = 3
        while True:
            text = learn.predict(text, micro_batch, temperature=temperature)
            finished = text.endswith(stop) and added_words > min_words
            if finished or added_words >= min_words * 2:
                break
            added_words += micro_batch
            result.value = text
            progress.value += micro_batch
    finally:
        # The model is reloaded on every call; drop the reference eagerly, even
        # on error (a stored traceback would otherwise keep this frame, and the
        # model, alive), so its memory can be reclaimed.
        learn = None
        gc.collect()

    if not text.endswith(stop):
        text += stop

    # Remove unknown-word tokens, then collapse the runs of spaces they leave.
    text = text.replace('xxunk', '').replace('Xxunk', '')
    text = re.sub(r' {2,}', ' ', text)
    text = text.replace(' i ', ' I ')
    text = text.replace(' .', '.')
    text = text.replace(' ,', ',')
    return text

def on_click_generate(change):
    """Button callback: generate a paragraph from the text box and display it.

    The button and text box are disabled while generating and are always
    re-enabled afterwards, even if generation fails. On failure the error is
    shown in ``result`` and re-raised so it also reaches the widget log.

    Args:
        change: the Button instance passed by ``Button.on_click`` (unused).
    """
    progress.value = 1
    btn_generate.disabled = True
    inp_text.disabled = True
    result.value = 'Generating ... please wait'
    try:
        preds = generate_until(inp_text.value)
    except Exception as exc:
        # Show the problem instead of leaving a stale status message; keep the
        # user's input so they can retry.
        result.value = f'Generation failed: {exc}'
        raise
    else:
        result.value = preds
        inp_text.value = ''
    finally:
        # Without this the UI would stay locked after any error.
        btn_generate.disabled = False
        inp_text.disabled = False
        progress.value = 0

# Hook the button up to the generator, then stack the UI top to bottom.
btn_generate.on_click(on_click_generate)
display(VBox([
    widgets.Label('Enter a few words to start the paragraph'),
    inp_text,
    btn_generate,
    progress,
    result,
]))
