In [None]:
import torch
import numpy as np
import gym
from gym import spaces
from typing import Dict, List, Tuple, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
from datasets import Dataset
from tqdm import tqdm
import logging

# Root logging setup: timestamped records at INFO level and above,
# plus a module-scoped logger used throughout this file.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

class FlagTraderAgent:
    """
    Implementation of FLAG-TRADER agent: Fusion LLM-Agent with Gradient-based
    Reinforcement Learning for Financial Trading.

    The agent wraps a causal LM with a value head (actor-critic), freezes the
    bottom transformer layers, and owns a `PPOTrainer` for policy updates.

    NOTE(review): construction targets the legacy TRL PPO API (the one built
    around `AutoModelForCausalLMWithValueHead` and `PPOTrainer.step`).
    Confirm the installed `trl` version still provides it.
    """

    # Action labels recognised in generated text.
    _ACTIONS = ("Buy", "Sell", "Hold")

    def __init__(self, model_name, device='cuda:0', freeze_layers=0.7):
        """
        Initialize FLAG-TRADER agent with parameter-efficient fine-tuning
        
        Args:
            model_name (str): Name of the LLM to use
            device (str): Device to run the model on
            freeze_layers (float): Proportion of layers to freeze (0.0-1.0)
        """
        self.device = device
        self.model_name = model_name

        # Initialize tokenizer; fall back to EOS when no pad token is defined
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # PPO Configuration (legacy TRL field names).
        # `init_kl_coef` / `ppo_epochs` replace the former `kl_coef` /
        # `num_train_epochs`, and `output_dir` is dropped: those names belong
        # to a different PPOConfig generation than `model_name` and
        # `mini_batch_size`, so the mixed set could not be constructed.
        # NOTE(review): `ppo_epochs=3` assumes the old `num_train_epochs=3`
        # meant optimisation passes per PPO batch - confirm.
        self.config = PPOConfig(
            model_name=model_name,
            learning_rate=5e-5,
            batch_size=32,
            mini_batch_size=4,
            gradient_accumulation_steps=8,
            ppo_epochs=3,
            vf_coef=0.1,
            init_kl_coef=0.05,
            cliprange=0.2,
            gamma=0.99,
            lam=0.95,
            seed=42
        )

        # Load policy model with value head (actor-critic)
        # NOTE(review): weights are loaded in float16 and then optimised
        # directly; verify this is numerically stable for training.
        self.policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map=self.device
        )

        # Implement parameter-efficient fine-tuning
        self._freeze_bottom_layers(freeze_proportion=freeze_layers)

        # Load reference model (for PPO's KL penalty). The legacy trainer
        # expects the reference to use the same value-head wrapper type as
        # the policy, so it is loaded through the same class.
        self.ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map=self.device
        )

        # Value model is embedded within policy_model (provided by trl)
        self.value_model = self.policy_model

        # Simple reward model: constant reward of 1 per sequence. It is kept
        # as an attribute but is not passed to the trainer below.
        class SimpleRewardModel(torch.nn.Module):
            def forward(self, input_ids, attention_mask=None):
                return torch.ones((input_ids.shape[0], 1), device=input_ids.device)

        self.reward_model = SimpleRewardModel().to(self.device)

        # PPO Trainer
        self.trainer = PPOTrainer(
            config=self.config,
            model=self.policy_model,
            ref_model=self.ref_model,
            tokenizer=self.tokenizer,
            dataset=None,  # Experiences are supplied at training time
            optimizer=None  # Will use default Adam optimizer
        )

    @staticmethod
    def _find_transformer_layers(model):
        """
        Locate the stack of transformer blocks on a causal LM

        Args:
            model: Model exposing either `transformer.h` or `model.layers`

        Returns:
            The layer container, or None when neither layout is present
        """
        gpt_style = getattr(getattr(model, 'transformer', None), 'h', None)
        if gpt_style is not None:
            return gpt_style
        return getattr(getattr(model, 'model', None), 'layers', None)

    def _freeze_bottom_layers(self, freeze_proportion: float):
        """
        Implement parameter-efficient fine-tuning by freezing bottom layers
        
        Args:
            freeze_proportion (float): Proportion of layers to freeze (0.0-1.0);
                values outside that range are clamped
        """
        # The value-head wrapper stores the underlying LM as
        # `pretrained_model`; probing the wrapper itself never finds the
        # layers, so unwrap first (plain models are used as-is).
        base_model = getattr(self.policy_model, 'pretrained_model', self.policy_model)

        transformer_layers = self._find_transformer_layers(base_model)
        if transformer_layers is None:
            logger.warning("Could not identify transformer layers. No layers frozen.")
            return

        # Calculate how many layers to freeze
        freeze_proportion = min(max(float(freeze_proportion), 0.0), 1.0)
        num_layers = len(transformer_layers)
        num_frozen = int(num_layers * freeze_proportion)

        logger.info(f"Freezing {num_frozen} of {num_layers} transformer layers")

        # Freeze bottom layers
        for i in range(num_frozen):
            for param in transformer_layers[i].parameters():
                param.requires_grad = False

        # Log number of trainable parameters
        total_params = sum(p.numel() for p in self.policy_model.parameters())
        trainable_params = sum(p.numel() for p in self.policy_model.parameters() if p.requires_grad)
        logger.info(f"Total parameters: {total_params:,}")
        if total_params:
            logger.info(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%})")

    def generate_prompt(self, state: Dict) -> str:
        """
        Generate a structured prompt for the LLM based on the current state
        
        Args:
            state (Dict): Current market state; reads the optional keys
                "price_data" and "account_info" ("cash", "asset", "total")
            
        Returns:
            str: Formatted prompt for the LLM, ending with "Output Action: "
        """
        # Extract key information from state (missing keys fall back to empty)
        price_data = state.get("price_data", [])
        account_info = state.get("account_info", {})

        # Format state information into a structured prompt
        prompt = (
            "Financial Stock Trading\n"
            "Task: Assist in making optimal buy, hold, or sell decisions for stock "
            "portfolio. The goal is to maximize returns while managing risk.\n\n"
            "Legible Actions: Choose from \"Buy\", \"Sell\", or \"Hold\" based on "
            "market conditions and risk assessment.\n\n"
            "Current State:\n"
            f"Price History: {price_data}\n"
            f"Cash Balance: {account_info.get('cash', 0):.2f}\n"
            f"Asset Position: {account_info.get('asset', 0)}\n"
            f"Total Value: {account_info.get('total', 0):.2f}\n\n"
            "Output Action: "
        )

        return prompt

    @classmethod
    def _parse_action(cls, text: str) -> str:
        """
        Map free-form model output to one of the supported actions

        The action keyword that appears first in the text wins, so
        "Sell, do not buy" resolves to "Sell".

        Args:
            text (str): Generated completion

        Returns:
            str: "Buy", "Sell", or "Hold" ("Hold" when no keyword is found)
        """
        lowered = text.lower()
        hits = [(lowered.find(label.lower()), label) for label in cls._ACTIONS]
        hits = [(pos, label) for pos, label in hits if pos >= 0]
        return min(hits)[1] if hits else "Hold"

    def predict(self, state: Dict) -> str:
        """
        Predict the next action given the current state
        
        Args:
            state (Dict): Current market state
            
        Returns:
            str: Predicted action (Buy, Sell, or Hold)
        """
        prompt = self.generate_prompt(state)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # Generate response with temperature for exploration
        outputs = self.policy_model.generate(
            **inputs,
            max_new_tokens=16,
            temperature=0.7,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id
        )

        # Decode only the newly generated tokens: the returned sequence starts
        # with the prompt, whose own text mentions every action keyword.
        prompt_length = inputs["input_ids"].shape[1]
        completion = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        # Normalize action text - handle various formats the LLM might produce
        return self._parse_action(completion)

