In [44]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch import nn

In [45]:
# Dataset from https://www.kaggle.com/datasets/saurabhshahane/electricity-load-forecasting
# Read the training CSV with its 'datetime' column promoted straight to the row index.
df = pd.read_csv('kaggle_electricity_load_forecasting_train.csv', index_col='datetime')

In [46]:
# Target: 'nat_demand' as an (N, 1) column array; inputs: the four *_toc columns as (N, 4).
data_output = df[['nat_demand']].to_numpy()
data_input = df[['T2M_toc', 'QV2M_toc', 'TQL_toc', 'W2M_toc']].to_numpy()

In [47]:
# Convert the NumPy arrays to float32 tensors on the GPU when one is available,
# otherwise on the CPU (the previous hard-coded .cuda() failed without a GPU).
# torch.as_tensor replaces the legacy torch.Tensor constructor; dtype is pinned to
# float32, which is what torch.Tensor produces under the default dtype setting.
# Re-running this cell is safe: as_tensor/.to return an already-converted tensor as-is.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_output = torch.as_tensor(data_output, dtype=torch.float32).to(device)
data_input = torch.as_tensor(data_input, dtype=torch.float32).to(device)

In [48]:
# Sanity check: both tensors should have the same number of rows.
print(*(t.shape for t in (data_output, data_input)))

torch.Size([43775, 1]) torch.Size([43775, 4])


In [49]:
class AttentionNet(nn.Module):
    """Regressor that multiplies one ReLU branch by the Gram matrix of another.

    With ``h = inp_lin(x)``, ``a = relu(lin1(h))`` and ``b = relu(lin2(h))``,
    the output is ``out_lin(a @ (b^T @ b))``.

    Args:
        inp_size: size of the last input dimension.
        out_size: size of the last output dimension.
        linear_size: hidden width shared by every internal layer.

    NOTE(review): ``b^T @ b`` contracts over dimension -2. For the 2-D
    ``(rows, features)`` input this notebook passes in, that dimension is the
    batch. Every row's prediction then depends on all other rows, and the Gram
    matrix grows with the number of rows. This may explain the very large,
    unstable training losses logged for this model. Confirm whether a
    per-sample (3-D) input was intended.

    NOTE(review): no activation sits between ``inp_lin`` and ``lin1``/``lin2``,
    so those pairs compose to single affine maps before the ReLU.
    """

    def __init__(self, inp_size, out_size, linear_size=36):
        super().__init__()
        self.inp_size, self.linear_size, self.out_size = inp_size, linear_size, out_size

        # Keep this creation order: it determines which RNG draws each
        # layer's initial weights receive.
        self.inp_lin = nn.Linear(inp_size, linear_size)
        self.lin1 = nn.Linear(linear_size, linear_size)
        self.lin2 = nn.Linear(linear_size, linear_size)
        self.out_lin = nn.Linear(linear_size, out_size)

    def forward(self, x):
        """Return ``out_lin(relu(lin1(h)) @ (b^T @ b))`` with ``b = relu(lin2(h))``."""
        h = self.inp_lin(x)
        a = F.relu(self.lin1(h))
        b = F.relu(self.lin2(h))
        gram = b.transpose(-2, -1) @ b
        return self.out_lin(a @ gram)


