In [None]:
# The three required sections must appear in this order; anything may sit between them.
_FORMAT_PATTERN = r"<macro state>.*?</macro state>.*?<reasoning>.*?</reasoning>.*?<positioning>.*?</positioning>"

# Captures the body of each <macro state> ... </macro state> section.
_MACRO_STATE_PATTERN = r"<macro state>(.*?)</macro state>"

# Tokens treated as "numbers written as strings" when found inside a macro state body.
# NOTE(review): the decimal alternative also flags ordinary numerals such as "2.5";
# confirm that is intended. "%?" rarely takes part in a match because the trailing
# \b cannot sit between "%" and a following space/newline.
_NUMBER_AS_STRING_PATTERN = (
    r"\b(?:one|two|three|four|five|six|seven|eight|nine|ten|hundred|thousand"
    r"|million|billion|point|percent|[0-9]+\.[0-9]+%?)\b"
)


def _completion_text(completion):
    """Return the text of one completion entry.

    Accepts a plain string, or a sequence whose first element is either a
    message dict (its "content" value is used) or a string. An empty sequence
    yields "" so it is scored as an invalid format instead of raising.
    """
    # A bare string must not be indexed: completion[0] would be its first character.
    if isinstance(completion, str):
        return completion
    if not completion:
        return ""
    first = completion[0]
    return first["content"] if isinstance(first, dict) else first


def _find_number_as_string(content):
    """Return the first flagged number token inside any <macro state> section, or None.

    Each section body is searched on its own, so text outside the macro state
    sections (e.g. in <reasoning>) can never be flagged. Matching is
    case-insensitive so capitalised words such as "Five" are caught.
    """
    for section in re.finditer(_MACRO_STATE_PATTERN, content, re.DOTALL):
        hit = re.search(_NUMBER_AS_STRING_PATTERN, section.group(1), re.IGNORECASE)
        if hit is not None:
            return hit.group(0)
    return None


def format_reward_func(prompts, completions, **kwargs):
    """
    Reward completions for following the required tag format, and penalize
    numbers written as strings inside the <macro state> section.

    Per-completion rewards:
      * 3.0  -- <macro state>, <reasoning> and <positioning> sections appear in
                order and no macro state section contains a flagged token;
      * 1.0  -- format is valid but a flagged token was found in a macro state;
      * -1.0 -- the required format is missing.

    Args:
        prompts: Unused by this function.
        completions: One entry per completion. Each entry may be a plain
            string, or a sequence whose first element is a message dict with a
            "content" key or a string.
        **kwargs: Only "step" is read (default 0); it is used for logging only.

    Returns:
        list[float]: one reward per completion, in input order.

    Side effects:
        On the main process (per is_main_process()), logs aggregate rates,
        every 20th step a sample completion, and a record via log_reward().
    """
    step = kwargs.get("step", 0)

    completion_contents = [_completion_text(completion) for completion in completions]

    rewards = []
    format_valid_count = 0
    number_as_string_count = 0

    for content in completion_contents:
        if re.search(_FORMAT_PATTERN, content, re.DOTALL) is None:
            # Penalty for incorrect format.
            rewards.append(-1.0)
            continue

        format_valid_count += 1
        if _find_number_as_string(content) is None:
            # Full reward for correct format without numbers as strings.
            rewards.append(3.0)
        else:
            # Penalize, but still reward the correct format (3.0 reduced to 1.0).
            number_as_string_count += 1
            rewards.append(1.0)

    # Nothing to log for an empty batch; check this before touching process state.
    if not completion_contents or not is_main_process():
        return rewards

    total = len(completion_contents)
    match_rate = format_valid_count / total
    # The string-number rate is measured only among correctly formatted completions.
    number_as_string_rate = number_as_string_count / format_valid_count if format_valid_count > 0 else 0

    logger.info(f"Step {step} - Format valid rate: {match_rate:.2f}, Numbers as strings rate: {number_as_string_rate:.2f}")

    # Log a sample completion periodically.
    if step % 20 == 0:
        sample = completion_contents[min(2, total - 1)]
        problematic_string = _find_number_as_string(sample)
        has_string_numbers = problematic_string is not None
        logger.info(f"Step {step} - Sample completion (has string numbers: {has_string_numbers}):\n{sample[:200]}...")
        if has_string_numbers:
            logger.info(f"Problematic string number found: '{problematic_string}'")

    log_reward({
        "step": step,
        "time": time.time(),
        "type": "format",
        "format_valid_rate": match_rate,
        "number_as_string_rate": number_as_string_rate,
        "avg_reward": sum(rewards) / len(rewards)
    })

    return rewards