<a href="https://colab.research.google.com/github/natviv/med-gpt3/blob/main/Medical_GPT_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Medical GPT-3. 

This colab features an exploration of potential medical applications with GPT-3.

This is built using Open AI API's integration with Weights and Biases (W&B).

## API key setup

In [None]:
# API key credentials
%env OPENAI_API_KEY=

## Install dependencies

In [4]:
!pip install --upgrade openai wandb



In [5]:
# Setup imports

import openai
import wandb
from pathlib import Path
import pandas as pd
import numpy as np
import json
from tqdm import tqdm

In [6]:
run = wandb.init(project='Medical GPT-3', job_type="dataset_preparation")

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


### Connect to dataset in Google Drive

In [12]:
from google.colab import drive 
drive.mount('/content/gdrive/')

Mounted at /content/gdrive/


### Read data into pandas csv

In [31]:
import pandas as pd 
df=pd.read_csv('gdrive/MyDrive/med-gpt3-data/data.csv')

### Analyze data

In [43]:
print(f"Number of rows is {len(df)} and number of columns is {len(df.columns)}")
for index, row in df.iterrows():
    print(f"ID: {row['ID']}, Text: {row['Text']}, Completion: {row['Completion']}")

df_test = df.iloc[:5,:]
df_train = df.iloc[5:,:]

Number of rows is 30 and number of columns is 3
ID: 1, Text: The pt has lbp, Completion: The patient has lower back pain
ID: 2, Text: The pt is a 30 y/o m, Completion: The patient is a 30 year old male
ID: 3, Text: VSS after Tx, Completion: Vital signs stable after treatment
ID: 4, Text: c/o gi pain, Completion: complains of gastrointestinal pain
ID: 5, Text: tx d/c due to c/o h/a, Completion: Treatment discontinued due to compaints of headache
ID: 6, Text: abx dosage recommended, Completion: Antibiotics dosage recommended
ID: 7, Text: thx required, Completion: therapy required
ID: 8, Text: The pt requires lbp pt, Completion: The patient requires lower back pain physical therapy
ID: 9, Text: MBC likely impaired due to covid, Completion: Maximum breathing capacity likely impaired due to covid
ID: 10, Text: The pt has recurring history of jt pain, Completion: The patient has recurring history of joint pain
ID: 11, Text: Nacl low in pt, Completion: Sodium chloride low in patient
ID: 12, T

### Use completion API to see how the 'text-davinci-002' model performs. Try with the following approaches:

* Zero shot with only instruction
* Few shot in-context learning 
* Few shot with chain of thought prompting -> https://arxiv.org/abs/2201.11903

In [44]:
import os
openai.api_key = os.getenv("OPENAI_API_KEY")

# This can be made more efficient and converted into a few batch calls
def get_predictions(df, model, context, temperature=0.1, max_tokens=20, use_qa_prefix=False):
  results = []
  for _, row in tqdm(df.iterrows()):
      input = row['Text'] if not use_qa_prefix else row['Text'] + 'Q:'
      prompt = context + input + ' ->'
      res = openai.Completion.create(model=model, 
                                     prompt=prompt, 
                                     max_tokens=max_tokens, 
                                     temperature = temperature, 
                                     stop=[" END"])
      completion = res['choices'][0]['text']
      completion = completion[1:] # remove initial space
      results.append(f"Text: {row['Text']}, Target: {row['Completion']}, Prediction: {completion}")
  return results

def print_results(results): 
  for row in results:
    print(f"\n{row}")

model = 'text-davinci-002'
# Zero-shot with only instruction prompt
instruction = 'Convert the doctor note into patient readable format without abbreviations.\n'
results = get_predictions(df_test, model, instruction)
print_results(results)


5it [00:07,  1.51s/it]


Text: The pt has lbp, Target: The patient has lower back pain, Prediction: The patient has low back pain.

Text: The pt is a 30 y/o m, Target: The patient is a 30 year old male, Prediction: The patient is a 30 year old male.

The patient is a 30 year old male.

Text: VSS after Tx, Target: Vital signs stable after treatment, Prediction: VSS (visual acuity) after treatment

The patient's visual acuity was 20/

Text: c/o gi pain, Target: complains of gastrointestinal pain, Prediction: complaining of gastrointestinal pain

