In [1]:
import torch
def get_per_head_fixed_mask_causal(batch_length, mask_percents, flag, soft, device, soft_penalty=-1.0,
                                   generator=None) -> torch.Tensor:
    """Build an additive causal mask where each head additionally masks a fraction of past tokens.

    Every head starts from a standard causal mask: 0.0 on and below the diagonal,
    ``-inf`` above it. Then, for head ``h`` and each query row ``i >= 1``,
    ``min(int(i * mask_percents[h]), i)`` of the ``i`` strictly-earlier positions
    ``0 .. i-1`` are additionally masked. The diagonal itself is never masked, so
    every row keeps at least one finite entry.

    Args:
        batch_length: Sequence length L; each head's mask is L x L.
        mask_percents: One masking fraction per head. Its length sets the number of heads.
        flag: If False, mask the earliest positions ``0 .. num_tokens-1`` of each row.
            If True, mask a random subset of the earlier positions.
        soft: If True, masked positions get ``soft_penalty`` instead of ``-inf``.
        device: Device on which the mask (and random permutations) are created.
        soft_penalty: Value written to masked positions when ``soft`` is True.
        generator: Optional ``torch.Generator`` used by the random (``flag=True``) mode,
            which makes the mask reproducible. It should be on the same device as
            ``device``. ``None`` uses the global RNG, as before.

    Returns:
        Tensor of shape ``(1, num_heads, L, L)`` in the default floating dtype, holding
        0.0 for kept positions and ``-inf`` / ``soft_penalty`` for masked ones.
    """
    num_heads = len(mask_percents)
    mask = torch.full((1, num_heads, batch_length, batch_length), float('-inf'), device=device)

    # The causal lower triangle is identical for every head, so compute it once.
    rows, cols = torch.tril_indices(batch_length, batch_length, offset=0, device=device)
    mask[..., rows, cols] = 0.0

    fill_value = soft_penalty if soft else float('-inf')
    for h, percent in enumerate(mask_percents):
        # Row 0 has no earlier positions, so masking starts at row 1.
        for i in range(1, batch_length):
            num_tokens = min(int(i * percent), i)
            if num_tokens <= 0:
                continue
            if flag:
                # Earlier positions are exactly 0..i-1, so a permutation of range(i)
                # already is the list of candidate column indices.
                to_mask = torch.randperm(i, device=device, generator=generator)[:num_tokens]
            else:
                to_mask = slice(0, num_tokens)
            mask[0, h, i, to_mask] = fill_value
    return mask

In [7]:
from mask_visualization import visualize_multi_head_masks, quick_demo

# Render the per-head masks for 8 heads whose masking fraction grows linearly
# from 0 to 0.25 (the values below are k/28 for k = 0..7, rounded to 4 decimals).
fig = visualize_multi_head_masks(
    batch_length=64,
    mask_percents=[0.0, 0.0357, 0.0714, 0.1071, 0.1429, 0.1786, 0.2143, 0.25],
    flag=False,     # Sequential (not random) masking
    soft=False,     # Hard masking: no soft penalties
    device='cpu'
)
fig.savefig('my_masks.png')

  plt.tight_layout(rect=[0, 0.03, 1, 0.92])


In [6]:
import torch
import torch.nn as nn

class StateMaskGate(nn.Module):
    """Time-step action blinding gate (paper-aligned StateMask integration).

    Maps the agent's current latent (feature embedding) to a pass-through
    probability g_t in (0, 1). At rollout time, draw b_t ~ Bernoulli(g_t):
    b_t = 1 passes the policy action through, b_t = 0 blinds it (replaces it
    with a random action).

    This module is trained separately (e.g., periodically) to preserve the
    agent's action distribution (fidelity) while encouraging more blinding
    (sparsity), decoupled from the main RL objective.

    Args:
        feat_dim: Size F of the last dimension of the latent passed to ``forward``.
        hidden_dim: Width of the single hidden layer of the gate MLP.
    """

    def __init__(self, feat_dim: int, hidden_dim: int = 128):
        super().__init__()
        # Linear -> ReLU -> Linear: one gate logit per input row.
        self.net = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, last_latent: torch.Tensor) -> torch.Tensor:
        """Return the pass-through probability g_t for each item in the batch.

        Args:
            last_latent: ``[B, F]`` embedding of the current time step fed to the actor.

        Returns:
            ``[B, 1]`` tensor of sigmoid gate probabilities.
        """
        return torch.sigmoid(self.net(last_latent))  # [B, 1]

    @staticmethod
    def kl_fidelity_with_uniform(original_logits: torch.Tensor, gate_prob: torch.Tensor) -> torch.Tensor:
        """Mean KL(P || Q) between the original policy P and its blinded mixture Q.

        The blinded action distribution is approximated as the convex mixture
        ``Q = g * P + (1 - g) * Uniform``. With probability ``g`` the policy
        action passes through; otherwise a uniformly random action is taken.

        Args:
            original_logits: ``[..., A]`` action logits. Entries may be ``-inf``
                (e.g. masked/invalid actions). Such actions have zero probability
                under P and contribute nothing to the KL.
            gate_prob: Pass-through probability ``g``. Either shaped like
                ``forward``'s output (``[..., 1]``) or with that trailing
                singleton dropped (``[...]``).

        Returns:
            Scalar tensor: the per-row KL averaged over all leading dimensions.
        """
        original_log_probs = original_logits.log_softmax(dim=-1)
        original_probs = original_log_probs.exp()
        if gate_prob.dim() == original_probs.dim() - 1:
            # A [B] gate against [B, A] probabilities must broadcast per row,
            # not along the action axis.
            gate_prob = gate_prob.unsqueeze(-1)
        num_actions = original_logits.shape[-1]
        uniform_probs = torch.full_like(original_probs, 1.0 / num_actions)
        mixed_probs = gate_prob * original_probs + (1.0 - gate_prob) * uniform_probs
        # Terms with P == 0 contribute 0 by convention. Zero their log-prob so a
        # -inf log-prob (from a -inf logit) does not yield 0 * (-inf) = NaN.
        safe_log_probs = original_log_probs.masked_fill(original_probs == 0, 0.0)
        # KL(P || Q) = sum P log(P/Q). The clamp keeps log(Q) finite.
        kl = (original_probs * (safe_log_probs - mixed_probs.clamp_min(1e-8).log())).sum(dim=-1)
        return kl.mean()

    @staticmethod
    def sparsity_term(gate_prob: torch.Tensor) -> torch.Tensor:
        """Sparsity penalty: the mean pass-through probability.

        Minimizing it pushes g_t down, i.e. towards blinding more time steps.
        The fidelity term counteracts this on steps where blinding would change
        the policy's action distribution, so blinding concentrates on
        non-critical steps.
        """
        return gate_prob.mean()


# 1536 input features = 1024 + 512 (presumably two concatenated embeddings; confirm).
statemask = StateMaskGate(feat_dim=1024 + 512)
print(statemask)

StateMaskGate(
  (net): Sequential(
    (0): Linear(in_features=1536, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=1, bias=True)
  )
)
