# In [None]:
import json
import logging
import math
import os
import re
import time
from datetime import datetime

import numpy as np
import pandas as pd
import torch
from accelerate import PartialState
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback
from trl import GRPOTrainer, GRPOConfig

# Set up logging.
# The timestamp is captured once so the training log and the rewards log of
# the same run always share one suffix; calling datetime.now() twice could
# straddle a second boundary and give the two files different names.
os.makedirs("logs", exist_ok=True)
run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = f"logs/training_{run_timestamp}.log"
rewards_file = f"logs/rewards_{run_timestamp}.jsonl"

# Log both to the per-run file and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("GRPO-Trader")

# System prompt prepended (stripped) as the "system" message of every training
# example in prepare_data. The text is what the model is trained on, so edit it
# deliberately: the tag names below must stay in sync with the regexes in
# format_reward_func and extract_positioning. The stated [-1, 1] positioning
# range is what the policy is asked to output.
SYSTEM_PROMPT = """
You are a macro event driven portfolio manager, you make positioning decision of S&P500 index based on market context and news.

Only edit macro state if there are changes in the macro regime that would impact returns of S&P500.

Positioning should be a float that ranges from -1 (full short) to 1 (full long).

You must respond in the following XML format:

<macro state>
...
</macro state>
<reasoning>
...
</reasoning>
<positioning>
...
</positioning>
"""

def is_main_process():
    """Return True when running on the main (process_index 0) process.

    Used throughout the script to gate logging and file writes so that
    multi-process runs do not duplicate them.

    Returns:
        True on the main process, or when the accelerate process state cannot
        be obtained (so single-process/debug runs still log); False otherwise.
    """
    try:
        state = PartialState()
    except Exception:
        # Best-effort fallback only. Unlike the previous bare ``except:``, this
        # no longer swallows KeyboardInterrupt/SystemExit.
        return True
    return state.process_index == 0

def log_reward(data):
    """Append ``data`` as one JSON line to the run's rewards file.

    Only the main process writes; on every other process this is a no-op.

    Args:
        data: JSON-serialisable object describing one reward measurement.
    """
    if is_main_process():
        with open(rewards_file, 'a') as sink:
            sink.write(f"{json.dumps(data)}\n")

def prepare_data(df):
    """Build the GRPO training dataset from a market-data DataFrame.

    Each row becomes one conversational prompt: the system prompt plus a user
    message listing every column except ``date`` as ``name:value`` pairs. Each
    prompt is paired with the next row's close-to-close return, which the
    return-based reward uses as the realised outcome.

    Args:
        df: DataFrame with at least a ``close`` column. ``pct_change`` is
            positional, so rows are assumed to be in chronological order —
            TODO confirm the CSV is sorted by date. The caller's DataFrame is
            not modified.

    Returns:
        ``datasets.Dataset`` with ``prompt`` and ``returns`` columns. Rows whose
        computed next-period return is NaN (at least the final row) are
        dropped.
    """
    if is_main_process():
        logger.info("Preparing training data...")

    # Work on a copy. The previous version assigned new columns onto the
    # caller's frame, so calling it twice leaked 'prompt'/'returns' into the
    # "Market Context" text.
    df = df.copy()

    def build_prompt(row):
        # NOTE(review): every non-date column is shown to the model; make sure
        # none of them is derived from future prices.
        context = ', '.join(f'{k}:{v}' for k, v in row.drop('date', errors='ignore').items())
        return [
            {"role": "system", "content": SYSTEM_PROMPT.strip()},
            {"role": "user", "content": f"Market Context:{context}"},
        ]

    # Prompts are built BEFORE 'returns' is added, so the target never
    # appears in the context the model sees.
    df['prompt'] = df.apply(build_prompt, axis=1)

    # Row i gets (close[i+1] - close[i]) / close[i]: the outcome of a position
    # taken on row i.
    df['returns'] = df['close'].pct_change().shift(-1)

    # A NaN return (e.g. the last row) would make every return reward for that
    # prompt NaN, so such rows carry no usable training signal.
    train_dataset = df.loc[df['returns'].notna(), ['prompt', 'returns']]
    data = Dataset.from_pandas(train_dataset, preserve_index=False)

    if is_main_process():
        logger.info(f"Prepared {len(data)} training examples")
    return data

