In [None]:
import copy
import json
import logging
import math
import os
import re
import time
from datetime import datetime

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback
from trl import GRPOTrainer, GRPOConfig

# Set up logging: every record goes both to a timestamped file under logs/
# and to a console stream handler.
os.makedirs("logs", exist_ok=True)
log_file = f"logs/training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
# Explicit UTF-8 so model completions containing non-ASCII text can always be
# written, independent of the platform's locale encoding.
file_handler = logging.FileHandler(log_file, encoding="utf-8")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        file_handler,
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger("GRPO-Trader")

# System prompt (with a worked example) used as the system message of every
# training prompt. The XML tags in the example must stay identical to the ones
# matched by check_format / extract_positioning.
SYSTEM_PROMPT = """
You are a macro event-driven portfolio manager. You make positioning decisions on the S&P500 index based on market context and news.

Only edit macro state if there are changes in the macro regime that would impact returns of S&P500.

Positioning should be a float that ranges from -1 (full short) to 1 (full long).

Example:
Market Context: PMI:52.6, inflation:5.1%, unemployment:3.8%

<macro state>
Economy is showing strength with PMI in expansion territory, but inflation remains elevated.
</macro state>
<reasoning>
The PMI at 52.6 indicates expansion, which is positive for equities. However, inflation at 5.1% is above the Fed's target, suggesting possible rate hikes. Unemployment is low at 3.8%, supporting consumer spending.
</reasoning>
<positioning>
0.3
</positioning>

You must respond in the above XML format.
"""

# File path constants.
# One timestamp per run so the rewards and completions files of the same run
# always share a suffix (separate datetime.now() calls can straddle a second).
RUN_TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
RESULTS_DIR = "./results"
os.makedirs(RESULTS_DIR, exist_ok=True)
REWARDS_FILE = os.path.join(RESULTS_DIR, f"rewards_{RUN_TIMESTAMP}.jsonl")
COMPLETIONS_FILE = os.path.join(RESULTS_DIR, f"completions_{RUN_TIMESTAMP}.jsonl")
MODEL_DIR = "./outputs/Qwen-2.5-0.5B-GRPO-trader"
os.makedirs(MODEL_DIR, exist_ok=True)

def log_to_jsonl(file_path, data):
    """Append ``data`` to ``file_path`` as a single JSON-encoded line.

    The file is opened in append mode, so it is created when missing and any
    existing lines are kept.
    """
    encoded = json.dumps(data)
    with open(file_path, 'a') as handle:
        handle.write(f"{encoded}\n")

def extract_positioning(text):
    """Parse the ``<positioning>`` value from a model completion.

    The first ``<positioning>...</positioning>`` section is read and its
    stripped contents are converted with ``float``. The result is clipped to
    [-1.0, 1.0], the range the system prompt specifies. The value is multiplied
    by the realised return in ``return_reward``, so without the clip the policy
    could inflate its reward just by emitting larger numbers.

    Args:
        text: Completion text.

    Returns:
        The position in [-1.0, 1.0]. Returns 0.0 (flat) when the tag is
        missing, the contents are not a number, or the number is NaN/infinite.
    """
    try:
        match = re.search(r"<positioning>(.*?)</positioning>", text, re.DOTALL)
        if match is None:
            return 0.0
        value = float(match.group(1).strip())
    except (TypeError, ValueError) as e:
        # TypeError: text is not a string; ValueError: contents not numeric.
        logger.error(f"Error extracting positioning: {e}")
        return 0.0
    if not math.isfinite(value):
        return 0.0
    return max(-1.0, min(1.0, value))

def check_format(text):
    """Return True when ``text`` contains the three response sections in order.

    The sections are ``<macro state>``, ``<reasoning>`` and ``<positioning>``,
    each with its closing tag. Any text, including newlines, may appear inside
    and between them.
    """
    section_pattern = r"<macro state>.*?</macro state>.*?<reasoning>.*?</reasoning>.*?<positioning>.*?</positioning>"
    found = re.search(section_pattern, text, flags=re.DOTALL)
    return bool(found)

def format_reward_func(prompts, completions, **kwargs):
    """Reward completions that follow the <macro state>/<reasoning>/<positioning> format.

    Args:
        prompts: Prompts for the batch (unused; part of the reward-function
            signature).
        completions: One entry per completion: either a plain string, or a
            list whose first element is a message dict with a ``"content"``
            key (or a string).
        **kwargs: ``global_step`` (default 0) controls periodic logging; other
            keyword arguments are ignored.

    Returns:
        A list with 3.0 for each well-formed completion and -2.0 otherwise.
    """
    def _to_text(completion):
        # A plain-string completion must be used whole; indexing it with [0]
        # would yield only its first character.
        if isinstance(completion, str):
            return completion
        first = completion[0]
        return first["content"] if isinstance(first, dict) else first

    completion_contents = [_to_text(completion) for completion in completions]

    # Log a sample completion periodically.
    global_step = kwargs.get("global_step", 0)
    if global_step % 20 == 0 and completion_contents:
        sample_idx = min(2, len(completion_contents) - 1)
        sample = completion_contents[sample_idx]
        format_valid = check_format(sample)
        logger.info(f"Step {global_step} - Sample completion (format valid: {format_valid}):\n{sample[:200]}...")

    # Calculate rewards.
    is_valid_format = [check_format(content) for content in completion_contents]
    format_rewards = [3.0 if valid else -2.0 for valid in is_valid_format]

    # Log batch statistics.
    if is_valid_format:
        format_rate = sum(is_valid_format) / len(is_valid_format)
        log_to_jsonl(REWARDS_FILE, {
            "step": global_step,
            "time": time.time(),
            "type": "format",
            "format_valid_rate": format_rate,
            "avg_reward": sum(format_rewards) / len(format_rewards)
        })

    return format_rewards

