In [None]:
!pip install datasets ipywidgets

In [None]:
import functools
import gc
import urllib.request
from pathlib import Path

import ipywidgets as widgets
import tensorflow as tf
import transformers
from datasets import load_dataset
from transformers import AutoTokenizer, TFAutoModelForCausalLM

# Model identifiers passed to `from_pretrained` in get_model:
# the tokenizer is loaded from `checkpoint` and the model weights from `saved_checkpoint`.
checkpoint, saved_checkpoint = 'distilgpt2', 'mikegarts/mishka-distgpt2'

# An improved version of remarqify, this time based on the self-attention mechanism (distilgpt2)
## A neural network trained to continue a few input words as a paragraph in the style of Erich Maria Remarque's books

In [None]:
MIN_WORDS = 60
path = Path()
btn_generate = widgets.Button(description='Generate paragraph')
inp_text = widgets.Text(placeholder='I was drinking because')
progress = widgets.IntProgress(
    value=0,
    min=0,
    max=MIN_WORDS * 2,
    description='Generating: ',
    bar_style='info', # 'success', 'info', 'warning', 'danger' or ''
    style={'bar_color': 'maroon'},
    orientation='horizontal'
)

result = widgets.Textarea(rows=13, disabled=True)

@functools.lru_cache(maxsize=None)
def get_model():
    """Load the model and its tokenizer, at most once per kernel session.

    On the first call, this loads the model weights from ``saved_checkpoint``
    and the tokenizer from ``checkpoint``. While loading, it shows a status
    message in the ``result`` textarea. The pair is then cached, so later
    calls (one per button click) return immediately instead of reloading the
    model each time. If loading raises, nothing is cached and the next call
    retries.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    previous_text = result.value
    result.value = 'Downloading model...'
    try:
        model = TFAutoModelForCausalLM.from_pretrained(saved_checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    finally:
        # Restore whatever was displayed before (e.g. 'Generating ... please wait')
        # instead of blanking the box while generation is still to come.
        result.value = previous_text
    return model, tokenizer

def generate(prompt):
    """Continue ``prompt`` with text sampled from the model.

    Args:
        prompt: Seed text typed by the user. If it is empty or whitespace-only,
            the tokenizer's BOS token is used as the seed instead. An empty
            string would otherwise encode to a zero-length ``input_ids`` tensor,
            leaving the sampler with no starting token.

    Returns:
        str: The decoded text (prompt included, special tokens removed), cut
        back to its last '.' so that it ends on a complete sentence. If the
        text contains no '.', a '.' is appended to the whole text.
    """
    model, tokenizer = get_model()

    input_context = prompt if prompt.strip() else tokenizer.bos_token
    input_ids = tokenizer.encode(input_context, return_tensors="tf")

    outputs = model.generate(
        input_ids=input_ids,
        # Total length in tokens (prompt + continuation), not a word count.
        max_length=MIN_WORDS,
        temperature=0.7,
        num_return_sequences=1,
        do_sample=True,
        # Presumably the distilgpt2 tokenizer defines no pad token. Passing EOS
        # explicitly matches generate's own fallback without its warning.
        pad_token_id=tokenizer.eos_token_id,
    )

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Drop the trailing sentence, which is usually cut off mid-way.
    return text.rsplit('.', 1)[0] + '.'

def on_click_generate(change):
    """Button callback that generates a paragraph from the current prompt.

    Disables the button and prompt box while the model runs. They are always
    re-enabled afterwards, even if loading or generation raises, so the UI
    cannot get stuck in a disabled state. A failure is reported in the
    ``result`` box, since exceptions raised inside widget callbacks are
    typically not shown in the cell output.

    Args:
        change: The object passed by ``btn_generate.on_click`` (unused).
    """
    progress.value = 1
    btn_generate.disabled = True
    inp_text.disabled = True
    result.value = 'Generating ... please wait'
    try:
        result.value = generate(inp_text.value)
    except Exception as exc:
        result.value = f'Generation failed: {exc}'
    else:
        # Clear the prompt only on success, so a failed prompt can be retried.
        inp_text.value = ''
    finally:
        btn_generate.disabled = False
        inp_text.disabled = False
        progress.value = 0

# Hook the button up to the generator, then render the UI as one vertical stack:
# instructions, prompt box, button, progress bar, output box.
btn_generate.on_click(on_click_generate)
display(
    widgets.VBox(
        children=[
            widgets.Label('Enter a few words to start the paragraph'),
            inp_text,
            btn_generate,
            progress,
            result,
        ]
    )
)
