In [None]:
def ppo_step(self, query_tensors, response_tensors, rewards):
    """Run one PPO optimization step on decoded prompt/response text pairs.

    Each query/response pair is decoded with ``self.tokenizer`` into a
    ``{"prompt": ..., "response": ...}`` dict. A response tensor is expected
    to start with the query tokens, so only the tokens after the query length
    are decoded as the response text. Query tensors with more than one
    dimension are treated as batches and decoded row by row, including
    batches of size 1.

    NOTE(review): TRL's classic ``PPOTrainer.step`` takes
    ``(queries, responses, scores)`` as lists of tensors, not text dicts --
    confirm what ``self.trainer`` (built by ``initialize_trainer``) expects.

    Parameters:
    - query_tensors: List of query token tensors (1-D, or 2-D batches)
    - response_tensors: List of response token tensors, parallel to
      ``query_tensors`` and shaped like them
    - rewards: List of reward values; a list whose length differs from the
      number of decoded pairs is broadcast (single value), truncated, or
      padded with its last value (0.0 when empty). Non-list rewards are
      passed through unchanged.

    Returns:
    - Statistics returned by ``self.trainer.step``, or ``None`` if that call
      raised (the traceback is printed).
    """
    if self.trainer is None:
        self.initialize_trainer()

    def decode_pair(query, response):
        # The response is assumed to begin with the query tokens; skip them.
        query_length = query.shape[0]
        return {
            "prompt": self.tokenizer.decode(query, skip_special_tokens=True),
            "response": self.tokenizer.decode(
                response[query_length:],
                skip_special_tokens=True,
            ),
        }

    # Format inputs for the PPO trainer.
    texts = []
    for i, query in enumerate(query_tensors):
        response = response_tensors[i]

        if len(query.shape) > 1:
            # Batched tensors -- including a batch of one -- are decoded row
            # by row. Treating a (1, L) tensor as a single sample would decode
            # a 2-D tensor and slice the response along the batch dimension.
            for j in range(query.shape[0]):
                texts.append(decode_pair(query[j], response[j]))
        else:
            texts.append(decode_pair(query, response))

    # Ensure rewards match the number of texts.
    if isinstance(rewards, list) and len(rewards) != len(texts):
        if len(rewards) == 1:
            # A single reward is broadcast to every pair.
            rewards = [rewards[0]] * len(texts)
        else:
            print(f"Warning: Number of rewards ({len(rewards)}) doesn't match number of texts ({len(texts)})")
            if len(rewards) < len(texts):
                # Pad with the last reward; an empty list has none, so use 0.0.
                fill_value = rewards[-1] if rewards else 0.0
                rewards = rewards + [fill_value] * (len(texts) - len(rewards))
            else:
                rewards = rewards[:len(texts)]

    # Log what we're passing to the PPO trainer.
    print(f"Running PPO step with {len(texts)} text pairs and {len(rewards) if isinstance(rewards, list) else 'tensor'} rewards")

    # Best-effort: a failing trainer step is reported and yields None.
    try:
        stats = self.trainer.step(texts, rewards)
        return stats
    except Exception as e:
        print(f"Error in PPO step: {e}")
        import traceback
        traceback.print_exc()
        return None

In [None]:
def ppo_step(self, query_tensors, response_tensors, rewards):
    """Run a PPO optimization step on decoded prompt/response text pairs.

    Each query/response pair is decoded with ``self.tokenizer`` into a
    ``{"prompt": ..., "response": ...}`` dict. A response tensor is expected
    to start with the query tokens, so only tokens past the query length are
    decoded as the response (an empty string if there are none). A leading
    batch dimension of size 1 is squeezed away; larger batches are split
    into one pair per row.

    Rewards are normalized to a flat list of floats, one per pair:
    - a list/tuple is converted to floats, then truncated or padded with its
      last value (0.0 if empty), with a printed warning on mismatch;
    - anything else is flattened via ``.flatten().tolist()`` and truncated or
      padded with 0.0; if that conversion fails, every reward becomes 0.0.

    NOTE(review): TRL's classic ``PPOTrainer.step`` takes
    ``(queries, responses, scores)`` as lists of tensors, not text dicts and
    floats -- confirm what ``self.trainer`` (built by ``initialize_trainer``)
    actually expects.

    Parameters:
    - query_tensors: List of query token tensors
    - response_tensors: List of response token tensors, parallel to
      ``query_tensors``
    - rewards: Reward values (list/tuple of numbers, or a tensor/array)

    Returns:
    - Statistics from ``self.trainer.step``, or ``None`` if formatting or the
      step failed (the traceback and some input diagnostics are printed).
    """
    if self.trainer is None:
        self.initialize_trainer()

    def format_pair(query, response):
        # Decode one 1-D query/response pair into the dict passed to the trainer.
        prompt_text = self.tokenizer.decode(query, skip_special_tokens=True)
        query_length = query.shape[0]
        if response.shape[0] > query_length:
            response_text = self.tokenizer.decode(
                response[query_length:], skip_special_tokens=True
            )
        else:
            # The response does not extend beyond the query.
            response_text = ""
        return {"prompt": prompt_text, "response": response_text}

    try:
        texts = []
        for i, query in enumerate(query_tensors):
            response = response_tensors[i]

            # Drop a leading batch dimension of size 1.
            if len(query.shape) > 1:
                query = query.squeeze(0)
            if len(response.shape) > 1:
                response = response.squeeze(0)

            if len(query.shape) > 1:
                # squeeze(0) leaves batches of size > 1 untouched; decode them
                # row by row instead of slicing along the batch dimension.
                texts.extend(
                    format_pair(query[j], response[j])
                    for j in range(query.shape[0])
                )
            else:
                texts.append(format_pair(query, response))

        # Ensure rewards are a flat list of floats matching the texts.
        if isinstance(rewards, (list, tuple)):
            rewards = [float(r) for r in rewards]
            if len(rewards) != len(texts):
                print(f"Warning: rewards length ({len(rewards)}) doesn't match texts length ({len(texts)})")
                if len(rewards) > len(texts):
                    rewards = rewards[:len(texts)]
                else:
                    last_reward = rewards[-1] if rewards else 0.0
                    rewards = rewards + [last_reward] * (len(texts) - len(rewards))
        else:
            # Tensor/array-like rewards: flatten to a Python list of floats.
            try:
                rewards_list = [float(r) for r in rewards.flatten().tolist()]
            except (AttributeError, TypeError, ValueError, RuntimeError):
                print("Warning: Could not process rewards properly")
                rewards = [0.0] * len(texts)
            else:
                if len(rewards_list) >= len(texts):
                    rewards = rewards_list[:len(texts)]
                else:
                    rewards = rewards_list + [0.0] * (len(texts) - len(rewards_list))

        print(f"Running PPO step with {len(texts)} text pairs")

        stats = self.trainer.step(texts, rewards)
        return stats

    except Exception as e:
        # Best-effort: report the failure with input diagnostics and return None.
        print(f"Error in PPO step: {e}")
        import traceback
        traceback.print_exc()

        if len(query_tensors) > 0:
            print("\nSample query tensor shape:", query_tensors[0].shape)
        if len(response_tensors) > 0:
            print("Sample response tensor shape:", response_tensors[0].shape)
        print("Rewards type:", type(rewards))
        if isinstance(rewards, list) and len(rewards) > 0:
            print("Sample reward:", rewards[0])

        return None