Text: tx d/c due to c/o h/a, Target: Treatment discontinued due to compaints of headache, Prediction: 
The doctor has discharged me from the hospital due to my complaints of headaches.





### Zero shot with instruction only prompting seems to generate spurious results particularly beyond the length necessary. Lets see if this can be fixed with some in-context learning examples in addition to the instruction.

In the above scenario, only 2/5 are correct and model can be seen rambling along in a couple.

In [45]:
instruction = 'Convert the doctor note into patient readable format without abbreviations.\n'
example1 = 'The pt exhibits symptoms of CAD -> The patient exhibits symptoms of Coronary Artery Disease\n'
example2 = 'Pt has prior history of hbp -> Patient has prior history of high blood pressure\n'

# Few-shot in-context learning with instruction prompt and additional few shot examples
context = instruction + example1 + example2
results = get_predictions(df_test, model, context)
print_results(results)

5it [00:06,  1.34s/it]


Text: The pt has lbp, Target: The patient has lower back pain, Prediction: The patient has low back pain

Text: The pt is a 30 y/o m, Target: The patient is a 30 year old male, Prediction: The patient is a 30 year old male

Text: VSS after Tx, Target: Vital signs stable after treatment, Prediction: Vital Signs Stable after treatment

Text: c/o gi pain, Target: complains of gastrointestinal pain, Prediction: complains of gastrointestinal pain

The patient exhibits symptoms of Coronary Artery Disease. The patient

Text: tx d/c due to c/o h/a, Target: Treatment discontinued due to compaints of headache, Prediction: Treatment was discontinued due to complaint of headache





### After in-context learning with only a couple of examples, the model seems to have improved quite significantly. 

Now the model is able to get 4/5 out of 5 examples correct and the rambling / tendency to generate long sentences seems to have reduced significantly.

In [48]:
instruction = 'Convert the doctor note into patient readable format without abbreviations.\n'
example1 = 'Q: The pt exhibits symptoms of CAD -> A: The patient exhibits symptoms of Coronary Artery Disease. Explanation: Here pt stands for patient and CAD stands for Coronary Artery Disease.\n'
example2 = 'Q: Pt has prior history of hbp -> A: Patient has prior history of high blood pressure. Explanation: Here pt stands for patient and hbp stands for high blood pressure.\n'

# Few-shot in-context learning with instruction prompt and additional few shot examples with chain of thought prompting
context = instruction + example1 + example2
results = get_predictions(df_test, model, context, temperature=0.4, max_tokens=30)
print_results(results)

5it [00:07,  1.57s/it]


Text: The pt has lbp, Target: The patient has lower back pain, Prediction: The patient has low back pain. Explanation: Here pt stands for patient and lbp stands for low back pain.

Text: The pt is a 30 y/o m, Target: The patient is a 30 year old male, Prediction: The patient is a 30 year old male. Explanation: Here pt stands for patient, y/o stands for years old, and m stands for

Text: VSS after Tx, Target: Vital signs stable after treatment, Prediction: A: Patient had a heart attack after treatment. Explanation: Here Tx stands for treatment and VSS stands for heart attack.

Text: c/o gi pain, Target: complains of gastrointestinal pain, Prediction: complaining of gastrointestinal pain

Text: tx d/c due to c/o h/a, Target: Treatment discontinued due to compaints of headache, Prediction: Treatment was discontinued due to complaint of headache.





### Chain of thought prompting doesn't seem to be super useful here. Perhaps this is due to the simple nature of the task.

However, it is interesting to see inoherent explanations co-relate with in-correct model outputs. This suggests one mechanism to consider using models in medical applications might be to check whether the explanations / proof of work are coherent or not.

### Now let's see if finetuning a smaller model can help with improving the performance. 

I am going to try finetuning with a modest number of examples (20) as this is primarily for demonstration purposes.

Using Weights & Biases integration for model training.

In [51]:
run = wandb.init(project='Medical GPT-3')

# artifact = run.use_artifact('/content/gdrive/MyDrive/med-gpt3-data/data.csv', type='raw_dataset')
# artifact_dir = artifact.download()+"/data.csv"

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

In [59]:
#Shuffling the dataset with fixed seed

