# In [None]:
import re
import numpy as np
import pandas as pd
import torch
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
from datasets import Dataset

from trl import GRPOTrainer, GRPOConfig
from trl.trainer.grpo_trainer import is_conversational, maybe_apply_chat_template


##############################################################################
# 1) Setup your system prompt
##############################################################################

# System message prepended to every training conversation. It asks the model
# for three tagged sections (<macro state>, <reasoning>, <positioning>), which
# the format regex and the positioning parser further down look for.
SYSTEM_PROMPT = "\n".join(
    [
        "You are a macro event driven portfolio manager, you make positioning decision of S&P500 index based on market context and news.",
        "",
        "Only edit macro state if there are changes in the macro regime that would impact returns of S&P500.",
        "",
        "Positioning should be a float that ranges from -1 (full short) to 1 (full long).",
        "",
        "You must respond in the following XML format:",
        "",
        "<macro state>",
        "...",
        "</macro state>",
        "<reasoning>",
        "...",
        "</reasoning>",
        "<positioning>",
        "...",
        "</positioning>",
    ]
)

##############################################################################
# 2) Prepare your data as a conversation
#    IMPORTANT: Each "prompt" MUST be a list of role-content dicts
##############################################################################

def prepare_data(df: pd.DataFrame) -> Dataset:
    """Build a HuggingFace Dataset of chat prompts plus a per-row reward target.

    Each row of ``df`` becomes a two-message conversation:

    * a ``system`` message containing ``SYSTEM_PROMPT``;
    * a ``user`` message ``"Market Context:"`` followed by every column of the
      row except ``date`` (if present), rendered as ``name:value`` pairs joined
      by ``", "``.

    The ``returns`` column holds the next row's percentage change of ``close``
    (``pct_change().shift(-1)``), so the last row is NaN. When there is no
    ``close`` column, a constant 0.05 placeholder is used instead.

    The caller's DataFrame is NOT modified; all work happens on a copy.
    Previously the ``prompt``/``returns`` columns were written into the input
    frame. Calling the function again on that same frame (e.g. re-running a
    notebook cell) then serialised the forward-return target into the user
    message, leaking the label into the prompt.

    Args:
        df: One row per decision step. Every column except ``date`` is shown
            to the model.

    Returns:
        A ``datasets.Dataset`` with exactly two columns, ``prompt`` (list of
        role/content dicts) and ``returns`` (float, may be NaN).
    """
    # Work on a private copy so the caller's frame never gains the derived
    # "prompt"/"returns" columns.
    df = df.copy()

    def row_to_prompt(row: pd.Series) -> List[Dict[str, str]]:
        """Render one row as a [system, user] conversation."""
        context = ", ".join(
            f"{k}:{v}" for k, v in row.drop("date", errors="ignore").items()
        )
        return [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": "Market Context:" + context},
        ]

    # Prompts are built BEFORE "returns" is added, so the target never
    # appears in the market context shown to the model.
    df["prompt"] = df.apply(row_to_prompt, axis=1)

    if "close" in df.columns:
        # Forward one-step return: (close[t+1] / close[t]) - 1; NaN on the last row.
        df["returns"] = df["close"].pct_change().shift(-1)
    else:
        df["returns"] = 0.05  # dummy reward if you don't have close

    # Keep only the columns the trainer / reward functions consume.
    train_dataset = df[["prompt", "returns"]]
    return Dataset.from_pandas(train_dataset, preserve_index=False)

##############################################################################
# 3) Two reward functions:
#    (a) format_reward_func: checks if model output matches the XML pattern
#    (b) return_reward: tries to parse <positioning> and scale by 'returns'
##############################################################################

# Whole-response format check used by format_reward_func: the three tagged
# sections must appear in order. DOTALL (re.S) lets ".*?" cross the newlines
# inside each section, and "\s*" tolerates whitespace between sections.
PATTERN = re.compile(
    r"<macro state>.*?</macro state>"
    r"\s*<reasoning>.*?</reasoning>"
    r"\s*<positioning>.*?</positioning>",
    flags=re.S,
)

def format_reward_func(prompts, completions, **kwargs):
    """
    Score each completion purely on whether it follows the response format.

    A completion earns 2.0 when ``PATTERN`` finds the
    <macro state> / <reasoning> / <positioning> sequence anywhere in its text,
    and 0.0 otherwise. The large constant gives a strong format signal.

    Conversational completions (lists of message dicts) are scored on the
    ``content`` of their first message; plain strings are scored as-is.
    ``prompts`` and any extra keyword arguments (dataset columns forwarded by
    the trainer) are accepted for interface compatibility and ignored.
    """
    def _text_of(completion):
        # Chat-style completion -> first message body; raw text passes through.
        if isinstance(completion, list):
            return completion[0]["content"]
        return completion

    texts = [_text_of(completion) for completion in completions]
    return [2.0 if re.search(PATTERN, text) else 0.0 for text in texts]

def extract_positioning(text: str) -> float:
    """
    Parse the <positioning> value from a completion and clamp it to [-1, 1].

    The system prompt defines positioning as a float from -1 (full short) to
    1 (full long). Clamping stops the return reward from being inflated by
    out-of-range answers such as "50", which the model could otherwise learn
    to exploit.

    Args:
        text: Raw completion text.

    Returns:
        The body of the first <positioning>...</positioning> pair parsed as a
        float and clamped to [-1.0, 1.0]. Returns 0.0 (flat) when ``text`` is
        not a string, the tag is missing, the body is not a number, or the
        number is NaN/infinite. A NaN or inf position would otherwise poison
        the reward.
    """
    import math

    # Non-string input used to fall into a bare `except:`; handle it explicitly.
    if not isinstance(text, str):
        return 0.0
    match = re.search(r"<positioning>(.*?)</positioning>", text, flags=re.DOTALL)
    if match is None:
        return 0.0
    try:
        value = float(match.group(1).strip())
    except ValueError:
        # The tag is present but its content is not a number (e.g. "long").
        return 0.0
    if not math.isfinite(value):
        return 0.0
    return max(-1.0, min(1.0, value))

