In [1]:
import torch
import numpy as np
import pandas as pd
from torch import nn

In [2]:
# Dataset from https://www.kaggle.com/datasets/saurabhshahane/electricity-load-forecasting
df = pd.read_csv('kaggle_electricity_load_forecasting_train.csv')
df = df.set_index('datetime')

In [3]:
data_output = np.array(df['nat_demand']).reshape(-1, 1)
data_input = np.array(df[['T2M_toc', 'QV2M_toc','TQL_toc','W2M_toc']])

In [4]:
data_output = torch.Tensor(data_output).cuda()
data_input = torch.Tensor(data_input).cuda()

In [5]:
print(data_output.shape, data_input.shape)

torch.Size([43775, 1]) torch.Size([43775, 4])


In [39]:
class AttentionNet(nn.Module):
    def __init__(self, inp_size, out_size, linear_size=36):
        super(AttentionNet, self).__init__()

        self.inp_size = inp_size
        self.linear_size = linear_size
        self.out_size = out_size
        
        self.inp_lin = nn.Linear(self.inp_size, self.linear_size)
        self.lin1 = nn.Linear(self.linear_size, self.linear_size)
        self.lin2 = nn.Linear(self.linear_size, self.linear_size)
        self.out_lin = nn.Linear(self.linear_size, self.out_size)
    def forward(self, x):
        x = self.inp_lin(x)
        x1 = self.lin1(x)
        x2 = self.lin2(x)
        temp = torch.matmul(x1, torch.matmul(x2.transpose(-2, -1), x2))
        x = self.out_lin(temp)
        return x


class DLFTNet(nn.Module):
    def __init__(self, inp_size, out_size, linear_size=36):
        super(DLFTNet, self).__init__()

        self.inp_size = inp_size
        self.linear_size = linear_size
        self.out_size = out_size
        
        self.inp_lin = nn.Linear(self.inp_size, self.linear_size)
        self.lin1 = nn.Linear(self.linear_size, self.linear_size)
        self.out_lin = nn.Linear(self.linear_size ** 2, self.out_size)
        
    def forward(self, x):
        x = self.inp_lin(x)
        x1 = torch.fft.fft(x).real.view(-1, x.shape[1], 1)
        x2 = self.lin1(x).view(-1, 1, x.shape[1])
        temp = torch.matmul(x1, x2).reshape(-1, x.shape[1] ** 2)
        x = self.out_lin(temp)
        return x

class MLPNet(nn.Module):
    def __init__(self, inp_size, out_size, freq_domain=None, linear_size=36):
        super(MLPNet, self).__init__()

        self.inp_size = inp_size
        self.linear_size = linear_size
        self.out_size = out_size
        self.freq_domain = freq_domain
        
        self.inp_lin = nn.Linear(self.inp_size, self.linear_size)
        self.lin1 = nn.Linear(self.linear_size, self.linear_size)
        self.out_lin = nn.Linear(self.linear_size, self.out_size)
        
    def forward(self, x):
        x = self.inp_lin(x)
        x1 = self.lin1(x)
        x = self.out_lin(x1)
        return x

In [40]:
net = MLPNet(4, 1).cuda()
opti = torch.optim.Adam(net.parameters(), lr=0.0001)
loss = torch.nn.L1Loss().cuda()

for i in range(0, 10000) :
    opti.zero_grad()
    out = net(data_input)
    l = loss(out, data_output)
    l.backward()
    opti.step()
    if i % 100 == 0 :
        print(i, l.item())

