# In [None]:
import json
import logging
import math
import os
import re
import time
from datetime import datetime

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback
from trl import GRPOTrainer, GRPOConfig

# System prompt prepended to every training example (see prepare_data).
# NOTE: the tag names below (<macro state>, <reasoning>, <positioning>) must stay
# in sync with the regex patterns in the reward functions and the sampling
# callback; renaming a tag here silently breaks the format reward.
SYSTEM_PROMPT = """
You are a macro event driven portfolio manager, you make positioning decisions of S&P500 index based on market context and news.

Only edit macro state if there are changes in the macro regime that would impact returns of S&P500.

Positioning should be a float that ranges from -1 (full short) to 1 (full long).

Example:
Market Context: PMI:52.6, inflation:5.1%, unemployment:3.8%

<macro state>
Economy is showing strength with PMI in expansion territory, but inflation remains elevated.
</macro state>
<reasoning>
The PMI at 52.6 indicates expansion, which is positive for equities. However, inflation at 5.1% is above the Fed's target, suggesting possible rate hikes. Unemployment is low at 3.8%, indicating a strong labor market that supports consumer spending.
</reasoning>
<positioning>
0.3
</positioning>

You must respond in the above XML format.
"""

class Logger:
    """File + console logger for one GRPO training run.

    Creates ``log_dir`` if needed and writes three files into it, all sharing a
    single timestamp suffix: a plain-text training log, a JSONL file of reward
    statistics and a JSONL file of sampled completions. Reward and completion
    entries are appended to their files as they arrive and are also kept in
    memory so ``save_rewards`` / ``save_completions`` can rewrite them in full.
    """

    def __init__(self, log_dir, name="GRPO-Trader"):
        """Set up the log files and the named stdlib logger.

        Args:
            log_dir: Directory for the log files (created if missing).
            name: Name of the ``logging`` logger to configure.
        """
        os.makedirs(log_dir, exist_ok=True)

        # One timestamp for all three files; separate datetime.now() calls can
        # straddle a second boundary and give the files different suffixes.
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.log_file = os.path.join(log_dir, f"training_{stamp}.log")
        self.reward_file = os.path.join(log_dir, f"rewards_{stamp}.jsonl")
        self.completion_file = os.path.join(log_dir, f"completions_{stamp}.jsonl")

        # Configure the named logger directly rather than via logging.basicConfig:
        # basicConfig is a no-op once the root logger has any handler (notebook
        # re-runs, pytest, a second Logger in the same process), which left the
        # new log file unwritten.
        self.logger = logging.getLogger(name)
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False  # avoid duplicate lines via root handlers
        for old_handler in list(self.logger.handlers):
            # Detach handlers left by a previous Logger so messages are not
            # duplicated and the previous file handle is released.
            self.logger.removeHandler(old_handler)
            old_handler.close()
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        for handler in (logging.FileHandler(self.log_file, encoding='utf-8'), logging.StreamHandler()):
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

        self.rewards = []
        self.completions = []

    @staticmethod
    def _append_jsonl(path, entry):
        """Append one JSON object as a line to ``path``."""
        with open(path, 'a', encoding='utf-8') as f:
            f.write(json.dumps(entry) + '\n')

    @staticmethod
    def _write_jsonl(path, entries):
        """Overwrite ``path`` with one JSON object per line."""
        with open(path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(entry) + '\n' for entry in entries)

    def log_info(self, message):
        self.logger.info(message)

    def log_warning(self, message):
        self.logger.warning(message)

    def log_error(self, message):
        self.logger.error(message)

    def log_reward(self, step, reward_data):
        """Record reward statistics for ``step`` (memory + append to JSONL)."""
        entry = {"step": step, "time": time.time(), **reward_data}
        self.rewards.append(entry)
        self._append_jsonl(self.reward_file, entry)

    def log_completion(self, step, prompt, completion, reward, format_valid, position=None):
        """Record one sampled completion for ``step`` (memory + append to JSONL)."""
        entry = {
            "step": step,
            "time": time.time(),
            "prompt": prompt,
            "completion": completion,
            "reward": reward,
            "format_valid": format_valid,
            "position": position
        }
        self.completions.append(entry)
        self._append_jsonl(self.completion_file, entry)

    def save_rewards(self):
        """Rewrite the reward file from the in-memory entries."""
        self._write_jsonl(self.reward_file, self.rewards)

    def save_completions(self):
        """Rewrite the completion file from the in-memory entries."""
        self._write_jsonl(self.completion_file, self.completions)

