In [None]:
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback
from trl import GRPOTrainer, GRPOConfig
from typing import Dict, List, Tuple
import json
from collections import deque
from datasets import Dataset
import re

# First, apply our monkey patch to fix the GRPOTrainer
def apply_no_pop_fix():
    """
    Patch ``GRPOTrainer`` so it never mutates the caller's conversational prompts.

    ``GRPOTrainer._generate_and_score_completions`` is replaced by a wrapper
    that, for conversational batches, hands the original method copies of the
    input dicts whose ``"prompt"`` list (and every message dict in it) is new.
    In-place edits made by the original method (e.g. popping messages) then
    cannot reach the dataset rows. Non-conversational and empty batches are
    passed through unchanged.

    Call this before creating your GRPOTrainer instance. Calling it again is a
    no-op, so re-running the notebook cell does not stack wrappers.

    Returns:
        The patched ``GRPOTrainer`` class itself (not a subclass).
    """
    import functools

    from trl.data_utils import is_conversational

    original_method = GRPOTrainer._generate_and_score_completions

    # Idempotence: wrapping an already-wrapped method would copy every batch twice.
    if getattr(original_method, "_no_pop_patched", False):
        print("✅ GRPOTrainer no-pop fix already applied.")
        return GRPOTrainer

    @functools.wraps(original_method)
    def patched_generate_and_score_completions(self, inputs):
        # An empty batch has no sample to inspect; let the original handle it.
        if inputs and is_conversational(inputs[0]):
            # Fresh outer dict, fresh prompt list, fresh message dicts: enough to
            # isolate any list/dict mutation the original method performs.
            copied_inputs = [
                {**inp, "prompt": [dict(message) for message in inp["prompt"]]}
                for inp in inputs
            ]
            return original_method(self, copied_inputs)

        # For non-conversational, just use the original method
        return original_method(self, inputs)

    patched_generate_and_score_completions._no_pop_patched = True

    # Apply the monkey patch
    GRPOTrainer._generate_and_score_completions = patched_generate_and_score_completions

    print("✅ Applied fix to prevent GRPOTrainer from popping messages from prompt.")

    return GRPOTrainer

# Apply the fix. apply_no_pop_fix() patches GRPOTrainer in place and returns the
# class itself, so FixedGRPOTrainer *is* GRPOTrainer (not a subclass) and every
# GRPOTrainer instance uses the patched method.
FixedGRPOTrainer = apply_no_pop_fix()

# System prompt for the portfolio manager. It is sent as the system message of
# every training prompt (see prepare_data) and defines the XML layout that
# format_reward_func and return_reward check for.
# NOTE(review): the few-shot example shows "Market Context: PMI:52.6, ..." with a
# space after the colon, while prepare_data builds "Market Context:k:v, ..." with
# no space. The example reasoning says "Employment is strong at 3.8%" although
# 3.8% is the unemployment figure. The prompt asks the model to "edit macro
# state", but prepare_data never feeds a previous macro state back in. Any change
# to this text changes the training prompt, so keep it byte-identical unless
# retraining deliberately.
SYSTEM_PROMPT = """
You are a macro event driven portfolio manager, you make positioning decision of S&P500 index based on market context and news.

Only edit macro state if there are changes in the macro regime that would impact returns of S&P500.

Positioning should be a float that ranges from -1 (full short) to 1 (full long).

Example:
Market Context: PMI:52.6, inflation:5.1%, unemployment:3.8%

<macro state>
Economy is showing strength with PMI in expansion territory, but inflation remains elevated.
</macro state>
<reasoning>
The PMI at 52.6 indicates expansion, which is positive for equities. However, inflation at 5.1% is above the Fed's target, suggesting possible rate hikes. Employment is strong at 3.8%, supporting consumer spending.
</reasoning>
<positioning>
0.3
</positioning>

You must respond in the above XML format.
"""

