<a href="https://colab.research.google.com/github/datjandra/Clairvaux/blob/master/PhireBlast.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Environment setup (shell commands run through the notebook's `!` escape).
# Install the released transformers package together with its `torch` extra.
# NOTE(review): presumably kept so the torch-related dependencies are in place
# before the package itself is swapped out below — confirm.
!pip install 'transformers[torch]'
# Remove the released transformers package (-y skips the confirmation prompt)...
!pip uninstall -y transformers
# ...and install the current development version straight from GitHub instead.
!pip install git+https://github.com/huggingface/transformers
# Gradio provides the web form built in the next cell.
!pip install gradio

In [None]:
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

# Make "cuda" the default device for newly created tensors; predict() below
# also moves the inputs and the model to "cuda" explicitly, so this notebook
# expects a GPU runtime.
torch.set_default_device("cuda")

def predict(name, gender, age, conditions):
  """Generate sample FHIR JSON text for the patient described by the form.

  The phi-2 model is loaded on every call: from the Hugging Face hub the first
  time (and then saved to ``./storage``), from ``./storage`` afterwards. The
  tokenizer is always loaded from ``microsoft/phi-2``.

  Args:
    name: Patient name, inserted into the prompt as-is.
    gender: Patient gender, inserted into the prompt as-is.
    age: Patient age, inserted into the prompt as-is.
    conditions: Medical conditions, inserted into the prompt as-is.

  Returns:
    The decoded text of the first sequence returned by ``model.generate``.

  Raises:
    Any exception raised while loading the model/tokenizer or generating is
    propagated unchanged to the caller.
  """
  PERSIST_DIR = "./storage"
  # Bind both names before entering the try block. Without this, a failure in
  # from_pretrained() (network error, out of memory, ...) would leave them
  # unbound and the `del` statements in `finally` would raise
  # UnboundLocalError, hiding the real error from the caller.
  model = None
  tokenizer = None
  try:
    if not os.path.exists(PERSIST_DIR):
      # First run: fetch the weights from the hub and keep a local copy so
      # later calls can load from disk.
      model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", low_cpu_mem_usage=True, trust_remote_code=True)
      # NOTE(review): `from_pt` looks like a from_pretrained() option rather
      # than a save_pretrained() one; left untouched — confirm it is needed.
      model.save_pretrained(PERSIST_DIR, from_pt=True)
    else:
      model = AutoModelForCausalLM.from_pretrained(PERSIST_DIR, torch_dtype="auto")

    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    prompt = "Instruct: Sample data in FHIR JSON format of {age} year old {gender} patient named {name} with {conditions}.\nOutput:\n"
    prompt = prompt.format(age=age, gender=gender, name=name, conditions=conditions)
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    model.to("cuda")
    outputs = model.generate(**inputs, max_length=256)
    # Only the first (and only) sequence of the batch is returned.
    text = tokenizer.batch_decode(outputs)[0]
    return text
  finally:
    # Drop the references to the model and tokenizer before returning,
    # whether or not generation succeeded.
    del model
    del tokenizer

# Build the PhireBlast form: a header, two rows of patient fields, an output
# box and a submit button that sends the field values to predict().
with gr.Blocks() as demo:
  # Header: logo followed by the application title.
  gr.Markdown(
    "<div class='pull-left'>"
    "<img width='100' src='https://raw.githubusercontent.com/datjandra/PhireBlast/main/phireblast.png'>"
    "</div>"
    "<h3>PhireBlast</h3>"
  )

  # First row: who the patient is.
  with gr.Row():
    name = gr.Textbox(label="Name")
    gender = gr.Dropdown(
      choices=["male", "female"],
      label="Gender",
      value="female",
    )

  # Second row: age and medical conditions.
  with gr.Row():
    age = gr.Textbox(label="Age")
    conditions = gr.Textbox(label="Conditions")

  # Generated text is shown in a ten-line text box.
  output = gr.Textbox(label="Data", lines=10)
  submit_button = gr.Button(value="Submit")
  # The input order must match predict's parameter order.
  submit_button.click(
    fn=predict,
    inputs=[name, gender, age, conditions],
    outputs=output,
  )

# Start the app with debug output enabled.
demo.launch(debug=True)

In [None]:
# Optional cleanup cell: run only if needed to clean up memory.
# `%reset -f` clears the notebook's interactive namespace; -f skips the
# confirmation prompt.
%reset -f