In [2]:
from models import Koopa

Koopa.Model

models.Koopa.Model

In [5]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [85]:
class FourierSelecter(nn.Module):
    def __init__(self, input_len, alpha):
        super(FourierSelecter, self).__init__()
        
        self.frequency_size = input_len//2 + 1 
        self.hidden_size_factor = 1
        self.scale = 0.02
        self.alpha = alpha
        self.w1 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size, self.frequency_size * self.hidden_size_factor))
        self.b1 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size * self.hidden_size_factor))

        self.w2 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size * self.hidden_size_factor, self.frequency_size))
        self.b2 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size))

        self.w3 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size, self.frequency_size * self.hidden_size_factor))
        self.b3 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size * self.hidden_size_factor))

    def forward(self, x):
        if "mps" in str(x.device):
            device = 'mps'
            x = x.to(torch.device('mps'))
        else:
            device = None

        xf = torch.fft.rfft(x, dim=1)
        xf = xf.permute(0, 2, 1)

        o1_real = F.relu(torch.einsum('bli,ii->bli', xf.real, self.w1[0]) - torch.einsum('bli,ii->bli', xf.imag, self.w1[1]) + self.b1[0])
        o1_imag = F.relu(torch.einsum('bli,ii->bli', xf.imag, self.w1[0]) + torch.einsum('bli,ii->bli', xf.real, self.w1[1]) + self.b1[1])
        z1 = torch.stack([o1_real, o1_imag], dim=-1)
        # z1 = F.softshrink(z1, lambd=self.sparsity_threshold)

        o2_real = F.relu(torch.einsum('bli,ii->bli', o1_real, self.w2[0]) - torch.einsum('bli,ii->bli', o1_imag, self.w2[1]) + self.b2[0])
        o2_imag = F.relu(torch.einsum('bli,ii->bli', o1_imag, self.w2[0]) + torch.einsum('bli,ii->bli', o1_real, self.w2[1]) + self.b2[1])
        z2 = torch.stack([o2_real, o2_imag], dim=-1)
        # z2 = F.softshrink(z2, lambd=self.sparsity_threshold)

        o3_real = F.relu(torch.einsum('bli,ii->bli', o2_real, self.w3[0]) - torch.einsum('bli,ii->bli', o2_imag, self.w3[1]) + self.b3[0])
        o3_imag = F.relu(torch.einsum('bli,ii->bli', o2_imag, self.w3[0]) + torch.einsum('bli,ii->bli', o2_real, self.w3[1]) + self.b3[1])
        z3 = torch.stack([o3_real, o3_imag], dim=-1)
        # z3 = F.softshrink(z3, lambd=self.sparsity_threshold)

        logits = torch.view_as_complex(z1+z2+z3)
        logits = (logits * torch.conj(logits)).real.permute(0, 2, 1)

        gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()  # ~Gumbel(0,1)
        
        y_soft = F.softmax((logits + gumbels) / 0.5, dim=1)
        indices = y_soft.topk(int(self.frequency_size * (1-self.alpha)), dim=1, largest=False).indices
        masked = y_soft.scatter_(dim=1, index=indices, value=0.)
        x_masked = (masked - y_soft).detach() + y_soft
        
        x_inv = torch.fft.irfft(x_masked, dim=1)
        x_var = x - x_inv

        return x_var, x_inv
x = torch.randn((64, 128, 24))
fs = FourierSelecter(128, 0.2)
fs(x)

(tensor([[[ 1.5674,  3.3550,  2.1269,  ...,  1.7712,  1.0232,  0.2149],
          [ 0.9400, -0.0569, -1.6833,  ..., -0.0278, -0.3581,  0.4287],
          [-0.1486, -0.5470,  1.2446,  ...,  0.0596,  0.6384,  0.4652],
          ...,
          [-0.0233,  0.1718, -0.8715,  ...,  0.9948, -0.8395, -1.8394],
          [-1.6671, -1.5316,  0.6663,  ..., -0.9863, -0.5209, -1.8509],
          [ 0.1499, -0.9797, -1.0295,  ...,  1.1925, -1.2845,  1.4688]],
 
         [[-1.0955, -0.6625, -1.7931,  ..., -0.8457, -1.8101, -0.2930],
          [-1.5159,  0.7875,  1.4526,  ..., -0.2835, -0.1012, -0.5280],
          [ 1.0258,  0.0875,  0.5651,  ..., -0.1704, -1.3374,  1.0270],
          ...,
          [-0.9452,  0.4338, -0.2754,  ..., -0.2145,  0.5292, -0.6241],
          [-0.7583,  1.2524,  0.2361,  ..., -0.5022, -0.3223, -0.3035],
          [-0.6411, -1.2344, -1.7940,  ..., -0.4620,  1.0499,  0.4139]],
 
         [[ 2.7460, -0.7254, -2.2387,  ..., -0.2388,  0.9874,  0.8271],
          [-1.1311,  0.6453,

: 

In [20]:
fs = FourierSelecter()
fs(x)

torch.Size([64, 65, 24])
torch.Size([64, 24, 65])


RuntimeError: einsum(): subscript i is repeated for operand 1 but the sizes don't match, 16640 != 65

In [2]:
import torch
x = torch.randn((1, 10, 1))
x

tensor([[[-0.1373],
         [-0.5983],
         [ 0.7446],
         [-1.3609],
         [ 0.6735],
         [ 0.7368],
         [-0.3499],
         [ 0.0016],
         [-0.5404],
         [ 0.8356]]])