In [1]:
#Import libraries
import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output, State
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from tqdm import tqdm, trange
import torch.nn.functional as F

The dash_core_components package is deprecated. Please replace
`import dash_core_components as dcc` with `from dash import dcc`
  import dash_core_components as dcc
The dash_html_components package is deprecated. Please replace
`import dash_html_components as html` with `from dash import html`
  import dash_html_components as html


In [2]:
#Import tokenizer and models
# One shared tokenizer plus three distilgpt2 instances: the stock weights,
# and two copies that receive fine-tuned weights from local checkpoints.
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
device = torch.device('cpu')
model = GPT2LMHeadModel.from_pretrained('distilgpt2')
model_chat = GPT2LMHeadModel.from_pretrained('distilgpt2')
model_imdb = GPT2LMHeadModel.from_pretrained('distilgpt2')

# map_location keeps loading working on a CPU-only machine even when the
# checkpoint tensors were saved from another device.
model_chat.load_state_dict(
    torch.load("./reddit_chat_text_gen_epoch20.pt", map_location=device)
)
# The IMDB checkpoint belongs in model_imdb; loading it into model_chat would
# overwrite the chat weights and leave model_imdb as plain distilgpt2.
model_imdb.load_state_dict(
    torch.load("./IMDB_text_gen_epoch20.pt", map_location=device)
)

<All keys matched successfully>

In [3]:
def generate(
    model,
    tokenizer,
    prompt,
    entry_count=10,
    entry_length=30,  # maximum number of generated tokens per entry
    top_p=0.8,
    temperature=1.,
):
    """Sample text continuations of ``prompt`` using nucleus (top-p) sampling.

    Args:
        model: Callable language model. It is put in eval mode and called as
            ``model(input_ids)``; element ``[0]`` of the result is indexed as
            ``[:, -1, :]`` to get the next-token logits.
        tokenizer: Object providing ``encode(text)`` and ``decode(ids)``.
        prompt: Text every generated entry starts from.
        entry_count: Number of independent entries to generate.
        entry_length: Maximum number of tokens sampled per entry.
        top_p: Cumulative-probability threshold; tokens past the threshold in
            the sorted distribution are excluded (the most likely token is
            always kept).
        temperature: Divisor applied to the logits before sampling. Values
            that are not positive fall back to 1.0.

    Returns:
        A list of ``entry_count`` decoded strings. Entries that hit
        ``entry_length`` without sampling the end-of-text token get the
        literal ``<|endoftext|>`` appended, so every entry ends with it.

    Raises:
        ValueError: If ``prompt`` encodes to zero tokens.
    """
    model.eval()

    # Encode once; the prompt and the end-of-text marker never change.
    prompt_ids = tokenizer.encode(prompt)
    if len(prompt_ids) == 0:
        raise ValueError("prompt must encode to at least one token")
    eos_ids = set(tokenizer.encode("<|endoftext|>"))

    scale = temperature if temperature > 0 else 1.0
    filter_value = -float("Inf")
    generated_list = []

    with torch.no_grad():
        for _ in trange(entry_count):
            entry_finished = False
            generated = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0)

            for _step in range(entry_length):
                # No labels are passed: the loss is not needed for sampling.
                logits = model(generated)[0]
                logits = logits[:, -1, :] / scale

                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Shift the mask right by one so the first token that crosses
                # the threshold is kept, and never drop the top token.
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value

                next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                generated = torch.cat((generated, next_token), dim=1)

                if next_token.item() in eos_ids:
                    entry_finished = True
                    break

            # Index the single batch row instead of squeeze() so a one-token
            # sequence still yields a list.
            output_text = tokenizer.decode(generated[0].tolist())
            if not entry_finished:
                output_text = f"{output_text}<|endoftext|>"
            generated_list.append(output_text)

    return generated_list

In [5]:
#START OF THE ACTUAL APP INTERFACE
app = dash.Dash()

