In [None]:
from init_notebook import *

In [None]:
class DiagonalEmbedding(nn.Module):

    def __init__(
            self,
            channels_in: int,
            channels_out: int,
            diagonal: bool = True,
            symmetric: bool = True,
            fft: bool = False,
            fft_concat_dim: int = -1,
            requires_grad: bool = True,
    ):
        """
        Wrapper around torch.nn.Embedding with an optional FFT representation
        and a readout (reverse) path back to token class logits.

        :param channels_in: int, vocabulary size
        :param channels_out: int, internal representation size
        :param diagonal: bool, if True, the embedding weights are initialized with
            a diagonal matrix, e.g. if channels_in==channels_out, the representation
            matches the input
        :param symmetric: bool, if True, the embedding weights and readout weights are shared.
            If False, the readout has its own set of weights (a `nn.Linear` that starts
            out as a copy of the embedding weights with a zero bias).
        :param fft: bool, If True, the representation is the concatenation of the
            real and imaginary FFT transform
        :param fft_concat_dim: int, either -1 or -2,
            if -1, fft real and imaginary output is concatenated along the sequence dimension
            if -2, it's concatenated along the channel dimensions and `channels_out` is divided by 2!
        :param requires_grad: bool, if False, the embedding weights (and the readout
            weights and bias when `symmetric` is False) are frozen
        """
        super().__init__()
        self._diagonal = diagonal
        self._symmetric = symmetric
        self._fft = fft
        self._fft_concat_dim = fft_concat_dim

        if fft:
            # The check is on the representation size, so the message names `channels_out`.
            assert channels_out > 0 and channels_out & (channels_out - 1) == 0, \
                "`channels_out` must be a power of 2 when using `fft`"
            assert fft_concat_dim in (-1, -2), f"`fft_concat_dim` must be -1 or -2, got '{fft_concat_dim}'"
            if fft_concat_dim == -2:
                # real and imaginary halves are stacked on the channel axis,
                # so each half only gets half of the requested channels
                channels_out //= 2

        self.input = nn.Embedding(channels_in, channels_out)
        with torch.no_grad():
            if diagonal:
                self.input.weight[:] = create_diagonal_matrix(self.input.weight.shape)
            elif fft:
                self.input.weight[:] = F.softmax(self.input.weight, dim=-1)
            if not requires_grad:
                self.input.weight.requires_grad = False

            if not symmetric:
                self.output = nn.Linear(channels_out, channels_in)
                # nn.Linear stores its weight as [out_features, in_features]
                # == [channels_in, channels_out], which is exactly the layout of the
                # embedding weight. Copying it un-transposed makes the untied readout
                # start out identical to the tied one (`self.input.weight @ x`) and
                # also works when channels_in != channels_out.
                self.output.weight[:] = self.input.weight
                self.output.bias[:] = 0.
                if not requires_grad:
                    self.output.weight.requires_grad = False
                    self.output.bias.requires_grad = False

    def extra_repr(self) -> str:
        return (
            f"diagonal={self._diagonal}, symmetric={self._symmetric}"
            f", fft={self._fft}, fft_concat_dim={self._fft_concat_dim}"
        )

    def forward(
            self,
            x: torch.Tensor,
            reverse: bool = False,
    ) -> torch.Tensor:
        """
        Converts token indices to representation or representation to token class logits

        :param x: torch.Tensor,
            if reverse==False, the token indices of shape [B, L] (where L is sequence length),
            if reverse==True and fft==False, the representation of shape [B, C, L]
            if reverse==True and fft==True, the representation
                of shape [B, C, L] if `fft_concat_dim==-2` or [B, C, L*2] if `fft_concat_dim==-1`
        :param reverse: bool, if True, reverses the embedding
        :return: torch.Tensor,
            if reverse==False and fft==False, the representation of shape [B, C, L]
            if reverse==False and fft==True, the representation
                of shape [B, C, L] if `fft_concat_dim==-2` or [B, C, L*2] if `fft_concat_dim==-1`
            if reverse==True, the token class logits of shape [B, L, V] (where V is vocab_size)
        """
        if not reverse:
            # [B, L] -> [B, L, C] -> [B, C, L]
            outp = self.input(x).permute(0, 2, 1)
            if self._fft:
                # transform along the channel axis, then store the complex result
                # as two real halves along the configured axis
                outp = torch.fft.fft(outp, dim=-2)
                outp = torch.concat([outp.real, outp.imag], dim=self._fft_concat_dim)
            return outp

        else:
            if self._fft:
                # split the two real halves apart again and undo the transform
                x = torch.complex(*torch.split(x, x.shape[self._fft_concat_dim] // 2, dim=self._fft_concat_dim))
                x = torch.fft.ifft(x, dim=-2).real

            if self._symmetric:
                # tied readout: [V, C] @ [B, C, L] -> [B, V, L] -> [B, L, V]
                return (self.input.weight @ x).permute(0, 2, 1).contiguous()
            else:
                return self.output(x.permute(0, 2, 1))


# Round-trip demo: embed token ids with the FFT representation and an untied
# readout, then map the representation back to token logits.
emb = DiagonalEmbedding(channels_in=16, channels_out=16, symmetric=False, fft=True)
inp = torch.randint(
    low=0,
    high=9,
    size=(2, 5),
    generator=torch.Generator().manual_seed(23),  # fixed seed -> same token ids on every run
)
outp = emb(inp)
display(inp)
print("output:", outp.shape, sep="\n")
display(outp)
print("reverse:")
inp2 = emb(outp, reverse=True)
# the argmax over the class logits is shown first so it can be compared with `inp`
display(inp2.argmax(dim=-1), inp2)


# try without weights

In [None]:
class DiagonalEmbedding(nn.Module):

    def __init__(
            self,
            channels_in: int,
            channels_out: int,
            diagonal: bool = True,
            symmetric: bool = True,
            no_weights: bool = False,
            fft: bool = False,
            fft_concat_dim: int = -1,
    ):
        """
        Wrapper around torch.nn.Embedding with an optional weight-free (one-hot)
        mode, an optional FFT representation and a readout (reverse) path.

        :param channels_in: int, vocabulary size
        :param channels_out: int, internal representation size
        :param diagonal: bool, if True, the embedding weights are initialized with
            a diagonal matrix, e.g. if channels_in==channels_out, the representation
            matches the input. Has no effect when `no_weights` is True.
        :param symmetric: bool, if True, the embedding weights and readout weights are shared.
            If False, the readout has its own set of weights.
        :param no_weights: bool, if True, no embedding table is created and the
            representation is the one-hot encoding of the token indices.
            Requires `channels_in == channels_out` (after the halving that
            `fft_concat_dim == -2` applies to `channels_out`).
        :param fft: bool, If True, the representation is the concatenation of the
            real and imaginary FFT transform
        :param fft_concat_dim: int, either -1 or -2,
            if -1, fft real and imaginary output is concatenated along the sequence dimension
            if -2, it's concatenated along the channel dimensions and `channels_out` is divided by 2!
        """
        super().__init__()
        self._diagonal = diagonal
        self._symmetric = symmetric
        self._no_weights = no_weights
        self._fft = fft
        self._fft_concat_dim = fft_concat_dim

        if fft:
            # The check is on the representation size, so the message names `channels_out`.
            assert channels_out > 0 and channels_out & (channels_out - 1) == 0, \
                "`channels_out` must be a power of 2 when using `fft`"
            assert fft_concat_dim in (-1, -2), f"`fft_concat_dim` must be -1 or -2, got '{fft_concat_dim}'"
            if fft_concat_dim == -2:
                # real and imaginary halves are stacked on the channel axis,
                # so each half only gets half of the requested channels
                channels_out //= 2

        if self._no_weights:
            assert channels_in == channels_out, \
                f"In and out channels must match when no_weights is True, got {channels_in} and {channels_out}"
            self.input = None
            self._channels = channels_in
        else:
            self.input = nn.Embedding(channels_in, channels_out)
            with torch.no_grad():
                if diagonal:
                    self.input.weight[:] = create_diagonal_matrix(self.input.weight.shape)
                elif fft:
                    self.input.weight[:] = F.softmax(self.input.weight, dim=-1)

        if not symmetric:
            self.output = nn.Linear(channels_out, channels_in)

    def extra_repr(self) -> str:
        return (
            f"diagonal={self._diagonal}, symmetric={self._symmetric}, no_weights={self._no_weights}"
            f", fft={self._fft}, fft_concat_dim={self._fft_concat_dim}"
        )

    def forward(
            self,
            x: torch.Tensor,
            reverse: bool = False,
    ) -> torch.Tensor:
        """
        Converts token indices to representation or representation to token class logits

        :param x: torch.Tensor,
            if reverse==False, the token indices of shape [B, L] (where L is sequence length),
            if reverse==True and fft==False, the representation of shape [B, C, L]
            if reverse==True and fft==True, the representation
                of shape [B, C, L] if `fft_concat_dim==-2` or [B, C, L*2] if `fft_concat_dim==-1`
        :param reverse: bool, if True, reverses the embedding
        :return: torch.Tensor,
            if reverse==False and fft==False, the representation of shape [B, C, L]
            if reverse==False and fft==True, the representation
                of shape [B, C, L] if `fft_concat_dim==-2` or [B, C, L*2] if `fft_concat_dim==-1`
            if reverse==True, the token class logits of shape [B, L, V] (where V is vocab_size)
        """
        if not reverse:

            if self.input is not None:
                # [B, L] -> [B, L, C] -> [B, C, L]
                outp = self.input(x).permute(0, 2, 1)
            else:
                # Weight-free path: one-hot encode the token ids directly in the
                # [B, C, L] layout by writing a 1 at channel index x[b, l] for
                # every (b, l). scatter needs int64 indices of the same rank as
                # the target, hence the cast and the inserted channel axis.
                batch_size, seq_len = x.shape
                outp = torch.zeros(batch_size, self._channels, seq_len, device=x.device)
                outp = outp.scatter(1, x.to(torch.int64).unsqueeze(1), 1)

            if self._fft:
                # transform along the channel axis, then store the complex result
                # as two real halves along the configured axis
                outp = torch.fft.fft(outp, dim=-2)
                outp = torch.concat([outp.real, outp.imag], dim=self._fft_concat_dim)
            return outp

        else:
            if self._fft:
                # split the two real halves apart again and undo the transform
                x = torch.complex(*torch.split(x, x.shape[self._fft_concat_dim] // 2, dim=self._fft_concat_dim))
                x = torch.fft.ifft(x, dim=-2).real

            if not self._symmetric:
                return self.output(x.permute(0, 2, 1))

            if self.input is None:
                # Without weights the embedding is the identity (one-hot) map, so the
                # tied readout only has to move the channel axis last: [B, C, L] -> [B, L, V]
                return x.permute(0, 2, 1).contiguous()

            # tied readout: [V, C] @ [B, C, L] -> [B, V, L] -> [B, L, V]
            return (self.input.weight @ x).permute(0, 2, 1).contiguous()

# Side-by-side demo: the same token ids embedded once through a weight table
# and once through the weight-free variant, to compare the two representations.
emb = DiagonalEmbedding(channels_in=10, channels_out=10, no_weights=False)
inp = torch.randint(
    low=0,
    high=9,
    size=(2, 5),
    generator=torch.Generator().manual_seed(23),  # fixed seed -> same token ids on every run
)
outp = emb(inp)
display(inp)
print(outp.shape)
display(outp)

# same input, now without an embedding table
emb = DiagonalEmbedding(channels_in=10, channels_out=10, no_weights=True)
outp2 = emb(inp)
print(outp2.shape)
display(outp2)
