In [None]:
from init_notebook import *

In [None]:
class DiagonalEmbedding(nn.Module):
    """
    Token embedding with a matching readout ("reverse") path.

    forward(x) maps token indices [B, L] to a channels-first representation
    [B, C, L] ([B, 2*C, L] when ``fft=True``); forward(x, reverse=True) maps such a
    representation back to per-token class logits [B, L, V].
    """

    def __init__(
            self,
            channels_in: int,
            channels_out: int,
            diagonal: bool = True,
            symmetric: bool = True,
            fft: bool = False,
    ):
        """
        Wrapper around torch.nn.Embedding

        :param channels_in: int, vocabulary size
        :param channels_out: int, internal representation size
        :param diagonal: bool, if True, the embedding weights are initialized with
            a diagonal matrix (via ``create_diagonal_matrix``), e.g. if
            channels_in==channels_out, the representation matches the input
        :param symmetric: bool, if True, the embedding weights and readout weights are shared.
            If False, the readout has its own set of weights (a separate ``nn.Linear``).
        :param fft: bool, if True, the representation is the concatenation of the
            real and imaginary parts of the FFT over the channel axis, which doubles
            the `channels_out`. The embedding weights are additionally passed through
            a softmax over the channel axis at initialization.
        """
        super().__init__()
        self._diagonal = diagonal
        self._symmetric = symmetric
        self._fft = fft

        self.input = nn.Embedding(channels_in, channels_out)
        with torch.no_grad():
            if diagonal:
                self.input.weight[:] = create_diagonal_matrix(self.input.weight.shape)
            if fft:
                # each embedding row becomes a non-negative vector summing to 1
                self.input.weight[:] = F.softmax(self.input.weight, dim=-1)

        # When symmetric, the reverse path reuses self.input.weight directly,
        # so no separate readout module is created.
        if not symmetric:
            self.output = nn.Linear(channels_out, channels_in)

    def extra_repr(self) -> str:
        return f"diagonal={self._diagonal}, symmetric={self._symmetric}, fft={self._fft}"

    def forward(
            self,
            x: torch.Tensor,
            reverse: bool = False,
    ) -> torch.Tensor:
        """
        Converts token indices to representation or representation to token class logits

        :param x: torch.Tensor,
            if reverse==False, the token indices of shape [B, L] (where L is sequence length),
            if reverse==True and fft==False, the representation of shape [B, C, L]
            if reverse==True and fft==True, the representation of shape [B, C*2, L]
        :param reverse: bool, if True, reverses the embedding
        :return: torch.Tensor,
            if reverse==False and fft==False, the representation of shape [B, C, L]
            if reverse==False and fft==True, the representation of shape [B, C*2, L]
            if reverse==True, the token class logits of shape [B, L, V] (where V is vocab_size)
        :raises ValueError: if reverse==True, fft==True and the channel dimension of `x`
            is odd (it must hold equally sized real and imaginary halves)
        """
        if not reverse:
            outp = self.input(x).permute(0, 2, 1)
            if self._fft:
                outp = torch.fft.fft(outp, dim=-2)
                outp = torch.cat([outp.real, outp.imag], dim=-2)
            return outp

        if self._fft:
            num_channels = x.shape[-2]
            # An odd size would make torch.split return three chunks and
            # torch.complex fail with an unrelated-looking error.
            if num_channels % 2:
                raise ValueError(
                    "Expected an even number of channels (real + imaginary halves) "
                    f"when fft=True, got {num_channels}"
                )
            real, imag = torch.split(x, num_channels // 2, dim=-2)
            # the forward FFT was taken over channels, so invert over channels too
            x = torch.fft.ifft(torch.complex(real, imag), dim=-2).real

        if self._symmetric:
            # [V, C] @ [B, C, L] -> [B, V, L] -> [B, L, V]
            return (self.input.weight @ x).permute(0, 2, 1).contiguous()
        return self.output(x.permute(0, 2, 1))

with torch.no_grad():
    # Round trip: token indices -> FFT representation -> per-token logits.
    # NOTE: with symmetric=False the readout is the module's own nn.Linear, which is
    # freshly initialised and untrained here, so inp2.argmax(-1) is not expected to
    # reproduce `inp`; this cell mainly checks shapes and value ranges.
    emb = DiagonalEmbedding(256, 512, fft=True, diagonal=True, symmetric=False)
    inp = torch.tensor([[0, 1, 2, 3, 2, 1, 201]], dtype=torch.long)
    outp = emb(inp)
    inp2 = emb(outp, reverse=True)
    print(outp.shape, outp.min(), outp.max())
    display((inp, inp2.argmax(-1)))

In [None]:
with torch.no_grad():
    inp = torch.tensor([[0, 1, 2, 4, 16, 201]], dtype=torch.long)
    # `emb` was constructed with fft=True, so x is already the concatenated real/imag
    # spectrum [B, 2*C, L]. This cell checks that another FFT -> real/imag concat ->
    # complex -> inverse FFT round trip over the channel axis is lossless.
    x = emb(inp)

    x2 = torch.fft.fft(x, dim=-2)
    xc = torch.cat([x2.real, x2.imag], dim=-2)
    display(px.imshow(xc[0].T, height=400, aspect=False))

    x2 = torch.complex(*torch.split(xc, xc.shape[-2] // 2, dim=-2))
    x3 = torch.fft.ifft(x2, dim=-2)
    # x is real, so the inverse transform should give it back in .real (and ~0 in .imag).
    torch.testing.assert_close(x3.real, x, rtol=1e-4, atol=1e-4)
    display(px.imshow(x3.real[0].T, height=250, aspect=False))
    display(px.imshow(x3.imag[0].T, height=250, aspect=False))

    outp = emb(x3.real, reverse=True).argmax(-1)
    display(outp)

In [None]:
# Imaginary part of the one-sided FFT (rfft) of `x` from the previous cell,
# taken along the channel axis (dim -2), rendered as a heatmap.
xf = torch.fft.rfft(input=x, dim=-2).imag
print(xf.shape)
display(px.imshow(img=xf[0].transpose(0, 1), height=400, aspect=False))

In [None]:
# Build torchvision's VGG-11 (called with no arguments, so torchvision's defaults apply)
# and let the notebook render the module tree as this cell's output.
from torchvision.models import vgg11
vgg11()