In [1]:
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn as nn
from torchvision import datasets,transforms
import torch.nn.functional as F
import sys
sys.path.append("../code/")
from dataloader import APPLIANCE_ORDER, get_train_test
from sklearn.metrics import mean_absolute_error
import os

In [2]:
cuda_av = False
if torch.cuda.is_available():
    cuda_av = True

torch.manual_seed(0)
np.random.seed(0)

In [3]:

class CustomRNN(nn.Module):
    def __init__(self, cell_type, hidden_size, num_layers, bidirectional):
        super(CustomRNN, self).__init__()
        torch.manual_seed(0)

        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1
        if cell_type == "RNN":
            self.rnn = nn.RNN(input_size=1, hidden_size=hidden_size,
                              num_layers=num_layers, batch_first=True,
                              bidirectional=bidirectional)
        elif cell_type == "GRU":
            self.rnn = nn.GRU(input_size=1, hidden_size=hidden_size,
                              num_layers=num_layers, batch_first=True,
                              bidirectional=bidirectional)
        else:
            self.rnn = nn.LSTM(input_size=1, hidden_size=hidden_size,
                               num_layers=num_layers, batch_first=True,
                               bidirectional=bidirectional)

        self.linear = nn.Linear(hidden_size * self.num_directions, 1)
        self.act = nn.ReLU()

    def forward(self, x):
        pred, hidden = self.rnn(x, None)
        pred = self.linear(pred).view(pred.data.shape[0], -1, 1)
        pred = torch.min(pred, x)
        return pred

