In [1]:
import re
from datasets import load_dataset, Dataset, interleave_datasets, concatenate_datasets

# Load and prep dataset
# System message placed first in every prompt built by get_datasets(). It asks
# the model to put its reasoning inside <reasoning> tags and its final answer
# inside <answer> tags, each tag on its own line.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Template rendering a reasoning/answer pair in the same tag layout; fill it
# with XML_COT_FORMAT.format(reasoning=..., answer=...). The trailing backslash
# after the opening quotes suppresses the leading newline.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    """Return the stripped text of the last ``<answer>...</answer>`` section.

    Everything after the final ``<answer>`` tag is taken (the whole input when
    the tag is absent), then cut at the first ``</answer>`` that follows (kept
    whole when there is no closing tag). Surrounding whitespace is removed.
    """
    _, _, after_open_tag = text.rpartition("<answer>")
    answer_body, _, _ = after_open_tag.partition("</answer>")
    return answer_body.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# For 1-shot prompting, insert an example user/assistant message pair between the system and user messages in the prompts below
def get_datasets(split = "train") -> Dataset:
    """Build one dataset that concatenates GSM8K, PubMedQA and Health_Benchmarks.

    Every row ends up with a ``prompt`` (system message plus one user message),
    an ``answer`` and a ``db_set`` tag (``'gsm8k'``, ``'pubmedqa'`` or
    ``'med_mc'``) naming its source.

    Args:
        split: Split name used to index the GSM8K and PubMedQA datasets. The
            Health_Benchmarks configs are indexed by their category name
            instead, so ``split`` does not affect them.

    Returns:
        The three prepared datasets concatenated in the order GSM8K, PubMedQA,
        Health_Benchmarks.
    """

    def _chat(user_content):
        # Two-message conversation shared by all three sources.
        return [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': user_content},
        ]

    # --- GSM8K: the gold answer is whatever follows the '####' marker ---
    def _gsm8k_row(x):
        return {
            'prompt': _chat(x['question']),
            'answer': extract_hash_answer(x['answer']),
            'db_set': 'gsm8k',
        }

    data = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore
    data = data.map(_gsm8k_row)  # type: ignore
    data = data.remove_columns(['question'])

    # --- PubMedQA: yes / no / maybe questions over a scientific context ---
    def _joined_context(x):
        return "\n".join(x['context']['contexts'])

    def _has_short_context(x):
        # avoid long traces
        return len(_joined_context(x)) < 1024

    def _pubmedqa_row(x):
        user_content = (
            "Given the scientific context below:\n" +
            _joined_context(x) +
            "\n\nAnswer the following question:\n" +
            x['question'] +
            " with 'yes', 'no' or 'maybe'. You need to carefully review the context and reason before answering."
        )
        return {
            'prompt': _chat(user_content),
            'answer': x['final_decision'],
            'db_set': 'pubmedqa',
        }

    # two times more than other datasets
    data_qa = load_dataset("qiaojin/PubMedQA", "pqa_artificial")[split]
    data_qa = data_qa.filter(_has_short_context)
    data_qa = data_qa.map(_pubmedqa_row)  # type: ignore
    data_qa = data_qa.remove_columns(['pubid', 'question', 'context', 'long_answer', 'final_decision'])

    # --- Health_Benchmarks: multiple choice, one config per category ---
    categories = [
        'Lab_Medicine', 'Wearables', 'Dermatology', 'Gastroenterology',
        'Internal_Medicine', 'Oncology', 'Orthopedics', 'General_Surgery',
        'Ophthalmology', 'Audiology', 'Head_Neck_Surgery', 'Elderly_Care',
        'Pediatrics', 'Allergy_Immunology', 'Rheumatology', 'Pharmacy',
        'Obstetrics_Gynecology', 'Microbiology', 'Dentistry',
        'Physical_Medicine_and_Rehabilitation', 'Neurology', 'Psychiatry',
        'Pathology', 'Genetics', 'Rare_Diseases', 'Hematology', 'Emergency',
        'Endocrinology', 'Radiology', 'Cardiology', 'Pulmonology',
        'Infectious_Diseases', 'Critical_Care', 'Pediatric_Surgery',
        'Neuroscience', 'Epidemiology', 'Fitness_Sports', 'Health_Education',
        'Health_Economics', 'Health_Entrepreneurship', 'Hospital_Management',
        'Mental_Health', 'Nutrition', 'Palliative_Care', 'Preventive_Medicine',
        'Public_Health', 'Social_Media_Addiction', 'Sleep', 'Supplements',
        'Vaccination', 'Work_Health', 'Wellbeing',
    ]

    def _mc_row(x):
        user_content = (
            "\n\nAnswer the following question:\n" +
            x['Questions'] +
            "\n With 'A', 'B', 'C' or 'D'. You need to carefully review the context and reason before answering."
        )
        return {
            'prompt': _chat(user_content),
            'answer': x['Answers'],
            'db_set': 'med_mc',
        }

    per_category = []
    for category in categories:
        # Each config is indexed by a key equal to its own category name.
        per_category.append(load_dataset("yesilhealth/Health_Benchmarks", category)[category])
    data_mc = concatenate_datasets(per_category)
    data_mc = data_mc.map(_mc_row)  # type: ignore
    data_mc = data_mc.remove_columns(['Answers', 'Questions'])

    return concatenate_datasets([data, data_qa, data_mc])


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Build the combined dataset with the default split ("train"). get_datasets()
# concatenates its three sources one after another, so shuffle with a fixed
# seed to mix them reproducibly.
dataset = get_datasets()
dataset = dataset.shuffle(seed=42)

In [24]:
from pprint import pprint

# Pretty print the first 2 rows, each followed by a separator line
for i in range(2):
    pprint(dataset[i])
    print("-" * 50)


{'answer': 'no',
 'db_set': 'pubmedqa',
 'prompt': [{'content': '\n'
                        'Respond in the following format:\n'
                        '<reasoning>\n'
                        '...\n'
                        '</reasoning>\n'
                        '<answer>\n'
                        '...\n'
                        '</answer>\n',
             'role': 'system'},
            {'content': 'Given the scientific context below:\n'
                        'To define the role of dopamine-2 receptors in the '
                        'sympatho-inhibitory effects of '
                        'gamma-l-glutamyl-l-dopa in conscious rabbits.\n'
                        'gamma-l-glutamyl-l-dopa (gludopa) was infused iv at '
                        '25 and 100 micrograms.kg-1.min-1 with and without '
                        'prior dopamine-2 receptor blockade by YM-09151-2 (50 '
                        'micrograms.kg-1 iv) in conscious rabbits.\n'
                        'Mean arterial

In [3]:
# NOTE(review): presumably a small subset to keep the demo run short — confirm.
# This rebinds `dataset`, so later cells see only these rows.
dataset = dataset.select(range(100))  # Take first 100 rows


In [4]:
dataset

Dataset({
    features: ['answer', 'prompt', 'db_set'],
    num_rows: 100
})

In [5]:
# Hold out 10% of the rows as a test split; the fixed seed keeps the split
# reproducible across runs.
small_split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = small_split["train"]
test_dataset = small_split["test"]


In [6]:
train_dataset

Dataset({
    features: ['answer', 'prompt', 'db_set'],
    num_rows: 90
})

#### Reward Functions

In [7]:
## Reward functions
def correctness_reward_func(prompts, completions, answer, db_set, **kwargs) -> list[float]:
    """Score each completion by whether its extracted answer matches the gold answer.

    Args:
        prompts: Batch of conversations; the content of the last message of the
            first conversation is printed for debugging.
        completions: Batch of completions; each is a list whose first element
            is a dict with a ``'content'`` string.
        answer: Gold answers, aligned with ``completions``.
        db_set: Source tag per sample (``'gsm8k'`` or anything else).
        **kwargs: Ignored.

    Returns:
        One reward per completion. For ``'gsm8k'``: 2.0 on an exact match, 1.0
        when the gold answer only appears inside the extracted text, else 0.0.
        For other sources: 2.0 on a case-insensitive exact match, else 0.0.
        A missing or empty gold answer always scores 0.0.
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]

    # Debug trace of the first sample; skipped for an empty batch so the
    # indexing below cannot raise.
    if responses:
        q = prompts[0][-1]['content']
        print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")

    rewards = []
    for r, a, dt in zip(extracted_responses, answer, db_set):
        if not a:
            # No gold answer (None or ""): nothing to compare against. Without
            # this guard `None in r` raises and `"" in r` is always True.
            rewards.append(0.0)
        elif dt == "gsm8k":
            # Exact match must be tested first: it implies containment, so the
            # reverse order would make the 2.0 branch unreachable.
            if r == a:
                rewards.append(2.0)
            elif a in r:
                rewards.append(1.0)
            else:
                rewards.append(0.0)
        else:
            rewards.append(2.0 if r.lower() == a.strip().lower() else 0.0)
    return rewards


