## Fine-Tuning the OpenAI GPT-4o mini Model with Clinical Studies Data from ClinicalTrials.gov

In [16]:
import json
import os
import pandas as pd
from pprint import pprint
import openai

# Shared OpenAI client used by every API call below (file upload, fine-tuning, chat).
# Credentials are read from environment variables so no secret is stored in the
# notebook; `os.environ.get` yields None for any variable that is not set.
client = openai.OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
    organization=os.environ.get("OPENAI_API_ORG_KEY"),
    project=os.environ.get("OPENAI_API_PROJ_KEY"),
)

### Read in Dataset

In [4]:
# Load the clinical-trials CSV from a path relative to the working directory.
ctgov_df = pd.read_csv("csv/CTGOV_DATA.csv")

# Preview the first rows. The "Unnamed: 0" column in the preview looks like a row
# index that was saved with the CSV — TODO confirm; it is not read by later cells.
ctgov_df.head()

Unnamed: 0,nct_id,title,phase,submitted_date,condition,facility,country,eligible_crtieria
0,NCT03657433,Intravenous Infusions of Ferumoxytol Compared ...,Phase 3,2022-09-05,Iron Deficiency Anemia of Pregnancy,University of Arizona,United States,Inclusion Criteria:\n\nMaternal age >/= 18\nSi...
1,NCT00078949,Chemotherapy Before Autologous Stem Cell Trans...,Phase 3,2020-03-30,Lymphoma,Rush-Presbyterian-St. Luke's Medical Centre,United States,DISEASE CHARACTERISTICS:\n\nHistologically con...
2,NCT00078949,Chemotherapy Before Autologous Stem Cell Trans...,Phase 3,2020-03-30,Lymphoma,Indiana University Medical Center,United States,DISEASE CHARACTERISTICS:\n\nHistologically con...
3,NCT00078949,Chemotherapy Before Autologous Stem Cell Trans...,Phase 3,2020-03-30,Lymphoma,Hackensack University Medical Center,United States,DISEASE CHARACTERISTICS:\n\nHistologically con...
4,NCT00078949,Chemotherapy Before Autologous Stem Cell Trans...,Phase 3,2020-03-30,Lymphoma,"University of Cincinnati, Barrett Cancer Centre",United States,DISEASE CHARACTERISTICS:\n\nHistologically con...


### Data Preparation

In [41]:
# System prompt placed at the start of every conversation, for both the
# fine-tuning examples and the inference request.
# NOTE(review): the assistant target in each example is the trial's `condition`,
# so "extract the clinical trial studies" is vaguer than the actual task. Editing
# the text after fine-tuning would make inference prompts differ from training.
system_message = "You are a helpful clinical trial studies assistant. You are to extract the clinical trial studies provided."


def create_user_message(row):
    """Render one trial record as the user turn of a chat example.

    Args:
        row: Record indexable by column name (e.g. a DataFrame row or a dict)
            with the keys "title", "phase", "submitted_date", "facility",
            "country" and "eligible_crtieria" (the misspelling is the key
            this function actually reads).

    Returns:
        A string of "Label: value" sections, each followed by a blank line.
    """
    # (label shown in the prompt, column the value comes from), in output order.
    labelled_columns = (
        ("Title", "title"),
        ("Phase", "phase"),
        ("Date", "submitted_date"),
        ("Facility", "facility"),
        ("Country", "country"),
        ("Criteria", "eligible_crtieria"),
    )
    sections = [f"{label}: {row[column]}\n\n" for label, column in labelled_columns]
    return "".join(sections)


def prepare_example_conversation(row):
    """Build one chat-format example from a trial record.

    The conversation holds the shared system prompt, the user message rendered
    by ``create_user_message`` and an assistant turn containing
    ``row["condition"]``.

    Args:
        row: Record indexable by column name; must provide "condition" plus
            the columns read by ``create_user_message``.

    Returns:
        A dict with a single "messages" key holding the three role/content
        dicts in system, user, assistant order.
    """
    system_turn = {"role": "system", "content": system_message}
    user_turn = {"role": "user", "content": create_user_message(row)}
    assistant_turn = {"role": "assistant", "content": row["condition"]}
    return {"messages": [system_turn, user_turn, assistant_turn]}


# Sanity check: show the chat example built from the first row of the dataset.
pprint(prepare_example_conversation(ctgov_df.iloc[0]))