def extract_positioning(text):
    """Parse the numeric position from the ``<positioning>`` section.

    Args:
        text: Model completion text.

    Returns:
        The float between the first ``<positioning>`` / ``</positioning>``
        pair, with surrounding whitespace ignored. Returns 0.0 (flat) in
        three cases:

        - the tags are missing;
        - the content is not a number;
        - the number is NaN or infinite. ``float`` accepts "nan"/"inf",
          which would otherwise turn the return-based reward into NaN/inf.
    """
    try:
        match = re.search(r"<positioning>(.*?)</positioning>", text, re.DOTALL)
        if match is None:
            return 0.0
        position = float(match.group(1).strip())
    except (TypeError, ValueError) as e:
        # TypeError: ``text`` is not a string; ValueError: non-numeric content.
        if is_main_process():
            logger.error(f"Error extracting positioning: {e}")
        return 0.0
    return position if math.isfinite(position) else 0.0

def format_reward_func(prompts, completions, **kwargs):
    """Reward completions that follow the required XML response layout.

    A completion scores 3.0 if it contains a ``<macro state>``, a
    ``<reasoning>`` and a ``<positioning>`` section, in that order and each
    with its closing tag. Otherwise it scores -1.0. Text before, after or
    between the sections is allowed, because the search is unanchored.

    Args:
        prompts: Prompts of the batch (unused; part of the reward signature).
        completions: One entry per sample: either a plain string, or a
            conversational list whose first element is a message dict with a
            ``"content"`` key.
        **kwargs: Extra fields from the trainer. ``step`` (int, default 0) is
            used only to label log lines.

    Returns:
        List of float rewards, one per completion.
    """
    step = kwargs.get("step", 0)

    # re.DOTALL lets ".*?" span the newlines inside and between sections.
    pattern = re.compile(
        r"<macro state>.*?</macro state>.*?<reasoning>.*?</reasoning>.*?<positioning>.*?</positioning>",
        re.DOTALL,
    )

    def text_of(completion):
        # A plain-string completion must be returned whole: the previous
        # ``completion[0]`` fallback took only its first character, so such
        # completions could never match the pattern.
        if isinstance(completion, str):
            return completion
        first = completion[0]
        return first["content"] if isinstance(first, dict) else first

    completion_contents = [text_of(completion) for completion in completions]

    # Log a sample completion periodically (main process only).
    if is_main_process() and step % 20 == 0 and completion_contents:
        sample = completion_contents[min(2, len(completion_contents) - 1)]
        logger.info(f"Step {step} - Sample completion:\n{sample[:200]}...")

    matches = [pattern.search(content) is not None for content in completion_contents]
    rewards = [3.0 if matched else -1.0 for matched in matches]

    # Log the format match rate (main process only).
    if is_main_process() and matches:
        match_rate = sum(matches) / len(matches)
        logger.info(f"Step {step} - Format match rate: {match_rate:.2f}")
        log_reward({
            "step": step,
            "time": time.time(),
            "type": "format",
            "match_rate": match_rate,
            "avg_reward": sum(rewards) / len(rewards)
        })

    return rewards

def return_reward(prompts, completions, returns, **kwargs):
    """Reward each completion by the P&L of the position it states.

    The position parsed from ``<positioning>`` is clipped to [-1, 1], the
    range the system prompt asks for. It is then multiplied by the realised
    next-period return of the prompt it answers.

    Args:
        prompts: Prompts of the batch (unused; part of the reward signature).
        completions: One entry per sample: either a plain string, or a
            conversational list whose first element is a message dict with a
            ``"content"`` key.
        returns: Values of the dataset ``returns`` column. Presumably one entry
            per completion — verify against the trainer. The modulo indexing
            below tolerates a shorter list.
        **kwargs: Extra fields from the trainer. ``step`` (int, default 0) is
            used only to label log lines.

    Returns:
        List of finite float rewards, one per completion. A completion gets
        0.0 when the position or return is unusable.
    """
    step = kwargs.get("step", 0)

    def text_of(completion):
        # A plain-string completion must be returned whole: ``completion[0]``
        # would be only its first character.
        if isinstance(completion, str):
            return completion
        first = completion[0]
        return first["content"] if isinstance(first, dict) else first

    completion_contents = [text_of(completion) for completion in completions]

    rewards = []
    positions = []  # clipped positions, for logging

    for i, content in enumerate(completion_contents):
        try:
            position = extract_positioning(content)
            if not math.isfinite(position):
                position = 0.0
            # Without clipping, the policy could inflate its reward
            # arbitrarily by answering with a huge |position|.
            position = max(-1.0, min(1.0, position))
            positions.append(position)
            reward = position * float(returns[i % len(returns)])
            # A NaN return (e.g. an unlabelled last row) would otherwise
            # poison the group's advantages.
            rewards.append(reward if math.isfinite(reward) else 0.0)
        except Exception as e:
            if is_main_process():
                logger.error(f"Error in return_reward: {e}")
            rewards.append(0.0)

    # Log stats about positions and rewards (main process only).
    if is_main_process() and positions:
        avg_position = sum(positions) / len(positions)
        avg_reward = sum(rewards) / len(rewards)
        # Rewards are position * daily return (typically ~1e-2), so 4 decimals
        # are needed for the value to be readable at all.
        logger.info(f"Step {step} - Avg position: {avg_position:.2f}, Avg reward: {avg_reward:.4f}")
        log_reward({
            "step": step,
            "time": time.time(),
            "type": "returns",
            "avg_position": avg_position,
            "avg_reward": avg_reward
        })

    return rewards

