In [1]:
import sys
import os

# Make the parent of the current working directory importable, so that
# project-local packages resolve when the notebook is run from a subfolder.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))  # adjust if needed
if sys.path.count(project_root) == 0:
    # Prepend (rather than append) so this location is searched first.
    sys.path.insert(0, project_root)


In [2]:
import torch
from src.feature_transformer import FeatureTransformer

# Sizes for the smoke test: a small random batch and the transformer settings.
batch_size, input_dim = 8, 16
n_d, n_steps = 16, 3

# Random stand-in batch of shape (batch_size, input_dim).
x = torch.randn((batch_size, input_dim))

# Module under test, built from the sizes above.
ft = FeatureTransformer(
    input_dim=input_dim,
    n_d=n_d,
    n_steps=n_steps,
)

# Feed the same batch through every step index and report the output shape.
for step in range(n_steps):
    out = ft(x, step_idx=step)
    print(f"Step {step} output shape:", out.shape)


Step 0 output shape: torch.Size([8, 16])
Step 1 output shape: torch.Size([8, 16])
Step 2 output shape: torch.Size([8, 16])


In [3]:
from src.glu import GatedLinearUnit
import torch

# Random stand-in batch: 8 rows by 16 columns.
x = torch.randn((8, 16))

# GLU block configured with matching input and output widths.
glu = GatedLinearUnit(
    input_dim=16,
    output_dim=16,
)

# Single forward pass; only the resulting shape is reported.
out = glu(x)
print("GLU output shape:", out.shape)



GLU output shape: torch.Size([8, 16])


In [4]:
import torch
from src.attentive_transformer import AttentiveTransformer

# Seed the global RNG so the random input (and any randomly initialised
# parameters inside the layer) are the same on every run; without this the
# mask values printed below change on each Restart & Run All.
torch.manual_seed(0)

# Dummy input
batch_size = 8
input_dim = 16

x = torch.randn(batch_size, input_dim)
prior = torch.ones(batch_size, input_dim)  # start with uniform prior

attn = AttentiveTransformer(input_dim=input_dim, output_dim=input_dim)
mask = attn(x, prior)

# Check the invariants explicitly instead of relying on eyeballing the printout:
# one mask value per (row, feature), and every row summing to ~1.
row_sums = mask.sum(dim=1)
assert mask.shape == (batch_size, input_dim), f"unexpected mask shape: {tuple(mask.shape)}"
assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-4), (
    f"mask rows do not sum to 1: {row_sums}"
)

print("Attention mask shape:", mask.shape)
print("Sum across features (should be close to 1):", row_sums)
print("Some example rows:\n", mask[:3])


Attention mask shape: torch.Size([8, 16])
Sum across features (should be close to 1): tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
Some example rows:
 tensor([[0.0000, 0.0000, 0.0760, 0.1058, 0.1270, 0.3628, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.1546, 0.0000, 0.0000, 0.1738],
        [0.0000, 0.0000, 0.0400, 0.2520, 0.0000, 0.0368, 0.0000, 0.0000, 0.0110,
         0.0000, 0.0000, 0.0000, 0.0793, 0.0000, 0.2426, 0.3383],
        [0.0181, 0.0098, 0.0434, 0.2471, 0.0000, 0.0044, 0.0311, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0600, 0.2371, 0.0000, 0.1259, 0.2231]],
       grad_fn=<SliceBackward0>)