def int_reward_func(completions, db_set, **kwargs) -> list[float]:
    """Reward completions whose extracted answer has the form expected for its source.

    Args:
        completions: Batch of completions; each is a list whose first element
            is a dict with a ``'content'`` string.
        db_set: Source tag per sample, aligned with ``completions``.
        **kwargs: Ignored.

    Returns:
        0.5 per completion when the extracted answer is well formed, else 0.0:
        all digits for ``'gsm8k'``; exactly ``yes``/``no``/``maybe`` for
        ``'pubmedqa'``; exactly one of ``A``-``D`` for any other source. The
        word and letter checks are case-insensitive.
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    rewards = []
    for r, dt in zip(extracted_responses, db_set):
        if dt == "gsm8k":
            well_formed = r.isdigit()
        elif dt == "pubmedqa":
            # Whole-answer comparison: a substring test would also accept
            # e.g. "unknown" (contains "no").
            well_formed = r.lower() in ("yes", "no", "maybe")
        else:
            # Whole-answer comparison: a substring test on single letters
            # accepts almost any English text.
            well_formed = r.lower() in ("a", "b", "c", "d")
        rewards.append(0.5 if well_formed else 0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the exact tag-per-line layout.

    The whole completion must be ``<reasoning>``, the reasoning text,
    ``</reasoning>``, ``<answer>``, the answer text and ``</answer>``, with
    each tag on its own line and a final trailing newline.

    Args:
        completions: Batch of completions; each is a list whose first element
            is a dict with a ``"content"`` string.
        **kwargs: Ignored.

    Returns:
        0.5 per completion that matches the layout, else 0.0.
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets `.` span newlines; without it any reasoning or answer
    # longer than one line can never match.
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that start with a reasoning section followed by an answer section.

    Lenient counterpart of the strict check: the completion must begin with
    ``<reasoning>...</reasoning>``, optional whitespace, then
    ``<answer>...</answer>``. Line breaks around the tags are not required and
    text after the closing tag is tolerated.

    Args:
        completions: Batch of completions; each is a list whose first element
            is a dict with a ``"content"`` string.
        **kwargs: Ignored.

    Returns:
        0.5 per completion that matches, else 0.0.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets `.` span newlines; without it a section containing a line
    # break (including the tag-per-line layout) can never match.
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    """Give partial credit for each expected tag that appears exactly once.

    Each of ``<reasoning>\\n``, ``\\n</reasoning>\\n``, ``\\n<answer>\\n`` and
    ``\\n</answer>`` is worth 0.125 when it occurs exactly once. The two
    answer-tag checks also subtract 0.001 per character found after the
    closing answer tag (the second check discounts one character).
    """
    def _appears_once(marker):
        return text.count(marker) == 1

    score = 0.0
    for reasoning_marker in ("<reasoning>\n", "\n</reasoning>\n"):
        if _appears_once(reasoning_marker):
            score += 0.125

    if _appears_once("\n<answer>\n"):
        # Text after the last "\n</answer>\n"; the whole text if it is absent.
        trailing = text.rpartition("\n</answer>\n")[2]
        score += 0.125
        score -= len(trailing) * 0.001

    if _appears_once("\n</answer>"):
        # Text after the closing tag, one character of which is not penalised.
        trailing = text.rpartition("\n</answer>")[2]
        score += 0.125
        score -= (len(trailing) - 1) * 0.001

    return score

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Apply ``count_xml`` to the first message's content of every completion."""
    scores = []
    for completion in completions:
        scores.append(count_xml(completion[0]["content"]))
    return scores

