# In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
from datasets import Dataset
import numpy as np
import gym
from gym import spaces


class FlagTraderAgent:
    """Language-model trading agent bundled with a TRL PPO trainer.

    The agent formats an environment state into a text prompt, samples a
    short completion from the policy model, and returns the text that
    follows the ``Decision:`` marker.

    Attributes set by ``__init__``: ``device``, ``tokenizer``, ``config``,
    ``policy_model``, ``ref_model``, ``reward_model``, ``value_model`` and
    ``trainer``.
    """

    def __init__(self, model_name, device='cuda:0'):
        """Load the tokenizer and models and construct the PPO trainer.

        Args:
            model_name: Model identifier handed to every ``from_pretrained``
                call below.
            device: Device used for model placement and for the tokenized
                inputs in ``predict``. Defaults to ``'cuda:0'``.
        """
        self.device = device

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Fall back to EOS for padding when the tokenizer defines no pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # PPO Configuration
        # NOTE(review): the keyword names accepted by PPOConfig / PPOTrainer
        # differ between TRL releases -- confirm against the installed version.
        self.config = PPOConfig(
            model_name=model_name,
            learning_rate=5e-5,
            batch_size=32,
            mini_batch_size=4,
            gradient_accumulation_steps=8,
            num_train_epochs=3,
            vf_coef=0.1,
            kl_coef=0.05,
            cliprange=0.2,
            gamma=0.99,
            lam=0.95,
            output_dir="ppo_output",
            seed=42
        )

        # Load policy model with value head (actor-critic).
        # Placement follows the `device` argument so it always agrees with the
        # inputs that `predict` moves to `self.device`.
        self.policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(
            model_name, torch_dtype=torch.float16, device_map=device
        )

        # Load reference model (for PPO's KL penalty)
        self.ref_model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16, device_map=device
        )

        # Placeholder reward model: ignores its inputs and emits a constant
        # 1.0 per sequence, shaped (batch, 1), on the inputs' device.
        class SimpleRewardModel(torch.nn.Module):
            def forward(self, input_ids, attention_mask=None):
                return torch.ones((input_ids.shape[0], 1), device=input_ids.device)

        self.reward_model = SimpleRewardModel().to(device)

        # Value model is embedded within policy_model (provided by trl)
        self.value_model = self.policy_model

        # PPO Trainer
        # The separately loaded reference model is passed here: using the
        # policy as its own reference would make the policy-vs-reference KL
        # term identically zero and leave `self.ref_model` unused.
        self.trainer = PPOTrainer(
            args=self.config,
            model=self.policy_model,
            ref_model=self.ref_model,
            reward_model=self.reward_model,
            tokenizer=self.tokenizer,
            train_dataset=None,
            value_model=self.value_model,
            data_collator=None
        )

    def generate_prompt(self, state):
        """Return the text prompt for ``state``.

        Args:
            state: Any object; it is interpolated into the prompt with its
                default string formatting.

        Returns:
            The prompt string, ending with ``Decision:``.
        """
        return f"Trading Task:\nState: {state}\nActions:[Buy,Sell,Hold]\nDecision:"

    def predict(self, state):
        """Sample a decision string from the policy model for ``state``.

        Args:
            state: Environment state forwarded to ``generate_prompt``.

        Returns:
            The decoded text after the last ``Decision:`` marker, stripped of
            surrounding whitespace. Up to 16 new tokens are sampled, so the
            result is free text rather than a guaranteed action keyword.
        """
        prompt = self.generate_prompt(state)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.policy_model.generate(**inputs, max_new_tokens=16, temperature=0.7, do_sample=True)
        # The decoded sequence contains the prompt too; keep only what follows
        # the final "Decision:" marker.
        decision = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return decision.split("Decision:")[-1].strip()

