In [1]:
import os
import backoff
import openai
from openai import AzureOpenAI

# Exception types that cause completions_with_backoff to retry the request.
error_types = (openai.BadRequestError, TypeError)

# Upper bound on attempts per request. Without a bound, backoff retries
# forever; a 400-style BadRequestError or a TypeError usually recurs on every
# attempt, so an unbounded retry would hang the notebook instead of failing.
MAX_COMPLETION_TRIES = 5

# Dedicated handle for the Azure deployment. A later cell rebinds the global
# name `client` to a plain OpenAI client, so the Azure wrapper below must not
# look up `client` at call time.
azure_client = AzureOpenAI(
    api_version="2023-07-01-preview",
    api_key=os.getenv('OPENAIAZURE_APIKEY'),
    azure_endpoint="https://gpt4v-jb.openai.azure.com",
)
# Backward-compatible alias for code that still refers to `client`.
client = azure_client


@backoff.on_exception(backoff.expo, error_types, max_tries=MAX_COMPLETION_TRIES)
def completions_with_backoff(**kwargs):
    """Create a chat completion on the Azure deployment, retrying on failure.

    All keyword arguments are forwarded unchanged to
    ``azure_client.chat.completions.create``.

    Retries use exponential backoff and are triggered by any exception in
    ``error_types``. After ``MAX_COMPLETION_TRIES`` attempts the last
    exception is re-raised to the caller.

    Returns:
        The response object returned by ``chat.completions.create``.
    """
    return azure_client.chat.completions.create(**kwargs)

def call_gpt_azure(system_prompt, prompt, temperature=0, n=1,
                   prompt_cost_per_1k=0.03, completion_cost_per_1k=0.06):
    """Send a system + user prompt to the Azure deployment and price the call.

    Args:
        system_prompt: Text sent as the ``system`` message.
        prompt: Text sent as the ``user`` message.
        temperature: Sampling temperature forwarded to the API.
        n: Number of completions requested from the API. Only the first
            choice is returned, regardless of ``n``.
        prompt_cost_per_1k: Price per 1000 prompt tokens used for the
            cost estimate.
        completion_cost_per_1k: Price per 1000 completion tokens used for
            the cost estimate.

    Returns:
        A ``(completion, cost)`` tuple: the message content of the first
        choice, and the estimated cost computed from ``response.usage``.
    """
    response = completions_with_backoff(
        model="JBGPT4TURBO_1106_PREVIEW",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ],
        temperature=temperature,
        n=n,
    )

    # NOTE(review): the default rates reproduce the values that were
    # previously hard-coded here. Confirm they match the pricing of the
    # deployed model -- the deployment name suggests a GPT-4 Turbo (1106)
    # model, whose list price may differ.
    usage = response.usage
    cost = (
        usage.completion_tokens * (completion_cost_per_1k / 1000)
        + usage.prompt_tokens * (prompt_cost_per_1k / 1000)
    )

    completion = response.choices[0].message.content
    return completion, cost

In [22]:
from openai import OpenAI
# Rebinds the global name `client` (first assigned to an AzureOpenAI instance
# in the first cell) to a client for the public OpenAI API; call_gpt below
# reads this global at call time.
client = OpenAI(api_key=os.getenv('OPENAI_APIKEY'))

def call_gpt(system_prompt, user_prompt, temperature=0, max_tokens=512,
             model="gpt-4-0125-preview"):
    """Send a system + user prompt to the OpenAI chat API and return the reply.

    Args:
        system_prompt: Text sent as the ``system`` message.
        user_prompt: Text sent as the ``user`` message.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Upper bound on generated tokens forwarded to the API.
        model: Name of the chat model to query. Pass another model name
            (e.g. ``"gpt-4-1106-preview"``) to switch models without
            editing this function.

    Returns:
        The message content of the first choice in the response.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return response.choices[0].message.content

The prompt can include the predicates themselves.

In [23]:
# Free-text cohort definition plus the dataset field mapping; this text is
# interpolated into the user prompt.
task_description = """
The task is to predict in-hospital mortality. I want to extract a cohort of all patients that were admitted to the hospital who did not have COVID in the window 30 days leading up to the admission. These patients must have been in the hospital for at least 48 hours (ie. wasn't re-admitted, not discharged, or didn't die within the 48 hour window) and were either discharged or died in the hospital. I want to use data from the past 30 days before admission as well as 24 hours into the admission. The patients should have at least 50 events in the window leading up to the admission.