def return_reward(prompts, completions, returns, **kwargs):
    """Reward each completion with the P&L of its position: ``position * return``.

    Args:
        prompts: Prompts for the batch. For a few samples, the user message
            (index 1 of a list-of-messages prompt) is logged.
        completions: One entry per completion: either a plain string, or a
            list whose first element is a message dict with a ``"content"``
            key (or a string).
        returns: Returns to score against; completion ``i`` uses
            ``returns[i % len(returns)]``.
        **kwargs: ``global_step`` (default 0) controls periodic logging; other
            keyword arguments are ignored.

    Returns:
        A list of float rewards with exactly one entry per completion.
        Completions that fail ``check_format``, or whose scoring raises, get
        -1.0.
    """
    def _to_text(completion):
        # A plain-string completion must be used whole; indexing it with [0]
        # would yield only its first character.
        if isinstance(completion, str):
            return completion
        first = completion[0]
        return first["content"] if isinstance(first, dict) else first

    completion_contents = [_to_text(completion) for completion in completions]

    global_step = kwargs.get("global_step", 0)
    rewards = []
    positions = []
    valid_formats = []

    for i, completion in enumerate(completion_contents):
        reward = -1.0  # default for malformed completions or scoring errors
        try:
            format_valid = check_format(completion)
            valid_formats.append(format_valid)

            if format_valid:
                position = extract_positioning(completion)
                positions.append(position)
                period_return = returns[i % len(returns)]
                reward = position * period_return

                # Log sample completions with their rewards.
                if i < 2 and global_step % 20 == 0:
                    prompt_text = prompts[i][1]["content"] if isinstance(prompts[i], list) else prompts[i]
                    log_to_jsonl(COMPLETIONS_FILE, {
                        "step": global_step,
                        "prompt": prompt_text,
                        "completion": completion[:500],
                        "position": position,
                        "reward": reward,
                        "return": period_return
                    })
        except Exception as e:
            logger.error(f"Error calculating reward: {e}")
        # Append exactly once per completion. Previously an error raised after
        # the reward was appended (e.g. while logging) appended a second value
        # and misaligned rewards with completions.
        rewards.append(reward)

    # Log batch statistics.
    if rewards:
        log_to_jsonl(REWARDS_FILE, {
            "step": global_step,
            "time": time.time(),
            "type": "returns",
            "avg_reward": sum(rewards) / len(rewards),
            "format_valid_rate": sum(valid_formats) / len(valid_formats) if valid_formats else 0,
            "avg_position": sum(positions) / len(positions) if positions else 0
        })

    return rewards

class SampleGenerationCallback(TrainerCallback):
    """Periodically sample one completion from the model being trained and log it.

    Every ``every_n_steps`` global steps, the prompt of the first dataset
    example is run through the model. The completion, its format validity and
    its parsed position are written to the logger and to COMPLETIONS_FILE.
    """

    def __init__(self, tokenizer, data, every_n_steps=100):
        self.tokenizer = tokenizer
        self.data = data
        self.every_n_steps = every_n_steps

    def on_step_end(self, args, state, control, **kwargs):
        """Generate a sample every ``every_n_steps`` global steps."""
        if state.global_step % self.every_n_steps != 0:
            return
        # The Trainer hands callbacks the model as the ``model`` keyword
        # argument. It does not pass a ``trainer`` entry, so the previous
        # kwargs.get("trainer") lookup was always None and sampling never ran.
        model = kwargs.get("model")
        if model is not None:
            self.generate_sample(model, state.global_step)

    def generate_sample(self, model, step):
        """Generate and log one completion for the first dataset example."""
        logger.info(f"\n===== Sample Generation at Step {step} =====")

        # Take a sample from the dataset.
        sample = self.data[0]
        messages = copy.deepcopy(sample['prompt'])  # Deep copy to prevent modification

        # Build the chat-formatted prompt and move it to the model's device.
        text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer([text], return_tensors="pt").to(model.device)

        # Generate in eval mode (disables dropout), then restore the previous
        # mode so the training loop continues unaffected.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                output_ids = model.generate(
                    **inputs,
                    max_new_tokens=512,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
        finally:
            if was_training:
                model.train()

        # Decode only the newly generated tokens (drop the prompt prefix).
        output = self.tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

        # Analyze output.
        format_valid = check_format(output)
        position = extract_positioning(output) if format_valid else None

        # Log results.
        logger.info(f"Format valid: {format_valid}")
        if position is not None:
            logger.info(f"Position: {position}")
        logger.info(f"Output: {output[:300]}...")

        # Save to the completions file.
        log_to_jsonl(COMPLETIONS_FILE, {
            "step": step,
            "type": "sample_generation",
            "prompt": text[:200],
            "completion": output,
            "format_valid": format_valid,
            "position": position
        })

