In [99]:
import torch
import torch.nn.functional as F

class AttnMach(torch.nn.Module):
    def __init__(self, n_vars, d_vocab, d_hidden):
        super().__init__()
        self.n_vars = n_vars
        self.d_vocab = d_vocab
        self.d_hidden = d_hidden

        self._w_v = torch.nn.Parameter(torch.randn(1, d_vocab))
        self._w_q = torch.nn.Parameter(torch.randn(d_hidden, d_vocab))
        self._w_k = torch.nn.Parameter(torch.randn(d_hidden, d_vocab))

    @property
    def w_v(self) -> torch.Tensor:
        return self._w_v.abs()

    @property
    def w_q(self) -> torch.Tensor:
        return self._w_q.abs()

    @property
    def w_k(self) -> torch.Tensor:
        return self._w_k.abs()

    def _contract(self, x: torch.Tensor) -> torch.Tensor:
        q = torch.einsum("bvd,hd->bvh", x, self.w_q)
        k = torch.einsum("bvd,hd->bvh", x, self.w_k)
        v = torch.einsum("bvd,hd->bvh", x, self.w_v)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
        return y


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Computes log prob of seqs.

        Args:
            x (torch.Tensor): (B, n_vars, d_vocab)

        Returns:
            torch.Tensor: (B,)
        """
        assert torch.all(x < self.d_vocab), f"Expected input to be in range [0, {self.d_vocab}), got max {x.max()}"
        x = torch.nn.functional.one_hot(x, num_classes=self.d_vocab).to(torch.get_default_dtype())  # (B, n_vars, d_vocab)

        p_tilde = self._contract(x)
        z = self._contract(torch.ones_like(x))
        print(p_tilde.shape, z.shape)

        if (p_tilde > z.prod(dim=-1)).any():
            print("Error: p_tilde > z")

        loss = z.log().sum(dim=-1) - p_tilde.squeeze(-1).log().sum(dim=-1)
        return loss.mean()
        
        

In [100]:
d_vocab = 2
d_hidden = 8
n_vars = 4

x = torch.randint(0, d_vocab, (1, n_vars))
# print(x.shape)

attn_mach = AttnMach(n_vars, d_vocab, d_hidden)
loss = attn_mach(x)
loss


torch.Size([1, 4, 1]) torch.Size([1, 4, 1])


tensor(-0.1248, grad_fn=<MeanBackward0>)

In [17]:
y.shape

torch.Size([1, 4, 1])