In [None]:
import os
import random
import re
import time
from collections import deque
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from datasets import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# System prompt for the portfolio manager.
# LLMTradingAgent.format_state prepends SYSTEM_PROMPT.strip() to every prompt it
# builds. The tag names below are matched literally by regexes/substring checks in
# LLMTradingAgent.check_format / extract_positioning and in
# LLMPPOTrainer.compute_format_reward, so any rename here must be mirrored there.
# NOTE(review): "<macro state>" contains a space, so it is not a well-formed XML
# tag; it only works because the parsers treat it as plain text.
# NOTE(review): format_state does not feed the previous response's macro state
# back to the model, so "only edit macro state" has no prior state to edit —
# confirm whether that context is meant to be included.
SYSTEM_PROMPT = """
You are a macro event driven portfolio manager, you make positioning decision of S&P500 index based on market context and news.

Only edit macro state if there are changes in the macro regime that would impact returns of S&P500.

Positioning should be a float that ranges from -1 (full short) to 1 (full long).

You must respond in the following XML format:

<macro state>
...
</macro state>
<reasoning>
...
</reasoning>
<positioning>
...
</positioning>
"""

# =============================================================================
# 1. Environment Definition
# =============================================================================

class MacroTradingEnv:
    """
    Environment for training a macro trading agent.

    Handles market state, position tracking, and reward calculation. Each row of
    ``df`` is one step; ``step()`` rewards the chosen position with
    ``df['returns']`` at the current row, minus a 0.1% cost on the position change.
    """

    def __init__(self, df: pd.DataFrame, window_size: int = 7, min_episode_length: int = 30):
        """
        Args:
            df: Data with at least 'headline' and 'returns' columns. Every other
                column except 'date' is exposed to the agent as market context.
            window_size: Number of most recent headlines kept in the state.
            min_episode_length: For random starts, the number of rows kept ahead
                of the latest allowed start index.
        """
        self.df = df.copy()
        self.position = 0.0  # -1 (full short) to 1 (full long)
        self.current_step = 0
        self.window_size = window_size
        self.min_episode_length = min_episode_length
        self.headline_window = deque(maxlen=window_size)
        self.action_history = deque(maxlen=5)

        # Ensure required columns exist
        for col in ('headline', 'returns'):
            assert col in self.df.columns, f"DataFrame must contain '{col}' column"

        # Start with a flat position history
        self.action_history.extend([0.0] * 5)

    def reset(self, random_start: bool = True):
        """Reset the environment, optionally to a random starting row.

        Returns:
            The initial state dictionary (see ``get_state``).
        """
        if random_start:
            # Leave enough data ahead of the start for a full episode
            max_start = len(self.df) - self.min_episode_length
            if max_start > self.window_size:
                self.current_step = random.randint(self.window_size, max_start)
            else:
                self.current_step = self.window_size
        else:
            self.current_step = self.window_size

        # Reset state
        self.position = 0.0
        self.headline_window.clear()
        self.action_history.clear()

        # Fill the window with the `window_size` headlines ending at (and
        # including) the current row, oldest first. step() appends the new
        # current row's headline after advancing, so including the current row
        # here keeps the window contiguous; otherwise the start row's headline
        # would never be shown. Rows outside the frame get a placeholder.
        n_rows = len(self.df)
        for idx in range(self.current_step - self.window_size + 1, self.current_step + 1):
            if 0 <= idx < n_rows:
                self.headline_window.append(self.df.iloc[idx]['headline'])
            else:
                self.headline_window.append("No headline available")

        # Start with a flat position history
        self.action_history.extend([0.0] * 5)

        return self.get_state()

    def step(self, new_position: float) -> Tuple[Dict, float, bool, Dict]:
        """Apply ``new_position`` for one row and return (state, reward, done, info).

        The reward is ``returns[current_step] * new_position`` minus a 0.1% cost
        on the absolute position change. Once the last row is reached, further
        calls return a zero reward with ``done=True``.
        """
        if self.current_step >= len(self.df) - 1:
            return self.get_state(), 0.0, True, {'status': 'completed'}

        # Calculate position change
        position_change = new_position - self.position
        transaction_cost = abs(position_change) * 0.001  # 0.1% friction

        # Update position
        self.position = new_position

        # Calculate return (using pre-calculated returns from dataframe)
        next_return = self.df.iloc[self.current_step]['returns']
        position_return = next_return * self.position

        # Move to next row
        self.current_step += 1

        # Update headline window and action history
        if self.current_step < len(self.df):
            self.headline_window.append(self.df.iloc[self.current_step]['headline'])
        self.action_history.append(new_position)

        # Reward is the position return net of transaction cost
        reward = position_return - transaction_cost

        # Check if episode is done
        done = self.current_step >= len(self.df) - 1

        info = {
            'return': position_return,
            'transaction_cost': transaction_cost,
            'position_change': position_change
        }

        return self.get_state(), reward, done, info

    def get_state(self) -> Dict:
        """Return the current state dictionary.

        Keys: 'market_context' (all columns of the current row except
        'headline', 'returns' and 'date'), 'headlines', 'position' and
        'action_history'. Clamps ``current_step`` to the last row if it has run
        past the end of the frame.
        """
        if self.current_step >= len(self.df):
            self.current_step = len(self.df) - 1

        current_row = self.df.iloc[self.current_step]

        skipped = {'headline', 'returns', 'date'}
        context = {col: current_row[col] for col in current_row.index if col not in skipped}

        return {
            'market_context': context,
            'headlines': list(self.headline_window),
            'position': self.position,
            'action_history': list(self.action_history)
        }