def prepare_data(df):
    """
    Build the GRPO training dataset from a market-data DataFrame.

    Each row becomes a two-message chat prompt. The system message is
    ``SYSTEM_PROMPT``. The user message lists every column except ``date`` as
    ``name:value`` pairs after ``"Market Context:"``. Each prompt is paired
    with the forward return ``close[t+1] / close[t] - 1``, which
    ``return_reward`` multiplies by the model's positioning.

    Args:
        df: DataFrame with a ``close`` column and optionally a ``date`` column.
            Rows are assumed to be in chronological order (TODO confirm
            against the CSV). The caller's frame is not modified.

    Returns:
        A ``datasets.Dataset`` with ``prompt`` and ``returns`` columns. Rows
        without a forward return (always the last row) are dropped, because a
        missing return would become a null/NaN reward during training.
    """
    # Work on a copy so adding helper columns does not mutate the caller's frame.
    df = df.copy()

    def build_prompt(row):
        context = ', '.join(f'{k}:{v}' for k, v in row.drop('date', errors='ignore').items())
        return [
            {"role": "system", "content": SYSTEM_PROMPT.strip()},
            {"role": "user", "content": f"Market Context:{context}"},
        ]

    # Build prompts before 'returns' exists so the target never leaks into the prompt.
    df['prompt'] = df.apply(build_prompt, axis=1)
    # Forward return: what holding from row t to row t+1 would have earned.
    df['returns'] = df['close'].pct_change().shift(-1)
    # The last row (and any row whose next close is missing) has nothing to score against.
    df = df.dropna(subset=['returns'])

    return Dataset.from_pandas(df[['prompt', 'returns']], preserve_index=False)

def extract_positioning(text):
    """
    Parse the numeric positioning from a model completion.

    Looks for the first ``<positioning>...</positioning>`` block and converts
    its stripped contents to ``float``. The value is returned as written and is
    NOT clipped to [-1, 1]; callers that use it as a position size must clip.

    Args:
        text: The completion text.

    Returns:
        The parsed position, or 0.0 (flat) when the tag is missing, its content
        is not a number, ``text`` is not a string, or the number is NaN/infinite.
    """
    import math

    try:
        match = re.search(r"<positioning>(.*?)</positioning>", text, re.DOTALL)
        if match is None:
            return 0.0
        # Extract the value and clean any whitespace
        value = float(match.group(1).strip())
    except (TypeError, ValueError) as e:
        print(f"Positioning extraction error: {e}")
        return 0.0

    # float() accepts "nan"/"inf"; those would poison the reward, so treat them as flat.
    return value if math.isfinite(value) else 0.0

def format_reward_func(prompts, completions, **kwargs):
    """
    Reward completions that follow the required XML response layout.

    A completion scores 3.0 when it contains ``<macro state>``, ``<reasoning>``
    and ``<positioning>`` blocks in that order (other text may surround or sit
    between them), and -2.0 otherwise.

    Args:
        prompts: Batch prompts (unused; part of the TRL reward signature).
        completions: One entry per completion. Each is either a plain string or
            a conversational list whose first element is a message dict with a
            ``"content"`` key.
        **kwargs: Extra dataset columns passed by the trainer (ignored).

    Returns:
        A list of float rewards aligned with ``completions``.
    """
    pattern = r"<macro state>.*?</macro state>.*?<reasoning>.*?</reasoning>.*?<positioning>.*?</positioning>"

    def completion_text(completion):
        # Plain strings are used as-is; indexing them would yield a single character.
        if isinstance(completion, str):
            return completion
        if not completion:
            return ""
        first = completion[0]
        return first["content"] if isinstance(first, dict) else first

    completion_contents = [completion_text(completion) for completion in completions]

    # Debug output for first few completions
    for i, content in enumerate(completion_contents[:3]):
        print(f"Completion {i}: {content[:100]}...")

    # Stronger reward/penalty for format
    return [
        3.0 if re.search(pattern, content, re.DOTALL) is not None else -2.0
        for content in completion_contents
    ]