def main():
    """Run GRPO training of the trading policy end to end.

    Steps:

    1. Load ``../train.csv`` and build the prompt/returns dataset.
    2. Load the local Qwen2.5-0.5B-Instruct snapshot and its tokenizer.
    3. Train with two reward functions: format-based and return-based.
    4. Save the final model to ``output_dir``.

    Returns:
        The ``GRPOTrainer`` after training.

    Raises:
        Exception: whatever ``trainer.train()`` raises is logged with its
            traceback (main process only) and re-raised. A failed run
            therefore exits non-zero instead of looking successful.
    """
    if is_main_process():
        logger.info("Starting GRPO Trader training with multi-GPU support")

    # Model and output configuration
    model_path = './huggingface_mirror/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775'
    output_dir = "outputs/Qwen-2.5-0.5B-GRPO-trader"
    os.makedirs(output_dir, exist_ok=True)

    # Load dataset
    if is_main_process():
        logger.info("Loading dataset")
    train_df = pd.read_csv('../train.csv')
    data = prepare_data(train_df)

    # Load model and tokenizer
    if is_main_process():
        logger.info(f"Loading model from {model_path}")
    # No device_map: device_map="auto" spreads the weights over every visible
    # GPU inside *each* process, which conflicts with DDP (accelerate refuses to
    # wrap such a model in distributed mode). The Trainer moves the model onto
    # this process's device itself.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        padding_side="left",
        trust_remote_code=True
    )

    # Set pad token if needed
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # GRPO configuration
    grpo_config = GRPOConfig(
        output_dir=output_dir,
        run_name="Qwen-2.5-0.5B-GRPO-trader",
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type='cosine',
        logging_steps=10,
        bf16=True,
        per_device_train_batch_size=2,  # This is PER GPU
        gradient_accumulation_steps=8,
        # NOTE(review): TRL requires the generation batch (per-device batch x
        # number of processes, times gradient accumulation in newer TRL
        # versions) to be divisible by num_generations. Check this for your
        # GPU count and TRL version.
        num_generations=4,
        max_prompt_length=768,
        max_completion_length=256,
        num_train_epochs=1,
        save_steps=100,
        max_grad_norm=0.1,
        temperature=0.7,
        num_iterations=2,
        reward_weights=[1.0, 1.0],  # format reward, return reward
        # Multi-GPU settings
        ddp_find_unused_parameters=False,
        dataloader_num_workers=4,
        # Generate with the policy model itself rather than a vLLM engine.
        use_vllm=False
    )

    # Latest completed global step, shared with the reward wrappers via a
    # one-element list so the callback below can update it.
    current_step = [0]

    # Reward wrappers that inject the current step for log labelling.
    def format_reward_with_step(prompts, completions, **kwargs):
        return format_reward_func(prompts, completions, step=current_step[0], **kwargs)

    def return_reward_with_step(prompts, completions, returns, **kwargs):
        return return_reward(prompts, completions, returns, step=current_step[0], **kwargs)

    if is_main_process():
        logger.info("Initializing trainer with multi-GPU support")
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            format_reward_with_step,
            return_reward_with_step
        ],
        args=grpo_config,
        train_dataset=data,
    )

    class StepTracker(TrainerCallback):
        """Copy ``state.global_step`` into ``current_step`` after each step."""

        def on_step_end(self, args, state, control, **kwargs):
            current_step[0] = state.global_step

    trainer.add_callback(StepTracker())

    # Start training
    if is_main_process():
        logger.info("Starting training")
    try:
        trainer.train()
    except Exception as e:
        if is_main_process():
            # logger.exception also records the traceback.
            logger.exception(f"Training failed with error: {str(e)}")
        raise
    if is_main_process():
        logger.info("Training completed successfully")

    # Checkpoints are only written every save_steps, so persist the final
    # weights explicitly. Called on all processes; the Trainer decides which
    # process actually writes.
    trainer.save_model(output_dir)

    return trainer


if __name__ == "__main__":
    main()