# =============================================================================
# 2. LLM Trading Agent
# =============================================================================

class LLMTradingAgent:
    """
    Trading agent that uses a language model to make decisions based on market context.

    ``predict`` renders the state into a prompt (SYSTEM_PROMPT, market context,
    position, headlines, previous positions), samples a completion and parses
    the ``<positioning>`` tag into a float clipped to [-1, 1].
    """

    # Signed decimal with optional exponent. The alternatives are grouped so the
    # sign also applies to plain integers ("-1" must parse as -1.0, not 1.0).
    _FLOAT_RE = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")
    _POSITIONING_RE = re.compile(r"<positioning>(.*?)</positioning>", re.DOTALL)
    _FORMAT_RE = re.compile(
        r"<macro state>.*?</macro state>.*?<reasoning>.*?</reasoning>.*?<positioning>.*?</positioning>",
        re.DOTALL,
    )

    def __init__(self, model_name="facebook/opt-350m"):
        """Load tokenizer and causal LM ``model_name`` (fp16 on CUDA, fp32 on CPU)."""
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # NOTE(review): trust_remote_code=True executes code shipped with the
        # model repository; only use it with trusted model names.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            trust_remote_code=True
        )

        # Generation parameters
        self.generation_kwargs = {
            "max_new_tokens": 256,
            "temperature": 0.7,
            "top_p": 0.9,
            "do_sample": True,
        }

    def format_state(self, state: Dict) -> str:
        """Create the prompt text for ``state`` (see ``MacroTradingEnv.get_state``)."""
        # Format market context numbers with magnitude-dependent precision
        context_str = []
        for k, v in state['market_context'].items():
            if isinstance(v, (int, float)):
                if abs(v) < 0.01:
                    formatted_value = f"{v:.6f}"
                elif abs(v) < 1:
                    formatted_value = f"{v:.4f}"
                else:
                    formatted_value = f"{v:.2f}"
            else:
                formatted_value = str(v)

            context_str.append(f"{k}: {formatted_value}")

        # Format headlines
        headlines_str = "\n".join([f"- {h}" for h in state['headlines']])

        # Format previous positions
        positions_str = ", ".join([f"{pos:.2f}" for pos in state['action_history']])

        # Combine all context
        prompt = f"{SYSTEM_PROMPT.strip()}\n\n"
        prompt += "Market Context:\n"
        prompt += ", ".join(context_str) + "\n\n"
        prompt += f"Current Position: {state['position']:.2f}\n\n"
        prompt += "Recent Headlines:\n"
        prompt += headlines_str + "\n\n"
        prompt += f"Previous Positions: [{positions_str}]"

        return prompt

    def extract_positioning(self, text: str) -> float:
        """Extract the positioning value from the model's XML response.

        Returns the first number inside the first ``<positioning>...</positioning>``
        pair, falling back to ``float()`` of the whole tag content. Returns 0.0
        (flat) when the tag is missing, the content is not numeric, or the value
        is NaN/inf.
        """
        try:
            match = self._POSITIONING_RE.search(text)
        except TypeError as e:  # e.g. text is None
            print(f"Error extracting position: {e}")
            return 0.0
        if match is None:
            return 0.0

        position_str = match.group(1).strip()
        number = self._FLOAT_RE.search(position_str)
        try:
            value = float(number.group() if number else position_str)
        except ValueError:
            print(f"Could not convert position to float: {position_str}")
            return 0.0

        # float() accepts "nan"/"inf"; a non-finite position would poison rewards.
        if not np.isfinite(value):
            print(f"Non-finite position ignored: {position_str}")
            return 0.0
        return value

    def check_format(self, text: str) -> bool:
        """Return True if the response contains the three tags in the required order."""
        return bool(self._FORMAT_RE.search(text))

    def predict(self, state: Dict) -> Tuple[float, str]:
        """Generate a trading decision for ``state``.

        Returns:
            (position clipped to [-1, 1], decoded completion text without the prompt)
        """
        prompt = self.format_state(state)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # pad_token may equal eos_token (set in __init__), so pass the attention
        # mask and pad id explicitly instead of letting generate() guess them.
        gen_kwargs = {"pad_token_id": self.tokenizer.pad_token_id, **self.generation_kwargs}
        outputs = self.model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            **gen_kwargs
        )

        # Decode only the newly generated tokens
        response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

        position = float(np.clip(self.extract_positioning(response), -1.0, 1.0))
        return position, response