def return_reward(prompts, completions, returns, **kwargs):
    """
    Score each completion by the P&L of its positioning over the next period.

    A well-formatted completion (same XML layout as ``format_reward_func``)
    earns ``position * forward_return``. The position parsed by
    ``extract_positioning`` is clipped to [-1, 1], the range the system prompt
    allows, so the policy cannot inflate its reward by emitting a larger
    number. A badly formatted completion gets -1.0.

    Args:
        prompts: Batch prompts (unused; part of the TRL reward signature).
        completions: Plain strings, or conversational lists whose first element
            is a message dict with a ``"content"`` key.
        returns: Forward returns from the dataset's ``returns`` column, indexed
            like ``completions`` (wrapped with modulo if shorter).
        **kwargs: Other dataset columns passed by the trainer (ignored).

    Returns:
        A list of float rewards aligned with ``completions``. A well-formatted
        completion whose return is missing or non-finite scores 0.0, since
        there is nothing to score it against.
    """
    import math

    pattern = r"<macro state>.*?</macro state>.*?<reasoning>.*?</reasoning>.*?<positioning>.*?</positioning>"
    n_returns = len(returns) if returns is not None else 0

    def completion_text(completion):
        # Plain strings are used as-is; indexing them would yield a single character.
        if isinstance(completion, str):
            return completion
        if not completion:
            return ""
        first = completion[0]
        return first["content"] if isinstance(first, dict) else first

    rewards = []
    for i, completion in enumerate(completions):
        text = completion_text(completion)
        if re.search(pattern, text, re.DOTALL) is None:
            rewards.append(-1.0)  # Penalty for incorrect format
            continue

        position = extract_positioning(text)
        if not math.isfinite(position):
            position = 0.0
        # Enforce the [-1, 1] sizing limit stated in the system prompt.
        position = max(-1.0, min(1.0, position))

        try:
            realized = float(returns[i % n_returns]) if n_returns else math.nan
        except (TypeError, ValueError):
            realized = math.nan  # e.g. a null return coming out of the dataset
        rewards.append(position * realized if math.isfinite(realized) else 0.0)
    return rewards

# Read the raw market data and turn it into the chat-formatted GRPO dataset.
TRAIN_CSV_PATH = '../train.csv'
train = pd.read_csv(TRAIN_CSV_PATH)
data = prepare_data(train)

# Where the base model lives and where checkpoints/logs of this run go.
model_path = './huggingface_mirror/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775'
run_name = "Qwen-2.5-0.5B-GRPO-trader"
output_dir = f"outputs/{run_name}"

# GRPO hyper-parameters, grouped by concern and merged into one GRPOConfig.
_optimizer_kwargs = dict(
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    max_grad_norm=0.1,
)
_batching_kwargs = dict(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_generations=4,  # completions sampled per prompt (the GRPO group)
    num_train_epochs=1,
    num_iterations=2,  # per TRL docs: optimisation passes over each generated batch
)
_generation_kwargs = dict(
    max_prompt_length=512,
    max_completion_length=512,
    temperature=0.7,  # sampling temperature for rollouts; higher = more exploration
)
_logging_kwargs = dict(
    logging_steps=1,
    save_steps=100,
    log_on_each_node=False,
    log_completions=True,  # log sampled completions for debugging
)

grpo_config = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    bf16=True,
    # One weight per reward function, in trainer order: [format, returns].
    # The format reward is +3.0/-2.0 while the return reward is position times
    # the next-period return, so equal weights only mean equal influence if the
    # two are of similar magnitude.
    reward_weights=[1.0, 1.0],
    **_optimizer_kwargs,
    **_batching_kwargs,
    **_generation_kwargs,
    **_logging_kwargs,
)

