# Import packages

In [1]:
import torch
from transformers import pipeline
import json
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# Import Med-Gemma-4b-it
## Note you may need to install flash attention from torch

In [4]:
pipe = pipeline(
    "text-generation",
    model="google/medgemma-4b-it",
    model_kwargs={
        "torch_dtype": torch.bfloat16, 
        "attn_implementation": "flash_attention_2" 
    },
    device_map="auto" 
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:26<00:00, 13.25s/it]
Device set to use cuda:0


# Import clinical notes (Progress Notes)

In [5]:
import pandas as pd
df = pd.read_parquet('latest_notes.parquet')
#df = df.iloc[:100]
sample_notes = df['note_text'].values.tolist() # Convert to list for easier handling

In [6]:
#system_prompt = "Extract the following information from the clinical note in JSON format."

In [7]:
# system_prompt = """
# You are a strict JSON generator. Only output JSON that matches the schema below.
# Do not include any extra text, commentary, or explanations.

# Schema:
# - imatinib_mentioned: true if the drug imatinib is mentioned in the note, otherwise false
# - related_drugs_mentioned: true if drugs related to imatinib (e.g., dasatinib, nilotinib, bosutinib) are mentioned, otherwise false
# - cml_diagnosed: true if chronic myeloid leukemia is diagnosed, otherwise false
# - cml_in_regression: true if chronic myeloid leukemia is mentioned as being in regression, otherwise false

# Rules:
# 1. Only mark a field as true if the note clearly indicates it.
# 2. If the note does not explicitly mention a field, mark it false.
# 3. The output must always be valid JSON with all four fields present.
# """


# Set prompts

In [8]:
system_prompt = """
You are a strict JSON generator. Only output JSON that matches the schema below.
Do not include any extra text, commentary, or explanations.

Schema:
- imatinib_mentioned: true if the drug imatinib (also known as Gleevec) is mentioned in the note, otherwise false
- related_drugs_mentioned: true if drugs related to imatinib (e.g., dasatinib, nilotinib, bosutinib) are mentioned, otherwise false
- cml_diagnosed: true if chronic myeloid leukemia (CML) is diagnosed, otherwise false
- cml_in_regression: true if chronic myeloid leukemia is mentioned as being in regression, otherwise false
- aml_diagnosed: true if acute myeloid leukemia (AML) is diagnosed, otherwise false
- blast_phase_cml: true if blast phase CML is explicitly mentioned, otherwise false
- bmt_history: true if history of bone marrow transplant (BMT) is mentioned, otherwise false
- acute_phase_cml: true if acute phase CML is explicitly mentioned, otherwise false

Rules:
1. Only mark a field as true if the note clearly indicates it.
2. If the note does not explicitly mention a field, mark it false.
3. The output must always be valid JSON with all eight fields present.
"""


In [9]:
prompts = [
    system_prompt + "\n\n" + (note[:11000] if len(note) > 11000 else note)
    for note in sample_notes
]

In [10]:
schema_fields = [
    "imatinib_mentioned",
    "related_drugs_mentioned",
    "cml_diagnosed",
    "cml_in_regression",
    "aml_diagnosed",
    "blast_phase_cml",
    "bmt_history",
    "acute_phase_cml"
]


# Run inference

In [11]:
outputs = pipe(
     prompts,
     batch_size=8,
     max_new_tokens=128,
     do_sample=False,
     temperature=0 
 )

The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignore

# Extraction of inferenced notes

In [12]:
labelled_notes = []
for i, output in enumerate(outputs):
    structured_output_text = output[0]["generated_text"]
    
    try:
        json_start_index = structured_output_text.find("{")
        json_end_index = structured_output_text.rfind("}") + 1
        
        if json_start_index != -1 and json_end_index != 0:
            json_str = structured_output_text[json_start_index:json_end_index]
            label = json.loads(json_str)
        else:
            raise ValueError("No JSON object found in the output")

        for k in schema_fields:
            if k not in label:
                label[k] = False
                
    except (json.JSONDecodeError, ValueError, IndexError):
        label = {k: False for k in schema_fields}
    
    labelled_notes.append({"original_note": sample_notes[i], **label})

In [13]:
output_filename = 'labelled_notes_additonal_prompts_medgemma4b.json'
with open(output_filename, 'w') as f:
    json.dump(labelled_notes, f, indent=4)

print(f"Successfully saved {len(labelled_notes)} labeled notes to {output_filename}")


Successfully saved 1762 labeled notes to labelled_notes_additonal_prompts_medgemma4b.json