def prepare_data(df):
    """Build the GRPO training dataset from a market-data DataFrame.

    Each row becomes a chat prompt: the system prompt plus a "Market Context:"
    user message listing every column except ``date``. It is paired with
    ``returns``, the percentage change of ``close`` from this row to the next.

    Args:
        df: Market data with a ``close`` column and optionally ``date``. Rows
            are presumably in chronological order, since returns are computed
            row-to-row (TODO confirm against the CSV). The caller's frame is
            not modified.

    Returns:
        A ``datasets.Dataset`` with ``prompt`` and ``returns`` columns. Rows
        without a next-row return (the final row) are dropped.
    """
    logger.info("Preparing training data...")
    # Work on a copy so the caller's DataFrame does not gain prompt/returns columns.
    df = df.copy()

    def prepare_prompt(df):
        df['prompt'] = df.apply(lambda row: [
            {
                "role": "system",
                "content": SYSTEM_PROMPT.strip()
            },
            {
                "role": "user",
                "content": f"Market Context:{', '.join(f'{k}:{v}' for k, v in row.drop('date', errors='ignore').items())}"
            }
        ], axis=1)
        return df

    # Build prompts BEFORE adding ``returns``. The prompt lists every column,
    # so adding the forward return first would leak the target into the input.
    df = prepare_prompt(df)
    df['returns'] = df['close'].pct_change().shift(-1)

    # The last row has no next-row return (NaN), so its reward would be
    # undefined. Drop such rows instead of training on them.
    n_rows = len(df)
    df = df.dropna(subset=['returns'])
    if len(df) < n_rows:
        logger.info(f"Dropped {n_rows - len(df)} rows without a next-period return")

    train_dataset = df[['prompt', 'returns']]
    data = Dataset.from_pandas(train_dataset, preserve_index=False)

    logger.info(f"Prepared {len(data)} training examples")
    return data

def main():
    """Load data, model and tokenizer, then run GRPO training.

    Returns:
        The ``GRPOTrainer`` after training finishes and the final model has
        been saved to MODEL_DIR.

    Raises:
        Any exception from ``trainer.train()`` is logged with its traceback
        and re-raised, so a failed run is not mistaken for a successful one.
    """
    logger.info("Starting GRPO Trader training")

    # Model configuration
    model_path = './huggingface_mirror/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775'

    # Load dataset
    train_df = pd.read_csv('../train.csv')
    data = prepare_data(train_df)

    # Load model and tokenizer.
    # No device_map="auto" / .to("cuda"): the Trainer moves the model to the
    # training device itself. device_map may also shard the model across GPUs,
    # which a following .to("cuda") would contradict.
    logger.info(f"Loading model from {model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        padding_side="left",
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # GRPO configuration
    grpo_config = GRPOConfig(
        output_dir=MODEL_DIR,
        run_name="Qwen-2.5-0.5B-GRPO-trader",
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type='cosine',
        logging_steps=10,
        bf16=True,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_generations=4,
        max_prompt_length=1024,
        max_completion_length=512,
        num_train_epochs=1,
        save_steps=100,
        max_grad_norm=0.1,
        temperature=0.7,
        num_iterations=2,
        reward_weights=[1.0, 1.0],
        log_completions=True
    )

    # Named wrappers (instead of two lambdas) inject the current global step.
    # Each reward is also reported under its own function name; two lambdas
    # would both be named "<lambda>". ``trainer`` is looked up when a reward is
    # computed, which happens only after it is assigned below.
    def format_reward_with_step(prompts, completions, **kwargs):
        return format_reward_func(prompts, completions, global_step=trainer.state.global_step, **kwargs)

    def return_reward_with_step(prompts, completions, returns, **kwargs):
        return return_reward(prompts, completions, returns, global_step=trainer.state.global_step, **kwargs)

    # Initialize trainer
    logger.info("Initializing trainer")
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=grpo_config,
        train_dataset=data,
        reward_funcs=[format_reward_with_step, return_reward_with_step],
    )

    # Add callbacks
    sample_callback = SampleGenerationCallback(tokenizer, data)
    trainer.add_callback(sample_callback)

    # Start training
    logger.info("Starting training")
    try:
        trainer.train()
    except Exception as e:
        logger.exception(f"Training failed with error: {str(e)}")
        raise
    logger.info("Training completed successfully")

    # Periodic checkpoints only exist at multiples of save_steps; persist the
    # final weights explicitly.
    trainer.save_model(MODEL_DIR)

    return trainer

# Entry point: run training when this module is executed directly rather than imported.
if __name__ == "__main__":
    main()