# BioGPT
## Generative pre-trained transformer for biomedical text generation and mining

In [None]:
import html
import os

import ipywidgets as widgets
import torch
from fairseq.models.transformer_lm import TransformerLanguageModel
from IPython.display import clear_output, display

In [None]:
# Root directories for the BioGPT model checkpoints and the matching data
# (dictionary + BPE codes). Each model is looked up in a subdirectory named
# after it (see ``start``). Set BIOGPT_CHECKPOINTS_PATH / BIOGPT_DATA_PATH to
# point at a different location; the shared-mount defaults are used otherwise.
checkpoints_path = os.environ.get("BIOGPT_CHECKPOINTS_PATH", "/opt/shared/data/biogpt/checkpoints")
data_path = os.environ.get("BIOGPT_DATA_PATH", "/opt/shared/data/biogpt/data")

In [None]:
def start(model):
    """Load a BioGPT checkpoint as a fairseq hub interface for generation.

    Args:
        model: Name of the model subdirectory (e.g. 'BioGPT' or
            'BioGPT-Large') under both ``checkpoints_path`` and ``data_path``.

    Returns:
        The hub interface returned by ``from_pretrained``. It is moved to the
        GPU when CUDA is available and left on the CPU otherwise.
    """
    hub = TransformerLanguageModel.from_pretrained(
        os.path.join(checkpoints_path, model),
        "checkpoint.pt",
        os.path.join(data_path, model),
        tokenizer='moses',
        bpe='fastbpe',
        bpe_codes=os.path.join(data_path, model, "bpecodes"),
        # Forwarded to fairseq as overrides; presumably the generation length
        # bounds (min output length, and b in the max length a*x + b).
        min_len=100,
        max_len_b=1024,
    )
    # An unconditional .cuda() raises on CPU-only machines; stay on CPU there.
    if torch.cuda.is_available():
        hub.cuda()
    return hub

In [None]:
# Cache of loaded models keyed by model name; filled lazily by generate().
m = dict()

In [None]:
def generate(prompt, model, beam=5):
    """Generate a continuation of ``prompt`` with the named BioGPT model.

    Models are loaded on first use via ``start`` and cached in the
    module-level dict ``m`` keyed by model name. Only the first request for a
    given model pays the loading cost.

    Args:
        prompt: Text to continue.
        model: Model name, e.g. 'BioGPT' or 'BioGPT-Large'.
        beam: Beam size passed to the hub interface's ``generate``.

    Returns:
        The decoded text of the first hypothesis for the prompt.
    """
    if model not in m:
        print('Cold start. Loading the weights...')
        m[model] = start(model)
        print('GPU is running' if torch.cuda.is_available() else 'GPU not available, running on CPU')
    hub = m[model]
    src_tokens = hub.encode(prompt)
    # generate() takes a batch of inputs and returns one hypothesis list per
    # input; fairseq presumably orders each list best-score first.
    hypotheses = hub.generate([src_tokens], beam=beam)[0]
    return hub.decode(hypotheses[0]["tokens"])

In [None]:
def button_factory(text):
    """Create a suggestion button labelled ``text``.

    The button uses the module-level ``layout`` object, which must be defined
    before this function is called. All buttons built here share that one
    Layout instance.

    Note: ``display``/``flex_flow``/``align_items`` are Layout properties,
    not Button traits. Passing them to ``widgets.Button`` only triggers a
    traitlets DeprecationWarning and has no effect, so they are not passed.
    """
    return widgets.Button(
        description=text,
        disabled=False,
        layout=layout,
    )

In [None]:
# Layout shared by the suggestion buttons made in button_factory:
# automatic width, fixed 40px height.
layout = widgets.Layout(width='auto', height='40px')

# Free-text prompt box. Flex properties (display/flex_flow/align_items) are not
# Text traits; passing them only raised a DeprecationWarning, so they are omitted.
inp = widgets.Text(
    value='',
    placeholder='Type something',
    description='Prompt:',
    disabled=False,
)
# Main action button; its click handler is attached further down.
button = widgets.Button(
    description='Generate',
    disabled=False,
    tooltip='Click me',
    icon='cog'
)
# Selects which checkpoint directory generate() loads.
model_dropdown = widgets.Dropdown(
    options=['BioGPT', 'BioGPT-Large'],
    value='BioGPT',
    description='Model:',
    disabled=False,
)
# Captures printed output from the generation callback (used via `with output:`).
output = widgets.Output()
# Renders the loading spinner and then the generated text, as HTML.
response_widget = widgets.HTML(
    value="",
    description='<b>></b>',
)

def btn_generate(btn):
    """Click handler for the Generate button.

    Shows a loading spinner in ``response_widget``, runs ``generate`` on the
    current prompt with the selected model, and then shows the result. Output
    printed during generation goes to the ``output`` widget.

    Args:
        btn: The clicked button (unused).
    """
    with output:
        clear_output()
        response_widget.value = '<img src="https://user-images.githubusercontent.com/3059371/49334754-3c9dfe00-f5ab-11e8-8885-0192552d12a1.gif" width="50" />'
        try:
            text = generate(inp.value, model_dropdown.value)
        except Exception as err:
            # Replace the spinner so the failure is visible in the UI, then
            # re-raise so the traceback is still reported via `output`.
            response_widget.value = '<b>Error:</b> ' + html.escape(str(err))
            raise
        clear_output()
        # response_widget renders HTML. Escape the model text (which may echo
        # the user's prompt) so '<', '&', etc. show literally, not as markup.
        response_widget.value = html.escape(text)

# Clicking "Generate" runs the model on the current prompt.
button.on_click(btn_generate)

# Prompt box and Generate button on one row, model picker underneath.
search_bar = widgets.VBox([
    widgets.HBox([inp, button]),
    model_dropdown,
])

# Example prompts offered as one-click suggestion buttons.
suggestion_prompts = [
    'COVID-19 is',
    'A 65-year-old female patient with a past medical history of',
]
suggestion_buttons = list(map(button_factory, suggestion_prompts))

def btn_populate_search_bar(btn):
    """Copy the clicked suggestion button's label into the prompt box."""
    suggestion = btn.description
    inp.value = suggestion

# Clicking a suggestion copies its text into the prompt box.
for btn in suggestion_buttons:
    btn.on_click(btn_populate_search_bar)

suggestions = widgets.HBox(suggestion_buttons)

# btn_generate runs inside `with output:`, so the cold-start messages printed
# by generate() (and any traceback captured by the Output widget) land there.
# Display it as well; otherwise that feedback is never visible.
display(search_bar, suggestions, response_widget, output)