In [None]:
class PPOPortfolioManager:
    """
    Portfolio manager using PPO to make macro-driven investment decisions.

    The manager wraps a causal language model: it renders a market state into
    a text prompt (`format_state`), samples a response (`predict`), parses a
    numeric position out of a ``<positioning>`` tag (`extract_positioning`)
    and forwards collected (prompt, response, reward) data to a TRL
    ``PPOTrainer`` (`ppo_step`).

    NOTE(review): the ``PPOConfig`` / ``PPOTrainer`` arguments used below
    look like a mix of two different TRL API generations (e.g. ``log_with``
    and ``trainer.step`` alongside ``value_model`` / ``num_ppo_epochs``).
    Confirm them against the TRL version actually installed.
    """

    def __init__(self, model_name="Qwen/Qwen2.5-0.5B-Instruct"):
        """Load tokenizer and models and build the PPO configuration.

        Args:
            model_name: Hugging Face model identifier used for the tokenizer,
                the policy model, the reference model and the value model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name

        # PPO configuration
        # NOTE(review): confirm every keyword is accepted by the installed
        # TRL PPOConfig; unknown keywords raise at construction time.
        self.ppo_config = PPOConfig(
            learning_rate=5e-5,
            per_device_train_batch_size=8,
            gradient_accumulation_steps=1,
            max_grad_norm=1.0,
            optimize_cuda_cache=True,
            num_train_epochs=3,
            seed=42,
            log_with="tensorboard",
            logging_dir="./output",  # Use a default value
            logging_steps=500,
            run_name="ppo_portfolio_manager",
            save_strategy="steps",
            save_steps=500,
            # PPO specific parameters
            num_ppo_epochs=4,
            gamma=0.99,
            lam=0.95,
            cliprange=0.2,
            cliprange_value=0.2,
            vf_coef=0.1,
            kl_coef=0.05,
            whiten_rewards=False,
            temperature=0.7,
            response_length=512,
            remove_unused_columns=False
        )

        # Initialize tokenizer; fall back to EOS as the padding token when
        # the tokenizer does not define one.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Initialize policy model with value head for PPO
        self.policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        # Initialize reference model
        from transformers import AutoModelForCausalLM
        self.ref_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        # Initialize value model
        # NOTE(review): this is a plain causal LM; verify that the trainer
        # in use accepts such a model as its value model.
        self.value_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        # Initialize dummy reward model: a placeholder module whose forward
        # pass ignores its inputs and always returns a single zero.
        class DummyRewardModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # A parameter is kept so the module has a device to report.
                self.dummy_param = torch.nn.Parameter(torch.zeros(1))

            def forward(self, *args, **kwargs):
                return torch.zeros(1, device=self.dummy_param.device)

        self.reward_model = DummyRewardModel()

        # Trainer is created lazily by initialize_trainer() / ppo_step().
        self.trainer = None

        # Define text generation parameters
        self.generation_kwargs = {
            "max_new_tokens": 512,
            "temperature": 0.7,
            "top_p": 0.95,
            "do_sample": True,
        }

    def initialize_trainer(self):
        """Initialize the PPO trainer with a two-row placeholder dataset."""
        dummy_data = {"prompt": [""] * 2, "response": [""] * 2, "reward": [0.0] * 2}
        dummy_dataset = Dataset.from_dict(dummy_data)

        # NOTE(review): keyword names differ between TRL releases; confirm
        # this signature against the installed PPOTrainer.
        self.trainer = PPOTrainer(
            args=self.ppo_config,
            model=self.policy_model,
            ref_model=self.ref_model,
            reward_model=self.reward_model,
            value_model=self.value_model,
            train_dataset=dummy_dataset,
            tokenizer=self.tokenizer
        )

    @staticmethod
    def _format_context_value(value) -> str:
        """Render one market-context value; smaller magnitudes get more decimals."""
        if not isinstance(value, (int, float)):
            return str(value)
        magnitude = abs(value)
        if magnitude < 0.01:
            return f"{value:.6f}"
        if magnitude < 1:
            return f"{value:.4f}"
        return f"{value:.2f}"

    def format_state(self, state: Dict) -> str:
        """Create the prompt text from the current state.

        Args:
            state: Mapping that must provide ``'market_context'`` (a mapping
                of name to value), ``'headlines'`` (iterable of strings),
                ``'action_history'`` (iterable of numbers) and ``'position'``
                (a number).

        Returns:
            The system prompt followed by the formatted market context,
            current position, headlines and previous positions.
        """
        context_items = [
            f"{name}: {self._format_context_value(value)}"
            for name, value in state['market_context'].items()
        ]
        headlines_str = "\n".join(f"- {h}" for h in state['headlines'])
        positions_str = ", ".join(f"{pos:.2f}" for pos in state['action_history'])

        sections = [
            f"{SYSTEM_PROMPT.strip()}\n\n",
            "Market Context:\n",
            ", ".join(context_items) + "\n\n",
            f"Current Position: {state['position']:.2f}\n\n",
            "Recent Headlines:\n",
            headlines_str + "\n\n",
            f"Previous Positions: [{positions_str}]",
        ]
        return "".join(sections)

    def extract_positioning(self, text: str) -> float:
        """Extract the positioning value from an XML-style response.

        The first ``<positioning>...</positioning>`` tag is located and the
        first signed number inside it is returned.

        Args:
            text: Model response to parse.

        Returns:
            The parsed position, or ``0.0`` when the tag is missing, the
            content is not a finite number, or ``text`` cannot be searched.
        """
        import math

        try:
            tag_match = re.search(r"<positioning>(.*?)</positioning>", text, re.DOTALL)
        except TypeError as e:
            # Best effort: a non-string response is treated as "no position".
            print(f"Error extracting position: {e}")
            return 0.0
        if not tag_match:
            return 0.0

        position_str = tag_match.group(1).strip()
        # The alternation is grouped so the optional sign applies to plain
        # integers too; ungrouped, "-1" would be read as "1".
        number_match = re.search(r"[-+]?(?:\d*\.\d+|\d+)", position_str)
        candidate = number_match.group() if number_match else position_str
        try:
            value = float(candidate)
        except ValueError:
            print(f"Could not convert position to float: {position_str}")
            return 0.0
        # float() also accepts "nan"/"inf"; neither is a usable position and
        # NaN would pass through a later clip unchanged.
        if not math.isfinite(value):
            print(f"Could not convert position to float: {position_str}")
            return 0.0
        return value

    def check_format(self, text: str) -> bool:
        """Check if the response follows the required XML format.

        Returns ``True`` when ``<macro state>``, ``<reasoning>`` and
        ``<positioning>`` sections appear in that order anywhere in ``text``.
        """
        pattern = r"<macro state>.*?</macro state>.*?<reasoning>.*?</reasoning>.*?<positioning>.*?</positioning>"
        return bool(re.search(pattern, text, re.DOTALL))

    def _input_device(self):
        """Return the device prompt tensors should be moved to."""
        device = getattr(self.policy_model, "device", None)
        if device is None:
            # Not every nn.Module wrapper exposes `.device`; fall back to the
            # device of the first parameter.
            device = next(self.policy_model.parameters()).device
        return device

    def predict(self, state: Dict) -> Tuple[float, str, torch.Tensor, torch.Tensor]:
        """Generate a trading decision based on the current state.

        Args:
            state: State mapping accepted by `format_state`.

        Returns:
            A tuple ``(position, response, query_ids, output_ids)`` where
            ``position`` is the parsed position clipped to ``[-1.0, 1.0]``,
            ``response`` is the decoded continuation without the prompt,
            ``query_ids`` are the tokenized prompt ids and ``output_ids`` is
            the raw ``generate`` output (prompt plus continuation).
        """
        prompt = self.format_state(state)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self._input_device())

        # Pass the attention mask explicitly instead of letting generate()
        # infer it from the input ids.
        outputs = self.policy_model.generate(
            inputs.input_ids,
            attention_mask=inputs.get("attention_mask"),
            **self.generation_kwargs
        )

        # Decode only the newly generated tokens (everything after the prompt).
        prompt_length = inputs.input_ids.shape[1]
        response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        position = np.clip(self.extract_positioning(response), -1.0, 1.0)

        return position, response, inputs.input_ids, outputs

    @staticmethod
    def _drop_batch_dim(tokens):
        """Return ``tokens`` without a leading batch axis of size one."""
        if getattr(tokens, "ndim", 1) == 2 and tokens.shape[0] == 1:
            return tokens[0]
        return tokens

    def ppo_step(self, query_tensors, response_tensors, rewards):
        """Run a PPO optimization step.

        Args:
            query_tensors: Iterable of prompt token-id tensors. Each item may
                be 1-D or carry a leading batch axis of size one, as returned
                by `predict`.
            response_tensors: Iterable of full output token-id tensors
                (prompt followed by continuation), aligned with
                ``query_tensors``.
            rewards: Rewards forwarded unchanged to the trainer.

        Returns:
            Whatever ``self.trainer.step`` returns.
        """
        if self.trainer is None:
            self.initialize_trainer()

        texts = []
        for query, response in zip(query_tensors, response_tensors):
            # Without dropping the batch axis, the prompt-length slice below
            # would cut along the batch dimension and yield an empty response.
            query = self._drop_batch_dim(query)
            response = self._drop_batch_dim(response)
            prompt_length = query.shape[0]
            texts.append({
                "prompt": self.tokenizer.decode(query, skip_special_tokens=True),
                "response": self.tokenizer.decode(response[prompt_length:], skip_special_tokens=True),
            })

        # NOTE(review): confirm the installed PPOTrainer provides `step` and
        # accepts decoded text dicts rather than token tensors.
        stats = self.trainer.step(texts, rewards)
        return stats