Here are the relevant data fields in the dataset:
Admission is shown by ADMISSION in event_type
Discharge is shown by DISCHARGE in event_type
Death is shown by DEATH in event_type
COVID is shown by COVID in diagnosis
"""

In [24]:
# System message that sets the model's persona; note the leading and trailing
# newline characters, which are part of the prompt text.
system_prompt = "\nYou are an expert with electronic health records and understand the structure of medical time series data.\n"

# User message sent to the model. It is an f-string: `task_description` is
# interpolated near the end, so that variable must be defined before this cell
# runs. The text consists of (1) the instruction, (2) a template YAML
# configuration with `predicates` and `windows` sections, (3) prose describing
# every field, and (4) the cohort description followed by an empty fenced
# block for the model to fill in.
#
# NOTE(review): the prompt text is runtime behaviour, so it is left untouched
# here; the points below are visible in the text and worth confirming:
#   - In the template, `min`/`max` under each `includes` item are indented
#     deeper than `predicate`, which does not look like valid YAML nesting.
#   - Both the `target` and `input` template windows carry a `label`, while
#     the prose states that 'label' is reserved for one window.
#   - The `trigger` template sets `offset` to a predicate name, while the
#     prose defines 'offset' as a number plus a time unit.
#   - Typos in the prose: "can bed defined", "There windows are",
#     "One of the windows also have".
user_prompt = \
f"""
Your objective is to create a configuration file based on the provided patient cohort description. The configuration will be used to query a dataset for valid patients. Ensure to include all fields by inferring values where possible.

Begin by first defining the predicates that you need to describe the task. Then, define the windows that you need to segment the patient time series data and place constraints on the windows by using the predicates that were created.

Below is an example configuration file, and your output should follow the same YAML format:

```
predicates:
    <predicate_name_1>:
        column: <column_name>
        value: <column_value>
        system: <boolean OR count>
    <predicate_name_2>:
        column: <column_name>
        value: <column_value>
        system: <boolean OR count>
    <predicate_1_or_2>:
        type: ANY
        predicates: [predicate_name_1, predicate_name_2]
        system: boolean
    <predicate_1_and_2>:
        type: ALL
        predicates: [predicate_name_1, predicate_name_2]
        system: boolean
    ...

windows:
    trigger:
        start: <predicate_name_1>
        duration: 
        offset: <predicate_name_1>
        end:
        excludes:
        includes: 
        - predicate: <predicate_name_1>
                min: <int>
                max: <int>
        st_inclusive: <True OR False>
        end_inclusive: <True OR False>
    gap:
        start: <window_name_1.end>
        duration: <# hours>
        offset:
        end:
        excludes:
        - predicate: <predicate_name_2>
        includes:
        st_inclusive: <True OR False>
        end_inclusive: <True OR False>
    target:
        start: <window_name_2.end>
        duration:
        offset:
        end: <predicate_name_3>
        excludes:
        includes: 
        - predicate: <predicate_name_2>
                min: <int>
                max: <int>
        st_inclusive: <True OR False>
        end_inclusive: <True OR False>
        label: <predicate_name_3>
    input:
        start: <window_name_2.end>
        duration:
        offset:
        end: <predicate_name_3>
        excludes:
        includes: 
        - predicate: <predicate_name_2>
                min: <int>
                max: <int>
        st_inclusive: <True OR False>
        end_inclusive: <True OR False>
        label: <predicate_name_3>
    ...
```

The first part of the configuration file defines predicates, which are boolean markers for events in the dataset. These predicates will be used to filter the patients. Each predicate can bed defined using a 'name', 'column', 'value', and 'system'. 
    - 'column' is the name of the column in the dataset
    - 'value' is the value that the column should have
    - 'system' is either "boolean" or "count". Some common predicates with a "boolean" system are 'admission', 'death', and 'discharge'. Some common predicates with a "count" system are 'lab', 'medication', and 'procedure'.

