In [1]:
%load_ext autoreload
%autoreload 2

# Dependencies

In [2]:
!uv pip install dotenv "numpy<2.0.0"

[2mUsing Python 3.11.13 environment at: /usr[0m
[2mAudited [1m2 packages[0m [2min 82ms[0m[0m


In [3]:
!uv pip install openpipe-art openpipe --prerelease allow --no-cache-dir

[2mUsing Python 3.11.13 environment at: /usr[0m
[2mAudited [1m2 packages[0m [2min 109ms[0m[0m


# Environment variables

In [4]:
import os

# Optional API keys. An empty value leaves the process environment untouched,
# so a key that is already exported outside the notebook keeps working; a
# non-empty value is exported for this kernel.
WANDB_API_KEY = ""
OPENPIPE_API_KEY = ""
os.environ.update(
    {
        env_name: env_value
        for env_name, env_value in (
            ("WANDB_API_KEY", WANDB_API_KEY),
            ("OPENPIPE_API_KEY", OPENPIPE_API_KEY),
        )
        if env_value
    }
)

# Model identity (passed to art.TrainableModel further down).
MODEL_NAME = "001"
PROJECT = "generate-fhir-single-turn"
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"

# Training / evaluation settings.
LEARNING_RATE = 1.2e-5
GROUPS_PER_STEP = 1
EVAL_STEPS = 50
VAL_SET_SIZE = 100
NUM_EPOCHS = 1
NUM_GENERATIONS = 6

# Main loop

## FHIR validation tool

In [None]:
import json
import requests
from typing import Dict, Any, Union, List, Optional
from dataclasses import dataclass
from datetime import datetime

@dataclass
class ValidationResult:
    """Result of FHIR resource validation from server.

    Attributes:
        is_valid: Whether the validation passed.
        errors: Error messages collected during validation.
        warnings: Warning messages collected during validation.
        information: Informational messages collected during validation.
        resource_type: Resource type associated with the result, if known.
        timestamp: ISO-8601 creation time (naive UTC). Filled in
            automatically when not supplied.
        operation_outcome: Raw server response body, if one was obtained.
    """
    is_valid: bool
    errors: List[str]
    warnings: List[str]
    information: List[str]
    # These three default to None, so they are annotated Optional.
    resource_type: Optional[str] = None
    timestamp: Optional[str] = None
    operation_outcome: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Stamp the result with the current UTC time when none was given."""
        if self.timestamp is None:
            # datetime.utcnow() is deprecated; build the same naive-UTC value
            # from an aware "now" so the string format stays unchanged
            # (no "+00:00" suffix).
            from datetime import timezone

            self.timestamp = (
                datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
            )

def validate_fhir_resource(
    resource_json: Union[str, Dict[str, Any]],
    fhir_server_url: str,
    resource_type: Optional[str] = None,
    profile_url: Optional[str] = None,
    timeout: int = 30
) -> ValidationResult:
    """
    Validate a FHIR resource by making a REST call to a FHIR server's validation endpoint.

    The function never raises: every failure (bad input JSON, network error,
    unexpected status, unparsable response) is reported as an entry in
    ``ValidationResult.errors``.

    Args:
        resource_json: FHIR resource as JSON string or dictionary
        fhir_server_url: Base URL of the FHIR server (e.g., "https://hapi.fhir.org/baseR4")
        resource_type: Optional resource type for validation endpoint (e.g., "Patient")
        profile_url: Optional profile URL to validate against
        timeout: Request timeout in seconds (default: 30)

    Returns:
        ValidationResult: Object containing validation status, errors, and warnings.
        ``is_valid`` is True exactly when no error message was collected.
        ``resource_type`` is the ``resourceType`` found in the resource itself
        (not the ``resource_type`` argument).
    """
    errors: List[str] = []
    warnings: List[str] = []
    information: List[str] = []
    detected_resource_type = None
    operation_outcome = None

    try:
        # Parse JSON if string input
        if isinstance(resource_json, str):
            try:
                resource_dict = json.loads(resource_json)
            except json.JSONDecodeError as e:
                return ValidationResult(
                    is_valid=False,
                    errors=[f"Invalid JSON format: {str(e)}"],
                    warnings=[],
                    information=[]
                )
        else:
            resource_dict = resource_json

        # Extract resource type from the resource
        if isinstance(resource_dict, dict) and 'resourceType' in resource_dict:
            detected_resource_type = resource_dict['resourceType']

        # Build validation URL: an explicit resource_type wins over the
        # detected one; with neither, use the server-level $validate.
        fhir_base_url = fhir_server_url.rstrip('/')
        endpoint_type = resource_type or detected_resource_type
        if endpoint_type:
            validation_url = f"{fhir_base_url}/{endpoint_type}/$validate"
        else:
            validation_url = f"{fhir_base_url}/$validate"

        headers = {
            'Content-Type': 'application/fhir+json',
            'Accept': 'application/fhir+json'
        }

        # Add profile parameter if specified
        params = {'profile': profile_url} if profile_url else {}

        # Make the validation request
        response = requests.post(
            validation_url,
            json=resource_dict,
            headers=headers,
            params=params,
            timeout=timeout
        )
        status = response.status_code

        if status in (200, 400, 422):
            # 200 = validation ran; 400/422 = server rejected the resource.
            # In all three cases the body is expected to be an OperationOutcome.
            try:
                outcome_body = response.json()
            except ValueError:
                # Catch ValueError rather than json.JSONDecodeError: depending
                # on the requests version and whether simplejson is installed,
                # response.json() raises a plain ValueError, json's or
                # simplejson's JSONDecodeError, or requests' own
                # JSONDecodeError - all of which are ValueError subclasses.
                if status == 200:
                    errors.append("Server returned invalid JSON response")
                else:
                    errors.append(f"Validation failed with status {status}: {response.text}")
            else:
                operation_outcome = outcome_body
                _, parsed_errors, parsed_warnings, parsed_info = _parse_operation_outcome(operation_outcome)
                errors.extend(parsed_errors)
                warnings.extend(parsed_warnings)
                information.extend(parsed_info)

        elif status == 404:
            errors.append("FHIR server validation endpoint not found - check server URL and resource type")

        else:
            errors.append(f"FHIR server returned status {status}: {response.text}")

    except requests.exceptions.Timeout:
        errors.append(f"Request to FHIR server timed out after {timeout} seconds")

    except requests.exceptions.ConnectionError:
        errors.append("Could not connect to FHIR server - check URL and network connectivity")

    except requests.exceptions.RequestException as e:
        errors.append(f"Request to FHIR server failed: {str(e)}")

    except Exception as e:
        # Deliberate catch-all: callers rely on always getting a result back.
        errors.append(f"Unexpected error during validation: {str(e)}")

    return ValidationResult(
        is_valid=len(errors) == 0,
        errors=errors,
        warnings=warnings,
        information=information,
        resource_type=detected_resource_type,
        operation_outcome=operation_outcome
    )

def _parse_operation_outcome(operation_outcome: Dict[str, Any]) -> tuple[bool, List[str], List[str], List[str]]:
    """
    Parse FHIR OperationOutcome resource to extract validation results.

    Returns:
        Tuple of (is_valid, errors, warnings, information)
    """
    errors = []
    warnings = []
    information = []
    is_valid = True

    if not isinstance(operation_outcome, dict):
        errors.append("Invalid OperationOutcome format")
        return False, errors, warnings, information

    # Check if it's actually an OperationOutcome resource
    if operation_outcome.get('resourceType') != 'OperationOutcome':
        errors.append("Expected OperationOutcome resource from validation endpoint")
        return False, errors, warnings, information

    # Parse issues from OperationOutcome
    issues = operation_outcome.get('issue', [])

    for issue in issues:
        if not isinstance(issue, dict):
            continue

        severity = issue.get('severity', 'error')
        code = issue.get('code', 'unknown')
        details = issue.get('details', {})
        diagnostics = issue.get('diagnostics', '')
        location = issue.get('location', [])

        # Build error message
        message_parts = []

        if details and isinstance(details, dict):
            detail_text = details.get('text', '')
            if detail_text:
                message_parts.append(detail_text)

        if diagnostics:
            message_parts.append(diagnostics)

        if location:
            location_str = ', '.join(location)
            message_parts.append(f"Location: {location_str}")

        if not message_parts:
            message_parts.append(f"Validation issue ({code})")

        message = ' - '.join(message_parts)

        # Categorize by severity
        if severity in ['fatal', 'error']:
            errors.append(f"Error: {message}")
            is_valid = False
        elif severity == 'warning':
            warnings.append(f"Warning: {message}")
        elif severity == 'information':
            information.append(f"Info: {message}")

    # If no issues but we have an OperationOutcome, it might be successful
    if not issues and operation_outcome.get('resourceType') == 'OperationOutcome':
        information.append("Validation completed successfully")

    return is_valid, errors, warnings, information

## Model creation and data load

In [7]:
import datasets
import art
from art.local import LocalBackend
from openpipe import AsyncOpenPipe
from typing import List, Dict, Any

backend = LocalBackend(
    # Normally we don't want to run the server in-process, but for the output
    # to show up properly on Google Colab we'll enable this.
    in_process=True,
    path="./.art"
)
model = art.TrainableModel(
    name=MODEL_NAME,
    project=PROJECT,
    base_model=BASE_MODEL,
    _internal_config=art.dev.InternalModelConfig(
        init_args=art.dev.InitArgs(
            gpu_memory_utilization=0.75,
        ),
        peft_args=art.dev.PeftArgs(
            lora_alpha=8,
        ),
        trainer_args=art.dev.TrainerArgs(
            max_grad_norm=0.1,
        ),
    ),
)
await model.register(backend)
op_client = AsyncOpenPipe(api_key=os.getenv("OPENPIPE_API_KEY"))

## Load the training data
print("Loading training data...")
train_dataset: datasets.Dataset = datasets.load_dataset("jdjkelly/fhir_synthetic_snippets")

train_data_list: List[Dict[str, Any]] = list(train_dataset)  # type: ignore
print(f"Training data size: {len(train_data_list)}")

# Get OpenAI Client for the ART Model
openai_client = model.openai_client()

Loading training data...
Training data size: 1
Starting training from global step 0


Iterating dataset:   0%|          | 0/1 [00:00<?, ?batch/s]

['train']
0
0
0


## Rollout

In [None]:
# for batch_inputs, epoch, global_step, epoch_step in data_iterator:
#   print(batch_inputs)
#   print(epoch)
#   print(global_step)
#   print(epoch_step)
#   break
from pydantic import BaseModel
import requests
import openai

# Pydantic model with a single integer field, `step`.
# NOTE(review): nothing in the visible notebook instantiates this class - the
# training loop builds `Scenario2048(step=i)` instead. Presumably this was
# meant to be the scenario type handed to `rollout`; confirm and reconcile.
class GenerateFhirBundle(BaseModel):
    step: int

@art.retry(exceptions=(openai.LengthFinishReasonError, requests.ReadTimeout))
async def rollout(
    model: art.Model,
    row: Dict[str, Any],
) -> art.Trajectory:
    """Run one episode against ``model`` and return the scored trajectory.

    The trajectory starts with the FHIR system prompt. Each loop iteration
    appends a user message, requests a completion, optionally reports it to
    OpenPipe, and applies the reply via ``apply_agent_move``. A ``ValueError``
    from that call ends the episode with reward -1; otherwise the episode ends
    when ``check_game_finished`` is true and the reward is derived from the
    board values. The final reward is then attached to the last OpenPipe log
    entry (best effort).

    NOTE(review): the loop body still follows the 2048 game example.
    ``generate_game``, ``render_board``, ``apply_agent_move``,
    ``check_game_finished``, ``max_cell_value``, ``total_board_value``,
    ``WINNING_VALUE`` and ``scenario`` are not defined in the visible part of
    this notebook, ``row`` is never read, and the ``{{note}}`` placeholder in
    the system prompt is never substituted. These need to be replaced with
    the FHIR task logic before this can run - TODO confirm intended design.

    Args:
        model: ART model used to create the chat completions.
        row: Dataset row for this rollout (currently unused - see note above).

    Returns:
        The trajectory with messages, choices, metrics and reward filled in.

    Raises:
        Any exception raised while creating the chat completion is re-raised;
        the ``art.retry`` decorator is configured for
        ``openai.LengthFinishReasonError`` and ``requests.ReadTimeout``.
    """
    # Imported locally: the function uses both modules, but no visible cell
    # imports them, so this keeps the block self-contained.
    import math
    import time

    game = generate_game()

    move_number = 0

    trajectory = art.Trajectory(
        messages_and_choices=[
            {
                "role": "system",
                "content": """
You are a health informaticist expert in FHIR.
You will receive unstructured notes and you need to structure them into FHIR resources.
You must only include data that is present in the note.
You must only return a valid FHIR JSON Bundle, with the appropriate resources, with no additional explanation.
You may include multiple resources in the bundle.
You must follow the FHIR R4 specification.
You must not include a meta element in the resources.
When generating a CodeableConcept, you must include a coding element with a system, code, and display.
When generating a CodeableConcept, you must use a display matching what is expected by the CodeSystem.
Each entry in a Bundle must have a fullUrl which is the identity of the resource in the entry.
The id of a resource must be a valid UUID in lowercase.

You have access to a validator tool that will validate the FHIR bundle.
You should use this tool recursively to fix errors, using it again after you have called it to ensure that FHIR resources are fully valid after making changes.

Include the FHIR JSON bundle in your final response.
<note>
{{note}}
</note>
""",
            }
        ],
        reward=0,
    )

    while True:
        trajectory.messages_and_choices.append(
            {"role": "user", "content": render_board(game)}
        )

        # Millisecond timestamps for the OpenPipe report below.
        requested_at = int(time.time() * 1000)
        messages = trajectory.messages()

        async def get_completion():
            client = model.openai_client()
            # NOTE(review): 128 completion tokens looks like a leftover from
            # the game example; presumably too small for a FHIR Bundle - confirm.
            return await client.chat.completions.create(
                max_completion_tokens=128,
                messages=messages,
                model=model.name,
            )

        try:
            chat_completion = await get_completion()
            last_completion = chat_completion
        except openai.LengthFinishReasonError:
            # Re-raised silently so the retry decorator can handle it.
            raise
        except Exception as e:
            print("caught exception generating chat completion", e)
            raise

        # Reporting is best effort: a failure here must not abort the rollout.
        try:
            if op_client.api_key:
                await op_client.report(
                    requested_at=requested_at,
                    received_at=int(time.time() * 1000),
                    req_payload={
                        "model": model.name,
                        "messages": messages,
                        "metadata": {
                            "game_id": game["id"],
                            "notebook-id": "2048",
                            # NOTE(review): `scenario` is not a parameter or
                            # local of this function; unless it exists as a
                            # global this raises NameError, which the except
                            # below turns into a printed message.
                            "step": str(scenario.step),
                            "move_number": str(move_number),
                        },
                    },
                    resp_payload=chat_completion,
                    status_code=200,
                )
        except Exception as e:
            print(f"Error reporting to OpenPipe: {e}")

        choice = chat_completion.choices[0]
        content = choice.message.content
        assert isinstance(content, str)
        trajectory.messages_and_choices.append(choice)

        try:
            apply_agent_move(game, content)
            move_number += 1
        except ValueError:
            # An unusable reply ends the episode with a penalty.
            trajectory.reward = -1
            break

        if check_game_finished(game):
            max_value = max_cell_value(game)
            board_value = total_board_value(game)
            trajectory.metrics["max_value"] = max_value
            trajectory.metrics["board_value"] = board_value

            if max_value < WINNING_VALUE:
                # scale max value logarithmically between 0 for 2 and 1 for WINNING_VALUE
                max_value_reward = (math.log(max_value, 2) - 1) / (
                    math.log(WINNING_VALUE, 2) - 1
                )
                # scale board value logarithmically between 0 for 2 * 16 and 1 for WINNING_VALUE * 16
                board_value_reward = (math.log(board_value, 2) - 1) / (
                    math.log(WINNING_VALUE * 16, 2) - 1
                )
                # combine the two rewards, with max value having a higher weight
                trajectory.reward = max_value_reward + (board_value_reward * 0.2)
            else:
                # double reward if the agent wins
                trajectory.reward = 2
            break

    # Attach the final reward to the last logged completion (best effort).
    try:
        if op_client.api_key:
            await op_client.update_log_metadata(
                filters=[
                    {
                        "field": "completionId",
                        "equals": last_completion.id,
                    }
                ],
                metadata={
                    "reward": str(trajectory.reward),
                    "reward_assigned": "true",
                },
            )
    except Exception as e:
        print(f"Error updating log metadata: {e}")

    return trajectory


## Training loop



In [None]:
start_step = await model.get_step()
print(f"Starting training from global step {start_step}")

data_iterator = art.utils.iterate_dataset(
    dataset=train_data_list,
    groups_per_step=GROUPS_PER_STEP,
    num_epochs=NUM_EPOCHS,
    initial_step=start_step,
    use_tqdm=True,
)

for i in range(await model.get_step(), 10):
    train_groups = await art.gather_trajectory_groups(
        (
            art.TrajectoryGroup(
                rollout(model, Scenario2048(step=i)) for _ in range(18)
            )
            for _ in range(1)
        ),
        pbar_desc="gather",
        max_exceptions=18,
    )
    await model.delete_checkpoints()
    await model.train(
        train_groups,
        config=art.TrainConfig(learning_rate=3e-5),
        # Lowering the logprob_calculation_chunk_size is a memory saving measure
        # to allow longer sequences (up to 4096 tokens) to be processed on a T4.
        _config={"logprob_calculation_chunk_size": 8},
    )