In [20]:
def composite_reward_model(samples, responses, **kwargs):
    """Sum all five reward functions for each (sample, response) pair.

    Args:
        samples: Mappings that may provide ``"prompt"``, ``"answer"`` and
            ``"db_set"``. Missing keys default to ``""``, ``""`` and
            ``"gsm8k"``.
        responses: Generated text, one string per sample.
        **kwargs: Ignored.

    Returns:
        One float per pair: the sum of the xml-count, soft-format,
        strict-format, answer-form and correctness rewards. If scoring a pair
        raises, the error is printed and that pair's reward is 0.0.
    """
    rewards = []

    for i, (sample, response) in enumerate(zip(samples, responses)):
        try:
            # The reward functions expect batches, so wrap this single pair
            # as a batch of one.
            completions = [[{"content": response}]]
            db_sets = [sample.get("db_set", "gsm8k")]
            answers = [sample.get("answer", "")]

            prompt = sample.get("prompt", "")
            if isinstance(prompt, list) and prompt and all(
                isinstance(message, dict) and "content" in message for message in prompt
            ):
                # Already a chat-message list (the form get_datasets() builds).
                # Pass it through so the logged question is the last message's
                # text rather than the whole list.
                prompts = [prompt]
            else:
                prompts = [[{"content": prompt}]]

            # Evaluate all reward functions
            reward = sum([
                xmlcount_reward_func(completions)[0],
                soft_format_reward_func(completions)[0],
                strict_format_reward_func(completions)[0],
                int_reward_func(completions, db_sets)[0],
                correctness_reward_func(prompts, completions, answers, db_sets)[0]
            ])

        except Exception as e:
            # Best effort: one bad sample must not abort scoring of the rest.
            print(f"[Reward Error] Sample {i}: {e}")
            reward = 0.0

        rewards.append(reward)

    return rewards