class DLFTNet(nn.Module):
    """Outer-product network combining an FFT branch with a ReLU branch.

    The input is projected to ``linear_size`` features ``h``. The real part of
    the FFT of ``h``, taken along its last dimension, is combined with
    ``relu(lin1(h))`` through an outer product. This gives a
    ``linear_size x linear_size`` matrix per sample, which is flattened and fed
    to ``out_lin``.

    Args:
        inp_size: size of the last input dimension.
        out_size: size of the last output dimension.
        linear_size: hidden width. ``out_lin`` has ``linear_size ** 2`` inputs,
            so its parameter count grows quadratically with this value.
    """

    def __init__(self, inp_size, out_size, linear_size=36):
        super(DLFTNet, self).__init__()

        self.inp_size = inp_size
        self.linear_size = linear_size
        self.out_size = out_size

        # Creation order is unchanged: it fixes the RNG draws used for init.
        self.inp_lin = nn.Linear(self.inp_size, self.linear_size)
        self.lin1 = nn.Linear(self.linear_size, self.linear_size)
        self.out_lin = nn.Linear(self.linear_size ** 2, self.out_size)

    def forward(self, x):
        """Map ``x`` of shape ``(*, inp_size)`` to ``(*, out_size)``.

        All shape logic works on the last axis, so any number of leading
        dimensions (including none) is accepted. For 2-D input the result
        matches the previous axis-1-based implementation exactly.
        """
        h = self.inp_lin(x)                                  # (*, L)
        width = h.shape[-1]
        # FFT along the feature axis; the imaginary part is discarded.
        spectrum = torch.fft.fft(h).real.unsqueeze(-1)       # (*, L, 1)
        branch = F.relu(self.lin1(h)).unsqueeze(-2)          # (*, 1, L)
        outer = torch.matmul(spectrum, branch)               # (*, L, L)
        flat = outer.reshape(*h.shape[:-1], width * width)   # (*, L*L)
        return self.out_lin(flat)

class MLPNet(nn.Module):
    """Small feed-forward regressor: ``inp_lin -> lin1 -> ReLU -> out_lin``.

    Args:
        inp_size: size of the last input dimension.
        out_size: size of the last output dimension.
        freq_domain: stored as ``self.freq_domain`` but never read by
            ``forward``. NOTE(review): currently unused; confirm intent.
        linear_size: hidden width.

    NOTE(review): there is no activation between ``inp_lin`` and ``lin1``, so
    the two layers compose to a single affine map before the ReLU.
    """

    def __init__(self, inp_size, out_size, freq_domain=None, linear_size=36):
        super().__init__()
        self.inp_size, self.linear_size, self.out_size = inp_size, linear_size, out_size
        self.freq_domain = freq_domain

        # Same creation order as before, so seeded initialisation is unchanged.
        self.inp_lin = nn.Linear(inp_size, linear_size)
        self.lin1 = nn.Linear(linear_size, linear_size)
        self.out_lin = nn.Linear(linear_size, out_size)

    def forward(self, x):
        """Return ``out_lin(relu(lin1(inp_lin(x))))``."""
        hidden = F.relu(self.lin1(self.inp_lin(x)))
        return self.out_lin(hidden)

In [50]:
# Train MLPNet full-batch on every row with an L1 objective.
# NOTE(review): no held-out split is used, so the log shows training loss only.
torch.manual_seed(0)  # reproducible weight init, so a rerun reproduces the log

net = MLPNet(4, 1).to(data_input.device)  # follow the data instead of hard-coding .cuda()
opti = torch.optim.Adam(net.parameters(), lr=0.0001)
loss = torch.nn.L1Loss()  # holds no parameters or buffers, so it needs no device move

for i in range(10000):
    opti.zero_grad()
    out = net(data_input)
    l = loss(out, data_output)
    l.backward()
    opti.step()
    if i % 100 == 0:  # the logged value is the loss before this step's update
        print(i, l.item())

0 1182.92431640625
100 1179.1015625
200 1173.0625
300 1164.6961669921875
400 1153.1851806640625
500 1137.6766357421875
600 1117.4923095703125
700 1092.098388671875
800 1060.964111328125
900 1023.2356567382812
1000 978.001953125
1100 925.0321044921875
1200 864.125
1300 795.038818359375
1400 717.5177001953125
1500 631.3187255859375
1600 536.2973022460938
1700 434.2058410644531
1800 334.7191467285156
1900 256.9653625488281
2000 212.17080688476562
2100 192.01516723632812
2200 184.04354858398438
2300 180.54049682617188
2400 178.3057098388672
2500 176.3870086669922
2600 174.5247039794922
2700 172.6586151123047
2800 170.77926635742188
2900 168.8882293701172
3000 166.9946746826172
3100 165.1099395751953
3200 163.23497009277344
3300 161.38656616210938
3400 159.57266235351562
3500 157.77854919433594
3600 156.01593017578125
3700 154.28868103027344
3800 152.61277770996094
3900 151.00071716308594
4000 149.4571533203125
4100 147.96910095214844
4200 146.54611206054688
4300 145.19082641601562
4400 143