# Initialize model and tokenizer.
# The model is loaded without device_map="auto" or an explicit .to("cuda").
# The Trainer (through accelerate) moves it to the training device itself.
# Splitting it with device_map and then moving it by hand is redundant, breaks
# multi-process launches, and fails on machines without a "cuda" device.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    # NOTE(review): Qwen2.5 is supported natively by recent transformers; this
    # flag only matters if the local snapshot ships custom modeling code.
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    padding_side="left",  # decoder-only generation needs prompts padded on the left
    trust_remote_code=True,  # NOTE(review): see the note on the model above
)

# Ensure tokenizer has pad token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Add simple callback to print sample completions at various steps
class CompletionDebugCallback(TrainerCallback):
    """
    Periodically sample one completion from the model being trained and print it.

    Every ``every_n_steps`` global steps, on the main process only, the callback:

    - renders the prompt of ``data[sample_index]`` with the chat template;
    - samples a completion with the same length/temperature settings as training;
    - prints the completion, whether it matches the expected XML layout, and
      the positioning parsed from it.

    Args:
        every_n_steps: Print interval in global steps (default 50).
        sample_index: Row of the module-level ``data`` dataset used as prompt.
    """

    FORMAT_PATTERN = r"<macro state>.*?</macro state>.*?<reasoning>.*?</reasoning>.*?<positioning>.*?</positioning>"

    def __init__(self, every_n_steps=50, sample_index=0):
        super().__init__()
        self.every_n_steps = every_n_steps
        self.sample_index = sample_index

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % self.every_n_steps != 0:
            return
        # Other ranks would duplicate the generation and the printout.
        if not getattr(state, "is_world_process_zero", True):
            return

        # Prefer the objects the Trainer passes to callbacks; fall back to the globals.
        gen_model = kwargs.get("model")
        if gen_model is None:
            gen_model = model
        gen_tokenizer = kwargs.get("processing_class")
        if gen_tokenizer is None:
            gen_tokenizer = kwargs.get("tokenizer")
        if gen_tokenizer is None:
            gen_tokenizer = tokenizer

        print(f"\n=== Step {state.global_step} Sample Completions ===")
        # Generate a sample completion with the current model state
        messages = data[self.sample_index]['prompt']
        text = gen_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = gen_tokenizer([text], return_tensors="pt").to(gen_model.device)

        # Use the training-time limits so a correct but long answer is not cut
        # off and misreported as badly formatted.
        max_new_tokens = getattr(args, "max_completion_length", None)
        if max_new_tokens is None:
            max_new_tokens = 256
        temperature = getattr(args, "temperature", None)
        if temperature is None:
            temperature = 0.7

        was_training = gen_model.training
        gen_model.eval()  # disable train-only behaviour (e.g. dropout) while sampling
        try:
            with torch.no_grad():
                output_ids = gen_model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    do_sample=True
                )
        finally:
            # Hand the model back to the Trainer in the mode it was in.
            gen_model.train(was_training)

        prompt_len = inputs["input_ids"].shape[1]
        output = gen_tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
        print(f"Output: {output}\n")

        # Check if it matches the format
        has_format = re.search(self.FORMAT_PATTERN, output, re.DOTALL) is not None
        print(f"Has correct format: {has_format}")

        # extract_positioning handles its own parse errors and falls back to 0.0.
        position = extract_positioning(output)
        print(f"Extracted positioning: {position}")

# Build the patched GRPO trainer. The reward functions are combined using
# grpo_config.reward_weights, in the same order as listed here.
reward_functions = [format_reward_func, return_reward]
trainer = FixedGRPOTrainer(
    args=grpo_config,
    model=model,
    processing_class=tokenizer,
    train_dataset=data,
    reward_funcs=reward_functions,
)

# Periodically print a sampled completion while training runs.
debug_callback = CompletionDebugCallback()
trainer.add_callback(debug_callback)

trainer.train()