# =============================================================================
# 3. Custom PPO Implementation for LLM Trading
# =============================================================================

class Memory:
    """Rollout buffer holding parallel per-step lists of experience.

    Every ``add`` appends one item to each list; ``sample`` returns aligned
    lists taken at the same indices.
    """

    def __init__(self, batch_size=32):
        # Parallel lists, one entry per environment step.
        self.states = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.dones = []
        self.responses = []  # raw LLM text per step
        self.format_rewards = []  # format-compliance reward per step
        self.batch_size = batch_size

    def _buffers(self):
        """Return the per-step lists in the order used by ``add`` and ``sample``."""
        return (self.states, self.actions, self.rewards, self.next_states,
                self.dones, self.responses, self.format_rewards)

    def add(self, state, action, reward, next_state, done, response, format_reward):
        """Append one transition to every buffer."""
        entry = (state, action, reward, next_state, done, response, format_reward)
        for buffer, item in zip(self._buffers(), entry):
            buffer.append(item)

    def sample(self):
        """Return a batch as a 7-tuple of aligned lists.

        With at least ``batch_size`` entries stored, draws ``batch_size``
        distinct indices via ``np.random.choice``; otherwise returns every entry
        in insertion order.
        """
        count = len(self.states)
        if count >= self.batch_size:
            picked = np.random.choice(count, self.batch_size, replace=False)
        else:
            picked = range(count)
        return tuple([buffer[i] for i in picked] for buffer in self._buffers())

    def clear(self):
        """Empty every buffer in place."""
        for buffer in self._buffers():
            buffer.clear()

    def __len__(self):
        return len(self.states)