class FlagTraderEnv(gym.Env):
    """
    Trading environment for FLAG-TRADER, implementing Sharpe ratio based rewards

    The traded price is read from the last column of `data`. States are
    returned as dictionaries (see `_get_state`), and the reward after every
    step is the annualized Sharpe ratio of the PnL history accumulated since
    the last reset.
    """
    def __init__(self, data, initial_balance=10000, window_size=10, risk_free_rate=0):
        """
        Initialize trading environment
        
        Args:
            data (np.ndarray): Market data with time steps as rows and features as columns
            initial_balance (float): Initial cash balance
            window_size (int): Size of historical window to use for state
            risk_free_rate (float): Risk-free rate for Sharpe ratio calculation

        Raises:
            TypeError: If `data` is not a numpy array
            ValueError: If `data` is not 2-D, `window_size` is negative, or
                `data` has too few rows to fill one window
        """
        super().__init__()

        if not isinstance(data, np.ndarray):
            raise TypeError("Data must be a numpy array")
        # Fail early with a clear message instead of an IndexError from
        # `data.shape[1]` or from the first `_get_state` call in `reset`.
        if data.ndim != 2:
            raise ValueError("Data must be 2-D (time steps as rows, features as columns)")
        if window_size < 0:
            raise ValueError("window_size must be non-negative")
        if len(data) <= window_size:
            raise ValueError("Data must contain more rows than window_size")

        self.data = data
        self.initial_balance = initial_balance
        self.window_size = window_size
        self.risk_free_rate = risk_free_rate

        # Define action and observation spaces
        self.action_space = spaces.Discrete(3)  # Buy (0), Hold (1), Sell (2)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(window_size, data.shape[1])
        )

        # Reset environment
        self.reset()

    def reset(self):
        """
        Reset the environment to initial state
        
        Returns:
            dict: Initial state
        """
        self.current_step = self.window_size  # Start after window_size to have history
        self.balance = self.initial_balance
        self.asset = 0
        self.pnl_history = []

        # Get initial state
        return self._get_state()

    def step(self, action: int) -> Tuple[Dict, float, bool, Dict]:
        """
        Take a step in the environment based on the action
        
        Args:
            action (int): Action to take (0: Buy, 1: Hold, 2: Sell); any
                other value leaves the position unchanged
            
        Returns:
            tuple: (state, reward, done, info)
        """
        # Already at the last row: nothing left to trade
        if self.current_step >= len(self.data) - 1:
            return self._get_state(), 0, True, {}

        # Get current price
        current_price = self.data[self.current_step, -1]

        # Store portfolio value before action
        prev_value = self.balance + self.asset * current_price

        # Execute action
        # Buying requires a strictly positive price; otherwise the share
        # count below would divide by zero or come out negative.
        if action == 0 and current_price > 0 and self.balance >= current_price:  # Buy
            shares_to_buy = self.balance // current_price  # all-in, whole shares
            self.asset += shares_to_buy
            self.balance -= shares_to_buy * current_price
        elif action == 2 and self.asset > 0:  # Sell
            self.balance += self.asset * current_price  # liquidate everything
            self.asset = 0

        # Move to next time step
        self.current_step += 1

        # Calculate new portfolio value
        new_price = self.data[self.current_step, -1]
        new_value = self.balance + self.asset * new_price

        # Calculate daily PnL and add to history
        daily_pnl = new_value - prev_value
        self.pnl_history.append(daily_pnl)

        # Calculate reward based on Sharpe ratio
        reward = self._calculate_sharpe_ratio()

        # Check if episode is done
        done = self.current_step >= len(self.data) - 1

        # Create info dict
        info = {
            'portfolio_value': new_value,
            'daily_pnl': daily_pnl,
            'sharpe_ratio': reward
        }

        return self._get_state(), reward, done, info

    def _get_state(self) -> Dict:
        """
        Get current state representation
        
        Returns:
            dict: Dictionary with "price_data" (the `window_size` rows before
                the current step) and "account_info" (cash, asset, total)
        """
        # Get window of historical data (excludes the current row)
        historical_window = self.data[self.current_step - self.window_size:self.current_step]

        # Create account info, valued at the current row's price
        current_price = self.data[self.current_step, -1]
        portfolio_value = self.balance + self.asset * current_price

        # Format state as a dictionary
        state = {
            "price_data": historical_window.tolist(),
            "account_info": {
                "cash": float(self.balance),
                "asset": int(self.asset),
                "total": float(portfolio_value)
            }
        }

        return state

    def _calculate_sharpe_ratio(self) -> float:
        """
        Calculate Sharpe ratio using PnL history
        
        The ratio is computed on raw per-step PnL values (not percentage
        returns) and uses the population standard deviation.

        Returns:
            float: Annualized Sharpe ratio; 0.0 with fewer than two PnL
                entries or when the PnL has zero variance
        """
        if len(self.pnl_history) < 2:
            return 0.0

        pnl_array = np.array(self.pnl_history)
        avg_pnl = np.mean(pnl_array)
        std_pnl = np.std(pnl_array)

        # Avoid division by zero
        if std_pnl == 0:
            return 0.0

        sharpe = (avg_pnl - self.risk_free_rate) / std_pnl

        # Annualize the Sharpe ratio (assuming daily data, √252 factor)
        annualized_sharpe = sharpe * np.sqrt(252)

        # Return a plain Python float, as the annotation promises
        return float(annualized_sharpe)