class FlagTraderEnv(gym.Env):
    """Single-asset trading environment driven by a pre-loaded data array.

    Each row of ``data`` is one observation, and the last column of the
    current row is used as the asset price. The reward for a step is the
    change in portfolio value (cash plus holdings valued at the current
    row's price) between the row before and the row after the step.
    """

    def __init__(self, data, initial_balance=10000):
        """Store the market data and initialise the account.

        Args:
            data: 2-D indexable array (``data[i]`` is a row, ``data[i, -1]``
                its price); ``data.shape[1:]`` defines the observation shape.
            initial_balance: Cash balance restored on every ``reset``.
        """
        super().__init__()
        self.data = data
        self.initial_balance = initial_balance
        self.action_space = spaces.Discrete(3)  # 0 = Buy, 1 = Hold, 2 = Sell
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=data.shape[1:])
        self.reset()

    def reset(self):
        """Rewind to the first row, restore the starting cash, drop holdings.

        Returns:
            The first row of ``data``.
        """
        self.current_step = 0
        self.balance = self.initial_balance
        self.asset = 0
        return self.data[self.current_step]

    def step(self, action):
        """Apply ``action`` at the current price and advance one row.

        Args:
            action: ``0`` buys as many whole units as the balance affords,
                ``2`` sells all holdings, anything else leaves the account
                unchanged.

        Returns:
            A ``(next_state, reward, done, info)`` tuple; ``info`` is always
            an empty dict. ``done`` becomes True once the last row is
            reached; stepping again without ``reset()`` indexes past the end
            of ``data``.
        """
        price = self.data[self.current_step, -1]
        prev_value = self.balance + self.asset * price

        # A non-positive price is skipped for buys: zero would divide by zero
        # (producing inf/nan holdings) and a negative price would yield
        # meaningless floor-division / modulo results.
        if action == 0 and price > 0 and self.balance >= price:  # Buy
            self.asset += self.balance // price
            self.balance %= price
        elif action == 2 and self.asset > 0:  # Sell
            self.balance += self.asset * price
            self.asset = 0

        self.current_step += 1
        done = self.current_step >= len(self.data) - 1
        next_state = self.data[self.current_step]

        portfolio_value = self.balance + self.asset * self.data[self.current_step, -1]
        reward = portfolio_value - prev_value  # return-based reward

        return next_state, reward, done, {}

from tqdm import tqdm
from datasets import Dataset

def action_str_to_num(action_str):
    """Map a model response to a discrete action index.

    The response is free text sampled from a language model, so it rarely
    equals a bare keyword. The first word (case-insensitive, punctuation
    ignored) that names an action decides the result.

    Args:
        action_str: Response text, e.g. ``"Buy"``, ``"sell."`` or
            ``"Hold\\nbecause ..."``.

    Returns:
        ``0`` for Buy, ``1`` for Hold, ``2`` for Sell. Falls back to ``1``
        (Hold) when no action keyword is found or the input is not a string.
    """
    actions = {'buy': 0, 'hold': 1, 'sell': 2}
    if not isinstance(action_str, str):
        return 1

    # Turn every non-letter into a separator so "Buy." or "[Sell]" still match.
    cleaned = ''.join(ch if ch.isalpha() else ' ' for ch in action_str.lower())
    for word in cleaned.split():
        if word in actions:
            return actions[word]
    return 1

def orchestrate_training(agent, env, epochs=3, steps_per_epoch=200):
    """Alternate environment rollouts with PPO training calls.

    For every epoch the environment is reset, ``steps_per_epoch``
    interactions are collected as ``query`` / ``response`` / ``reward``
    records, the records are installed on the agent's trainer as its
    training dataset, and ``trainer.train()`` is invoked. The mean step
    reward of the epoch is printed afterwards.

    Args:
        agent: Object providing ``predict``, ``generate_prompt`` and a
            ``trainer`` attribute.
        env: Environment providing ``reset()`` and a ``step(action)`` that
            returns ``(next_state, reward, done, info)``.
        epochs: Number of rollout-then-train rounds.
        steps_per_epoch: Environment steps collected per round.
    """
    for epoch in range(epochs):
        rollout = []
        total_reward = 0
        current = env.reset()

        progress = tqdm(range(steps_per_epoch), desc=f"Epoch {epoch+1}/{epochs}")
        for _ in progress:
            response_text = agent.predict(current)
            upcoming, reward, done, _ = env.step(action_str_to_num(response_text))

            # Record the prompt for the state the decision was made in.
            record = {
                "query": agent.generate_prompt(current),
                "response": response_text,
                "reward": reward
            }
            rollout.append(record)
            total_reward += reward

            # Start a fresh episode whenever the environment reports the end.
            current = env.reset() if done else upcoming

        # Hand the collected rollout to the trainer and run PPO training.
        agent.trainer.train_dataset = Dataset.from_list(rollout)
        agent.trainer.train()

        avg_reward = total_reward / steps_per_epoch
        print(f"Epoch {epoch+1}: Avg Reward = {avg_reward:.2f}")

# Synthetic market data: uniform-random rows stand in for a real feed
# (replace with real data for real cases).
NUM_TIMESTEPS = 1000
NUM_FEATURES = 10  # features per timestep
market_data = np.random.rand(NUM_TIMESTEPS, NUM_FEATURES)

# Build the environment first, then the agent around the chosen checkpoint.
env = FlagTraderEnv(market_data)
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
agent = FlagTraderAgent(MODEL_NAME)

# Launch PPO training with the default epoch and step counts.
orchestrate_training(agent, env)