class LLMPPOTrainer:
    """PPO-style trainer pairing an LLM trading agent with two small MLP heads.

    The LLM (``agent``) chooses every position during rollouts. Two MLPs over a
    numeric state vector are trained alongside it: ``value_net`` (state-value
    baseline) and ``policy_net`` (tanh-bounded position head). Only these MLPs
    have optimizers; the LLM weights are not updated by this class.
    """

    def __init__(self, agent, env, learning_rate=3e-5, gamma=0.99, epsilon=0.2,
                 value_coef=0.5, entropy_coef=0.01, format_reward_weight=0.5):
        self.agent = agent
        self.env = env
        self.lr = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon  # PPO clipping parameter
        # NOTE(review): value_coef and entropy_coef are stored but not used by update_policy.
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.format_reward_weight = format_reward_weight
        self.memory = Memory()

        # Network input width is fixed from the env's current state; every later
        # state must vectorize to the same length.
        state_dim = len(self.vectorize_state(env.get_state()))

        # Value network (simple MLP)
        self.value_net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        ).to(self.agent.device)

        # Policy network (simple MLP)
        self.policy_net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Tanh()  # Outputs between -1 and 1
        ).to(self.agent.device)

        # Optimizers
        self.value_optimizer = torch.optim.Adam(self.value_net.parameters(), lr=self.lr)
        self.policy_optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=self.lr)

    def vectorize_state(self, state):
        """Convert a state dictionary to a float32 feature vector.

        Layout: [position, numeric market_context values..., action_history...,
        (llm_sentiment if present)]. Python and NumPy int/float scalars are
        included; other values (e.g. strings) are skipped. NaN/inf become 0.0.
        """
        vector = [state['position']]

        # Numeric market context features (NumPy scalars from pandas rows included)
        for value in state['market_context'].values():
            if isinstance(value, (int, float, np.integer, np.floating)):
                vector.append(value)

        vector.extend(state['action_history'])

        # NOTE(review): when 'llm_sentiment' is numeric it is already included by
        # the loop above, so this appends it a second time. Kept so that saved
        # network input widths stay the same.
        if 'llm_sentiment' in state['market_context']:
            vector.append(state['market_context']['llm_sentiment'])

        # Non-finite features (e.g. indicator warm-up rows) would make every
        # network output and loss NaN.
        return np.nan_to_num(np.array(vector, dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0)

    def compute_format_reward(self, response: str) -> float:
        """Score how well ``response`` follows the required tag structure.

        0.5 for all tags in order, 0.3 for all tags present but out of order,
        0.1 for exactly two tag pairs present, -0.2 otherwise.
        """
        if self.agent.check_format(response):
            return 0.5  # Full format reward

        tags_present = (
            "<macro state>" in response and "</macro state>" in response,
            "<reasoning>" in response and "</reasoning>" in response,
            "<positioning>" in response and "</positioning>" in response,
        )
        n_present = sum(tags_present)
        if n_present == 3:
            return 0.3  # Tags exist but not in correct order/format
        if n_present == 2:
            return 0.1  # Some tags exist
        return -0.2  # Format completely wrong

    def collect_rollout(self, num_episodes=5, max_steps=21):
        """Run ``num_episodes`` episodes with the LLM agent and store transitions.

        Each stored reward is the env reward plus ``format_reward_weight`` times
        the format reward.

        Returns:
            Mean combined episode reward, or 0.0 if no episodes were run.
        """
        total_rewards = []

        for _ in tqdm(range(num_episodes), desc="Collecting experience"):
            state = self.env.reset()
            episode_reward = 0.0

            for _step in range(max_steps):
                # The LLM picks the action
                position, response = self.agent.predict(state)
                format_reward = self.compute_format_reward(response)

                next_state, reward, done, _info = self.env.step(position)
                combined_reward = reward + self.format_reward_weight * format_reward

                self.memory.add(state, position, combined_reward, next_state, done, response, format_reward)

                episode_reward += combined_reward
                state = next_state
                if done:
                    break

            total_rewards.append(episode_reward)

        return float(np.mean(total_rewards)) if total_rewards else 0.0

    def _stack_states(self, states):
        """Vectorize a list of states into a (batch, state_dim) float32 tensor on the agent device."""
        # np.stack first: torch.FloatTensor(list_of_ndarrays) is very slow.
        return torch.from_numpy(np.stack([self.vectorize_state(s) for s in states])).to(self.agent.device)

    def update_policy(self, num_epochs=10):
        """Update policy and value networks on ``num_epochs`` sampled batches from memory."""
        if len(self.memory) == 0:
            return

        device = self.agent.device
        for _ in range(num_epochs):
            states, _actions, rewards, next_states, dones, _responses, _format_rewards = self.memory.sample()

            state_tensors = self._stack_states(states)
            next_state_tensors = self._stack_states(next_states)
            reward_tensors = torch.tensor(rewards, dtype=torch.float32, device=device).unsqueeze(1)
            not_done = 1.0 - torch.tensor(dones, dtype=torch.float32, device=device).unsqueeze(1)

            # One-step TD targets and advantages; no bootstrap past terminal steps.
            with torch.no_grad():
                values = self.value_net(state_tensors)
                next_values = self.value_net(next_state_tensors)
                returns_tensor = reward_tensors + self.gamma * next_values * not_done
                advantages_tensor = returns_tensor - values

            # Normalize advantages
            if len(advantages_tensor) > 1:
                advantages_tensor = (advantages_tensor - advantages_tensor.mean()) / (advantages_tensor.std() + 1e-8)

            # NOTE(review): old_actions and new_actions come from the same network
            # at the same moment, so `ratio` is exactly 1 and the PPO clip never
            # activates; the stored LLM actions are not used either. The gradient
            # therefore equals that of -mean(advantage * policy_net(state)). A real
            # PPO update needs log-probs of the taken actions recorded at rollout time.
            old_actions = self.policy_net(state_tensors).detach()
            new_actions = self.policy_net(state_tensors)
            ratio = torch.exp(new_actions - old_actions)

            clip_adv = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantages_tensor
            policy_loss = -torch.min(ratio * advantages_tensor, clip_adv).mean()

            value_loss = nn.MSELoss()(self.value_net(state_tensors), returns_tensor)

            # Update policy network
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            # Update value network
            self.value_optimizer.zero_grad()
            value_loss.backward()
            self.value_optimizer.step()

    def train(self, num_iterations=50, num_episodes_per_iter=5, update_epochs=10):
        """Run collect -> update -> clear for ``num_iterations`` iterations.

        Every 10 iterations, checkpoints and a reward-curve plot are written
        under ``models/``, which is created if missing.

        Returns:
            List of per-iteration mean episode rewards.
        """
        all_rewards = []

        for iteration in range(1, num_iterations + 1):
            mean_reward = self.collect_rollout(num_episodes=num_episodes_per_iter)
            self.update_policy(num_epochs=update_epochs)

            # On-policy: discard experience once it has been used
            self.memory.clear()

            all_rewards.append(mean_reward)
            print(f"Iteration {iteration}/{num_iterations}, Mean Reward: {mean_reward:.4f}")

            if iteration % 10 == 0:
                os.makedirs("models", exist_ok=True)
                self.save_models(f"models/iteration_{iteration}")

                fig = plt.figure(figsize=(10, 5))
                plt.plot(all_rewards)
                plt.title("Mean Episode Rewards")
                plt.xlabel("Iteration")
                plt.ylabel("Reward")
                plt.savefig(f"models/rewards_iter_{iteration}.png")
                plt.close(fig)

        return all_rewards

    def save_models(self, path_prefix):
        """Save policy/value state dicts and the LLM + tokenizer under ``path_prefix``."""
        # torch.save does not create missing directories.
        parent = os.path.dirname(path_prefix)
        if parent:
            os.makedirs(parent, exist_ok=True)

        torch.save(self.policy_net.state_dict(), f"{path_prefix}_policy.pt")
        torch.save(self.value_net.state_dict(), f"{path_prefix}_value.pt")

        self.agent.model.save_pretrained(f"{path_prefix}_llm")
        self.agent.tokenizer.save_pretrained(f"{path_prefix}_llm")

# =============================================================================
# 4. Main Training Function
# =============================================================================

def train_macro_trader(df, model_name="facebook/opt-350m", num_iterations=50):
    """Build the environment, agent and PPO trainer, train, then save the result.

    Args:
        df: Market data; must contain 'headline' and 'returns' columns.
        model_name: Model id passed to ``LLMTradingAgent``.
        num_iterations: Number of collect/update iterations (5 episodes each).

    Returns:
        ``(agent, env, rewards)``, where ``rewards`` is the list returned by
        ``LLMPPOTrainer.train``.
    """
    trading_env = MacroTradingEnv(df)
    llm_agent = LLMTradingAgent(model_name=model_name)
    ppo = LLMPPOTrainer(llm_agent, trading_env)

    reward_history = ppo.train(num_iterations=num_iterations, num_episodes_per_iter=5)

    # Persist the final networks and LLM alongside the periodic checkpoints
    ppo.save_models("models/final_model")

    return llm_agent, trading_env, reward_history

# Usage example
if __name__ == "__main__":
    # Load market data (replace the placeholder path with a real CSV)
    df = pd.read_csv('your_data.csv')

    # Derive returns from closes when absent: row t holds close[t+1] / close[t] - 1,
    # so the final row is NaN.
    if 'returns' not in df:
        df = df.assign(returns=df['close'].pct_change().shift(-1))

    # Train macro trader
    agent, env, rewards = train_macro_trader(df, num_iterations=30)