In [51]:
# Train AttentionNet full-batch on every row with an L1 objective.
# NOTE(review): no held-out split is used, so the log shows training loss only.
torch.manual_seed(0)  # reproducible weight init, so a rerun reproduces the log

net = AttentionNet(4, 1).to(data_input.device)  # follow the data instead of hard-coding .cuda()
opti = torch.optim.Adam(net.parameters(), lr=0.0001)
loss = torch.nn.L1Loss()  # holds no parameters or buffers, so it needs no device move

for i in range(10000):
    opti.zero_grad()
    out = net(data_input)
    l = loss(out, data_output)
    l.backward()
    opti.step()
    if i % 100 == 0:  # the logged value is the loss before this step's update
        print(i, l.item())

0 16949928.0
100 58438.375
200 6046.73095703125
300 76989.234375
400 44840.625
500 30474.568359375
600 10466.1630859375
700 31421.912109375
800 16472.296875
900 27446.421875
1000 10747.4072265625
1100 16210.2275390625
1200 11122.3759765625
1300 19126.126953125
1400 11635.0595703125
1500 14489.080078125
1600 11856.1865234375
1700 10987.4833984375
1800 14056.0205078125
1900 6495.38330078125
2000 8355.244140625
2100 4868.52490234375
2200 3878.144287109375
2300 10148.0576171875
2400 16889.34375
2500 4347.19873046875
2600 2375.919189453125
2700 1852.834716796875
2800 7450.623046875
2900 12951.109375
3000 4790.60986328125
3100 5395.76318359375
3200 3173.140625
3300 2804.185791015625
3400 15962.4580078125
3500 12554.427734375
3600 12189.3203125
3700 11385.1767578125
3800 1329.499267578125
3900 8431.568359375
4000 2396.234619140625
4100 3974.418701171875
4200 4754.67529296875
4300 2409.241455078125
4400 3222.387939453125
4500 3862.133544921875
4600 1722.7645263671875
4700 2258.510009765625
480

In [52]:
# Train DLFTNet full-batch on every row with an L1 objective.
# NOTE(review): no held-out split is used, so the log shows training loss only.
torch.manual_seed(0)  # reproducible weight init, so a rerun reproduces the log

net = DLFTNet(4, 1).to(data_input.device)  # follow the data instead of hard-coding .cuda()
opti = torch.optim.Adam(net.parameters(), lr=0.0001)
loss = torch.nn.L1Loss()  # holds no parameters or buffers, so it needs no device move

for i in range(10000):
    opti.zero_grad()
    out = net(data_input)
    l = loss(out, data_output)
    l.backward()
    opti.step()
    if i % 100 == 0:  # the logged value is the loss before this step's update
        print(i, l.item())

0 1067.1510009765625
100 130.3993682861328
200 114.87847137451172
300 114.79458618164062
400 114.73668670654297
500 114.70279693603516
600 114.68274688720703
700 114.66938018798828
800 114.65923309326172
900 114.65076446533203
1000 114.64447784423828
1100 114.63792419433594
1200 114.63272857666016
1300 114.62834930419922
1400 114.62430572509766
1500 114.62074279785156
1600 114.61722564697266
1700 114.61347961425781
1800 114.60973358154297
1900 114.6061782836914
2000 114.6025619506836
2100 114.59900665283203
2200 114.5954818725586
2300 114.59193420410156
2400 114.588134765625
2500 114.58421325683594
2600 114.58049011230469
2700 114.57674407958984
2800 114.5729751586914
2900 114.569091796875
3000 114.56509399414062
3100 114.56102752685547
3200 114.55689239501953
3300 114.55270385742188
3400 114.5484390258789
3500 114.54408264160156
3600 114.53961181640625
3700 114.5350570678711
3800 114.53048706054688
3900 114.52586364746094
4000 114.52123260498047
4100 114.51655578613281
4200 114.511825