def prepare_data(df, logger):
    """Build the GRPO training dataset from a market-data DataFrame.

    Each row becomes one conversational prompt: the system prompt plus a user
    message listing every column except ``date`` as ``name:value``. The prompt
    is paired with the percentage change from this row's ``close`` to the next
    row's ``close`` as the ``returns`` target.

    Args:
        df: DataFrame with at least a ``close`` column. Rows are presumably in
            chronological order (TODO confirm upstream sorting); the caller's
            DataFrame is not modified.
        logger: ``Logger`` used for progress messages.

    Returns:
        ``datasets.Dataset`` with ``prompt`` and ``returns`` columns. Rows whose
        forward return is not finite (always the last row) are dropped.

    Raises:
        ValueError: if ``df`` is empty or no row has a finite forward return.
    """
    logger.log_info("Preparing training data...")
    if df.empty:
        raise ValueError("prepare_data received an empty DataFrame")

    # Work on a copy so the caller's DataFrame is not mutated as a side effect.
    df = df.copy()

    def build_messages(row):
        context = ', '.join(f'{k}:{v}' for k, v in row.drop('date', errors='ignore').items())
        return [
            {
                "role": "system",
                "content": SYSTEM_PROMPT.strip()
            },
            {
                "role": "user",
                "content": f"Market Context:{context}"
            }
        ]

    # Prompts are built before `returns` exists, so the target never leaks
    # into the model's context.
    df['prompt'] = df.apply(build_messages, axis=1)
    df['returns'] = df['close'].pct_change().shift(-1)

    # The last row has no next close (NaN) and a zero close yields inf. Either
    # would make position * return NaN/inf (or None once converted to Arrow),
    # so drop those rows.
    finite = np.isfinite(df['returns'])
    dropped = int((~finite).sum())
    df = df.loc[finite]
    if dropped:
        logger.log_warning(f"Dropped {dropped} row(s) without a finite forward return")
    if df.empty:
        raise ValueError("No rows with a finite forward return; need at least two rows of data")

    data = Dataset.from_pandas(df[['prompt', 'returns']], preserve_index=False)

    logger.log_info(f"Prepared {len(data)} training examples")
    logger.log_info(f"Sample prompt: {data[0]['prompt']}")

    return data

def extract_positioning(text, low=-1.0, high=1.0):
    """Parse the numeric position from a ``<positioning>...</positioning>`` tag.

    Args:
        text: Model completion to search.
        low: Lower clamp bound (default -1.0, "full short" in the system prompt).
        high: Upper clamp bound (default 1.0, "full long" in the system prompt).

    Returns:
        The first tagged value as a float clamped to ``[low, high]``. Returns
        ``0.0`` (flat) when ``text`` is not a string, the tag is missing, or the
        value is not a finite number.
    """
    if not isinstance(text, str):
        return 0.0
    match = re.search(r"<positioning>(.*?)</positioning>", text, re.DOTALL)
    if match is None:
        return 0.0
    try:
        value = float(match.group(1).strip())
    except ValueError:
        return 0.0
    if not math.isfinite(value):
        return 0.0
    # Clamp: the reward is position * return, so an out-of-range position
    # (e.g. 50) would let the policy inflate its reward beyond the allowed range.
    return min(max(value, low), high)