df = pd.read_csv('gdrive/MyDrive/med-gpt3-data/data.csv')
ds = df.sample(frac=1.0, random_state=0)
ds.rename(columns={'Text': 'Prompt'}, inplace=True)
ds.to_csv("data.csv")
ds.head()

Unnamed: 0,ID,Prompt,Completion
2,3,VSS after Tx,Vital signs stable after treatment
28,29,oe pt recommended to be placed under obs for 2...,"On examination, patient recommended to placed ..."
13,14,PA will recommend next steps,Physician's Assistant will recommend next steps
10,11,Nacl low in pt,Sodium chloride low in patient
26,27,CXR required to assess tx,Chest x-ray required to assess treatment


### Using OpenAI tools to preprocess the data

In [61]:
!openai tools fine_tunes.prepare_data -f data.csv

Analyzing...

- Based on your file extension, your file is formatted as a CSV file
- Your file contains 30 prompt-completion pairs. In general, we recommend having at least a few hundred examples. We've found that performance tends to linearly increase for every doubling of the number of examples
- The `prompt` column/key should be lowercase
- The `completion` column/key should be lowercase
- The input file should contain exactly two columns/keys per row. Additional columns/keys present are: ['Unnamed: 0', 'ID']
- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://beta.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty
- Your data does not contain a common ending at the end of your completions. Havin

### Splitting the data into train and val sets

In [62]:
# The dataset has 30 examples. We will use 20 for training and 10 for testing

!head -n 20 data_prepared.jsonl > train.jsonl
!tail -n 10  data_prepared.jsonl > valid.jsonl


In [63]:
wandb.finish()

### GPT-3 fine-tuning hyper-parameters definition.

In [58]:
model = 'ada'  # can be ada, babbage or curie
n_epochs = 4
batch_size = 4
learning_rate_multiplier = 0.1
prompt_loss_weight = 0.1

### Model training

In [64]:
!openai api fine_tunes.create \
    -t train.jsonl \
    -v valid.jsonl \
    -m $model \
    --n_epochs $n_epochs \
    --batch_size $batch_size \
    --learning_rate_multiplier $learning_rate_multiplier \
    --prompt_loss_weight $prompt_loss_weight

Upload progress:   0% 0.00/1.91k [00:00<?, ?it/s]Upload progress: 100% 1.91k/1.91k [00:00<00:00, 2.68Mit/s]
Uploaded file from train.jsonl: file-yJOD9irAl3PoA6Ip4bdF6xnd
Upload progress: 100% 995/995 [00:00<00:00, 1.76Mit/s]
Uploaded file from valid.jsonl: file-LPVmd5FAgyeEiuEV42Pd5Enb
Created fine-tune: ft-0aDcGNgNcG0428ffEg98wQLJ
Streaming events until fine-tuning is complete...

(Ctrl-C will interrupt the stream, but not cancel the fine-tune)
[2022-05-16 02:05:38] Created fine-tune: ft-0aDcGNgNcG0428ffEg98wQLJ

Stream interrupted (client disconnected).
To resume the stream, run:

  openai api fine_tunes.follow -i ft-0aDcGNgNcG0428ffEg98wQLJ



In [65]:
!openai api fine_tunes.follow -i ft-0aDcGNgNcG0428ffEg98wQLJ

[2022-05-16 02:05:38] Created fine-tune: ft-0aDcGNgNcG0428ffEg98wQLJ
[2022-05-16 02:12:07] Fine-tune costs $0.00
[2022-05-16 02:12:08] Fine-tune enqueued. Queue number: 0
[2022-05-16 02:12:10] Fine-tune started
[2022-05-16 02:12:28] Completed epoch 1/4
[2022-05-16 02:12:31] Completed epoch 2/4
[2022-05-16 02:12:33] Completed epoch 3/4
[2022-05-16 02:12:36] Completed epoch 4/4
[2022-05-16 02:12:55] Uploaded model: ada:ft-personal-2022-05-16-02-12-53
[2022-05-16 02:12:58] Uploaded result file: file-UayQnPtF3tbXqyq2mFGwoZ9d
[2022-05-16 02:12:58] Fine-tune succeeded

Job complete! Status: succeeded 🎉
Try out your fine-tuned model:

openai api completions.create -m ada:ft-personal-2022-05-16-02-12-53 -p <YOUR_PROMPT>