In [9]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import PPOConfig, PPOTrainer
import torch


In [10]:
# Checkpoint used for both the tokenizer and the causal LM.
model_name = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name,device_map="auto")
# Use the EOS token id as the padding id in the model config.
model.config.pad_token_id = tokenizer.eos_token_id

Fetching 2 files: 100%|██████████| 2/2 [00:49<00:00, 24.91s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.10s/it]


In [11]:
# Make the tokenizer pad with its EOS token and keep the model config's
# padding id in sync with the tokenizer.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

In [12]:

# training_args = GRPOConfig(
#     learning_rate=5e-6,
#     adam_beta1=0.9,
#     adam_beta2=0.99,
#     weight_decay=0.1,
#     warmup_ratio=0.1,
#     lr_scheduler_type="cosine",
#     optim="adamw_8bit",  # Requires bitsandbytes
#     logging_steps=1,

#     per_device_train_batch_size=1,
#     gradient_accumulation_steps=2,  # Slightly smoother training
#     num_generations=2,  # Safe for Falcon-1B
#     max_prompt_length=512,
#     max_completion_length=128,
#     max_steps=15,
#     save_steps=3,
#     max_grad_norm=0.1,
#     # report_to="none",
#     report_to="wandb", 
#     output_dir="outputs",
# )