class CustomCNN(nn.Module):
    def __init__(self):
        super(CustomCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=7, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(20)

        self.conv2 = nn.Conv2d(20, 16, kernel_size=2, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(16)

        self.conv3 = nn.Conv2d(16, 64, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(64)

        self.conv4 = nn.ConvTranspose2d(64, 16, kernel_size=4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(16)

        self.conv5 = nn.ConvTranspose2d(16, 6, kernel_size=4, stride=2, padding=1)
        self.bn5 = nn.BatchNorm2d(6)

        self.conv6 = nn.ConvTranspose2d(6, 1, kernel_size=5, stride=1, padding=2) 
        
        self.act = nn.ReLU()
        
    # forward method
    def forward(self, input):
        
        e1 = self.conv1(input)
        bn1 = self.bn1(self.act(e1))
        e2 = self.bn2(self.conv2(bn1))        
        e5 = self.bn5(self.conv5(e2))
        e6 = self.conv6(e5)
        return e6

In [4]:
def preprocess(train, valid, test):
    out_train = [None for temp in range(len(ORDER))]
    for a_num, appliance in enumerate(ORDER):
        out_train[a_num] = Variable(
            torch.Tensor(train[:, APPLIANCE_ORDER.index(appliance), :, :].reshape((train.shape[0], 1, -1, 24))))
        if cuda_av:
            out_train[a_num] = out_train[a_num].cuda()

    out_valid = [None for temp in range(len(ORDER))]
    for a_num, appliance in enumerate(ORDER):
        out_valid[a_num] = Variable(
            torch.Tensor(valid[:, APPLIANCE_ORDER.index(appliance), :, :].reshape((valid.shape[0], 1, -1, 24))))
        if cuda_av:
            out_valid[a_num] = out_valid[a_num].cuda()
            
    out_test = [None for temp in range(len(ORDER))]
    for a_num, appliance in enumerate(ORDER):
        out_test[a_num] = Variable(
            torch.Tensor(test[:, APPLIANCE_ORDER.index(appliance), :, :].reshape((test.shape[0], 1, -1, 24))))
        if cuda_av:
            out_test[a_num] = out_test[a_num].cuda()

    return out_train, out_valid, out_test

In [5]:
dataset = 1
fold_num = 1
num_folds = 5
train, test = get_train_test(dataset, num_folds=num_folds, fold_num=fold_num)
valid = train[int(0.8*len(train)):].copy()
train = train[:int(0.8 * len(train))].copy()
train_aggregate = train[:, 0, :, :].reshape(train.shape[0], 1, -1, 24)
valid_aggregate = valid[:, 0, :, :].reshape(valid.shape[0], 1, -1, 24)
test_aggregate = test[:, 0, :, :].reshape(test.shape[0], 1, -1, 24)

In [6]:
ORDER = ['fridge', 'dr', 'hvac', 'dw', 'mw']
out_train, out_valid, out_test = preprocess(train, valid, test)

In [7]:
inp = Variable(torch.Tensor(train_aggregate), requires_grad=False)
valid_inp = Variable(torch.Tensor(valid_aggregate), requires_grad=False)
test_inp = Variable(torch.Tensor(test_aggregate), requires_grad=False)
if cuda_av:
    inp = inp.cuda()
    valid_inp = valid_inp.cuda()
    test_inp = test_inp.cuda()
valid_out = torch.cat([out_valid[appliance_num] for appliance_num, appliance in enumerate(ORDER)])
test_out = torch.cat([out_test[appliance_num] for appliance_num, appliance in enumerate(ORDER)])
train_out = torch.cat([out_train[appliance_num] for appliance_num, appliance in enumerate(ORDER)])

In [8]:
class AppliancesRNNCNN(nn.Module):
    def __init__(self, cell_type, hidden_size, num_layers, bidirectional, num_appliance):
        super(AppliancesRNNCNN, self).__init__()
        self.num_appliance = num_appliance
        self.preds = {}
        self.order = ORDER
        for appliance in range(self.num_appliance):
            if ORDER[appliance] in ['hvac', 'fridge']:
                print("use RNN")
                if cuda_av:
                    setattr(self, "Appliance_" + str(appliance), CustomRNN(cell_type, hidden_size,
                                                                           num_layers, bidirectional).cuda())
                else:
                    setattr(self, "Appliance_" + str(appliance), CustomRNN(cell_type, hidden_size,
                                                                           num_layers, bidirectional))
            else:
                print("use CNN")
                if cuda_av:
                    setattr(self, "Appliance_" + str(appliance), CustomCNN().cuda())
                else:
                    setattr(self, "Appliance_" + str(appliance), CustomCNN())


    def forward(self, *args):
        agg_current = args[0]
        flag = False
        if np.random.random() > args[1]:
            flag = True
        else:
            pass
        for appliance in range(self.num_appliance):
            agg_current = agg_current.contiguous()
            if ORDER[appliance] in ['hvac', 'fridge']:
                agg_current = agg_current.view(agg_current.shape[0], -1, 1)
            else:
                agg_current = agg_current.view(agg_current.shape[0], 1, -1, 24)
            
            self.preds[appliance] = getattr(self, "Appliance_" + str(appliance))(agg_current)
            
            agg_current = agg_current.view(agg_current.shape[0], 1, -1, 24)
            self.preds[appliance] = self.preds[appliance].view(self.preds[appliance].shape[0], 1, -1, 24)
            
            
            if flag:
                agg_current = agg_current - self.preds[appliance]
            else:
                agg_current = agg_current - args[2 + appliance]

        return torch.cat([self.preds[a] for a in range(self.num_appliance)])

In [9]:
loss_func = nn.L1Loss()
model = AppliancesRNNCNN('GRU', 50, 2, True, len(ORDER))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
p=0
params = [inp, p]

if cuda_av:
    model = model.cuda()
    loss_func = loss_func.cuda()
    
for a_num, appliance in enumerate(ORDER):
    params.append(out_train[a_num])

if cuda_av:
    train_out = train_out.cuda()

use CNN
use RNN
use RNN
use CNN
use CNN


In [None]:
for k in range(4000):
    pred = model(*params)
    optimizer.zero_grad()
    loss = loss_func(pred, train_out)
    
    
    
    if cuda_av:
        test_inp = test_inp.cuda()
    test_params = [test_inp, -2]
    for i in range(len(ORDER)):
        test_params.append(None)
    test_pr = model(*test_params)
    test_loss = loss_func(test_pr, test_out)
    
    print(k, loss.data[0], test_loss.data[0])

    loss.backward()
    optimizer.step()

0 39.79768371582031 50.715614318847656
1 39.78328323364258 50.57164001464844
2 39.761749267578125 50.69451141357422
3 39.74970245361328 50.75785446166992
4 39.735313415527344 50.810813903808594
5 39.69281768798828 50.751094818115234
6 39.679351806640625 50.8043212890625
7 39.74676513671875 50.60640335083008
8 39.74199676513672 50.79082107543945
9 39.71002197265625 50.89620590209961
10 39.73906707763672 50.675662994384766
11 39.70842361450195 50.635986328125
12 39.685211181640625 50.686038970947266
13 39.65639877319336 50.85262680053711
14 40.056365966796875 50.320438385009766
15 39.836082458496094 50.538856506347656
16 40.03162384033203 51.63629913330078
17 39.901519775390625 51.45805740356445
18 39.86384582519531 50.570762634277344
19 39.971168518066406 50.479331970214844
20 39.6970100402832 51.01506042480469
21 39.910465240478516 51.314849853515625
22 39.71820068359375 50.54423522949219
23 39.861507415771484 50.236942291259766
24 39.79844284057617 50.20582580566406
25 39.718826293945

206 38.781707763671875 48.8636360168457
207 38.77463912963867 48.86792755126953
208 38.75998306274414 48.83001708984375
209 38.760684967041016 48.94241714477539
210 38.76028060913086 48.8320426940918
211 38.7426872253418 48.88446044921875
212 38.751136779785156 48.93061828613281
213 38.760250091552734 48.822750091552734
214 38.752750396728516 48.86294174194336
215 38.80889129638672 49.06700134277344
216 38.77944564819336 48.7044677734375
217 38.75968933105469 48.73820495605469
218 38.84196472167969 49.08711242675781
219 38.76100540161133 48.65956115722656
220 38.77619552612305 48.6448860168457
221 38.80004119873047 49.04872512817383
222 38.72758483886719 48.87642288208008
223 38.781494140625 48.56551742553711
224 38.72853469848633 48.7703971862793
225 38.77311325073242 48.97548294067383
226 40.16291809082031 49.731117248535156
227 39.51485061645508 48.99440002441406
228 39.63370132446289 48.26299285888672
229 40.18051528930664 48.35436248779297
230 39.74372863769531 48.56155776977539
2

410 38.2194709777832 48.148258209228516
411 38.24100112915039 47.9892578125
412 38.223201751708984 48.08863830566406
413 38.20832824707031 48.24310302734375
414 38.207645416259766 48.2973518371582
415 38.19523620605469 48.16563034057617
416 38.17512512207031 48.16184997558594
417 38.18088150024414 48.15742111206055
418 38.1628303527832 48.129539489746094
419 38.158817291259766 48.17330551147461
420 38.17283248901367 48.38985824584961
421 38.146080017089844 48.1450309753418
422 38.15633773803711 48.08646774291992
423 38.136192321777344 48.223167419433594
424 38.11427688598633 48.19265365600586
425 38.124351501464844 47.96407699584961
426 38.11216735839844 48.02462387084961
427 38.1118278503418 48.154624938964844
428 38.09535217285156 48.025630950927734
429 38.09026336669922 47.97005844116211
430 38.078060150146484 48.12788772583008
431 38.0869026184082 47.97892379760742
432 38.065067291259766 48.04133224487305
433 38.0677604675293 48.173702239990234
434 38.05124282836914 48.058986663818

613 38.23997116088867 48.38224411010742
614 38.18611145019531 48.08335494995117
615 38.18595886230469 47.894813537597656
616 38.151302337646484 47.92594528198242
617 38.25449752807617 48.32221984863281
618 38.37269592285156 48.302833557128906
619 38.39710998535156 48.12592315673828
620 38.45766067504883 48.049617767333984
621 38.53377914428711 48.03269577026367
622 38.5617561340332 48.057003021240234
623 38.562374114990234 47.17237854003906
624 38.646087646484375 47.2802619934082
625 38.56296157836914 47.083919525146484
626 38.61040496826172 46.87226867675781
627 38.57467269897461 46.92098617553711
628 38.531166076660156 47.23373794555664
629 38.512176513671875 47.21135330200195
630 38.45014953613281 47.17363357543945
631 38.39898681640625 47.034523010253906
632 38.32847595214844 46.85841369628906
633 38.28773880004883 47.2380256652832
634 38.24151611328125 47.807430267333984
635 38.21395492553711 47.74641418457031
636 38.210445404052734 47.36509704589844
637 38.1746826171875 47.192741

816 37.756282806396484 45.789588928222656
817 37.7294807434082 45.880638122558594
818 37.719276428222656 46.13801956176758
819 37.72201156616211 46.2852897644043
820 37.77241516113281 46.203392028808594
821 37.76557922363281 46.047332763671875
822 37.75270080566406 45.950462341308594
823 37.7220458984375 45.91666793823242
824 37.70851516723633 45.980648040771484
825 37.7053108215332 46.03730773925781
826 37.70082473754883 45.99527359008789
827 37.674686431884766 45.834171295166016
828 37.670108795166016 45.74917984008789
829 37.6365852355957 45.72897720336914
830 37.597312927246094 45.8675537109375
831 37.58539962768555 46.063232421875
832 37.557411193847656 46.14111328125
833 37.5510368347168 46.083770751953125
834 37.56678771972656 46.044761657714844
835 37.55548095703125 46.09617233276367
836 37.54751205444336 46.177494049072266
837 37.56547927856445 46.09987258911133
838 37.53812026977539 46.07271957397461
839 37.49268341064453 46.06395721435547
840 37.47574234008789 46.12620162963

1018 36.69049072265625 45.26267623901367
1019 36.67869567871094 45.13559341430664
1020 36.69881057739258 45.03692626953125
1021 36.68067932128906 45.172183990478516
1022 36.70823287963867 45.32518005371094
1023 36.66942596435547 45.17879104614258
1024 36.69403076171875 45.07484436035156
1025 36.66827392578125 45.14262390136719
1026 36.68401336669922 45.222171783447266
1027 36.6655158996582 45.16740036010742
1028 36.6811637878418 45.195003509521484
1029 36.65415573120117 45.252620697021484
1030 36.65814208984375 45.27988815307617
1031 36.649024963378906 45.18157958984375
1032 36.687171936035156 45.347740173339844
1033 36.665252685546875 45.36952590942383
1034 36.66494369506836 45.25621795654297
1035 36.6882438659668 45.09735870361328
1036 36.6827278137207 45.08200454711914
1037 36.65964889526367 45.202327728271484
1038 36.65766906738281 45.1754150390625
1039 36.669960021972656 45.11517333984375
1040 36.66734313964844 45.11696243286133
1041 36.66130065917969 45.256011962890625
1042 36.64

1216 36.37154006958008 45.12362289428711
1217 36.347965240478516 45.073490142822266
1218 36.353946685791016 45.26455307006836
1219 36.3441276550293 45.46818161010742
1220 36.354515075683594 45.528072357177734
1221 36.329673767089844 45.5932731628418
1222 36.3363151550293 45.758235931396484
1223 36.3087158203125 45.67266845703125
1224 36.30790328979492 45.43418502807617
1225 36.30223846435547 45.427146911621094
1226 36.302066802978516 45.437950134277344
1227 36.29483413696289 45.312015533447266
1228 36.287391662597656 44.99674606323242
1229 36.27864074707031 45.050716400146484
1230 36.24937438964844 44.9139289855957
1231 36.253597259521484 44.796173095703125
1232 36.252479553222656 44.92899703979492
1233 36.25486755371094 45.07213592529297
1234 36.235252380371094 45.1187858581543
1235 36.23667907714844 45.16196823120117
1236 36.34889221191406 45.570526123046875
1237 36.346961975097656 45.528873443603516
1238 37.09493637084961 46.32267761230469
1239 37.2408447265625 46.10700607299805
124

1415 37.512115478515625 45.902488708496094
1416 37.5329704284668 46.34036636352539
1417 37.6547737121582 46.49771499633789
1418 37.48843002319336 46.30755615234375
1419 37.51424026489258 46.1246452331543
1420 37.550865173339844 46.11441421508789
1421 37.520633697509766 46.2547607421875
1422 37.51912307739258 46.49947738647461
1423 37.559425354003906 46.603939056396484
1424 37.596649169921875 46.52249526977539
1425 37.59428024291992 46.48722457885742
1426 37.59212875366211 46.148468017578125
1427 37.59925842285156 46.128631591796875
1428 37.57410430908203 46.27518081665039
1429 37.570098876953125 46.35195541381836
1430 37.50486755371094 46.4326057434082
1431 37.50594711303711 46.54008865356445
1432 37.44740676879883 46.334102630615234
1433 37.40164566040039 46.220664978027344
1434 37.357723236083984 46.2208137512207
1435 37.31621170043945 46.2397575378418
1436 37.27190399169922 46.24217224121094
1437 37.21885299682617 46.1953239440918
1438 37.16508483886719 46.193031311035156
1439 37.14

1614 36.493499755859375 44.77189636230469
1615 36.47965621948242 44.63148498535156
1616 36.48565673828125 44.630672454833984
1617 36.487789154052734 44.7642707824707
1618 36.45333480834961 44.672698974609375
1619 36.40994644165039 45.240623474121094
1620 36.391239166259766 45.397315979003906
1621 36.3521728515625 45.47786331176758
1622 36.32344055175781 45.47367477416992
1623 36.30118942260742 45.46116638183594
1624 36.27460479736328 45.66249465942383
1625 36.25749588012695 46.072296142578125
1626 36.267723083496094 46.14305877685547
1627 36.23397445678711 46.04378128051758
1628 36.20520782470703 46.07711410522461
1629 36.197242736816406 46.17570114135742
1630 36.18858337402344 46.32304000854492
1631 36.17343521118164 46.27810287475586
1632 36.173336029052734 46.22992706298828
1633 36.21532440185547 46.2030029296875
1634 36.21632385253906 45.586585998535156
1635 36.3399658203125 45.39582443237305
1636 36.45380401611328 45.30989074707031
1637 36.507102966308594 45.44822311401367
1638 36

1812 37.95018005371094 46.257080078125
1813 38.018184661865234 46.33340835571289
1814 38.01642608642578 46.03352355957031
1815 38.04738235473633 45.75645065307617
1816 37.998497009277344 45.812744140625
1817 37.90888595581055 45.98572540283203
1818 37.94707489013672 46.18328094482422
1819 37.870574951171875 46.181156158447266
1820 37.83321762084961 46.02615737915039
1821 37.700443267822266 46.09458923339844
1822 37.677982330322266 46.182373046875
1823 37.626976013183594 46.255619049072266
1824 37.57058334350586 46.12904357910156
1825 37.546958923339844 45.77189636230469
1826 37.718135833740234 45.39588165283203
1827 37.68800354003906 44.46315383911133
1828 37.50510787963867 45.153316497802734
1829 37.59309387207031 45.44932556152344
1830 37.45585632324219 45.709041595458984
1831 37.365203857421875 45.25679397583008
1832 37.376399993896484 44.275081634521484
1833 37.253570556640625 44.06208038330078
1834 37.05362319946289 44.1207275390625
1835 37.010231018066406 44.23941421508789
1836 3

2010 35.76500701904297 44.27279281616211
2011 35.78132247924805 44.22113800048828
2012 35.78865051269531 44.435447692871094
2013 35.80792999267578 44.55386734008789
2014 35.78178024291992 44.14190673828125
2015 35.842533111572266 44.03604507446289
2016 36.379791259765625 44.43563461303711
2017 36.775516510009766 44.77035903930664
2018 36.63943862915039 44.96575164794922
2019 36.41533660888672 44.77592086791992
2020 36.461830139160156 44.84571075439453
2021 36.61901092529297 45.04413986206055
2022 36.68180847167969 45.32027053833008
2023 36.727840423583984 45.205177307128906
2024 36.72573471069336 44.85914611816406
2025 36.74285125732422 44.54872512817383
2026 36.64324188232422 44.8195915222168
2027 36.743507385253906 45.43653869628906
2028 36.851951599121094 45.723480224609375
2029 36.67911148071289 45.34945297241211
2030 36.63716506958008 44.82468032836914
2031 36.600921630859375 44.47480773925781
2032 36.60029602050781 44.28969192504883
2033 36.527950286865234 44.51226043701172
2034 

2208 35.33072280883789 43.771209716796875
2209 35.2917594909668 43.61641311645508
2210 35.30812072753906 43.47312545776367
2211 35.30336380004883 43.75418472290039
2212 35.4044303894043 43.402530670166016
2213 35.61328887939453 44.18842697143555
2214 35.46015548706055 44.027587890625
2215 35.614990234375 43.57655715942383
2216 35.6844596862793 43.503543853759766
2217 35.68153381347656 43.82307434082031
2218 35.735687255859375 44.08013916015625
2219 35.64442825317383 44.02785873413086
2220 35.66947555541992 43.78961944580078
2221 35.67410659790039 43.58111572265625
2222 35.58258819580078 43.62321853637695
2223 35.60981750488281 44.138755798339844
2224 35.6028938293457 44.14802932739258
2225 35.548824310302734 43.77556228637695
2226 35.54766082763672 43.6646728515625
2227 35.48644256591797 44.15309143066406
2228 35.49472427368164 44.239952087402344
2229 35.45964431762695 43.84347915649414
2230 35.47300720214844 43.77435302734375
2231 35.51152420043945 43.97759246826172
2232 35.5534515380

2407 34.97987747192383 43.92095947265625
2408 34.98501205444336 43.82750701904297
2409 35.58460235595703 44.68819808959961
2410 35.213134765625 44.209598541259766
2411 35.397178649902344 43.37909698486328
2412 35.62270736694336 43.10086441040039
2413 35.26394271850586 43.29338455200195
2414 35.35457992553711 44.07536315917969
2415 35.36751174926758 44.19469451904297
2416 35.20079040527344 43.5159797668457
2417 35.40995407104492 43.28028106689453
2418 35.218528747558594 43.448272705078125
2419 35.1630859375 44.05827331542969
2420 35.25471115112305 44.247249603271484
2421 35.07474899291992 43.88945388793945
2422 35.17900848388672 43.46150207519531
2423 35.070533752441406 43.634464263916016
2424 35.144866943359375 44.00447463989258
2425 35.090065002441406 43.93342208862305
2426 35.12186050415039 43.52464294433594
2427 35.24101257324219 44.254878997802734
2428 35.32328414916992 44.24678421020508
2429 35.232269287109375 43.58208084106445
2430 35.31258010864258 43.381290435791016
2431 35.281

2605 35.46881103515625 43.540565490722656
2606 35.37432098388672 43.58475875854492
2607 35.38582229614258 43.79775619506836
2608 35.371795654296875 43.19873046875
2609 35.686737060546875 43.39186096191406
2610 35.650428771972656 43.94270706176758
2611 35.68405532836914 44.41960906982422
2612 35.85511779785156 44.89424514770508
2613 35.911617279052734 44.65113067626953
2614 36.07157897949219 44.37152862548828
2615 36.101341247558594 44.586936950683594
2616 36.02072525024414 44.53605270385742
2617 36.01003646850586 44.59003448486328
2618 35.945858001708984 44.628841400146484
2619 35.98671340942383 44.38785171508789
2620 36.03812026977539 44.31329345703125
2621 35.92002487182617 44.002967834472656
2622 35.926902770996094 43.773887634277344
2623 36.071624755859375 43.9338264465332
2624 35.90793991088867 43.73345184326172
2625 35.84428787231445 43.61677932739258
2626 35.81695556640625 43.670650482177734
2627 35.83198165893555 43.92963409423828
2628 35.83043670654297 43.714393615722656
2629 

2803 37.35737228393555 44.21686553955078
2804 37.3356819152832 44.227352142333984
2805 37.20907211303711 44.640899658203125
2806 37.292118072509766 45.12913131713867
2807 37.17943572998047 45.08879852294922
2808 36.98420715332031 44.7480354309082
2809 37.1138801574707 44.247493743896484
2810 36.90340805053711 44.418514251708984
2811 36.82848358154297 44.79487228393555
2812 36.70049285888672 44.95708084106445
2813 36.64427947998047 44.93696212768555
2814 36.68207550048828 44.962955474853516
2815 36.6596565246582 44.501041412353516
2816 36.95293426513672 44.38900375366211
2817 36.762718200683594 44.67399597167969
2818 36.88441467285156 44.78237533569336
2819 36.95428466796875 44.485591888427734
2820 37.18777847290039 44.41788101196289
2821 37.068729400634766 44.9797477722168
2822 37.07575988769531 45.06247329711914
2823 37.069007873535156 44.92588806152344
2824 36.95633316040039 44.68213653564453
2825 36.975345611572266 44.15553665161133
2826 37.01914978027344 44.23136901855469
2827 37.1

3003 35.930301666259766 43.54893112182617
3004 35.95269012451172 43.34597396850586
3005 35.887962341308594 43.39466857910156
3006 35.87776565551758 43.63867950439453
3007 35.87705993652344 43.80796813964844
3008 35.88130187988281 43.55537033081055
3009 35.82400131225586 43.77302932739258
3010 35.75810241699219 43.569332122802734
3011 35.70024871826172 43.47529602050781
3012 35.672908782958984 43.392234802246094
3013 35.66107940673828 43.519596099853516
3014 36.11448669433594 42.7207145690918
3015 36.09840393066406 42.75007629394531
3016 35.856266021728516 43.0079460144043
3017 35.72261047363281 43.4086799621582
3018 35.95764923095703 43.8851203918457
3019 35.93354415893555 43.77692413330078
3020 35.78312683105469 43.37669372558594
3021 35.68559646606445 42.9576416015625
3022 35.74338912963867 42.77688217163086
3023 35.60774612426758 42.886314392089844
3024 35.59056854248047 43.2057991027832
3025 35.61288070678711 43.273521423339844
3026 35.53659439086914 43.09947204589844
3027 35.52643

3202 36.39765548706055 44.508296966552734
3203 36.47353744506836 44.89423751831055
3204 36.23884582519531 44.593929290771484
3205 36.67281723022461 44.34667205810547
3206 36.40840148925781 44.36580276489258
3207 37.26424789428711 46.07711410522461
3208 37.25202560424805 46.5430793762207
3209 36.66984176635742 44.52305603027344
3210 37.762325286865234 43.853607177734375
3211 37.83373260498047 43.73353576660156
3212 37.69580841064453 43.74300003051758
3213 37.081016540527344 43.91747283935547
3214 37.2760009765625 45.3709716796875
3215 37.72994613647461 45.2733039855957
3216 37.739959716796875 45.65843963623047
3217 37.33625793457031 45.17192840576172
3218 37.0584602355957 44.15314483642578
3219 37.391090393066406 43.632179260253906
3220 37.47188186645508 44.302589416503906
3221 37.15006637573242 44.399967193603516
3222 36.87560272216797 45.09640884399414
3223 37.05705261230469 45.68402862548828
3224 37.021820068359375 45.68268585205078
3225 36.80229187011719 45.28184509277344
3226 36.62

3400 35.44096755981445 44.60430908203125
3401 35.44609451293945 44.56135559082031
3402 35.409080505371094 44.736228942871094
3403 35.41630172729492 44.91731643676758
3404 35.40279769897461 44.908538818359375
3405 35.37938690185547 44.73857498168945
3406 36.60301971435547 44.3093147277832
3407 37.17087173461914 44.654571533203125
3408 36.26948547363281 44.38298797607422
3409 36.315582275390625 45.037723541259766
3410 36.700408935546875 45.67925262451172
3411 37.133968353271484 46.087379455566406
3412 36.896305084228516 45.87765121459961
3413 36.6124267578125 45.71685791015625
3414 36.468467712402344 45.4416389465332
3415 36.47966003417969 45.1146240234375
3416 36.48594284057617 44.714881896972656
3417 36.26914978027344 44.668521881103516
3418 36.3233757019043 44.970130920410156
3419 36.37966537475586 45.527400970458984
3420 36.34025573730469 45.275691986083984
3421 36.18030548095703 45.017555236816406
3422 36.158756256103516 44.67641830444336
3423 36.18733215332031 44.76868438720703
342

3599 35.05258560180664 43.045501708984375
3600 35.04361343383789 43.004425048828125
3601 35.03584671020508 42.96205520629883
3602 35.029598236083984 42.980281829833984
3603 35.03425216674805 42.960594177246094
3604 35.03175354003906 43.00668716430664
3605 35.032962799072266 42.95521545410156
3606 35.07239532470703 43.19971466064453
3607 35.12580108642578 43.503971099853516
3608 35.196632385253906 42.858577728271484
3609 35.14979934692383 44.646270751953125
3610 35.162010192871094 44.83854293823242
3611 35.16889953613281 44.47412109375
3612 35.136653900146484 45.0064697265625
3613 35.18131637573242 45.2066535949707
3614 35.1175537109375 44.721683502197266
3615 35.138648986816406 44.301002502441406
3616 35.13623809814453 44.58506774902344
3617 35.08909606933594 44.588478088378906
3618 35.09605026245117 42.91505813598633
3619 35.076961517333984 43.0097541809082
3620 35.09694290161133 43.04542541503906
3621 35.093589782714844 42.80479431152344
3622 35.08768844604492 42.78373336791992
3623 

3798 34.924781799316406 42.69367218017578
3799 35.06382369995117 45.12726593017578
3800 35.460182189941406 45.454463958740234
3801 35.763938903808594 45.319068908691406
3802 36.224273681640625 45.22528076171875
3803 36.271522521972656 45.595611572265625
3804 36.45436477661133 46.31289291381836
3805 36.60586166381836 46.2822380065918
3806 36.75462341308594 45.93239212036133
3807 36.790550231933594 45.70922088623047
3808 36.851375579833984 45.904335021972656
3809 36.89069747924805 45.82808303833008
3810 36.90908432006836 45.960208892822266
3811 36.91318130493164 46.446746826171875
3812 36.791316986083984 45.84463882446289
3813 36.781410217285156 44.98701477050781
3814 36.70475387573242 44.025665283203125
3815 36.732120513916016 44.79462814331055
3816 36.581966400146484 44.256526947021484
3817 36.5181999206543 44.10581588745117
3818 36.44422912597656 44.07832717895508
3819 36.3243408203125 44.204872131347656
3820 36.27534484863281 44.283935546875
3821 36.20750427246094 44.323822021484375


In [None]:
test_pred = torch.split(test_pr, test_aggregate.shape[0])
test_fold = [None for x in range(len(ORDER))]
if cuda_av:
    for appliance_num, appliance in enumerate(ORDER):
        test_fold[appliance_num] = test_pred[appliance_num].cpu().data.numpy().reshape(-1, 24)
else:
    for appliance_num, appliance in enumerate(ORDER):
        test_fold[appliance_num] = test_pred[appliance_num].data.numpy().reshape(-1, 24)

In [None]:
test_gt_fold = [None for x in range(len(ORDER))]
for appliance_num, appliance in enumerate(ORDER):
    test_gt_fold[appliance_num] = test[:, APPLIANCE_ORDER.index(appliance), :, :].reshape(
        test_aggregate.shape[0],
        -1, 1).reshape(-1, 24)

In [None]:
test_error = {}

for appliance_num, appliance in enumerate(ORDER):
    test_error[appliance] = mean_absolute_error(test_fold[appliance_num], test_gt_fold[appliance_num])

In [None]:
test_error

In [32]:
test_error

{'dr': 51.895019602707926,
 'dw': 11.02489129080851,
 'fridge': 28.159192659996876,
 'hvac': 285.86724458947918,
 'mw': 7.4520018024466692}