def return_reward(prompts, completions, returns, **kwargs):
    """
    P&L-style reward: parsed position multiplied by the realised return.

    ``returns`` is the per-example dataset column forwarded by the trainer and
    is paired with ``completions`` element by element. A missing or NaN return
    counts as 0.0. A missing or unparsable <positioning> yields a position of
    0.0 via ``extract_positioning``. Either case gives a reward of 0.0 for that
    completion.
    """
    texts = [
        item[0]["content"] if isinstance(item, list) else item
        for item in completions
    ]

    scores = []
    for text, realised in zip(texts, returns):
        position = extract_positioning(text)
        realised = 0.0 if pd.isna(realised) else float(realised)
        scores.append(position * realised)
    return scores

##############################################################################
# 4) Patch the "pop" logic so it does NOT remove your system/user messages
##############################################################################
# We'll override GRPOTrainer by subclassing and modifying _generate_and_score_completions

from trl.trainer.grpo_trainer import GRPOTrainer as OriginalGRPOTrainer

class GRPOTrainerNoPop(OriginalGRPOTrainer):
    """
    GRPOTrainer that shields the caller's conversational prompts from mutation.

    Generation, reward computation, advantage calculation and log-prob
    bookkeeping are all left to the parent ``_generate_and_score_completions``.
    That includes vLLM generation and reward-function weighting. This subclass
    only hands the parent a private copy of each conversational ``prompt``, so
    any in-place ``pop()`` the parent performs on a message list cannot alter
    the example dicts it was given.

    NOTE(review): in the TRL releases looked at while reviewing (e.g.
    0.15/0.16), the parent only pops the final prompt message when its role is
    "assistant". That message is an assistant prefill, whose text is prepended
    to the completion. Plain [system, user] prompts are therefore never
    truncated. Confirm this against the installed TRL version before relying
    on it.
    """

    def _generate_and_score_completions(
        self, inputs: dict[str, Union[torch.Tensor, Any]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        """
        Copy each example's chat prompt, then defer to the parent implementation.

        Args:
            inputs: The batch of examples. It is iterated as a sequence of
                dicts, each with a "prompt" key, matching how the parent
                treats it despite the inherited annotation.

        Returns:
            Exactly what the parent ``_generate_and_score_completions`` returns.
            A hand-rolled replacement that returned placeholder zero advantages
            or ``None`` log-probs would make the policy-gradient term vanish.
        """
        protected_inputs = []
        for example in inputs:
            prompt = example["prompt"]
            if isinstance(prompt, list):
                # Shallow-copy the message list and each message dict so the
                # parent can pop/modify them without touching the caller's data.
                example = {**example, "prompt": [dict(message) for message in prompt]}
            protected_inputs.append(example)
        return super()._generate_and_score_completions(protected_inputs)


##############################################################################
# 5) Putting it all together in an actual script
##############################################################################

def main():
    """
    Run a small end-to-end GRPO demo on a toy market dataset.

    The demo proceeds in these steps:

    * Build a six-row DataFrame of closing prices and news headlines.
    * Turn it into chat prompts with ``prepare_data``.
    * Load a causal LM and its tokenizer.
    * Train with ``format_reward_func`` and ``return_reward`` through
      ``GRPOTrainerNoPop``.

    Checkpoints and logs are written to ``./grpo-trader-out``.
    """
    # Toy market data: one row per decision step.
    df = pd.DataFrame({
        "close": [100, 101, 102, 105, 104, 107],
        "news": ["Fed hawkish", "Earnings up", "Geopolitical tension", "Jobs data", "Another event", "Yet more news"]
    })

    dataset = prepare_data(df)

    # Must be a Hub repo id (or local path) that actually exists. Qwen2-family
    # chat checkpoints are published with an "-Instruct" suffix, and the 1.5B
    # size matches run_name below.
    # NOTE(review): the previous id "Qwen/Qwen2-7B-Chat" does not appear to
    # exist on the Hub. Swap in any chat checkpoint you have access to.
    model_path = "Qwen/Qwen2.5-1.5B-Instruct"
    output_dir = "./grpo-trader-out"

    grpo_config = GRPOConfig(
        output_dir=output_dir,
        run_name="Qwen-1.5B-GRPO-trader",
        learning_rate=5e-6,
        bf16=True,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_generations=4,
        max_prompt_length=256,
        max_completion_length=256,
        num_train_epochs=1,
        logging_steps=1,
        # NOTE(review): GRPO learns from reward differences between the
        # completions sampled for the same prompt. A temperature this low may
        # make them nearly identical and leave little signal. Consider raising it.
        temperature=0.1,
        # Presumably required so the "returns" column survives collation and
        # reaches return_reward as a keyword argument.
        remove_unused_columns=False,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    trainer = GRPOTrainerNoPop(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[format_reward_func, return_reward],
        args=grpo_config,
        train_dataset=dataset,
    )

    trainer.train()

    # Afterwards, `trainer.model.generate(...)` can be used to check whether the
    # model now produces the XML structure.


if __name__ == "__main__":
    main()
