In [None]:
import json
import os

import requests
import ipywidgets as widgets
from ipywidgets import GridspecLayout

import lightonmuse

In [None]:
# Name of the Muse model queried by this notebook. The "-fr" suffix and the
# French default prompt suggest a French-language model (assumption — confirm
# against the Muse model list). Change it here rather than inside the call.
MODEL_NAME = "orion-fr"

# object responsible for performing the Create API calls
creator = lightonmuse.Create(MODEL_NAME)

In [None]:
### Interface widgets — feel free to skip this cell: it only builds the UI.

# --- prompt -----------------------------------------------------------------
prompt = widgets.Textarea(
    value="La capitale de France s'appelle",
    placeholder="Type something",
    description="Prompt:",
    disabled=False,
    layout=widgets.Layout(width="100%", height="120px"),
)

# --- generation parameters: decoding mode, length, number of candidates ------
mode = widgets.Dropdown(
    options=["greedy", "topk", "nucleus"],
    value="nucleus",
    description="Mode:",
)
n_tokens = widgets.IntSlider(description="N tokens:", min=1, max=2048, value=16)
best_of = widgets.IntSlider(description="Best of:", min=1, max=16, value=1)

generation_grid = GridspecLayout(1, 3, align_items="center")
generation_grid[0, 0] = mode
generation_grid[0, 1] = n_tokens
generation_grid[0, 2] = best_of

# --- sampling parameters (temperature uses a base-10 log-scale slider) -------
temperature = widgets.FloatLogSlider(
    description="Temperature: ",
    base=10,
    min=-3,
    max=3,
    value=1.0,
    step=0.1,
)
p = widgets.FloatSlider(description="p: ", min=0.0, max=1.0, value=0.9)
k = widgets.IntSlider(description="k:", min=1, max=5, value=3)

sampling_grid = GridspecLayout(1, 3, align_items="center")
sampling_grid[0, 0] = temperature
sampling_grid[0, 1] = p
sampling_grid[0, 2] = k

# --- repetition penalties ----------------------------------------------------
presence_penalty = widgets.FloatSlider(
    description="Presence penalty: ",
    min=0.0,
    max=1.0,
    step=0.05,
    value=0.0,
    style={"description_width": "initial"},
    layout=widgets.Layout(width="100%"),
)
frequency_penalty = widgets.FloatSlider(
    description="Frequency penalty: ",
    min=0.0,
    max=1.0,
    step=0.05,
    value=0.8,
    style={"description_width": "initial"},
    layout=widgets.Layout(width="100%"),
)

penalties_grid = GridspecLayout(1, 2, align_items="center")
penalties_grid[0, 0] = presence_penalty
penalties_grid[0, 1] = frequency_penalty

# --- word biases: ';'-separated lists of words to ban / to favour ------------
forbidden = widgets.Text(
    value="Paris;Lyon",
    placeholder="Type forbidden words, separated by ;",
    description="Forbidden words:",
    disabled=False,
    style={"description_width": "initial"},
    layout=widgets.Layout(width="90%"),
)
encouraged = widgets.Text(
    value="Bordeaux;Marseille",
    placeholder="Type encouraged words, separated by ;",
    description="Encouraged words:",
    disabled=False,
    style={"description_width": "initial"},
    layout=widgets.Layout(width="90%"),
)

# encouraged words are shown on the left, forbidden ones on the right
biases_grid = GridspecLayout(1, 2, align_items="center")
biases_grid[0, 0] = encouraged
biases_grid[0, 1] = forbidden

# --- stop words and the Generate button ---------------------------------------
stop_words = widgets.Text(
    value="",
    placeholder="Type stopwords, separated by ;",
    description="Stopwords:",
    disabled=False,
    layout=widgets.Layout(width="90%"),
)
button = widgets.Button(
    description="Generate",
    disabled=False,
    button_style="",
    icon="marker",  # FontAwesome icon name, without the `fa-` prefix
    layout=widgets.Layout(width="90%"),
)

bottom_grid = GridspecLayout(1, 2)
bottom_grid[0, 0] = stop_words
bottom_grid[0, 1] = button

# --- output area where each generation is printed -----------------------------
out = widgets.Output(layout=widgets.Layout(width="100%", border="1px solid black"))

In [None]:
### logic building the Create API call
def _split_words(text):
    """Split a ``;``-separated widget value into its non-empty words.

    Surrounding whitespace is stripped from every entry and empty entries are
    dropped, so ``"Paris; Lyon;"`` gives ``["Paris", "Lyon"]`` while ``""`` and
    ``" ; "`` give ``[]``.
    """
    return [word.strip() for word in text.split(";") if word.strip()]


def create(button):
    """Click handler: run a Muse Create call on the prompt and append the result.

    Reads the generation, sampling, penalty, word-bias and stop-word settings
    from the module-level interface widgets and calls ``creator`` with them. It
    then prints the prompt (wrapped in ``**``) followed by the first returned
    completion into ``out``, and appends that completion to the prompt text
    area so the next click continues the text.

    Parameters
    ----------
    button :
        The widget passed in by ``on_click``; not used.

    Returns
    -------
    tuple
        ``(outputs, cost, rid)`` exactly as returned by ``creator``.
    """
    # clear the output box at every generation
    out.clear_output()

    # Word biases. Forbidden words get -100, which effectively forbids them.
    # Encouraged words get 4.5, a fine-tuned value that works nicely in
    # combination with the penalties. Encouraged words are written second, so
    # a word listed in both fields ends up encouraged.
    biases = {}
    for word in _split_words(forbidden.value):
        biases[word] = -100
    for word in _split_words(encouraged.value):
        biases[word] = 4.5

    # stop words, with blanks and empty entries removed
    stop_words_list = _split_words(stop_words.value)

    # The "Best of" slider value is sent as n_completions (best_of itself
    # stays 1). Only the first returned completion is used below.
    n_completions = max(1, best_of.value)

    # build the dictionary of parameters
    params = {
        "n_tokens": n_tokens.value,
        "mode": mode.value,
        "temperature": temperature.value,
        "p": p.value,
        "k": k.value,
        "presence_penalty": presence_penalty.value,
        "frequency_penalty": frequency_penalty.value,
        "best_of": 1,
        "n_completions": n_completions,
        "seed": 42,  # fixed seed so repeated clicks are reproducible
    }

    # only send the optional fields when they actually contain words
    if biases:
        params["word_biases"] = biases
    if stop_words_list:
        params["stop_words"] = stop_words_list

    # call Create and take the first completion of the first output
    outputs, cost, rid = creator(text=prompt.value, **params)
    completion_text = outputs[0]["completions"][0]["output_text"]

    # update the output box, then extend the prompt with the generated text
    with out:
        print("**" + prompt.value + "**" + " " + completion_text)
    prompt.value += " " + completion_text
    return outputs, cost, rid

In [None]:
# Register `create` as the click handler of the Generate button.
# Re-running the `def create` cell produces a new function object. If it were
# simply registered again, the handler from the previous run would stay
# attached too, and every click would trigger one API call per handler.
# Detach the handler this cell registered last time (if any) before adding
# the current one, so re-running the cells keeps exactly one handler.
_previous_create_handler = globals().get("_registered_create_handler")
if _previous_create_handler is not None:
    button.on_click(_previous_create_handler, remove=True)
button.on_click(create)
_registered_create_handler = create

In [None]:
# Show the whole interface from top to bottom: prompt box, generation /
# sampling / penalty / bias rows, stop words + button, then the output area.
display(
    prompt,
    generation_grid,
    sampling_grid,
    penalties_grid,
    biases_grid,
    bottom_grid,
    out,
)