0 1183.7738037109375
100 1174.4654541015625
200 1161.39892578125
300 1141.6104736328125
400 1112.4664306640625
500 1070.949951171875
600 1014.36181640625
700 940.7183227539062
800 848.4071655273438
900 736.049072265625
1000 602.4129638671875
1100 447.8420104980469
1200 296.6191711425781
1300 210.6390380859375
1400 185.51734924316406
1500 179.68624877929688
1600 176.95565795898438
1700 174.54257202148438
1800 172.1291046142578
1900 169.6953887939453
2000 167.2498779296875
2100 164.8249053955078
2200 162.42539978027344
2300 160.08456420898438
2400 157.79054260253906
2500 155.5536346435547
2600 153.39151000976562
2700 151.32717895507812
2800 149.3759307861328
2900 147.5207977294922
3000 145.77261352539062
3100 144.14300537109375
3200 142.62582397460938
3300 141.2288360595703
3400 139.94454956054688
3500 138.76234436035156
3600 137.6855926513672
3700 136.7210235595703
3800 135.86961364746094
3900 135.12228393554688
4000 134.4762725830078
4100 133.9157257080078
4200 133.43516540527344
4300 

In [41]:
net = AttentionNet(4, 1).cuda()
opti = torch.optim.Adam(net.parameters(), lr=0.0001)
loss = torch.nn.L1Loss().cuda()

for i in range(0, 10000) :
    opti.zero_grad()
    out = net(data_input)
    l = loss(out, data_output)
    l.backward()
    opti.step()
    if i % 100 == 0 :
        print(i, l.item())

0 3775175.75
100 2986.615234375
200 22216.71484375
300 13953.3369140625
400 2511.17626953125
500 2357.7119140625
600 2202.16650390625
700 2006.95556640625
800 1869.9871826171875
900 1863.1217041015625
1000 2111.072998046875
1100 2038.2357177734375
1200 1965.3062744140625
1300 1949.006591796875
1400 1833.4266357421875
1500 1802.700927734375
1600 1763.21826171875
1700 2031.7774658203125
1800 1807.0172119140625
1900 1391.2127685546875
2000 1647.3631591796875
2100 1371.1881103515625
2200 1188.261962890625
2300 1069.1737060546875
2400 985.0745849609375
2500 919.1196899414062
2600 4851.12353515625
2700 1755.2647705078125
2800 1396.2376708984375
2900 1663.8245849609375
3000 479.5391845703125
3100 761.8577880859375
3200 720.2938232421875
3300 588.7936401367188
3400 521.6331787109375
3500 445.6144104003906
3600 381.92474365234375
3700 370.73504638671875
3800 284.79986572265625
3900 229.81053161621094
4000 191.0625457763672
4100 114.45570373535156
4200 114.08018493652344
4300 114.29058074951172


In [42]:
net = DLFTNet(4, 1).cuda()
opti = torch.optim.Adam(net.parameters(), lr=0.0001)
loss = torch.nn.L1Loss().cuda()

for i in range(0, 10000) :
    opti.zero_grad()
    out = net(data_input)
    l = loss(out, data_output)
    l.backward()
    opti.step()
    if i % 100 == 0 :
        print(i, l.item())

0 1306.522705078125
100 115.82056427001953
200 115.06350708007812
300 114.9590072631836
400 114.87504577636719
500 114.81510162353516
600 114.77745819091797
700 114.75737762451172
800 114.74542999267578
900 114.73768615722656
1000 114.7322769165039
1100 114.72813415527344
1200 114.7242202758789
1300 114.72026824951172
1400 114.7161865234375
1500 114.71195220947266
1600 114.70758819580078
1700 114.70308685302734
1800 114.6984634399414
1900 114.69368743896484
2000 114.68870544433594
2100 114.68358612060547
2200 114.67831420898438
2300 114.6729507446289
2400 114.66749572753906
2500 114.66195678710938
2600 114.6563491821289
2700 114.65055847167969
2800 114.64461517333984
2900 114.63856506347656
3000 114.63227081298828
3100 114.62579345703125
3200 114.61908721923828
3300 114.61228942871094
3400 114.60541534423828
3500 114.59842681884766
3600 114.59123229980469
3700 114.5838623046875
3800 114.57637023925781
3900 114.56886291503906
4000 114.56123352050781
4100 114.5534896850586
4200 114.54564