Complex predicates can be created by combining multiple predicates using "ANY" or "ALL" in the 'type' field. 
    - "ANY" type is used to combine predicates with an 'OR' relationship
    - "ALL" type is used to combine predicates with an 'AND' relationship. 
    - 'predicates' field is a list of the names of the predicates that are being combined. 
    - All complex predicates have a 'boolean' system.
    - For instance, one can create a predicate is_discharge_or_death by combining the predicates 'discharge' and 'death' using the "ANY" type.

There is also always a special predicate named 'any' that can be used to capture any event in the dataset. Use this predicate when a task requires that a window has a certain number of events.

Time strictly increases, thus negative means before the event and positive means after the event.

The second part of the configuration file defines windows that are used to segment the patient time series data. Each window has a 'name', 'start', 'duration', 'offset', 'end', 'excludes', 'includes', 'st_inclusive', 'end_inclusive'. One of the windows also have a 'label' field. A window can be a time-bound window or a predicate-bound window, and certain fields can be left blank if they are not applicable.

There windows are one of 'trigger', 'gap', 'target', and 'input'. 'trigger' is often a window that has one event (ie. same 'start' and 'end'). 'gap' is often a window that excludes certain predicates. 'target' is often a window that has the event of interest and thus has the 'label' field. 'input' is often a window that has the events that are used to predict the event of interest.

For all windows:
    - 'start' is the name of the predicate that the window starts from or "window_name.end" of a previous window.
    - 'offset' is the offset from the 'start', defined using a number and one of seconds, minutes, hours, or days.
    - 'excludes' is a list of predicates that the window must not contain defined by their names.
    - 'includes' is a list of predicates that the window must contain, defined by their names and the minimum and maximum number of times they should occur.
    - 'st_inclusive' is a boolean that determines if the predicate at the start should be included in the window.
    - 'end_inclusive' is a boolean that determines if the predicate at the end should be included in the window. 

For time-bound windows:
    - 'duration' is the duration of the window, defined using a number and one of seconds, minutes, hours, or days.

For predicate-bound windows:
    - 'end' is the name of the predicate that the window ends at.

'label' is reserved for one window and is a predicate name. This field should be used to capture the question that the cohort aims to answer (ie. does the patient die in this window?)

Keep the predicates simple. Any task requirements that indicate there needs to be a certain number of events in a window should be captured in the window 'excludes' or 'includes' fields.

Your output should ONLY have the configuration file.

----------------
Cohort Description:
{task_description}

----------------
```
<Configuration File>
```
"""

In [25]:
# Ask the OpenAI-hosted model for the configuration file.
response = call_gpt(system_prompt=system_prompt, user_prompt=user_prompt)
# Azure alternative (returns a (completion, cost) tuple):
# response, _ = call_gpt_azure(system_prompt, user_prompt)

In [26]:
# Printing renders the returned text with real line breaks; a bare `response`
# expression would display the escaped string repr instead.
print(response)

```yaml
predicates:
  admission:
    column: event_type
    value: ADMISSION
    system: boolean
  discharge:
    column: event_type
    value: DISCHARGE
    system: boolean
  death:
    column: event_type
    value: DEATH
    system: boolean
  covid:
    column: diagnosis
    value: COVID
    system: boolean
  discharge_or_death:
    type: ANY
    predicates: [discharge, death]
    system: boolean
  no_covid:
    column: diagnosis
    value: COVID
    system: boolean
    negate: true

windows:
  pre_admission:
    start: admission
    duration: 30 days
    offset: -30 days
    end:
    excludes:
      - predicate: covid
    includes:
      - predicate: any
        min: 50
        max:
    st_inclusive: False
    end_inclusive: True
  hospital_stay:
    start: admission
    duration: 48 hours
    offset:
    end:
    excludes:
      - predicate: discharge
      - predicate: death
    includes:
    st_inclusive: True
    end_inclusive: False
  post_admission:
    start: admission
    du