In [1]:
import sys, os

# Resolve the parent of the current working directory and make sure it is
# on the import path, so the `src` package imported below can be found.
# NOTE(review): assumes the notebook is launched from a direct subdirectory
# of the project root — confirm against the repository layout.
_cwd_parent = os.path.join(os.getcwd(), "..")
project_root = os.path.abspath(_cwd_parent)
if project_root not in sys.path:
    # Prepend (not append) so this checkout takes priority over any
    # other package of the same name already on the path.
    sys.path.insert(0, project_root)


import torch
from src.tabnet_step import TabNetStep

# --- Smoke test for a single TabNetStep --------------------------------
# Toy dimensions: one small batch whose feature width equals both the
# decision width and the output width.
batch_size, input_dim = 8, 16
n_d = output_dim = 16

# Random features plus an all-ones prior of the same shape.
x = torch.randn(batch_size, input_dim)
prior = torch.ones(batch_size, input_dim)

# Build one step without a shared transformer; pass shared layers through
# `shared_transformer` instead of None to exercise that path.
tabnet_step = TabNetStep(
    input_dim=input_dim,
    n_d=n_d,
    output_dim=output_dim,
    shared_transformer=None,
)

# One forward pass; the step hands back four tensors.
decision_out, next_feat, updated_prior, mask = tabnet_step(x, prior)

# Report the shape of every returned tensor.
for _label, _tensor in (
    ("Decision output", decision_out),
    ("Next feature", next_feat),
    ("Updated prior", updated_prior),
    ("Mask", mask),
):
    print(f"{_label} shape:", _tensor.shape)

# Eyeball checks on the mask: per-row totals are expected to be close to 1
# (per the message below), followed by the first sample's mask values.
_row_totals = mask.sum(dim=1)
print("Sum across mask rows (should be ~1):", _row_totals)
print("Example mask row:", mask[0])



Decision output shape: torch.Size([8, 16])
Next feature shape: torch.Size([8, 16])
Updated prior shape: torch.Size([8, 16])
Mask shape: torch.Size([8, 16])
Sum across mask rows (should be ~1): tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
Example mask row: tensor([0.0000, 0.1745, 0.0660, 0.0463, 0.0121, 0.3839, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0863, 0.1289, 0.0000, 0.0000, 0.1020, 0.0000],
       grad_fn=<SelectBackward0>)