{'messages': [{'content': 'You are a helpful clinical trial studies assistant. '
                          'You are to extract the clinical trial studies '
                          'provided.',
               'role': 'system'},
              {'content': 'Title: Intravenous Infusions of Ferumoxytol '
                          'Compared to Oral Ferrous Sulfate for the Treatment '
                          'of Anemia in Pregnancy\n'
                          '\n'
                          'Phase: Phase 3\n'
                          '\n'
                          'Date: 2022-09-05\n'
                          '\n'
                          'Facility: University of Arizona\n'
                          '\n'
                          'Country: United States\n'
                          '\n'
                          'Criteria: Inclusion Criteria:\n'
                          '\n'
                          'Maternal age >/= 18\n'
                          'Singleton gestation\n'
            

### Data Split for Training

In [18]:
# Training split: index labels 0 through 10000. `.loc` slices by label and includes
# both endpoints, so with the default integer index this is up to 10,001 rows.
# NOTE(review): the preview shows the same nct_id on several rows (one per facility),
# so cutting by position may place one trial in more than one split — confirm.
training_df = ctgov_df.loc[0:10000]

# apply the prepare_example_conversation function to each row of the training_df
training_data = training_df.apply(prepare_example_conversation, axis=1).tolist()

# Print the first five examples as a spot check.
for example in training_data[:5]:
    print(example)

{'messages': [{'role': 'system', 'content': 'You are a helpful clinical trial studies assistant. You are to extract the clinical trial studies provided.'}, {'role': 'user', 'content': 'Title: Intravenous Infusions of Ferumoxytol Compared to Oral Ferrous Sulfate for the Treatment of Anemia in Pregnancy\n\nPhase: Phase 3\n\nDate: 2022-09-05\n\nFacility: University of Arizona\n\nCountry: United States\n\nCriteria: Inclusion Criteria:\n\nMaternal age >/= 18\nSingleton gestation\n>/=20 weeks gestation, <37 weeks gestation\nHemoglobin <11g/dL and/or hematocrit <33%\nAble to read/speak English or Spanish\n\nExclusion Criteria:\n\nMaternal age <18\nMultiple gestation\n<20 weeks gestation, </= 37 weeks gestation\nHemoglobin >/=11g/dL and/or hematocrit >/=33%\nUnable to read or speak English or Spanish\nIncarcerated patients\n\n'}, {'role': 'assistant', 'content': 'Iron Deficiency Anemia of Pregnancy'}]}
{'messages': [{'role': 'system', 'content': 'You are a helpful clinical trial studies assist

### Data Split for validation

In [19]:
# Validation split: index labels 10001 through 20000 (both included), i.e. the
# rows immediately after the training slice `ctgov_df.loc[0:10000]`.
validation_df = ctgov_df.loc[10001:20000]
# Convert each row to a chat example the same way the training rows are converted.
validation_data = validation_df.apply(
    prepare_example_conversation, axis=1).tolist()

### Save data for conversations

In [13]:
def write_jsonl(data_list: list, filename: str) -> None:
    """Write JSON-serialisable objects to ``filename``, one per line (JSON Lines).

    Any missing parent directory of ``filename`` is created first, so the call
    also works on a fresh checkout where the output folder does not exist yet.
    An existing file is overwritten.

    Args:
        data_list: Objects accepted by ``json.dumps``; each becomes one line.
        filename: Destination path.

    Raises:
        TypeError: If an item is not JSON serialisable (raised by ``json.dumps``).
        OSError: If the directory or file cannot be created or written.
    """
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        # A bare file name has no directory part, and os.makedirs("") would raise.
        os.makedirs(parent_dir, exist_ok=True)

    # Explicit encoding keeps the output independent of the platform default.
    with open(filename, "w", encoding="utf-8") as out:
        out.writelines(json.dumps(item) + "\n" for item in data_list)

In [20]:
# Write both splits as JSON Lines files under the relative "data/" folder;
# the same paths are reused by the upload step.
training_file_name = "data/tmp_ctgov_finetune_training.jsonl"
write_jsonl(training_data, training_file_name)

validation_file_name = "data/tmp_ctgov_finetune_validation.jsonl"
write_jsonl(validation_data, validation_file_name)

### Upload files to OpenAI Repo

In [35]:
def upload_file(file_name: str, purpose: str) -> str:
    """Upload a local file through the module-level ``client`` and return its ID.

    Args:
        file_name: Path of the file to send; it is opened in binary mode.
        purpose: Passed through unchanged as the ``purpose`` argument of
            ``client.files.create``.

    Returns:
        The ``id`` attribute of the response object.
    """
    with open(file_name, "rb") as payload:
        uploaded = client.files.create(file=payload, purpose=purpose)
        file_id = uploaded.id
    return file_id


# Upload both JSONL files with purpose "fine-tune" and keep the returned IDs,
# which the fine-tuning job takes as its training and validation inputs.
training_file_id = upload_file(training_file_name, "fine-tune")
validation_file_id = upload_file(validation_file_name, "fine-tune")

print("Training file ID:", training_file_id)
print("Validation file ID:", validation_file_id)

Training file ID: file-6cyryWZCQ8x2mGUvomW3nsHz
Validation file ID: file-Y89q1s98eQCAIyHIeNIHXohG


### Fine Tuning

In [36]:
# Base model to fine-tune.
MODEL = "gpt-4o-mini-2024-07-18"

# Start the fine-tuning job from the two uploaded files. No hyperparameters are
# passed, so presumably the service defaults apply — TODO confirm.
# The suffix shows up inside the resulting model name ("...:ctgov-ner:...").
response = client.fine_tuning.jobs.create(
    training_file=training_file_id,
    validation_file=validation_file_id,
    model=MODEL,
    suffix="ctgov-ner",
)

# Keep the job ID: later cells use it to poll status, list events and fetch the model ID.
job_id = response.id

print("Job ID:", response.id)
print("Status:", response.status)

Job ID: ftjob-U5qYWnItOgyKrnUqsbV3svJS
Status: validating_files


### Check Job Status

In [52]:
# Fetch the current state of the job; re-run this cell to poll until it finishes.
response = client.fine_tuning.jobs.retrieve(job_id)

print("Job ID:", response.id)
print("Status:", response.status)
print("Trained Tokens:", response.trained_tokens)

Job ID: ftjob-U5qYWnItOgyKrnUqsbV3svJS
Status: succeeded
Trained Tokens: 13002778


### Track Progress

In [53]:
# Fetch job events. Only the page returned by this single call is used, which is
# presumably why the printed log starts late in the run — TODO confirm paging.
response = client.fine_tuning.jobs.list_events(job_id)

# Reverse in place so the messages print oldest first; the API appears to return
# them newest first (the printed log reads chronologically after this step).
events = response.data
events.reverse()

for event in events:
    print(event.message)

Step 1524/1539: training loss=0.15
Step 1525/1539: training loss=0.07
Step 1526/1539: training loss=0.09
Step 1527/1539: training loss=0.07
Step 1528/1539: training loss=0.05
Step 1529/1539: training loss=0.14
Step 1530/1539: training loss=0.17
Step 1531/1539: training loss=0.11
Step 1532/1539: training loss=0.09
Step 1533/1539: training loss=0.13
Step 1534/1539: training loss=0.09
Step 1535/1539: training loss=0.12
Step 1536/1539: training loss=0.14
Step 1537/1539: training loss=0.09
Step 1538/1539: training loss=0.14, full validation loss=0.28
Step 1539/1539: training loss=0.09
Checkpoint created at step 769 with Snapshot ID: ft:gpt-4o-mini-2024-07-18:personal:ctgov-ner:9uRw7ny2:ckpt-step-769
Checkpoint created at step 1538 with Snapshot ID: ft:gpt-4o-mini-2024-07-18:personal:ctgov-ner:9uRw7Qxm:ckpt-step-1538
New fine-tuned model created: ft:gpt-4o-mini-2024-07-18:personal:ctgov-ner:9uRw82NO
The job has successfully completed


### After Training, Get the Fine-Tuned Model ID

In [54]:
# Re-fetch the job and read the name of the model it produced.
response = client.fine_tuning.jobs.retrieve(job_id)
fine_tuned_model_id = response.fine_tuned_model

# Fail here with a clear message rather than passing None as the model name
# to the inference call.
if fine_tuned_model_id is None:
    raise RuntimeError(
        "Fine-tuned model ID not found. Your job has likely not been completed yet."
    )

print("Fine-tuned model ID:", fine_tuned_model_id)

Fine-tuned model ID: ft:gpt-4o-mini-2024-07-18:personal:ctgov-ner:9uRw82NO


### Inference

In [55]:
# Held-out slice: index labels 20001 through 30000, after the training and
# validation slices. Only its first row is used for this smoke test.
test_df = ctgov_df.loc[20001:30000]
test_row = test_df.iloc[0]

# Same system + user layout as the fine-tuning examples, without the assistant
# turn, which the model is expected to produce.
user_message = create_user_message(test_row)
test_messages = [
    {"role": "system", "content": system_message},
    {"role": "user", "content": user_message},
]

pprint(test_messages)

[{'content': 'You are a helpful clinical trial studies assistant. You are to '
             'extract the clinical trial studies provided.',
  'role': 'system'},
 {'content': 'Title: Benign Prostatic Hyperplasia Trial With Dutasteride And '
             'Tamsulosin Combination Treatment\n'
             '\n'
             'Phase: Phase 3\n'
             '\n'
             'Date: 2017-01-16\n'
             '\n'
             'Facility: GSK Investigational Site\n'
             '\n'
             'Country: Mexico\n'
             '\n'
             'Criteria: Inclusion criteria:\n'
             '\n'
             'A subject will be considered eligible for inclusion in this '
             'study only if all of the following criteria apply:\n'
             'males, aged ≥50 years\n'
             'clinical diagnosis of BPH by medical history and physical '
             'examination, including a digital rectal examination (DRE)\n'
             'International Prostate Symptom Score (IPSS) ≥12 points at 

In [56]:
# Send the held-out example to the fine-tuned model.
# temperature=0 asks for the least random output; max_tokens caps the reply length.
response = client.chat.completions.create(
    model=fine_tuned_model_id, messages=test_messages, temperature=0, max_tokens=500
)
# Show only the text of the first returned choice.
print(response.choices[0].message.content)

Benign Prostatic Hyperplasia
