In [1]:
import torch
import numpy as np
import pandas as pd
from torch import nn

In [2]:
# Dataset from https://www.kaggle.com/datasets/saurabhshahane/electricity-load-forecasting
df = pd.read_csv('kaggle_electricity_load_forecasting_train.csv')
df = df.set_index('datetime')

In [3]:
data_output = np.array(df['nat_demand']).reshape(-1, 1)
data_input = np.array(df[['T2M_toc', 'QV2M_toc','TQL_toc','W2M_toc']])

In [4]:
data_output = torch.Tensor(data_output).cuda()
data_input = torch.Tensor(data_input).cuda()

In [5]:
print(data_output.shape, data_input.shape)

torch.Size([43775, 1]) torch.Size([43775, 4])


In [6]:
class AttentionNet(nn.Module):
    def __init__(self, inp_size, out_size, linear_size=36):
        super(AttentionNet, self).__init__()

        self.inp_size = inp_size
        self.linear_size = linear_size
        self.out_size = out_size
        
        self.inp_lin = nn.Linear(self.inp_size, self.linear_size)
        self.lin1 = nn.Linear(self.linear_size, self.linear_size)
        self.lin2 = nn.Linear(self.linear_size, self.linear_size)
        self.out_lin = nn.Linear(self.linear_size, self.out_size)
    def forward(self, x):
        x = self.inp_lin(x)
        x1 = self.lin2(x)
        x2 = self.lin2(x)
        temp = torch.matmul(x1, torch.matmul(x2.transpose(-2, -1), x2))
        x = self.out_lin(x)
        return x


class DLFTNet(nn.Module):
    def __init__(self, inp_size, out_size, linear_size=36):
        super(DLFTNet, self).__init__()

        self.inp_size = inp_size
        self.linear_size = linear_size
        self.out_size = out_size
        
        self.inp_lin = nn.Linear(self.inp_size, self.linear_size)
        self.out_lin = nn.Linear(self.linear_size, self.out_size)
    def forward(self, x):
        x = self.inp_lin(x)
        x = torch.fft.fft(x).real
        x = self.out_lin(x)
        return x

class MLPNet(nn.Module):
    def __init__(self, inp_size, out_size, freq_domain=None, linear_size=36):
        super(MLPNet, self).__init__()

        self.inp_size = inp_size
        self.linear_size = linear_size
        self.out_size = out_size
        self.freq_domain = freq_domain
        
        self.inp_lin = nn.Linear(self.inp_size, self.linear_size)
        self.lin1 = nn.Linear(self.linear_size, self.linear_size)
        self.out_lin = nn.Linear(self.linear_size, self.out_size)
        
    def forward(self, x):
        x = self.inp_lin(x)
        x1 = self.lin1(x)
        x = self.out_lin(x)
        return x

In [7]:
net = MLPNet(4, 1).cuda()
opti = torch.optim.Adam(net.parameters(), lr=0.0001)
loss = torch.nn.L1Loss().cuda()

for i in range(0, 10000) :
    opti.zero_grad()
    out = net(data_input)
    l = loss(out, data_output)
    l.backward()
    opti.step()
    if i % 1000 == 0 :
        print(i, l.item())

  from .autonotebook import tqdm as notebook_tqdm


0 1185.302001953125
1000 1134.449951171875
2000 1046.0831298828125
3000 909.3234252929688
4000 725.9102172851562
5000 504.82080078125
6000 294.20172119140625
7000 207.03170776367188
8000 189.46913146972656
9000 178.40626525878906


In [8]:
net = AttentionNet(4, 1).cuda()
opti = torch.optim.Adam(net.parameters(), lr=0.0001)
loss = torch.nn.L1Loss().cuda()

for i in range(0, 10000) :
    opti.zero_grad()
    out = net(data_input)
    l = loss(out, data_output)
    l.backward()
    opti.step()
    if i % 1000 == 0 :
        print(i, l.item())

0 1182.427978515625
1000 1130.1849365234375
2000 1031.293212890625
3000 883.2610473632812
4000 691.5526123046875
5000 463.650634765625
6000 259.7709655761719
7000 190.357666015625
8000 177.751220703125
9000 167.99497985839844


In [9]:
net = DLFTNet(4, 1).cuda()
opti = torch.optim.Adam(net.parameters(), lr=0.0001)
loss = torch.nn.L1Loss().cuda()

for i in range(0, 10000) :
    opti.zero_grad()
    out = net(data_input)
    l = loss(out, data_output)
    l.backward()
    opti.step()
    if i % 1000 == 0 :
        print(i, l.item())

0 1159.499267578125
1000 888.193359375
2000 372.638671875
3000 171.71966552734375
4000 158.98695373535156
5000 146.78591918945312
6000 137.71897888183594
7000 132.97914123535156
8000 131.42691040039062
9000 131.1643829345703