def format_reward_func(prompts, completions, **kwargs):
    """Reward completions that follow the required XML response layout.

    A completion scores 3.0 when it contains, in order, a ``<macro state>``,
    ``<reasoning>`` and ``<positioning>`` block, and -2.0 otherwise.

    Args:
        prompts: Prompt per completion. Unused; kept for the TRL reward signature.
        completions: One entry per completion, either a plain string or a
            conversational list of messages (the first message's ``content``
            is scored).
        **kwargs: ``logger`` (optional ``Logger``) and ``step`` (int, default 0)
            are read; other keys are ignored.

    Returns:
        list[float]: one reward per completion.
    """
    logger = kwargs.get("logger", None)
    step = kwargs.get("step", 0)

    pattern = re.compile(
        r"<macro state>.*?</macro state>.*?<reasoning>.*?</reasoning>.*?<positioning>.*?</positioning>",
        re.DOTALL,
    )

    def completion_text(completion):
        # Plain-text completions arrive as str; indexing them with [0] would
        # score only their first character.
        if isinstance(completion, str):
            return completion
        if not completion:
            return ""
        first = completion[0]
        content = first.get("content", "") if isinstance(first, dict) else first
        return content if isinstance(content, str) else ""

    completion_contents = [completion_text(completion) for completion in completions]

    if logger and step % 50 == 0:
        for i, content in enumerate(completion_contents[:2]):
            logger.log_info(f"Step {step} - Completion {i}: {content[:100]}...")

    matches = [pattern.search(content) is not None for content in completion_contents]
    format_rewards = [3.0 if match else -2.0 for match in matches]

    if logger:
        format_match_rate = sum(matches) / len(matches) if matches else 0
        logger.log_reward(step, {
            "format_match_rate": format_match_rate,
            "avg_format_reward": sum(format_rewards) / len(format_rewards) if format_rewards else 0
        })

    return format_rewards

def return_reward(prompts, completions, returns, **kwargs):
    """Reward well-formatted completions by the P&L of their stated position.

    For a completion that matches the required XML layout, the reward is
    ``extract_positioning(completion) * returns[i]``. Badly formatted
    completions get -1.0. A formatted completion whose return is missing (None)
    or non-finite gets 0.0, because there is no signal to reward.

    Args:
        prompts: Prompt per completion; used only for sample logging.
        completions: One entry per completion, either a plain string or a
            conversational list of messages (the first message's ``content``
            is scored).
        returns: Forward return per completion. It is indexed modulo its length,
            as in the original implementation.
        **kwargs: ``logger`` (optional ``Logger``) and ``step`` (int, default 0)
            are read; other keys are ignored.

    Returns:
        list[float]: exactly one reward per completion.
    """
    logger = kwargs.get("logger", None)
    step = kwargs.get("step", 0)

    pattern = re.compile(
        r"<macro state>.*?</macro state>.*?<reasoning>.*?</reasoning>.*?<positioning>.*?</positioning>",
        re.DOTALL,
    )

    def completion_text(completion):
        # Plain-text completions arrive as str; indexing them with [0] would
        # score only their first character.
        if isinstance(completion, str):
            return completion
        if not completion:
            return ""
        first = completion[0]
        content = first.get("content", "") if isinstance(first, dict) else first
        return content if isinstance(content, str) else ""

    def forward_return(i):
        # Returns None when no usable (finite numeric) return is available.
        if returns is None or len(returns) == 0:
            return None
        value = returns[i % len(returns)]
        try:
            value = float(value)
        except (TypeError, ValueError):
            return None
        return value if math.isfinite(value) else None

    completion_contents = [completion_text(completion) for completion in completions]

    rewards = []
    position_rewards = []
    format_valid_count = 0

    for i, completion in enumerate(completion_contents):
        format_valid = pattern.search(completion) is not None
        position = extract_positioning(completion) if format_valid else None

        if format_valid:
            format_valid_count += 1
            ret = forward_return(i)
            if ret is None:
                reward = 0.0
            else:
                reward = position * ret
                position_rewards.append(reward)
        else:
            reward = -1.0
        # Exactly one append per completion, so rewards stay aligned with completions.
        rewards.append(reward)

        if logger and i < 2 and step % 50 == 0:
            # Best-effort sample logging, kept outside the reward computation
            # so a logging failure cannot add or change a reward.
            try:
                prompt = prompts[i]
                if isinstance(prompt, list):
                    last = prompt[-1]
                    prompt_text = last["content"] if isinstance(last, dict) else last
                else:
                    prompt_text = prompt
                logger.log_completion(
                    step=step,
                    prompt=prompt_text,
                    completion=completion[:500],
                    reward=reward,
                    format_valid=format_valid,
                    position=position
                )
            except Exception as exc:
                logger.log_warning(f"Failed to log completion {i} at step {step}: {exc}")

    if logger:
        format_valid_rate = format_valid_count / len(completion_contents) if completion_contents else 0
        avg_position_reward = sum(position_rewards) / len(position_rewards) if position_rewards else 0

        logger.log_reward(step, {
            "format_valid_rate": format_valid_rate,
            "avg_position_reward": avg_position_reward,
            "avg_total_reward": sum(rewards) / len(rewards) if rewards else 0
        })

    return rewards

