# In [None]:
def _squeeze_to_1d(token_ids):
    """Drop every size-1 dimension of ``token_ids``, keeping at least one axis.

    Shapes such as [seq], [1, seq], [1, 1, seq] or [1, seq, 1] all become
    [seq]; a single token becomes shape [1] instead of a 0-d scalar. An input
    with more than one non-singleton dimension (e.g. a real batch [2, seq])
    stays multi-dimensional, and the caller treats it as invalid.
    """
    squeezed = token_ids.squeeze()
    if len(squeezed.shape) == 0:
        return squeezed.reshape(1)
    return squeezed


def create_training_dataset(self, query_tensors, response_tensors, rewards):
    """Create a dataset for PPO training from experience.

    Each query/response pair is squeezed to a 1-D sequence of token IDs
    (whatever its original rank), decoded with ``self.tokenizer``, and the
    decoded prompt text is stripped from the front of the decoded response.
    Examples that fail to decode, are not a single sequence, or end up with
    an empty prompt/response are skipped with a printed message.

    Parameters:
    - query_tensors: List of query tensors (any shape whose non-singleton
      dimensions form a single sequence)
    - response_tensors: List of response tensors (same shape rules); paired
      with ``query_tensors`` by position, unmatched trailing entries ignored
    - rewards: List of reward values; a missing entry defaults to 0.0

    Returns:
    - Dataset with "prompt", "response" and "reward" columns. If no example
      survives, two placeholder rows are returned instead of an empty dataset.
    """
    formatted_data = {"prompt": [], "response": [], "reward": []}

    # Debug info
    print(f"Creating dataset from {len(query_tensors)} experiences")
    if len(query_tensors) > 0:
        print(f"Sample query tensor shape: {query_tensors[0].shape}")
    if len(response_tensors) > 0:
        print(f"Sample response tensor shape: {response_tensors[0].shape}")
    if len(response_tensors) != len(query_tensors):
        print(
            f"Warning: got {len(query_tensors)} query tensors but "
            f"{len(response_tensors)} response tensors; unmatched entries are ignored"
        )

    for i, (query, response) in enumerate(zip(query_tensors, response_tensors)):
        try:
            query = _squeeze_to_1d(query)
            response = _squeeze_to_1d(response)
            if len(query.shape) != 1 or len(response.shape) != 1:
                print(
                    f"Skipping example {i} - expected single sequences, got shapes "
                    f"{tuple(query.shape)} and {tuple(response.shape)}"
                )
                continue

            prompt_text = self.tokenizer.decode(query, skip_special_tokens=True)
            full_response_text = self.tokenizer.decode(response, skip_special_tokens=True)

            # Response tensors presumably hold prompt + generated tokens. Strip the
            # prompt on decoded text rather than by token count, because the
            # token boundaries of the two tensors need not line up.
            if full_response_text.startswith(prompt_text):
                response_text = full_response_text[len(prompt_text):]
            else:
                # No exact prefix: keep the full text (better than an empty
                # response for training).
                response_text = full_response_text

            if not (prompt_text and response_text):
                print(f"Skipping example {i} - empty prompt or response")
                continue

            # Convert the reward BEFORE appending anything, so a bad reward value
            # cannot leave the three columns with different lengths.
            reward_value = float(rewards[i]) if i < len(rewards) else 0.0

            formatted_data["prompt"].append(prompt_text)
            formatted_data["response"].append(response_text)
            formatted_data["reward"].append(reward_value)

        except Exception as e:
            # Best effort: a single malformed experience should not abort the batch.
            print(f"Error processing experience {i}: {e}")

    # Fall back to placeholder rows rather than an empty dataset
    # (presumably the downstream PPO step cannot handle zero rows).
    if len(formatted_data["prompt"]) == 0:
        print("Warning: No valid examples found. Creating dummy data.")
        formatted_data["prompt"] = ["dummy prompt"] * 2
        formatted_data["response"] = ["dummy response"] * 2
        formatted_data["reward"] = [0.0] * 2

    dataset = Dataset.from_dict(formatted_data)
    print(f"Created dataset with {len(dataset)} examples")
    return dataset

# In [None]:
def predict(self, state: Dict) -> Tuple[float, str, torch.Tensor, torch.Tensor]:
    """Generate a trading decision for the current environment state.

    The state is rendered to a prompt by ``self.format_state``, tokenized,
    and passed to ``self.policy_model.generate`` with
    ``self.generation_kwargs``. Returned tensors keep a leading batch
    dimension so they can be collected for PPO training.

    Parameters:
    - state: Current environment state

    Returns:
    - position: value from ``self.extract_positioning`` on the decoded
      output, clipped to [-1, 1]; 0.0 if anything in generation fails
    - response: decoded text generated after the prompt ("" on failure)
    - query_tensor: prompt token IDs, shape [1, prompt_length]
    - response_tensor: token IDs returned by ``generate`` (including the
      prompt for decoder-only models), shape [1, total_length]; on failure,
      the prompt token IDs again
    """
    prompt = self.format_state(state)

    inputs = self.tokenizer(prompt, return_tensors="pt").to(self.policy_model.device)

    try:
        with torch.no_grad():
            outputs = self.policy_model.generate(
                inputs.input_ids,
                # Pass the mask explicitly; otherwise generate() has to infer it
                # from pad_token_id, which is unreliable when pad == eos.
                attention_mask=inputs.get("attention_mask"),
                **self.generation_kwargs,
            )

        # Strip the prompt at the token level. Slicing the decoded text by
        # len(prompt) breaks whenever decoding does not reproduce the prompt
        # string exactly (whitespace normalisation, special tokens, ...).
        generated_ids = outputs[0]
        prompt_ids = inputs.input_ids[0]
        prompt_len = prompt_ids.shape[-1]
        if generated_ids.shape[-1] >= prompt_len and torch.equal(generated_ids[:prompt_len], prompt_ids):
            generated_ids = generated_ids[prompt_len:]
        response_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        # NOTE(review): the position is parsed from the full decoded text,
        # prompt included (as before); confirm extract_positioning cannot
        # match numbers that appear in the prompt itself.
        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        position = float(np.clip(self.extract_positioning(full_response), -1.0, 1.0))

        # Keep 2-D (batch, sequence) tensors for consistency with PPO buffers.
        query_tensor = inputs.input_ids
        response_tensor = outputs

        if len(query_tensor.shape) != 2:
            print(f"Warning: query_tensor has unexpected shape {query_tensor.shape}")
            # reshape (unlike view) also works on non-contiguous tensors
            query_tensor = query_tensor.reshape(1, -1)

        if len(response_tensor.shape) != 2:
            print(f"Warning: response_tensor has unexpected shape {response_tensor.shape}")
            response_tensor = response_tensor.reshape(1, -1)

        return position, response_text, query_tensor, response_tensor

    except Exception as e:
        print(f"Error in prediction: {e}")
        import traceback
        traceback.print_exc()

        # Safe defaults: neutral position, and the prompt IDs as both query and response.
        return 0.0, "", inputs.input_ids, inputs.input_ids