<a href="https://colab.research.google.com/github/natviv/med-gpt3/blob/main/Medical_GPT_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Medical GPT-3

This colab features an exploration of potential medical applications with GPT-3.

The following applications are prototyped and explored.

*   Convert clinical notes with abbreviations to patient readable format
*   A primary care conversational agent that triages with a patient

This is built using the OpenAI API's integration with Weights & Biases (W&B).

## API key setup

In [None]:
# API key credentials
%env OPENAI_API_KEY=
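
If the key above is left blank, later API calls fail only when the first request is made. A minimal sketch of an early check, assuming the key is supplied through the `OPENAI_API_KEY` environment variable as in the cell above. `require_api_key` is a hypothetical helper written for this notebook, not part of the `openai` package.

```python
import os

def require_api_key(env=os.environ, name="OPENAI_API_KEY"):
    """Return the API key from the environment, or raise if it is missing or blank."""
    key = env.get(name, "").strip()
    if not key:
        raise RuntimeError(f"{name} is not set; set it before calling the API.")
    return key
```

Calling `require_api_key()` right after the `%env` cell surfaces a missing key before any dataset or model work starts.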

## Install dependencies

In [2]:
!pip install --upgrade openai wandb

Collecting openai
  Downloading openai-0.18.1.tar.gz (42 kB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
Collecting wandb
  Downloading wandb-0.12.16-py2.py3-none-any.whl (1.8 MB)
Collecting pandas-stubs>=1.1.0.11
  Downloading pandas_stubs-1.2.0.58-py3-none-any.whl (162 kB)
Collecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.9-py3-none-any.whl (9.4 kB)
Collecting pathtools
  Downloading pathtools-0.

In [4]:
# Setup imports

import json
import numpy as np
import openai
import os
import pandas as pd
import wandb

from pathlib import Path
from tqdm import tqdm

In [5]:
run = wandb.init(project='Medical GPT-3', job_type="dataset_preparation")

wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: ··········
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


### Connect to dataset in Google Drive

In [6]:
from google.colab import drive 
drive.mount('/content/gdrive/')

Mounted at /content/gdrive/


### Read data into a pandas DataFrame

In [18]:
df=pd.read_csv('gdrive/MyDrive/med-gpt3-data/med_abbreviations.csv')

### Analyze data

In [19]:
print(f"Number of rows is {len(df)} and number of columns is {len(df.columns)}")
for index, row in df.iterrows():
    print(f"ID: {row['ID']}, Text: {row['Text']}, Completion: {row['Completion']}")

df_test = df.iloc[:5,:]
df_train = df.iloc[5:,:]

Number of rows is 30 and number of columns is 3
ID: 1, Text: The pt has lbp, Completion: The patient has lower back pain.
ID: 2, Text: The pt is a 30 y/o m, Completion: The patient is a 30 year old male.
ID: 3, Text: VSS after Tx, Completion: Vital signs stable after treatment.
ID: 4, Text: c/o gi pain, Completion: complains of gastrointestinal pain.
ID: 5, Text: tx d/c due to c/o h/a, Completion: Treatment discontinued due to compaints of headache.
ID: 6, Text: abx dosage recommended, Completion: Antibiotics dosage recommended.
ID: 7, Text: thx required, Completion: therapy required.
ID: 8, Text: The pt requires lbp pt, Completion: The patient requires lower back pain physical therapy.
ID: 9, Text: MBC likely impaired due to covid, Completion: Maximum breathing capacity likely impaired due to covid.
ID: 10, Text: The pt has recurring history of jt pain, Completion: The patient has recurring history of joint pain.
ID: 11, Text: Nacl low in pt, Completion: Sodium chloride low in patient

### Use the completion API to see how the 'text-davinci-002' model performs

Try the following approaches:

* Zero shot with only instruction
* Few shot in-context learning 
* Few shot with chain of thought prompting -> https://arxiv.org/abs/2201.11903. Also see https://twitter.com/npew/status/1525900849888866307
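
To make the difference between the approaches concrete, here is a small sketch of how the prompts are assembled. It assumes the `note -> expansion` line format used in this notebook; `build_prompt` is a hypothetical helper introduced for illustration, and the example pair is one of the in-context examples used later.

```python
def build_prompt(instruction, examples, text, separator=" ->"):
    """Assemble the instruction, optional worked examples, and the query into one prompt."""
    lines = [instruction.rstrip("\n")]
    # Each worked example is a (note, expansion) pair rendered as "note -> expansion"
    lines += [f"{source}{separator} {target}" for source, target in examples]
    # The query ends with the separator so the model continues with the expansion
    lines.append(f"{text}{separator}")
    return "\n".join(lines)

instruction = "Convert the doctor note into patient readable format without abbreviations."
examples = [("Pt has prior history of hbp", "Patient has prior history of high blood pressure")]

zero_shot = build_prompt(instruction, [], "The pt has lbp")
few_shot = build_prompt(instruction, examples, "The pt has lbp")
print(few_shot)
```

A chain of thought prompt would follow the same pattern, with each example's expansion extended by an explanation of the abbreviations.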

In [20]:
openai.api_key = os.getenv("OPENAI_API_KEY")

# This can be made more efficient and converted into a few batch calls
def get_predictions(df, model, context, temperature=0.1, max_tokens=20, use_qa_prefix=False):
  results = []
  for _, row in tqdm(df.iterrows()):
    # Prefix the note with 'Q: ' when the few-shot examples use the Q:/A: format
    text = 'Q: ' + row['Text'] if use_qa_prefix else row['Text']
    prompt = context + text + ' ->'
    res = openai.Completion.create(model=model,
                                   prompt=prompt,
                                   max_tokens=max_tokens,
                                   temperature=temperature,
                                   stop=[" END"])
    completion = res['choices'][0]['text']
    completion = completion[1:]  # remove initial space
    results.append(f"Text: {row['Text']}, Target: {row['Completion']}, Prediction: {completion}")
  return results

def print_results(results): 
  for row in results:
    print(f"\n{row}")

model = 'text-davinci-002'
# Zero-shot with only instruction prompt
instruction = 'Convert the doctor note into patient readable format without abbreviations.\n'
results = get_predictions(df_test, model, instruction)
print_results(results)


5it [00:07,  1.58s/it]


Text: The pt has lbp, Target: The patient has lower back pain., Prediction: The patient has low back pain.

Text: The pt is a 30 y/o m, Target: The patient is a 30 year old male., Prediction: The patient is a 30 year old male.

The patient is a 30 year old male.

Text: VSS after Tx, Target: Vital signs stable after treatment., Prediction: VSS (visual acuity) after treatment

The patient's visual acuity was 20/

Text: c/o gi pain, Target: complains of gastrointestinal pain., Prediction: complaining of gastrointestinal pain

Text: tx d/c due to c/o h/a, Target: Treatment discontinued due to compaints of headache., Prediction: 
The doctor has discharged me from the hospital due to my complaints of headaches.
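
The comment in `get_predictions` notes that the per-row calls could be batched. A rough sketch of that idea, assuming the legacy (pre-1.0) `openai` package used here, whose completion endpoint accepts a list of prompts and tags each returned choice with the `index` of its prompt. `batch_complete` and its `create_fn` parameter are hypothetical names; `create_fn` exists only so the helper can be exercised without network access.

```python
def batch_complete(prompts, model, create_fn=None, **params):
    """Send all prompts in one request and return the completions in prompt order."""
    if create_fn is None:
        import openai  # imported lazily so the helper can be tested offline
        create_fn = openai.Completion.create
    res = create_fn(model=model, prompt=list(prompts), **params)
    completions = [None] * len(prompts)
    # Choices are not guaranteed to arrive in order, so place each by its index
    for choice in res["choices"]:
        completions[choice["index"]] = choice["text"]
    return completions
```

With five test rows the saving is small, but for larger evaluation sets one request per batch avoids paying per-call latency on every row.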





### Zero-shot prompting with only an instruction produces spurious output, often running past the length needed. Let's see if adding a few in-context examples to the instruction fixes this.

In the run above, only 2 of the 5 predictions are correct, and the model rambles on in a couple of them.
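
Counting correct predictions by eye gets tedious as the test set grows. One possible way to automate it is a normalized exact-match score, sketched below. This is an assumption about how to score, not how the counts in this notebook were produced: it is stricter than human judgment and would mark "low back pain" against "lower back pain" as a miss. `normalize` and `exact_match_rate` are hypothetical helpers.

```python
import re

def normalize(text):
    """Lowercase, drop punctuation, and collapse whitespace."""
    text = re.sub(r"[^\w\s]", "", text.lower())
    return " ".join(text.split())

def exact_match_rate(pairs):
    """Fraction of (target, prediction) pairs that match after normalization."""
    hits = sum(normalize(target) == normalize(pred) for target, pred in pairs)
    return hits / len(pairs)

# Two (target, prediction) pairs taken from the zero-shot run above
pairs = [
    ("The patient is a 30 year old male.", "The patient is a 30 year old male"),
    ("Vital signs stable after treatment.", "VSS (visual acuity) after treatment"),
]
print(exact_match_rate(pairs))
```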

In [21]:
instruction = 'Convert the doctor note into patient readable format without abbreviations.\n'
example1 = 'The pt exhibits symptoms of CAD -> The patient exhibits symptoms of Coronary Artery Disease\n'
example2 = 'Pt has prior history of hbp -> Patient has prior history of high blood pressure\n'

# Few-shot in-context learning with instruction prompt and additional few shot examples
context = instruction + example1 + example2
results = get_predictions(df_test, model, context)
print_results(results)

5it [00:07,  1.58s/it]


Text: The pt has lbp, Target: The patient has lower back pain., Prediction: The patient has low back pain

Text: The pt is a 30 y/o m, Target: The patient is a 30 year old male., Prediction: The patient is a 30 year old male

Text: VSS after Tx, Target: Vital signs stable after treatment., Prediction: Vital Signs Stable after treatment

Text: c/o gi pain, Target: complains of gastrointestinal pain., Prediction: complains of gastrointestinal pain

The patient exhibits symptoms of Coronary Artery Disease. The patient

Text: tx d/c due to c/o h/a, Target: Treatment discontinued due to compaints of headache., Prediction: Treatment was discontinued due to complaint of headache





### With only a couple of in-context examples, the model improves significantly

The model now gets 4 out of 5 examples correct, and its tendency to ramble and generate overly long sentences is much reduced.

In [22]:
instruction = 'Convert the doctor note into patient readable format without abbreviations.\n'
example1 = 'Q: The pt exhibits symptoms of CAD -> A: The patient exhibits symptoms of Coronary Artery Disease. Explanation: Here pt stands for patient and CAD stands for Coronary Artery Disease.\n'
example2 = 'Q: Pt has prior history of hbp -> A: Patient has prior history of high blood pressure. Explanation: Here pt stands for patient and hbp stands for high blood pressure.\n'

# Few-shot in-context learning with instruction prompt and additional few shot examples with chain of thought prompting
context = instruction + example1 + example2
results = get_predictions(df_test, model, context, temperature=0.4, max_tokens=30)
print_results(results)

5it [00:06,  1.37s/it]


Text: The pt has lbp, Target: The patient has lower back pain., Prediction: A: The patient has low back pain. Explanation: Here pt stands for patient and lbp stands for low back pain.

Text: The pt is a 30 y/o m, Target: The patient is a 30 year old male., Prediction: The patient is a 30 year old male. Explanation: Here pt stands for patient, y/o stands for years old, and m stands for

Text: VSS after Tx, Target: Vital signs stable after treatment., Prediction: A: VSS after treatment. Explanation: Here Tx stands for treatment.

Text: c/o gi pain, Target: complains of gastrointestinal pain., Prediction: Complains of gastrointestinal pain

Text: tx d/c due to c/o h/a, Target: Treatment discontinued due to compaints of headache., Prediction: Treatment was discontinued due to complaint of headache.





### Chain of thought prompting doesn't seem to help much here. Perhaps this is due to the simple nature of the task.

However, it is interesting to see incoherent explanations correlate with incorrect model outputs. This suggests that one safeguard when using models in medical applications might be to check whether the explanations (the model's proof of work) are coherent.
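
As a very rough illustration of what such a check could look like, the sketch below parses "X stands for Y" claims from an explanation and tests whether each abbreviation appears in the note and each expansion appears in the answer. This is a crude string heuristic under the assumption that explanations follow the phrasing of the few-shot examples; it is one possible reading of "coherent", not a validated method, and substring matching will misfire on very short abbreviations. `explanation_is_coherent` is a hypothetical helper.

```python
import re

# Matches "pt stands for patient", stopping at a comma, the next claim, a period, or the end
PAIR = re.compile(r"(\S+) stands for (.+?)(?=,| and \S+ stands for|\.|$)")

def explanation_is_coherent(note, answer, explanation):
    """Check that every 'X stands for Y' claim is grounded in the note and the answer."""
    pairs = PAIR.findall(explanation)
    if not pairs:
        return False  # no parseable claims counts as incoherent
    note_l, answer_l = note.lower(), answer.lower()
    return all(
        abbr.lower() in note_l and full.strip().lower() in answer_l
        for abbr, full in pairs
    )

print(explanation_is_coherent(
    "The pt has lbp",
    "The patient has low back pain.",
    "Here pt stands for patient and lbp stands for low back pain.",
))
```

A check like this cannot tell whether an abbreviation was silently left unexpanded, so it would complement rather than replace comparison against a reference expansion.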

### Now let's see if fine-tuning a smaller model improves performance

We fine-tune with a modest number of examples (20), as this is primarily for demonstration purposes.

The Weights & Biases integration is used to track model training.

In [23]:
run = wandb.init(project='Medical GPT-3')

# artifact = run.use_artifact('/content/gdrive/MyDrive/med-gpt3-data/med_abbreviations.csv', type='raw_dataset')
# artifact_dir = artifact.download()+"/data.csv"


In [24]:
#Shuffling the dataset with fixed seed

df = pd.read_csv('gdrive/MyDrive/med-gpt3-data/med_abbreviations.csv')
ds = df.sample(frac=1.0, random_state=0)
ds.rename(columns={'Text': 'Prompt'}, inplace=True)
ds.to_csv("data.csv")
ds.head()

Unnamed: 0,ID,Prompt,Completion
2,3,VSS after Tx,Vital signs stable after treatment.
28,29,oe pt recommended to be placed under obs for 2...,On examination patient recommended to placed u...
13,14,PA will recommend next steps,Physician's Assistant will recommend next steps.
10,11,Nacl low in pt,Sodium chloride low in patient.
26,27,CXR required to assess tx,Chest x-ray required to assess treatment.


### Using OpenAI tools to preprocess the data

In [25]:
!openai tools fine_tunes.prepare_data -f data.csv

Analyzing...

- Based on your file extension, your file is formatted as a CSV file
- Your file contains 30 prompt-completion pairs. In general, we recommend having at least a few hundred examples. We've found that performance tends to linearly increase for every doubling of the number of examples
- The `prompt` column/key should be lowercase
- The `completion` column/key should be lowercase
- The input file should contain exactly two columns/keys per row. Additional columns/keys present are: ['Unnamed: 0', 'ID']
- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://beta.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty
- Your data does not contain a common ending at the end of your completions. Havin

### Splitting the data into train and val sets

In [26]:
# The dataset has 30 examples. We will use 20 for training and 10 for validation

!head -n 20 data_prepared.jsonl > train.jsonl
!tail -n 10  data_prepared.jsonl > valid.jsonl
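
Before training, it can be worth confirming what the prepared files actually contain: the record counts, the separator at the end of each prompt, and the ending appended to each completion. The sketch below assumes the files are JSON Lines with `prompt` and `completion` keys, as produced by the preparation tool; `common_suffix` and `inspect_jsonl` are hypothetical helpers.

```python
import json
import os

def common_suffix(strings):
    """Longest suffix shared by all strings (empty if there is none)."""
    reversed_strings = [s[::-1] for s in strings]
    return os.path.commonprefix(reversed_strings)[::-1]

def inspect_jsonl(path):
    """Return (record count, shared prompt suffix, shared completion suffix)."""
    with open(path, encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]
    prompts = [r["prompt"] for r in records]
    completions = [r["completion"] for r in records]
    return len(records), common_suffix(prompts), common_suffix(completions)
```

Running `inspect_jsonl("train.jsonl")` and `inspect_jsonl("valid.jsonl")` should report 20 and 10 records. The completion suffix is the useful part: if it differs from the `stop` value passed at inference time, the fine-tuned model has no learned signal to stop at that sequence.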


In [27]:
wandb.finish()


### GPT-3 fine-tuning 'ada' hyper-parameters definition.

In [28]:
model = 'ada'  # can be ada, babbage or curie
n_epochs = 10
batch_size = 4
learning_rate_multiplier = 0.05
prompt_loss_weight = 0.1

### Model training

In [29]:
!openai api fine_tunes.create \
    -t train.jsonl \
    -v valid.jsonl \
    -m $model \
    --n_epochs $n_epochs \
    --batch_size $batch_size \
    --learning_rate_multiplier $learning_rate_multiplier \
    --prompt_loss_weight $prompt_loss_weight

Found potentially duplicated files with name 'train.jsonl', purpose 'fine-tune' and size 1910 bytes
file-v1pWf4Rzd6epNVT5eSwehFXZ
file-yJOD9irAl3PoA6Ip4bdF6xnd
Enter file ID to reuse an already uploaded file, or an empty string to upload this file anyway: 
Upload progress: 100% 1.91k/1.91k [00:00<00:00, 818kit/s]
Uploaded file from train.jsonl: file-bqnPlyL9koVLWf0sCOUsxXFC
Found potentially duplicated files with name 'valid.jsonl', purpose 'fine-tune' and size 995 bytes
file-LPVmd5FAgyeEiuEV42Pd5Enb
file-Ex2NITSRAYxTYkkqotDvNmHT
Enter file ID to reuse an already uploaded file, or an empty string to upload this file anyway: 
Upload progress: 100% 995/995 [00:00<00:00, 1.03Mit/s]
Uploaded file from valid.jsonl: file-YPL77TwQe8YwFRxleUfLGJNm
Created fine-tune: ft-eP8h1UV1obmolYmIPQFniTXA
Streaming events until fine-tuning is complete...

(Ctrl-C will interrupt the stream, but not cancel the fine-tune)
[2022-05-18 04:12:45] Created fine-tune: ft-eP8h1UV1obmolYmIPQFniTXA
[2022-05-18 04:15:

In [30]:
!openai api fine_tunes.follow -i ft-eP8h1UV1obmolYmIPQFniTXA

[2022-05-18 04:12:45] Created fine-tune: ft-eP8h1UV1obmolYmIPQFniTXA
[2022-05-18 04:15:00] Fine-tune costs $0.00
[2022-05-18 04:15:00] Fine-tune enqueued. Queue number: 0
[2022-05-18 04:15:03] Fine-tune started
[2022-05-18 04:15:21] Completed epoch 1/10
[2022-05-18 04:15:24] Completed epoch 2/10
[2022-05-18 04:15:26] Completed epoch 3/10
[2022-05-18 04:15:29] Completed epoch 4/10
[2022-05-18 04:15:31] Completed epoch 5/10
[2022-05-18 04:15:34] Completed epoch 6/10
[2022-05-18 04:15:37] Completed epoch 7/10
[2022-05-18 04:15:39] Completed epoch 8/10
[2022-05-18 04:15:42] Completed epoch 9/10
[2022-05-18 04:15:45] Completed epoch 10/10
[2022-05-18 04:16:09] Uploaded model: ada:ft-personal-2022-05-18-04-16-07
[2022-05-18 04:16:12] Uploaded result file: file-pNSHjUMi6XVoJPuaHVbo5IAy
[2022-05-18 04:16:12] Fine-tune succeeded

Job complete! Status: succeeded 🎉
Try out your fine-tuned model:

openai api completions.create -m ada:ft-personal-2022-05-18-04-16-07 -p <YOUR_PROMPT>


## Sync fine-tune jobs to Weights & Biases

Log fine-tuning runs.

In [31]:
!openai wandb sync --project "Medical GPT-3"

wandb: Currently logged in as: natviv (natviv-gpt). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.12.16
wandb: Run data is saved locally in /content/wandb/run-20220518_041708-ft-eP8h1UV1obmolYmIPQFniTXA
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run ft-eP8h1UV1obmolYmIPQFniTXA
wandb: ⭐️ View project at https://wandb.ai/natviv-gpt/Medical%20GPT-3
wandb: 🚀 View run at https://wandb.ai/natviv-gpt/Medical%20GPT-3/runs/ft-eP8h1UV1obmolYmIPQFniTXA
wandb: Waiting for W&B process to finish... (success).
wandb: Run history:
wandb:             elapsed_examples ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


## Run inference on validation examples

Run some predictions on a few validation samples.

In [32]:
# create eval job
run = wandb.init(project='Medical GPT-3', job_type='eval')
entity = wandb.run.entity

wandb: Currently logged in as: natviv (natviv-gpt). Use `wandb login --relogin` to force relogin


In [33]:
# choose a fine-tuned model
artifact_job = run.use_artifact(f'{entity}/Medical GPT-3/fine_tune_details:latest', type='fine_tune_details')
artifact_job.metadata

wandb.config.update({k:artifact_job.metadata[k] for k in ['fine_tuned_model', 'model', 'hyperparams']})
fine_tuned_model = artifact_job.metadata['fine_tuned_model']
fine_tuned_model

'ada:ft-personal-2022-05-18-04-16-07'

Loading validation data as dataframe

In [34]:
df = pd.read_json("valid.jsonl", orient='records', lines=True)
df.head()

Unnamed: 0,prompt,completion
0,"PTA, vss were not normal ->","Prior to admission, vital signs were not normal."
1,UTI as primary cause of fever ->,Urinary tract infection as primary cause of f...
2,The pt has recurring history of jt pain ->,The patient has recurring history of joint pain.
3,The pt requires lbp pt ->,The patient requires lower back pain physical...
4,Pt dc ama ->,Patient discharged against medical advice.


Perform inference on 10 validation examples. 

In [35]:
n_samples = 10
df = df.iloc[:n_samples]

In [37]:
results = []
openai.api_key = os.getenv("OPENAI_API_KEY")

for _, row in tqdm(df.iterrows()):
    prompt = row['prompt']
    res = openai.Completion.create(model=fine_tuned_model, prompt=prompt, max_tokens=30, stop=[" END"])
    completion = res['choices'][0]['text']
    completion = completion[1:]       # remove initial space
    prompt = prompt[:-3]              # remove " ->"
    target = row['completion'][1:]  # remove initial space
    results.append(f"Prompt: {prompt}, Target: {target}, Prediction: {completion}")

print_results(results)

10it [00:06,  1.45it/s]


Prompt: PTA, vss were not normal, Target: Prior to admission, vital signs were not normal., Prediction: Train skills abilities test were high normal. Physically active was normal. Was fit was normal. Social skills skills normal. Knowledge of language was normal.

Prompt: UTI as primary cause of fever, Target: Urinary tract infection as primary cause of fever., Prediction: infection as primary cause of fever.

Moderate injury as primary cause of fever.

Rehabilitation of fever.

Hyperther

Prompt: The pt has recurring history of jt pain, Target: The patient has recurring history of joint pain., Prediction: The physician has repeated history of subjective pain and it flashes intermittently intermittently (2). Can the patient be treated with non-steroidal anti-

Prompt: The pt requires lbp pt, Target: The patient requires lower back pain physical therapy., Prediction: arge patient required. Please tell nurse. (no staff present) Please direct exhausted patient to facility. (staff dont exi




### Conclusion: the smaller Ada model is not very useful for this generation task, and its results are all over the place

A more expansive hyperparameter sweep might help fix this. Another option is to try fine-tuning the larger 'curie' or 'davinci' models for better results.

Create and log a W&B Table to explore, query & compare model predictions if helpful.

In [38]:
# prediction_table = wandb.Table(columns=['prompt', 'target', 'completion'], data=data)
# wandb.log({'predictions': prediction_table})
wandb.finish()
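
The commented-out table code above refers to a `data` variable that is never built. A minimal sketch of how the rows could be assembled, assuming validation records with `prompt` and `completion` keys and a parallel list of raw model outputs. `make_prediction_rows` is a hypothetical helper, and the W&B calls are left commented out because they need a live run.

```python
def make_prediction_rows(records, predictions, separator=" ->"):
    """Build [prompt, target, completion] rows from validation records and model outputs."""
    rows = []
    for record, prediction in zip(records, predictions):
        prompt = record["prompt"]
        if prompt.endswith(separator):
            prompt = prompt[: -len(separator)]  # drop the trailing " ->"
        rows.append([prompt, record["completion"].strip(), prediction.strip()])
    return rows

# With an active run, the rows could then be logged before wandb.finish():
# rows = make_prediction_rows(df.to_dict("records"), predictions)
# prediction_table = wandb.Table(columns=["prompt", "target", "completion"], data=rows)
# wandb.log({"predictions": prediction_table})
```

Logging has to happen before `wandb.finish()` is called, since the run is closed afterwards.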



### Now let's try fine-tuning with the curie model instead to see if it helps

In [39]:
model = 'curie'  # can be ada, babbage or curie
n_epochs = 10
batch_size = 4
learning_rate_multiplier = 0.05
prompt_loss_weight = 0.1

In [41]:
!openai api fine_tunes.create \
    -t train.jsonl \
    -v valid.jsonl \
    -m $model \
    --n_epochs $n_epochs \
    --batch_size $batch_size \
    --learning_rate_multiplier $learning_rate_multiplier \
    --prompt_loss_weight $prompt_loss_weight


Found potentially duplicated files with name 'train.jsonl', purpose 'fine-tune' and size 1910 bytes
file-yJOD9irAl3PoA6Ip4bdF6xnd
file-v1pWf4Rzd6epNVT5eSwehFXZ
file-bqnPlyL9koVLWf0sCOUsxXFC
Enter file ID to reuse an already uploaded file, or an empty string to upload this file anyway: 
Upload progress: 100% 1.91k/1.91k [00:00<00:00, 752kit/s]
Uploaded file from train.jsonl: file-hLDjkwyB4ou3Amdq2GgEzs5x
Found potentially duplicated files with name 'valid.jsonl', purpose 'fine-tune' and size 995 bytes
file-LPVmd5FAgyeEiuEV42Pd5Enb
file-YPL77TwQe8YwFRxleUfLGJNm
file-Ex2NITSRAYxTYkkqotDvNmHT
Enter file ID to reuse an already uploaded file, or an empty string to upload this file anyway: 
Upload progress: 100% 995/995 [00:00<00:00, 1.03Mit/s]
Uploaded file from valid.jsonl: file-zyAlFBzfJz6xu3eYpoFIu8UW
Created fine-tune: ft-O2QU3ssy2dxUFjhst2VRqrWE
Streaming events until fine-tuning is complete...

(Ctrl-C will interrupt the stream, but not cancel the fine-tune)
[2022-05-18 04:24:23] Creat

In [42]:
!openai api fine_tunes.follow -i ft-O2QU3ssy2dxUFjhst2VRqrWE

!openai wandb sync --project "Medical GPT-3" 

# create eval job
run = wandb.init(project='Medical GPT-3', job_type='eval')
entity = wandb.run.entity

# choose a fine-tuned model
artifact_job = run.use_artifact(f'{entity}/Medical GPT-3/fine_tune_details:latest', type='fine_tune_details')
artifact_job.metadata

wandb.config.update({k:artifact_job.metadata[k] for k in ['fine_tuned_model', 'model', 'hyperparams']})
fine_tuned_model = artifact_job.metadata['fine_tuned_model']
fine_tuned_model

[2022-05-18 04:24:23] Created fine-tune: ft-O2QU3ssy2dxUFjhst2VRqrWE
[2022-05-18 04:26:20] Fine-tune costs $0.01
[2022-05-18 04:26:21] Fine-tune enqueued. Queue number: 0
[2022-05-18 04:26:23] Fine-tune started
[2022-05-18 04:27:16] Completed epoch 1/10
[2022-05-18 04:27:20] Completed epoch 2/10
[2022-05-18 04:27:23] Completed epoch 3/10
[2022-05-18 04:27:27] Completed epoch 4/10
[2022-05-18 04:27:30] Completed epoch 5/10
[2022-05-18 04:27:34] Completed epoch 6/10
[2022-05-18 04:27:38] Completed epoch 7/10
[2022-05-18 04:27:41] Completed epoch 8/10
[2022-05-18 04:27:45] Completed epoch 9/10
[2022-05-18 04:27:48] Completed epoch 10/10
[2022-05-18 04:28:11] Uploaded model: curie:ft-personal-2022-05-18-04-28-09
[2022-05-18 04:28:14] Uploaded result file: file-OLYpH04cO0AARMQpsabxY3Iu
[2022-05-18 04:28:14] Fine-tune succeeded

Job complete! Status: succeeded 🎉
Try out your fine-tuned model:

openai api completions.create -m curie:ft-personal-2022-05-18-04-28-09 -p <YOUR_PROMPT>

'curie:ft-personal-2022-05-18-04-28-09'

In [43]:
# Perform inference
results = []
openai.api_key = os.getenv("OPENAI_API_KEY")

for _, row in tqdm(df.iterrows()):
    prompt = row['prompt']
    res = openai.Completion.create(model=fine_tuned_model, prompt=prompt, max_tokens=30, stop=[" END"])
    completion = res['choices'][0]['text']
    completion = completion[1:]       # remove initial space
    prompt = prompt[:-3]              # remove " ->"
    target = row['completion'][1:]  # remove initial space
    results.append(f"Prompt: {prompt}, Target: {target}, Prediction: {completion}")

print_results(results)

10it [00:22,  2.25s/it]


Prompt: PTA, vss were not normal, Target: Prior to admission, vital signs were not normal., Prediction: PTA, vs were not normal. Patient was diagnosed with possible ventricular tachycardia. Vital signs were not abnormal. The patient was diagnosed

Prompt: UTI as primary cause of fever, Target: Urinary tract infection as primary cause of fever., Prediction: CHRONIC CONDITION CAUSED HYPER-CATE-GASTRIC ISCHEMIA as the primary causes of fever.

Prompt: The pt has recurring history of jt pain, Target: The patient has recurring history of joint pain., Prediction: Patient has recurrent history of jt pain. 0.4. Decreased Range of Motion. Patient has decreased range of motion. 9. Treatment Plan

Prompt: The pt requires lbp pt, Target: The patient requires lower back pain physical therapy., Prediction: Physical Therapy required. Patient requires physical therapy. ( Codeset : G0702 Amelioration of Error ) Medication. Patient required Physio

Prompt: Pt dc ama, Target: Patient discharged against 




### Again, the results after fine-tuning are not very good, even with the curie model. This requires further probing and analysis. Perhaps the dataset is too small, or longer training or a broader hyperparameter search is needed.

### Next, we will prototype a simple conversational primary care agent with GPT-3

<figure>
<center>
<img src='https://singularityhub.com/wp-content/uploads/2019/02/doctor-robot-modern-future-health-artificial-intelligence-shutterstock-1072509989-1068x601.jpg' style="width:10px;height:20px" />
<figcaption></figcaption></center>
</figure>


### Prompt for the conversational AI model
This could do with more prompt engineering, but it seems to work decently for starters.

In [45]:

instruction = 'In the following interactions, you are supposed to interact with a user who may have a primary care concern and provide them with support. This support can include triaging, providing a differential diagnosis if certain or help with scheduling a doctor appointment. Please find some example conversations below. When prompted, you are to only provide the doctor\'s response.'

# Provide a simple example of doctor-patient conversation. Not sure how long of a context GPT-3 can handle effectively at the moment
example1 = "Doctor: How can I help? Patient: I have a rash on my skin. Doctor: Anything else? Patient: No Doctor:Is it hurting Patient: Yes, It is swollen and itchy Doctor: Ok, I will refer you to a specialist dermatologist."
example2 = "Doctor: How can I help? Patient: I have fever and headache. Doctor: For how long, have you had these symptoms? Patient: 1 day. Doctor: Please come over to the clinic if possible so we can take a closer look."

prompt = '\n'.join([instruction, example1, example2])


### Using the above prompt actually causes GPT-3 to generate the full dialogue instead of the next turn only

Reducing max_tokens to prevent the model from rambling.
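
Lowering `max_tokens` shortens the output but does not guarantee it ends at a turn boundary. Another option, sketched here as an assumption rather than what this notebook does, is to cut the generated text at the point where the model starts writing the patient's next line. `first_doctor_turn` is a hypothetical helper; the sample string is made up for illustration.

```python
def first_doctor_turn(generated, next_speaker="Patient:"):
    """Keep only the text before the model starts writing the patient's next line."""
    head, _, _ = generated.partition(next_speaker)
    return head.strip()

sample = " Anything else? Patient: No Doctor: Ok, I will refer you to a specialist."
print(first_doctor_turn(sample))
```

The completion endpoint's `stop` parameter accepts a list of sequences, so adding `"Patient:"` there should achieve a similar cut on the server side and save tokens.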

In [None]:
openai.api_key = os.getenv("OPENAI_API_KEY")

class DoctorAgent(object):
  ''' Simple Doctor Agent class '''
  def __init__(self):
    self._model = 'text-davinci-002'
    self._doctor_prefix = 'Doctor: '
    self._patient_prefix = 'Patient: '
    self._prompt = ''
    self._last_doctor_response = 'Hi, I am Dr. AI. How can I help you today? Please say Done to end conversation'
    self._last_patient_response = ''
    self._max_tokens=100
    self._temperature=0.2
    self._stop=[" END"]
    self._stop_agent=False
    self._num_turns=0

  def send_doctor_response(self):
    ''' Print last doctor response '''
    print(self._doctor_prefix + '\n' + self._last_doctor_response)
    return
  
  def get_patient_input(self):
    ''' Get next patient input '''
    self._last_patient_response = input(f'{self._patient_prefix}\n')
    if self._last_patient_response in ['Done', 'done']:
      self._stop_agent=True
    self._num_turns += 1
    return

  def stop_conversation(self):
    ''' Check if we should proceed to next turn or not. Limit to 10 turns '''
    return self._stop_agent or self._num_turns > 10 

  def update_prompt(self, prompt=None):
    ''' Update prompt with last conversation turn '''
    if prompt is not None:
      self._prompt = prompt
    else:
      self._prompt = ' '.join([self._prompt, self._last_doctor_response, self._patient_prefix, self._last_patient_response]) 

  def run_model(self):
    ''' Run completion model and update last doctor response '''
    self._prompt = ' '.join([self._prompt, self._doctor_prefix]) 
    model_response = openai.Completion.create(model=self._model, 
                                              prompt=self._prompt, 
                                              max_tokens=self._max_tokens, 
                                              temperature=self._temperature, 
                                              stop=self._stop)
    self._last_doctor_response = model_response['choices'][0]['text']
    self._last_doctor_response = self._last_doctor_response[1:] # remove initial space

In [53]:
# Create a doctor agent and run it
doctor_agent = DoctorAgent()
doctor_agent.update_prompt(prompt)

while(True):
  doctor_agent.send_doctor_response()
  doctor_agent.get_patient_input()
  if (doctor_agent.stop_conversation()):
    break
  doctor_agent.update_prompt()
  doctor_agent.run_model()

Doctor: 
Hi, I am Dr. AI. How can I help you today? Please say Done to end conversation
Patient: 
I am feeling very tired
Doctor: 
Anything else?
Patient: 
I am not feeling hungry
Doctor: 
Ok, I will refer you to a specialist.
Patient: 
Can you help with an appoitment?
Doctor: 
Yes, I can help you schedule an appointment with your primary care physician.
Patient: 
At what time
Doctor: 
The earliest available appointment is at 2:00 PM. Would that work for you?
Patient: 
yes, where are they located?
Doctor: 
They are located at 123 Main Street.
Patient: 
can you also provide a referral letter?
Doctor: 
Yes, I can provide you with a referral letter.
Patient: 
thanks
Doctor: 
You're welcome.
Patient: 
done


As you can see, the above conversation is fairly realistic and a good starting point. The model has a tendency to ramble, but this can hopefully be fixed by fine-tuning on some realistic primary care conversations.

Another thing to prototype with the API is retrieval of similar notes or patient records using embeddings and search, which has important applications in clinical settings.
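
The retrieval step itself is simple once each note has an embedding vector. A minimal sketch using cosine similarity, assuming the vectors come from some embedding model; the three-dimensional vectors below are made-up stand-ins for real embeddings, and `cosine` and `top_k_similar` are hypothetical helpers.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def top_k_similar(query_vec, note_vecs, k=2):
    """Return the indices of the k notes most similar to the query, best first."""
    order = sorted(
        range(len(note_vecs)),
        key=lambda i: cosine(query_vec, note_vecs[i]),
        reverse=True,
    )
    return order[:k]

# Toy vectors standing in for real note embeddings (made up for illustration)
note_vecs = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 1.0, 0.0]]
print(top_k_similar([1.0, 0.0, 0.0], note_vecs))
```

For more than a few thousand notes, a vectorized or indexed search would be a better fit than this linear scan.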