def action_str_to_num(action_str: str) -> int:
    """
    Convert action string to numeric action
    
    Matching ignores letter case and surrounding whitespace, so "buy",
    " SELL " and "Hold" are all recognised.

    Args:
        action_str (str): Action string (Buy, Hold, or Sell)
        
    Returns:
        int: Numeric action (0: Buy, 1: Hold, 2: Sell); 1 (Hold) for
            anything unrecognised, including non-string input
    """
    action_map = {'buy': 0, 'hold': 1, 'sell': 2}
    if not isinstance(action_str, str):
        return 1  # Default to Hold if action not recognized
    return action_map.get(action_str.strip().lower(), 1)

def orchestrate_training(agent: FlagTraderAgent, env: FlagTraderEnv, epochs: int = 3, steps_per_epoch: int = 200) -> None:
    """
    Train the FLAG-TRADER agent using PPO
    
    Each epoch rolls the agent through the environment for `steps_per_epoch`
    steps, then feeds the collected (prompt, action, reward) triples to the
    trainer in chunks of `agent.config.batch_size`.

    NOTE(review): updates go through `agent.trainer.step(queries, responses,
    scores)`, i.e. the legacy TRL PPO interface, which takes lists of token
    tensors and exactly `batch_size` items per call. Confirm against the
    installed `trl` version.

    Args:
        agent (FlagTraderAgent): The FLAG-TRADER agent
        env (FlagTraderEnv): Trading environment
        epochs (int): Number of training epochs
        steps_per_epoch (int): Number of steps per epoch
    """
    batch_size = agent.config.batch_size
    tokenizer = agent.tokenizer

    for epoch in range(epochs):
        logger.info(f"Starting epoch {epoch+1}/{epochs}")

        # Reset collections for this epoch
        experiences = []
        total_reward = 0

        # Reset environment
        state = env.reset()

        # Collect experiences
        for _ in tqdm(range(steps_per_epoch), desc=f"Collecting data - Epoch {epoch+1}/{epochs}"):
            # Get action from agent
            action_str = agent.predict(state)
            action_num = action_str_to_num(action_str)

            # Take action in environment
            next_state, reward, done, _info = env.step(action_num)

            # Store experience (the prompt is rebuilt from the pre-step state)
            experiences.append({
                "query": agent.generate_prompt(state),
                "response": action_str,
                "reward": float(reward)
            })

            # Update total reward
            total_reward += reward

            # Advance, or start a fresh episode when this one has finished
            state = env.reset() if done else next_state

        # Train with PPO
        logger.info(f"Training on {len(experiences)} experiences")
        num_batches = len(experiences) // batch_size
        for batch_index in range(num_batches):
            batch = experiences[batch_index * batch_size:(batch_index + 1) * batch_size]

            # The trainer consumes 1-D token-id tensors per sample. The
            # response is the normalised action label that was executed in
            # the environment, not the raw sampled completion.
            query_tensors = [
                tokenizer(item["query"], return_tensors="pt").input_ids[0].to(agent.device)
                for item in batch
            ]
            response_tensors = [
                tokenizer(item["response"], add_special_tokens=False, return_tensors="pt").input_ids[0].to(agent.device)
                for item in batch
            ]
            score_tensors = [
                torch.tensor(item["reward"], dtype=torch.float32, device=agent.device)
                for item in batch
            ]
            agent.trainer.step(query_tensors, response_tensors, score_tensors)

        # Samples that do not fill a complete batch cannot be stepped on
        leftover = len(experiences) - num_batches * batch_size
        if leftover:
            logger.warning(
                f"Skipped {leftover} experiences that did not fill a batch of {batch_size}"
            )

        # Calculate and log metrics (guard against an empty rollout)
        avg_reward = total_reward / steps_per_epoch if steps_per_epoch else 0.0
        logger.info(f"Epoch {epoch+1}: Avg Reward = {avg_reward:.4f}")

        # Additional logging for action distribution
        action_strs = [item["response"] for item in experiences]
        action_counts = {
            'Buy': action_strs.count('Buy'),
            'Hold': action_strs.count('Hold'),
            'Sell': action_strs.count('Sell')
        }
        logger.info(f"Action distribution: {action_counts}")

    logger.info("Training complete")