# PPO hyper-parameters for a very short run (15 steps, checkpoint every 3).
# NOTE(review): which of these keyword arguments PPOConfig accepts depends on
# the installed TRL version — confirm against its documentation.
training_args = PPOConfig(
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    report_to="wandb",  # NOTE(review): some TRL versions name this option `log_with` — confirm `report_to` is accepted
    logging_steps=1,
    
    max_steps=15,
    save_steps=3,
    batch_size=1,
    mini_batch_size=1,  # NOTE(review): presumably the PPO mini-batch size — confirm it is a constructor argument in this TRL version
    gradient_accumulation_steps=1,
    output_dir="outputs",

)


In [13]:
from peft import LoraConfig

# LoRA adapter settings for a causal LM: rank-8 adapters (alpha 16, dropout
# 0.05, no bias terms) attached to the modules named q_proj, k_proj, v_proj
# and o_proj.
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
)


In [27]:
# trainer = GRPOTrainer(
#     model = model,
#     processing_class = tokenizer,
#     reward_funcs = [
#         xmlcount_reward_func,
#         soft_format_reward_func,
#         strict_format_reward_func,
#         int_reward_func,
#         correctness_reward_func,
#     ],
#     args = training_args,
#     train_dataset = train_dataset,
#     eval_dataset=test_dataset,
#     peft_config=peft_config
# )


# FIXME: as written this call fails with
#   TypeError: PPOTrainer.__init__() missing 3 required positional arguments:
#   'ref_model', 'reward_model', and 'value_model'
# (see the recorded cell output). Those three arguments are commented out
# below, so `trainer` is not created by this cell.
# NOTE(review): the commented-out `reward_model` is a list of the Python reward
# functions — confirm what type PPOTrainer expects before re-enabling it.
trainer=PPOTrainer(
    model=model,
    processing_class=tokenizer,
    # ref_model=None,  # Using the same model as reference for simplicity
    # value_model=model,  # Using the same model as value model for simplicity
    # reward_model=[
    #     xmlcount_reward_func,
    #     soft_format_reward_func,
    #     strict_format_reward_func,
    #     int_reward_func,
    #     correctness_reward_func,
    # ],
    
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    peft_config=peft_config
)

TypeError: PPOTrainer.__init__() missing 3 required positional arguments: 'ref_model', 'reward_model', and 'value_model'

In [33]:
# Run training with whatever `trainer` object currently exists in the kernel.
# NOTE(review): the PPOTrainer cell above raises a TypeError, so the recorded
# output presumably comes from a trainer built in an earlier kernel state —
# confirm this cell works after Restart & Run All.
trainer.train()

-------------------- Question:
Given the scientific context below:
To verify whether the long-term retention of an emotionally arousing story is stronger than the retention of a neutral story, and the enhancing effects of emotional arousal on declarative memory in Alzheimer's disease (AD) patients.
Twenty subjects (10 with AD and 10 controls matched for age and educational level) were studied. After the audiovisual presentation (neutral story), the subjects rated the narrative's emotionality. Later, they answered a multiple-choice questionnaire about the stories. Two weeks later, they watched the emotionally arousing story.
Subjects who watched the emotionally arousing story assigned a score of emotionality higher than the subjects in the neutral group (P = 0.023). In addition, the participants remembered more details of the arousing story, and had a higher score in the questionnaire (P < 0.001).

Answer the following question:
Does emotional arousal enhance declarative memory in patie

Step,Training Loss
1,0.0
2,0.0
3,0.0
4,0.0
5,0.0
6,0.0
7,0.0
8,0.0
9,-0.0
10,-0.0


-------------------- Question:
Given the scientific context below:
The urinary bladder expresses Ca(2+)-activated Cl(-) channels (CACC), but its physiological role in governing contractility remains to be defined. The CACC modulator niflumic acid (NFA) is widely used despite the variable results arisen from different drug concentrations used. This study was designed to examine the effects of NFA at low concentrations on detrusor strip contractility.
Rat detrusor strips with mucosa-intact (+MU) and mucosa-denuded (-MU) were prepared in transverse (Tr) and longitudinal (Lg) with respect to the bladder orientation. Isometric force measurements were made at baseline (for spontaneous phasic contractile activity) and during drug stimulation (by carbachol, CCh) with and without NFA.
NFA (1 and 10 μmol/L) pretreatment enhanced CCh-induced contractions more in +MU than -MU strips with no selectivity on contractile direction. For spontaneous phasic contractions, NFA-treated strips in the Tr dire