class MetricsCallback(TrainerCallback):
    """Trainer callback that records selected metrics and periodically samples.

    ``on_log`` stores ``(global_step, value)`` pairs for the keys in
    ``metrics_history`` and echoes every log dict to the run logger.
    ``on_step_end`` generates one completion for ``data[0]`` every
    ``sample_every`` steps and logs whether it matches the response format.
    """

    # Same layout check as the reward functions.
    _FORMAT_PATTERN = re.compile(
        r"<macro state>.*?</macro state>.*?<reasoning>.*?</reasoning>.*?<positioning>.*?</positioning>",
        re.DOTALL,
    )

    def __init__(self, logger, tokenizer, data, sample_every=100, max_new_tokens=512, temperature=0.7):
        """
        Args:
            logger: Run ``Logger``.
            tokenizer: Tokenizer with a chat template, used for sampling.
            data: Dataset whose first row's ``prompt`` is used for samples.
            sample_every: Generate a sample every this many global steps
                (0 or None disables sampling).
            max_new_tokens: Generation length limit for samples.
            temperature: Sampling temperature for samples.
        """
        self.logger = logger
        self.tokenizer = tokenizer
        self.data = data
        self.sample_every = sample_every
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.metrics_history = {
            "loss": [],
            "learning_rate": [],
            "format_match_rate": [],
            "reward": []
        }

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs:
            for key, value in logs.items():
                if key in self.metrics_history:
                    self.metrics_history[key].append((state.global_step, value))

            self.logger.log_info(f"Step {state.global_step} - Metrics: {logs}")

    def on_step_end(self, args, state, control, model=None, **kwargs):
        if model is None or not self.sample_every or state.global_step % self.sample_every != 0:
            return
        try:
            self.sample_generation(state.global_step, model)
        except Exception as exc:
            # Sampling is diagnostic only; a failure here must not abort training.
            self.logger.log_warning(f"Sample generation failed at step {state.global_step}: {exc}")

    def sample_generation(self, step, model):
        """Generate one completion for ``data[0]`` and log its format/position."""
        self.logger.log_info(f"\n=== Step {step} Sample Generation ===")
        messages = self.data[0]['prompt']
        text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer([text], return_tensors="pt").to(model.device)

        # Generate in eval mode (no dropout) and restore the previous mode so
        # the training loop continues exactly as before.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                output_ids = model.generate(
                    **inputs,
                    max_new_tokens=self.max_new_tokens,
                    temperature=self.temperature,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id
                )
        finally:
            if was_training:
                model.train()

        output = self.tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

        has_format = self._FORMAT_PATTERN.search(output) is not None
        position = extract_positioning(output) if has_format else None

        self.logger.log_completion(
            step=step,
            prompt=text[:200],
            completion=output,
            reward=None,
            format_valid=has_format,
            position=position
        )

        self.logger.log_info(f"Output: {output[:200]}...")
        self.logger.log_info(f"Has correct format: {has_format}")
        if position is not None:
            self.logger.log_info(f"Extracted positioning: {position}")

