In [None]:
import torch
import math
import torch.nn.functional as F

def linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, apply_softmax: bool = False) -> torch.Tensor:
    """Dot-product attention without normalization ("linear"), optionally with a softmax.

    Computes ``y[b, i] = sum_j a[b, i, j] * v[b, j]`` where the scores are
    ``a[b, i, j] = <q[b, i], k[b, j]>``, or ``softmax_j`` of those scores when
    ``apply_softmax`` is True. No ``1/sqrt(D)`` scaling and no masking is applied.

    Args:
        q (torch.Tensor): (B, T, D) queries.
        k (torch.Tensor): (B, S, D) keys.
        v (torch.Tensor): (B, S, D_v) values; ``D_v`` may differ from ``D``.
        apply_softmax (bool): if True, normalize the scores over the key axis with a softmax.

    Returns:
        torch.Tensor: (B, T, D_v)
    """

    # attn[i, j] = <q_i, k_j>   -> (B, T, S)
    attn = torch.einsum("bid,bjd->bij", q, k)
    if apply_softmax:
        attn = F.softmax(attn, dim=-1)

    # y[i] = Σⱼ attn[i, j] * v[j]   -> (B, T, D_v)
    y = torch.einsum("bij,bjd->bid", attn, v)
    return y


class AttnMach(torch.nn.Module):
    """Normalized density model over sequences of tokens drawn from ``[0, d_vocab)``.

    The unnormalized score of a sequence ``x`` is the summed output of one
    un-softmaxed (linear) attention layer over its one-hot encoding::

        p_tilde(x) = sum_i sum_j <W_q[:, x_i], W_k[:, x_j]> * W_v[0, x_j]

    All weights pass through ``abs`` so ``p_tilde(x) >= 0``. That makes the
    partition function ``Z = sum_x p_tilde(x)`` over all ``d_vocab ** N`` sequences
    available in closed form (see ``forward``), giving ``P(x) = p_tilde(x) / Z``.

    Args:
        n_vars: sequence length (stored only; ``forward`` uses the input's length).
        d_vocab: number of distinct token values.
        d_hidden: query/key dimension.
    """

    def __init__(self, n_vars, d_vocab, d_hidden):
        super().__init__()
        self.n_vars = n_vars
        self.d_vocab = d_vocab
        self.d_hidden = d_hidden

        # Fan-in scaled init (std = sqrt(2 / d_vocab)); sign is discarded by abs() in the properties
        std_fan_in = torch.sqrt(torch.tensor(2.0)) / d_vocab**0.5
        self._w_v = torch.nn.Parameter(torch.randn(1, d_vocab) * std_fan_in)
        self._w_q = torch.nn.Parameter(torch.randn(d_hidden, d_vocab) * std_fan_in)
        self._w_k = torch.nn.Parameter(torch.randn(d_hidden, d_vocab) * std_fan_in)

    @property
    def w_v(self) -> torch.Tensor:
        """Non-negative value weights, (1, d_vocab)."""
        return self._w_v.abs()

    @property
    def w_q(self) -> torch.Tensor:
        """Non-negative query weights, (d_hidden, d_vocab)."""
        return self._w_q.abs()

    @property
    def w_k(self) -> torch.Tensor:
        """Non-negative key weights, (d_hidden, d_vocab)."""
        return self._w_k.abs()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Computes the mean negative log-likelihood of a batch of sequences.

        Args:
            x (torch.Tensor): (B, N) integer token ids in ``[0, d_vocab)``.

        Returns:
            torch.Tensor: scalar, the batch mean of ``-log P(x_b) = log Z - log p_tilde(x_b)``.
        """
        # Validate input (one_hot needs ids in [0, num_classes))
        assert torch.all((x >= 0) & (x < self.d_vocab)), (
            f"Expected input to be in range [0, {self.d_vocab}), got min {x.min()}, max {x.max()}"
        )
        x = torch.nn.functional.one_hot(x, num_classes=self.d_vocab).to(torch.get_default_dtype())  # (B, N, D)
        _, N, D = x.shape

        # Per-position queries, keys and scalar values
        q = torch.einsum("btd,hd->bth", x, self.w_q)  # (B, N, d_hidden)
        k = torch.einsum("btd,hd->bth", x, self.w_k)  # (B, N, d_hidden)
        v = torch.einsum("btd,hd->bth", x, self.w_v)[..., 0]  # (B, N)

        # p_tilde(x) = sum_i sum_j <q_i, k_j> v_j: the un-softmaxed attention output
        # summed over query positions (same value as linear_attention(q, k, v[..., None])[..., 0].sum(-1)).
        p_tilde = torch.einsum("bih,bjh,bj->b", q, k, v)  # (B,)

        # Check p_tilde always positive (its log is taken below)
        assert torch.all(p_tilde > 0), "p_tilde is not always positive"

        # Partition function Z = sum of p_tilde over all D**N sequences, split by position pair (i, j):
        #   i != j (N(N-1) pairs): the other N-2 positions are free
        #          -> D**(N-2) * sum_h (sum_a Wq[h,a]) * (sum_b Wk[h,b] Wv[b])
        #   i == j (N pairs): both factors read the same token, the other N-1 positions are free
        #          -> D**(N-1) * sum_h sum_a Wq[h,a] Wk[h,a] Wv[a]
        w_v = self.w_v[0]  # (D,)
        z_off = torch.einsum("ha,hb,b->", self.w_q, self.w_k, w_v)
        z_diag = torch.einsum("ha,ha,a->", self.w_q, self.w_k, w_v)
        # log Z = (N-2) log D + log(N(N-1) z_off + N D z_diag); log space avoids overflow of D**N
        log_z = (N - 2) * math.log(D) + torch.log(N * (N - 1) * z_off + N * D * z_diag)

        # Per-sequence NLL, (B,)
        loss = log_z - p_tilde.log()

        return loss.mean()
        
        

In [None]:
d_vocab = 2
d_hidden = 8
n_vars = 4

SEED = 0
N_STEPS = 10_000
LR = 5e-4
LOG_EVERY = 100
# Tolerance for float round-off when the NLL approaches 0
NEG_LOSS_TOL = 1e-6

# Seed so the sampled target sequence and the model initialization are reproducible
torch.manual_seed(SEED)

# Target: a single random sequence of n_vars discrete tokens; the model fits a density to it
x = torch.randint(0, d_vocab, (1, n_vars))

model = AttnMach(n_vars, d_vocab, d_hidden)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
for i in range(N_STEPS):
    optimizer.zero_grad()
    loss = model(x)
    loss_val = loss.item()
    # The loss is an NLL and should be >= 0; a negative value means the normalizer is wrong.
    # Check before stepping so an invalid loss is never applied to the weights.
    if loss_val < -NEG_LOSS_TOL:
        raise ValueError("Loss is negative")
    loss.backward()
    optimizer.step()
    if i % LOG_EVERY == 0:
        print(f"Iteration {i} loss: {loss_val}")



In [242]:
def test_linear_attention() -> bool:
    """Check ``linear_attention(..., apply_softmax=True)`` against PyTorch's SDPA.

    Both compute ``softmax(q k^T) v`` with scale 1.0 and no mask. The outputs are
    compared with a tolerance rather than ``==`` because the two code paths need
    not sum in the same order, so bitwise equality is not guaranteed.

    Returns:
        bool: True if the two results agree within tolerance.
    """
    B, T, D = 1, 4, 8
    # Local generator: deterministic inputs without touching the global RNG state
    gen = torch.Generator().manual_seed(0)
    q, k, v = (torch.randn(B, T, D, generator=gen).abs() for _ in range(3))
    y_lin = linear_attention(q, k, v, apply_softmax=True)
    y_pt = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=1.0)
    return torch.allclose(y_lin, y_pt, rtol=1e-5, atol=1e-6)

test_res_str = "PASS" if test_linear_attention() else "FAIL"
print(f"[{test_res_str}] linear_attention == scaled_dot_product_attention")

[PASS] linear_attention == scaled_dot_product_attention


In [258]:
# Baseline: negative log-likelihood (nats) of the uniform distribution over all
# d_vocab ** n_vars sequences -- a reference point for the training loss above.
math.log(d_vocab) * n_vars

2.772588722239781