def create_synthetic_market_data(time_steps: int = 1000, features: int = 5, seed: Optional[int] = None) -> np.ndarray:
    """
    Build a random market dataset for testing

    The last column is a price series (additive random walk starting near
    100, floored at 0.1); the other `features - 1` columns are standard
    normal noise.

    Args:
        time_steps (int): Number of rows (time steps)
        features (int): Total number of columns, price included
        seed (int, optional): Seeds NumPy's global random state when given

    Returns:
        np.ndarray: Array of shape (time_steps, features)
    """
    if seed is not None:
        np.random.seed(seed)

    # Draw order matters for reproducibility: walk increments come first,
    # the auxiliary feature columns second.
    increments = np.random.normal(0.001, 0.02, time_steps)
    price_path = np.clip(100 + increments.cumsum(), 0.1, None)

    auxiliary = np.random.randn(time_steps, features - 1)

    # Price goes in as the right-most column
    return np.hstack((auxiliary, price_path.reshape(-1, 1)))

def main():
    """Build data, environment and agent, run PPO training, then save the result"""
    save_dir = "flag_trader_model"

    # Environment over a reproducible synthetic dataset
    env = FlagTraderEnv(
        create_synthetic_market_data(time_steps=1000, features=10, seed=42),
        initial_balance=10000,
        window_size=10,
    )

    # LLM policy with the bottom 70% of its layers frozen
    agent = FlagTraderAgent(
        "Qwen/Qwen2.5-0.5B-Instruct",
        device='cuda:0',
        freeze_layers=0.7,
    )

    # PPO training loop
    orchestrate_training(agent, env, epochs=3, steps_per_epoch=200)

    # Persist model weights and tokenizer side by side
    for artifact in (agent.policy_model, agent.tokenizer):
        artifact.save_pretrained(save_dir)

    logger.info("Model saved to flag_trader_model/")

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()