def main():
    """Fine-tune Qwen2.5-0.5B-Instruct with GRPO on the trading dataset.

    Loads ``../train.csv`` and builds prompts with ``prepare_data``. Trains
    ``GRPOTrainer`` with two reward functions: response format and position P&L.
    Run logs go under ``./logs`` and checkpoints under the output directory.

    Returns:
        The ``GRPOTrainer`` after training.

    Raises:
        Any exception raised by ``trainer.train()``. It is logged with its
        traceback and re-raised after the reward/completion logs are saved.
    """
    log_dir = "./logs"
    output_dir = "./outputs/Qwen-2.5-0.5B-GRPO-trader"
    run_name = "Qwen-2.5-0.5B-GRPO-trader"

    logger = Logger(log_dir)
    logger.log_info("Starting GRPO Trader training")

    model_path = './huggingface_mirror/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775'

    logger.log_info("Loading dataset")
    train_df = pd.read_csv('../train.csv')
    data = prepare_data(train_df, logger)

    logger.log_info(f"Loading model and tokenizer from {model_path}")
    # No device_map / .to("cuda"): the Trainer places the model on its training
    # device itself, and combining device_map="auto" with a manual .to() conflicts
    # with that placement.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        padding_side="left",
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.log_info("Configuring training parameters")
    grpo_config = GRPOConfig(
        output_dir=output_dir,
        run_name=run_name,
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type='cosine',
        logging_steps=10,
        bf16=True,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_generations=4,
        max_prompt_length=1024,
        max_completion_length=512,
        num_train_epochs=1,
        save_steps=100,
        max_grad_norm=0.1,
        log_on_each_node=False,
        temperature=0.7,
        num_iterations=2,
        reward_weights=[1.0, 1.0],
        log_completions=True
    )

    # The reward functions need the current global step for throttled logging,
    # but the trainer only exists after GRPOTrainer(...) returns. The original
    # code passed `trainer` before it was assigned, which raised
    # UnboundLocalError. This holder is filled in right after construction.
    trainer_ref = {}

    def current_step(kwargs):
        # Prefer a trainer_state passed by TRL (if this TRL version passes
        # one), else the trainer captured below, else 0.
        state = kwargs.get("trainer_state")
        if state is None and "trainer" in trainer_ref:
            state = trainer_ref["trainer"].state
        return getattr(state, "global_step", 0)

    def format_reward_with_logging(prompts, completions, **kwargs):
        return format_reward_func(prompts, completions, logger=logger, step=current_step(kwargs))

    def return_reward_with_logging(prompts, completions, returns, **kwargs):
        return return_reward(prompts, completions, returns, logger=logger, step=current_step(kwargs))

    logger.log_info("Initializing trainer")
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[format_reward_with_logging, return_reward_with_logging],
        args=grpo_config,
        train_dataset=data,
    )
    trainer_ref["trainer"] = trainer

    metrics_callback = MetricsCallback(logger, tokenizer, data)
    trainer.add_callback(metrics_callback)

    logger.log_info("Starting training")
    try:
        trainer.train()
        logger.log_info("Training completed successfully")
    except Exception as e:
        # Log with traceback, then propagate so a failed run is not reported as success.
        logger.logger.exception(f"Training failed with error: {str(e)}")
        raise
    finally:
        logger.save_rewards()
        logger.save_completions()
        logger.log_info("Saved all logs and metrics")

    return trainer

# Script entry point: run training only when this file is executed directly,
# not when it is imported (e.g. to reuse the reward functions).
if __name__ == "__main__":
    main()