## Sync fine-tune jobs to Weights & Biases

Log fine-tuning runs.

In [66]:
!openai wandb sync --project "Medical GPT-3" 

[34m[1mwandb[0m: Currently logged in as: [33mnatviv[0m ([33mnatviv-gpt[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.12.16
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/content/wandb/run-20220516_022438-ft-tVmTevSFiUgcdJdeSAgokoQA[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mft-tVmTevSFiUgcdJdeSAgokoQA[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/natviv-gpt/Medical%20GPT-3[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/natviv-gpt/Medical%20GPT-3/runs/ft-tVmTevSFiUgcdJdeSAgokoQA[0m
[34m[1mwandb[0m: Waiting for W&B process to finish... [32m(success).[0m
[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:             elapsed_examples ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


## Run inference on validation examples

Run some predictions on a few validation samples.

In [67]:
# create eval job
run = wandb.init(project='Medical GPT-3', job_type='eval')
entity = wandb.run.entity

[34m[1mwandb[0m: Currently logged in as: [33mnatviv[0m ([33mnatviv-gpt[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [68]:
# choose a fine-tuned model
artifact_job = run.use_artifact(f'{entity}/Medical GPT-3/fine_tune_details:latest', type='fine_tune_details')
artifact_job.metadata

wandb.config.update({k:artifact_job.metadata[k] for k in ['fine_tuned_model', 'model', 'hyperparams']})
fine_tuned_model = artifact_job.metadata['fine_tuned_model']
fine_tuned_model

'ada:ft-personal-2022-05-16-02-12-53'

Loading validation data as dataframe

In [69]:
df = pd.read_json("valid.jsonl", orient='records', lines=True)
df.head()

Unnamed: 0,prompt,completion
0,"PTA, vss were not normal ->","Prior to admission, vital signs were not normal."
1,UTI as primary cause of fever ->,Urinary tract infection as primary cause of f...
2,The pt has recurring history of jt pain ->,The patient has recurring history of joint pain.
3,The pt requires lbp pt ->,The patient requires lower back pain physical...
4,Pt dc ama ->,Patient discharged against medical advice.


Perform inference on 10 validation examples. 

In [70]:
n_samples = 10
df = df.iloc[:n_samples]

In [71]:
results = []
import os
openai.api_key = os.getenv("OPENAI_API_KEY")

for _, row in tqdm(df.iterrows()):
    prompt = row['prompt']
    res = openai.Completion.create(model=fine_tuned_model, prompt=prompt, max_tokens=30, stop=[" END"])
    completion = res['choices'][0]['text']
    completion = completion[1:]       # remove initial space
    prompt = prompt[:-3]              # remove " ->"
    target = row['completion'][1:-4]  # remove initial space and "END"
    results.append(f"Prompt: {prompt}, Target: {target}, Prediction: {completion}")

print_results(results)

10it [00:17,  1.73s/it]


Prompt: PTA, vss were not normal, Target: Prior to admission, vital signs were not nor, Prediction: Delayed phase analysis Voluntary action test was not normal.



PATI, scales indicated normal physiological tests.



DAL

Prompt: UTI as primary cause of fever, Target: Urinary tract infection as primary cause of fe, Prediction: Statin User As Principal Cause of Fever.

104. When treating primary illness with a grant of continuous use of treatment.

105.

Prompt: The pt has recurring history of jt pain, Target: The patient has recurring history of joint p, Prediction: patient has recurrent jt pain.

Intervention

Treatment Plan

Instruct patient to adhere to prescription medication and instructions. Do not

Prompt: The pt requires lbp pt, Target: The patient requires lower back pain physical ther, Prediction: Trained patient requires lbp.

Memory loss disorder (more common) memory patient required memory loss disorder. Memory loss disorder. Memory loss patient

Prompt: Pt dc ama, Targe




### Smaller Ada model not so useful in the generation task and the results are a bit all over the place.

To try with 'curie' or 'da-vinci' for better results 

Create and log a W&B Table to explore, query & compare model predictions if helpful.

In [None]:
# prediction_table = wandb.Table(columns=['prompt', 'target', 'completion'], data=data)
# wandb.log({'predictions': prediction_table})
wandb.finish()