TrainOutput(global_step=15, training_loss=-1.986821492513021e-09, metrics={'train_runtime': 126.6531, 'train_samples_per_second': 0.237, 'train_steps_per_second': 0.118, 'total_flos': 0.0, 'train_loss': -1.986821492513021e-09})

In [49]:
# Load a second copy of the same checkpoint from `model_name`.
# NOTE(review): presumably meant as an untrained baseline, but the generation
# cells visible below call `model` and `trainer.model`, not `model1` — confirm.
model1 = AutoModelForCausalLM.from_pretrained(model_name,device_map="auto")


Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.26s/it]


In [13]:
# Generate an answer to one fixed question using `model`.
prompt = "Is Aspirin good for cardio vascular function?"
messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
# Render the conversation to a single prompt string (not token ids), with the
# generation prompt appended.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
# Tokenize as a batch of one and move the tensors to the model's device.
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512
)
# Drop the first len(input_ids) positions of each output so that only the
# tokens after the prompt are decoded.
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

# Decode the single sequence in the batch, omitting special tokens.
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [14]:
response

"Aspirin can be beneficial for cardiovascular health, particularly in certain populations and under specific conditions. Here's an overview of its role:\n\n1. **Prevention of Thrombosis**: Aspirin is often used to prevent blood clots (thrombosis) in individuals who have had a heart attack or stroke, or who are at high risk of developing such events. It works by inhibiting the production of thromboxane, a substance that helps platelets stick together to form clots.\n\n2. **Reduction of Risk in High-Risk Individuals**: For people with a history of cardiovascular disease, aspirin may help reduce the risk of another event. However, this benefit is more pronounced in those who are considered to be at high risk, such as those with diabetes, peripheral artery disease, or a family history of heart disease.\n\n3. **Use in Primary Prevention**: While aspirin is sometimes recommended for primary prevention in certain high-risk groups (e.g., those with diabetes or a history of smoking), the decisi

##### Check with base model1

In [61]:
# Same prompt and decoding steps as the earlier generation cell, but the text
# is generated with `trainer.model`.
# NOTE(review): the heading above mentions the base `model1`, while this cell
# calls `trainer.model` — confirm which model is intended.
prompt = "Is Aspirin good for cardio vascular function?"
messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
# Render the conversation to a single prompt string (not token ids), with the
# generation prompt appended.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
# Inputs are moved to `model.device`.
# NOTE(review): assumes `trainer.model` is on the same device — confirm.
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = trainer.model.generate(
    **model_inputs,
    max_new_tokens=512
)
# Drop the first len(input_ids) positions of each output so that only the
# tokens after the prompt are decoded.
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

# Decode the single sequence in the batch, omitting special tokens.
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [62]:
response

'Aspirin can have benefits for cardiovascular health, particularly in certain populations, but it is not without risks and should be used with caution. Here are some key points to consider:\n\n### Benefits:\n1. **Primary Prevention**: For individuals who are at high risk of developing cardiovascular disease (CVD), such as those with a history of heart attack or stroke, or significant coronary artery disease, low-dose aspirin (usually 75-325 mg daily) can reduce the risk of future events.\n2. **Secondary Prevention**: In patients who have already experienced a heart attack, stroke, or other CVD events, aspirin helps prevent further complications by reducing blood clot formation.\n\n### Risks:\n1. **Gastrointestinal Bleeding**: Aspirin can increase the risk of gastrointestinal bleeding, which is a serious concern, especially in older adults and those with pre-existing conditions like peptic ulcers.\n2. **Gastrointestinal Irritation**: It may cause stomach pain, nausea, and other gastroin