# Page layout, top to bottom: heading, model selector, prompt box,
# submit button, and the container that receives the generated text.
app.layout = html.Div(children=[
    # Centered page heading
    html.H1(
        'TEXT GENERATOR',
        style={'textAlign': 'center'},
    ),

    # Dropdown for choosing which model generates the text
    dcc.Dropdown(
        id='models',
        value='original',
        options=[
            {'label': 'original', 'value': 'original'},
            {'label': 'chat', 'value': 'chat'},
            {'label': 'IMDB', 'value': 'IMDB'},
        ],
    ),

    # Text box where the prompt is typed
    dcc.Textarea(
        id='textarea-state-example',
        value='',
        style={
            'width': '99%',
            'height': 200,
            'font-size': 'large',
            'margin-top': '10px',
            'margin-bottom': '10px',
        },
    ),

    # Button whose click count is read by the callback
    html.Button(
        'Submit',
        id='textarea-state-example-button',
        n_clicks=0,
        style={
            'font-size': '12px',
            'width': '140px',
            'display': 'inline-block',
            'margin-top': '10px',
            'margin-bottom': '10px',
            'margin-right': '5px',
            'height': '25px',
        },
    ),

    # Output container; 'pre-line' keeps line breaks of the generated text
    html.Div(
        id='textarea-state-example-output',
        style={
            'whiteSpace': 'pre-line',
            'position': 'absolute',
            'margin-left': 'auto',
            'margin-right': 'auto',
            'top': '400px',
            'font-size': 'large',
            'height': '200px',
            'width': 'auto',
        },
    ),
])

def _trim_to_last_sentence(text, eos_token="<|endoftext|>"):
    """Drop the end-of-text marker and any trailing unfinished sentence.

    The text is cut right after its last period. Cutting by position (rather
    than ``str.replace`` on the trailing fragment) guarantees that only the
    tail is removed, never an identical fragment earlier in the text. When the
    text contains no period at all it is returned whole (minus the marker)
    instead of being blanked out.
    """
    cleaned = text.replace(eos_token, "")
    last_period = cleaned.rfind(".")
    if last_period == -1:
        return cleaned
    return cleaned[: last_period + 1]


@app.callback(
    Output('textarea-state-example-output', 'children'),
    [Input('textarea-state-example-button', 'n_clicks'), Input('models', 'value')],
    State('textarea-state-example', 'value')
)
def update_output(n_clicks, value, value2):
    """Generate text for the prompt with the selected model.

    Args:
        n_clicks: Click count of the Submit button; nothing is generated
            until it is positive.
        value: Selected dropdown value ('original', 'chat' or 'IMDB').
        value2: Current content of the prompt textarea.

    Returns:
        The cleaned generated text, or ``None`` (empty output) when there is
        nothing to generate: no click yet, no prompt, or no known model
        selected.

    Note:
        The dropdown is registered as an ``Input``, so changing the model
        after the first click also re-runs generation.
    """
    if not n_clicks or not value2:
        return None

    # Looked up at call time so the callback always uses the current objects.
    models_by_name = {
        'original': model,
        'chat': model_chat,
        'IMDB': model_imdb,
    }
    selected_model = models_by_name.get(value)
    if selected_model is None:
        # Dropdown cleared or unexpected value: nothing to generate with.
        return None

    generated = generate(selected_model.to('cpu'), tokenizer, value2, entry_count=1)

    #Clean the output
    return _trim_to_last_sentence(' '.join(generated))

# Start the Dash server when this code runs as the main module.
# NOTE(review): debug mode is enabled while listening on 0.0.0.0 -- confirm
# this is only ever run on a trusted network.
if __name__ == "__main__":
    app.run_server(
        host="0.0.0.0",
        port=8050,
        debug=True,
        use_reloader=False,
    )

Dash is running on http://0.0.0.0:8050/

Dash is running on http://0.0.0.0:8050/

 * Serving Flask app '__main__'
 * Debug mode: on


  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/1 [00:00<?, ?it/s]
