In [5]:
import torch
import torch.nn as nn
import torch.optim as optim

custom convolution using fft

In [21]:
class Conv(nn.Module):
  def __init__(self, k) -> None:
    super().__init__()
    self.kernel = nn.Parameter(torch.randn(k))
  def forward(self, x):
    fft_x = torch.fft.fft(x, n= len(x[0])+len(self.kernel)-1)
    fft_kernel = torch.fft.fft(self.kernel, n= len(x[0])+len(self.kernel)-1)
    return torch.abs(torch.fft.ifft(fft_kernel*fft_x))


In [22]:
data = torch.tensor([
    [20.5, 21.0, 22.3, 23.1, 24.0, 22.7, 21.5, 20.8],
    [22.0, 23.2, 24.1, 25.5, 26.3, 25.0, 24.2, 23.8],
    [19.8, 20.3, 21.1, 20.5, 22.0, 22.8, 23.4, 22.7],
    [24.5, 25.3, 26.0, 25.7, 26.8, 27.2, 27.8, 26.9],
    [18.7, 19.2, 19.8, 20.5, 21.2, 20.9, 20.3, 19.7],
], dtype= torch.float32)

# Targets (temperatures for Day 8)
targets = torch.tensor([21.3, 25.1, 22.3, 28.0, 19.5], dtype= torch.float32)

In [23]:
class Model(nn.Module):
  def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.l1 = nn.Linear(8, 16)
    self.relu = nn.ReLU()
    self.conv = Conv(4)
    self.l2 = nn.Linear(19, 1)
  def forward(self, x):
    x = self.l1(x)
    x = self.relu(x)
    x = self.conv(x)
    return self.l2(x)


In [57]:
model = Model()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr = 0.01)
epochs = 10000
for i in range(epochs):
  loss = criterion(model(data).view(len(data)), targets)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  if i % 1000 == 0:
    print(loss)

tensor(186.1384, grad_fn=<MseLossBackward0>)
tensor(0.1847, grad_fn=<MseLossBackward0>)
tensor(0.0358, grad_fn=<MseLossBackward0>)
tensor(0.0047, grad_fn=<MseLossBackward0>)
tensor(0.0006, grad_fn=<MseLossBackward0>)
tensor(0.0003, grad_fn=<MseLossBackward0>)
tensor(0.0001, grad_fn=<MseLossBackward0>)
tensor(0.0004, grad_fn=<MseLossBackward0>)
tensor(4.5909e-05, grad_fn=<MseLossBackward0>)
tensor(0.0010, grad_fn=<MseLossBackward0>)


In [58]:
test_data = torch.tensor([
    [21.2, 22.4, 23.6, 24.8, 25.9, 26.7, 27.2, 28.0],
    [25.1, 26.3, 27.7, 28.4, 29.0, 29.8, 30.2, 31.0],
    [20.5, 21.3, 22.6, 23.8, 24.5, 25.1, 26.0, 26.8],
    [28.3, 29.5, 30.1, 31.2, 31.8, 32.5, 33.0, 34.2],
    [19.0, 20.2, 21.6, 22.3, 23.0, 24.1, 25.0, 25.7],
])
with torch.no_grad():
  print(model(test_data))
# test_targets = np.array([28.7, 32.8, 27.5, 35.0, 24.8])

tensor([[28.2597],
        [32.1648],
        [26.7476],
        [35.9150],
        [25.4856]])
