In [None]:
import torch
from torch import nn

In [None]:
import torch
import torch.nn as nn

class PDSSetNetwork(nn.Module):
    """Score each element of a group of size ``group_size`` for membership in a PDS.

    Every group element is embedded from its integer index. The embeddings then
    attend over the elements currently marked as members of the candidate set,
    and a sigmoid head produces one score in [0, 1] per element.

    Args:
        group_size: Number of group elements; the input's last dimension must
            equal this value.
    """

    def __init__(self, group_size=99):
        super().__init__()
        self.group_size = group_size

        # Element indices 0..group_size-1 as a (group_size, 1) float column.
        # Registered as a non-persistent buffer so it follows the module across
        # `.to(device)` calls without adding a key to the state_dict.
        self.register_buffer(
            "element_indices",
            torch.arange(group_size, dtype=torch.float32).unsqueeze(-1),
            persistent=False,
        )

        # DeepSets-style encoder: each element index is embedded independently.
        self.element_encoder = nn.Sequential(
            nn.Linear(1, 64),  # element index -> embedding
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
        )

        # Every element attends over the members of the current set.
        self.set_attention = nn.MultiheadAttention(
            embed_dim=128,
            num_heads=8,
            batch_first=True,
        )

        # Decision head: one sigmoid score per element ("should element i be in the PDS?").
        self.decision_head = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, current_pds_binary):
        """Score every group element given the current set membership.

        Args:
            current_pds_binary: Membership indicator of shape ``(group_size,)`` or
                ``(batch, group_size)``; any nonzero entry counts as a member.

        Returns:
            Scores in [0, 1]: shape ``(group_size,)`` for unbatched input,
            ``(batch, group_size)`` for batched input.

        Raises:
            ValueError: if the input is not 1-D/2-D or its last dimension is not
                ``group_size``.
        """
        unbatched = current_pds_binary.dim() == 1
        members = current_pds_binary.unsqueeze(0) if unbatched else current_pds_binary
        if members.dim() != 2 or members.shape[-1] != self.group_size:
            raise ValueError(
                f"expected input of shape ({self.group_size},) or (batch, {self.group_size}), "
                f"got {tuple(current_pds_binary.shape)}"
            )
        batch_size = members.shape[0]

        # (group_size, 128) -> (batch, group_size, 128); expand is a view, no copy.
        embeddings = self.element_encoder(self.element_indices)
        embeddings = embeddings.unsqueeze(0).expand(batch_size, -1, -1)

        # key_padding_mask uses True for keys to IGNORE, i.e. non-members.
        ignore = ~members.bool()
        # A row with no members would mask every key. The attention softmax then
        # returns NaN outputs and NaN gradients, so such rows attend to all
        # elements instead.
        empty_rows = ignore.all(dim=-1, keepdim=True)
        ignore = ignore & ~empty_rows

        attended, _ = self.set_attention(
            embeddings, embeddings, embeddings,
            key_padding_mask=ignore,
        )

        scores = self.decision_head(attended).squeeze(-1)  # (batch, group_size)
        return scores.squeeze(0) if unbatched else scores

In [None]:
# Usage with PPO
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class PDSFeaturesExtractor(BaseFeaturesExtractor):
    """Stable-Baselines3 features extractor that wraps ``PDSSetNetwork``.

    Args:
        observation_space: Observation space of the environment. Its first
            dimension is used as the group size. NOTE(review): assumes a 1-D
            binary membership vector (e.g. ``MultiBinary(group_size)``); confirm
            against the environment.
        features_dim: Width of the feature vector handed to the policy/value
            networks.
    """

    def __init__(self, observation_space, features_dim=256):
        super().__init__(observation_space, features_dim)
        group_size = observation_space.shape[0]
        self.pds_net = PDSSetNetwork(group_size)
        # PDSSetNetwork returns one score per group element (group_size values),
        # but SB3 builds the downstream MLP from the declared features_dim.
        # Project the scores to that width so the shapes agree.
        self.projection = nn.Sequential(
            nn.Linear(group_size, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations):
        """Map a batch of observations to a ``(batch, features_dim)`` feature tensor."""
        return self.projection(self.pds_net(observations))

# Train PPO with the custom set-based features extractor.
# NOTE(review): `env` is not defined anywhere in this notebook. It must be created
# before this cell runs (presumably an environment whose observation is the
# binary PDS membership vector); TODO add that cell.
model = PPO(
    policy="MlpPolicy",
    env=env,
    policy_kwargs={
        "features_extractor_class": PDSFeaturesExtractor,
        "features_extractor_kwargs": {"features_dim": 256},
    },
)