In [2]:
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn as nn
from torchvision import datasets,transforms
import torch.nn.functional as F
import sys
sys.path.append("../code/")
from dataloader import APPLIANCE_ORDER, get_train_test
from sklearn.metrics import mean_absolute_error
import os

In [3]:
cuda_av = False
if torch.cuda.is_available():
    cuda_av = True

torch.manual_seed(0)
np.random.seed(0)

In [4]:
class CustomRNN(nn.Module):
    def __init__(self, cell_type, hidden_size, num_layers, bidirectional):
        super(CustomRNN, self).__init__()
        torch.manual_seed(0)

        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1
        if cell_type == "RNN":
            self.rnn = nn.RNN(input_size=1, hidden_size=hidden_size,
                              num_layers=num_layers, batch_first=True,
                              bidirectional=bidirectional)
        elif cell_type == "GRU":
            self.rnn = nn.GRU(input_size=1, hidden_size=hidden_size,
                              num_layers=num_layers, batch_first=True,
                              bidirectional=bidirectional)
        else:
            self.rnn = nn.LSTM(input_size=1, hidden_size=hidden_size,
                               num_layers=num_layers, batch_first=True,
                               bidirectional=bidirectional)

        self.linear = nn.Linear(hidden_size * self.num_directions, 1)
        self.act = nn.ReLU()

    def forward(self, x):
        pred, hidden = self.rnn(x, None)
        pred = self.linear(pred).view(pred.data.shape[0], -1, 1)
        pred = torch.min(pred, x)
        return pred

class CustomCNN(nn.Module):
    def __init__(self):
        super(CustomCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=7, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(20)

        self.conv2 = nn.Conv2d(20, 16, kernel_size=2, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(16)

        self.conv3 = nn.Conv2d(16, 64, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(64)

        self.conv4 = nn.ConvTranspose2d(64, 16, kernel_size=4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(16)

        self.conv5 = nn.ConvTranspose2d(16, 6, kernel_size=4, stride=2, padding=1)
        self.bn5 = nn.BatchNorm2d(6)

        self.conv6 = nn.ConvTranspose2d(6, 1, kernel_size=5, stride=1, padding=2) 
        
        self.act = nn.ReLU()
        
    # forward method
    def forward(self, input):
        
        e1 = self.conv1(input)
        bn1 = self.bn1(self.act(e1))
        e2 = self.bn2(self.conv2(bn1))        
        e5 = self.bn5(self.conv5(e2))
        e6 = self.conv6(e5)
        return e6

In [15]:
class JointlyNN(nn.Module):
    def __init__(self, cell_type, hidden_size, num_layers, bidirectional, num_appliance):
        super(JointlyNN, self).__init__()
        self.num_appliance = num_appliance
        self.preds = {}
        self.order = ORDER
        for appliance in range(self.num_appliance):
            if ORDER[appliance] in ['hvac', 'fridge']:
                print("use RNN")
                if cuda_av:
                    setattr(self, "Appliance_" + str(appliance), CustomRNN(cell_type, hidden_size,
                                                                           num_layers, bidirectional).cuda())
                else:
                    setattr(self, "Appliance_" + str(appliance), CustomRNN(cell_type, hidden_size,
                                                                           num_layers, bidirectional))
            else:
                print("use CNN")
                if cuda_av:
                    setattr(self, "Appliance_" + str(appliance), CustomCNN().cuda())
                else:
                    setattr(self, "Appliance_" + str(appliance), CustomCNN())


    def forward(self, *args):
        agg_current = args[0]
        flag = False
        if np.random.random() > args[1]:
            flag = True
        else:
            pass
        for appliance in range(self.num_appliance):
            agg_current = agg_current.contiguous()
            if ORDER[appliance] in ['hvac', 'fridge']:
                agg_current = agg_current.view(agg_current.shape[0], -1, 1)
            else:
                agg_current = agg_current.view(agg_current.shape[0], 1, -1, 24)
            
            self.preds[appliance] = getattr(self, "Appliance_" + str(appliance))(agg_current)
            
            agg_current = agg_current.view(agg_current.shape[0], 1, -1, 24)
            self.preds[appliance] = self.preds[appliance].view(self.preds[appliance].shape[0], 1, -1, 24)
            

        return torch.cat([self.preds[a] for a in range(self.num_appliance)])

In [6]:
def preprocess(train, valid, test):
    out_train = [None for temp in range(len(ORDER))]
    for a_num, appliance in enumerate(ORDER):
        out_train[a_num] = Variable(
            torch.Tensor(train[:, APPLIANCE_ORDER.index(appliance), :, :].reshape((train.shape[0], 1, -1, 24))))
        if cuda_av:
            out_train[a_num] = out_train[a_num].cuda()

    out_valid = [None for temp in range(len(ORDER))]
    for a_num, appliance in enumerate(ORDER):
        out_valid[a_num] = Variable(
            torch.Tensor(valid[:, APPLIANCE_ORDER.index(appliance), :, :].reshape((valid.shape[0], 1, -1, 24))))
        if cuda_av:
            out_valid[a_num] = out_valid[a_num].cuda()
            
    out_test = [None for temp in range(len(ORDER))]
    for a_num, appliance in enumerate(ORDER):
        out_test[a_num] = Variable(
            torch.Tensor(test[:, APPLIANCE_ORDER.index(appliance), :, :].reshape((test.shape[0], 1, -1, 24))))
        if cuda_av:
            out_test[a_num] = out_test[a_num].cuda()

    return out_train, out_valid, out_test

In [8]:
dataset = 1
fold_num = 0
num_folds = 5
train, test = get_train_test(dataset, num_folds=num_folds, fold_num=fold_num)
valid = train[int(0.8*len(train)):].copy()
train = train[:int(0.8 * len(train))].copy()
train_aggregate = train[:, 0, :, :].reshape(train.shape[0], 1, -1, 24)
valid_aggregate = valid[:, 0, :, :].reshape(valid.shape[0], 1, -1, 24)
test_aggregate = test[:, 0, :, :].reshape(test.shape[0], 1, -1, 24)

In [9]:
ORDER = ['fridge', 'dr', 'hvac', 'dw', 'mw']
out_train, out_valid, out_test = preprocess(train, valid, test)

In [10]:
inp = Variable(torch.Tensor(train_aggregate), requires_grad=False)
valid_inp = Variable(torch.Tensor(valid_aggregate), requires_grad=False)
test_inp = Variable(torch.Tensor(test_aggregate), requires_grad=False)
if cuda_av:
    inp = inp.cuda()
    valid_inp = valid_inp.cuda()
    test_inp = test_inp.cuda()
valid_out = torch.cat([out_valid[appliance_num] for appliance_num, appliance in enumerate(ORDER)])
test_out = torch.cat([out_test[appliance_num] for appliance_num, appliance in enumerate(ORDER)])
train_out = torch.cat([out_train[appliance_num] for appliance_num, appliance in enumerate(ORDER)])

In [16]:
loss_func = nn.L1Loss()
model = JointlyNN('GRU', 50, 2, True, len(ORDER))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
p=0
params = [inp, p]

if cuda_av:
    model = model.cuda()
    loss_func = loss_func.cuda()
    
for a_num, appliance in enumerate(ORDER):
    params.append(out_train[a_num])

if cuda_av:
    train_out = train_out.cuda()

use RNN
use CNN
use RNN
use CNN
use CNN


In [None]:
for k in range(4000):
    pred = model(*params)
    optimizer.zero_grad()
    loss = loss_func(pred, train_out)
    
    
    
    if cuda_av:
        test_inp = test_inp.cuda()
    test_params = [test_inp, -2]
    for i in range(len(ORDER)):
        test_params.append(None)
    test_pr = model(*test_params)
    test_loss = loss_func(test_pr, test_out)
    
    print(k, loss.data[0], test_loss.data[0])

    loss.backward()
    optimizer.step()

0 221.59133911132812 188.591796875
1 220.53897094726562 187.48281860351562
2 219.91329956054688 186.83465576171875
3 219.52989196777344 186.43878173828125
4 219.22613525390625 186.1263427734375
5 218.92466735839844 185.80905151367188
6 218.62432861328125 185.4873046875
7 218.30767822265625 185.1513671875
8 217.98666381835938 184.80613708496094
9 217.66726684570312 184.4580078125
10 217.33322143554688 184.09536743164062
11 217.0020751953125 183.73162841796875
12 216.6675567626953 183.35255432128906
13 216.334716796875 182.98556518554688
14 216.0009002685547 182.5986785888672
15 215.6753692626953 182.2126007080078
16 215.3465576171875 181.83694458007812
17 215.01939392089844 181.45811462402344
18 214.6969757080078 181.0842742919922
19 214.36981201171875 180.72665405273438
20 214.04638671875 180.38150024414062
21 213.7194366455078 180.03146362304688
22 213.3905487060547 179.68411254882812
23 213.05931091308594 179.34034729003906
24 212.7296600341797 179.00279235839844
25 212.4017333984375

204 175.59304809570312 141.62095642089844
205 175.43492126464844 141.46566772460938
206 175.27964782714844 141.3124542236328
207 175.1210479736328 141.14593505859375
208 174.97303771972656 141.0147247314453
209 174.84034729003906 140.87338256835938
210 174.6712646484375 140.7008056640625
211 174.5189208984375 140.5298309326172
212 174.3612060546875 140.3745880126953
213 174.21871948242188 140.25047302246094
214 174.0660400390625 140.073974609375
215 173.9269256591797 139.96897888183594
216 173.7713165283203 139.8174285888672
217 173.64129638671875 139.6475067138672
218 173.47262573242188 139.49017333984375
219 173.31532287597656 139.3277587890625
220 173.16188049316406 139.1806182861328
221 173.02796936035156 139.0553741455078
222 172.88894653320312 138.9160919189453
223 172.7077178955078 138.73069763183594
224 172.5764617919922 138.5587158203125
225 172.42872619628906 138.43997192382812
226 172.274658203125 138.2880096435547
227 172.11793518066406 138.11422729492188
228 171.9496154785

404 147.04470825195312 114.14800262451172
405 147.30091857910156 114.68865966796875
406 146.98777770996094 114.27729797363281
407 146.9875946044922 113.7348403930664
408 146.63563537597656 113.6217041015625
409 146.7084503173828 114.05089569091797
410 146.5989227294922 114.02813720703125
411 146.20516967773438 113.2940902709961
412 146.2502899169922 113.23683166503906
413 146.05392456054688 113.41949462890625
414 145.96847534179688 113.4647445678711
415 145.68846130371094 112.98270416259766
416 145.65374755859375 112.77925872802734
417 145.4578094482422 112.82273864746094
418 145.36170959472656 112.78465270996094
419 145.1898193359375 112.40510559082031
420 145.08250427246094 112.30917358398438
421 145.05526733398438 112.70584869384766
422 144.91383361816406 112.50019836425781
423 144.81719970703125 112.08880615234375
424 144.5971221923828 112.0317153930664
425 144.5751953125 112.1870346069336
426 144.32931518554688 111.79803466796875
427 144.28890991210938 111.36991882324219
428 144.0

604 126.66309356689453 95.6648178100586
605 126.42514038085938 95.99076080322266
606 126.60081481933594 96.51979064941406
607 126.32902526855469 95.81584930419922
608 126.29386901855469 95.72312927246094
609 126.17681121826172 96.03165435791016
610 125.97064971923828 95.77095794677734
611 126.00428009033203 95.44572448730469
612 125.80223083496094 95.55712890625
613 125.80477905273438 95.70117950439453
614 125.73650360107422 95.38253784179688
615 125.63905334472656 95.33106994628906
616 125.57166290283203 95.3681411743164
617 125.35004425048828 95.12989807128906
618 125.23419189453125 94.88323974609375
619 125.24005889892578 95.07588195800781
620 125.13829040527344 94.97760772705078
621 125.20755767822266 94.60985565185547
622 124.99508666992188 94.94953918457031
623 125.08230590820312 94.96253204345703
624 125.03953552246094 94.7693862915039
625 124.8437728881836 94.80038452148438
626 124.58856964111328 94.4322738647461
627 124.545654296875 94.09615325927734
628 124.56702423095703 94.

807 111.5435562133789 83.37512969970703
808 111.73375701904297 83.54891204833984
809 111.46770477294922 83.29366302490234
810 111.18451690673828 82.86377716064453
811 110.96394348144531 82.6608657836914
812 110.9190673828125 82.57435607910156
813 110.88264465332031 82.60594177246094
814 110.93831634521484 82.45369720458984
815 110.99229431152344 82.7507095336914
816 111.04061889648438 83.09256744384766
817 110.99272918701172 82.56714630126953
818 110.73026275634766 82.58192443847656
819 110.50286865234375 82.25682067871094
820 110.50115203857422 82.0599365234375
821 110.3631362915039 82.2059097290039
822 110.3045654296875 82.05579376220703
823 110.24720001220703 81.92823028564453
824 110.27256774902344 82.21653747558594
825 110.17581176757812 81.92342376708984
826 110.1761474609375 82.07398223876953
827 110.47473907470703 82.44507598876953
828 110.39584350585938 82.63488006591797
829 110.16395568847656 82.01507568359375
830 109.87548065185547 81.92259979248047
831 109.72781372070312 81

1009 99.71269989013672 73.5921401977539
1010 99.75257873535156 73.1859130859375
1011 99.78924560546875 73.4818344116211
1012 99.66008758544922 73.52843475341797
1013 100.06207275390625 73.42796325683594
1014 99.758056640625 73.82237243652344
1015 99.41883850097656 73.10543823242188
1016 99.51667022705078 72.87403869628906
1017 99.28021240234375 73.10018920898438
1018 99.22260284423828 72.96623992919922
1019 99.26535034179688 72.81141662597656
1020 99.28019714355469 72.90619659423828
1021 99.29817199707031 73.19056701660156
1022 99.1584701538086 72.48664093017578
1023 99.19253540039062 73.07766723632812
1024 99.03630828857422 72.61307525634766
1025 99.0149154663086 72.65447235107422
1026 98.8198471069336 72.55033111572266
1027 98.8517074584961 72.57752990722656
1028 98.66573333740234 72.23361206054688
1029 98.86213684082031 72.79574584960938
1030 98.654052734375 72.2885513305664
1031 98.7942886352539 72.27678680419922
1032 98.6830062866211 72.27810668945312
1033 98.85437774658203 73.174

1212 90.49796295166016 65.94792938232422
1213 90.42053985595703 65.86297607421875
1214 90.25248718261719 65.83869934082031
1215 90.2596435546875 65.40313720703125
1216 90.29509735107422 65.94236755371094
1217 90.29420471191406 65.64134216308594
1218 90.42413330078125 65.78102111816406
1219 90.26986694335938 65.71533966064453
1220 89.97535705566406 65.59355163574219
1221 90.02620697021484 65.15214538574219
1222 89.82508087158203 65.38895416259766
1223 89.788818359375 65.4250717163086
1224 90.28321075439453 65.25761413574219
1225 89.85340881347656 65.5826187133789
1226 90.07706451416016 65.6961898803711
1227 89.73111724853516 65.21856689453125
1228 89.74693298339844 65.0900650024414
1229 89.5976791381836 65.22377014160156
1230 89.52632904052734 65.10616302490234
1231 89.64666748046875 64.76739501953125
1232 89.4472427368164 65.16111755371094
1233 89.38105010986328 65.06278991699219
1234 89.5246810913086 64.70793914794922
1235 89.7221450805664 65.57382202148438
1236 90.03382873535156 65.5

1413 82.90056610107422 59.49517059326172
1414 83.26374053955078 59.79403305053711
1415 82.99333190917969 59.61423110961914
1416 82.7513198852539 59.41602325439453
1417 83.25530242919922 59.588096618652344
1418 82.97843170166016 59.608917236328125
1419 83.0190658569336 59.740379333496094
1420 83.02620697021484 59.65638732910156
1421 82.74738311767578 59.32038879394531
1422 83.01820373535156 59.68183517456055
1423 82.75177001953125 59.1791877746582
1424 82.6156997680664 59.35606002807617
1425 83.00984954833984 59.68919372558594
1426 82.86512756347656 59.41814041137695
1427 82.71917724609375 59.388851165771484
1428 83.10256958007812 59.78245544433594
1429 82.33201599121094 58.9982795715332
1430 82.5126953125 59.1524772644043
1431 82.595947265625 59.15542221069336
1432 82.45320892333984 59.02534103393555
1433 82.08956146240234 58.8493766784668
1434 82.3886947631836 58.90665817260742
1435 82.30347442626953 59.03120040893555
1436 81.986083984375 58.62898254394531
1437 82.1668472290039 58.628

1614 76.52981567382812 54.46030044555664
1615 76.72937774658203 54.83625030517578
1616 77.03895568847656 54.70888137817383
1617 76.42628479003906 54.478702545166016
1618 76.21052551269531 54.0473518371582
1619 76.18253326416016 53.91887283325195
1620 76.5958023071289 54.6834602355957
1621 76.62031555175781 54.596778869628906
1622 76.95694732666016 54.74241256713867
1623 76.72415924072266 54.58369445800781
1624 76.22639465332031 54.188045501708984
1625 76.69363403320312 54.14623260498047
1626 76.40914154052734 54.23587417602539
1627 76.18423461914062 54.195003509521484
1628 76.74149322509766 54.324180603027344
1629 76.37481689453125 54.32381057739258
1630 76.25405883789062 54.319095611572266
1631 76.36695861816406 54.40336990356445
1632 76.29535675048828 54.216373443603516
1633 76.1444091796875 54.25569152832031
1634 76.05653381347656 54.046546936035156
1635 76.09176635742188 54.15304946899414
1636 75.67010498046875 53.821327209472656
1637 75.96058654785156 54.03062438964844
1638 75.586

1814 72.23290252685547 51.13351821899414
1815 72.33272552490234 51.05590057373047
1816 72.39204406738281 51.247859954833984
1817 71.83263397216797 50.69423294067383
1818 72.3489990234375 51.20930099487305
1819 72.2067642211914 51.299198150634766
1820 71.62918090820312 50.69375991821289
1821 72.41702270507812 51.42266082763672
1822 71.75265502929688 51.22365951538086
1823 71.92282104492188 51.18339538574219
1824 72.048828125 51.37051010131836
1825 71.87661743164062 51.14263916015625
1826 71.84736633300781 50.934078216552734
1827 71.79132080078125 51.06810760498047
1828 71.5684585571289 50.68655776977539
1829 72.12721252441406 51.47427749633789
1830 71.6717529296875 50.816043853759766
1831 71.41600799560547 50.63340377807617
1832 71.34772491455078 50.49947738647461
1833 71.5270004272461 50.72123336791992
1834 71.51114654541016 51.014404296875
1835 71.26661682128906 50.28000259399414
1836 71.48151397705078 50.78369903564453
1837 71.21664428710938 50.531612396240234
1838 71.06105041503906 

2015 67.45623779296875 47.644840240478516
2016 67.2793197631836 47.68109130859375
2017 67.381591796875 47.4635009765625
2018 67.39745330810547 47.94294357299805
2019 67.39608001708984 47.57496643066406
2020 67.87926483154297 48.09291076660156
2021 67.5971450805664 47.97181701660156
2022 67.64927673339844 48.181861877441406
2023 67.37601470947266 47.674278259277344
2024 67.42232513427734 48.008079528808594
2025 67.44075012207031 47.53406524658203
2026 67.38040924072266 47.76486587524414
2027 67.32099914550781 47.546722412109375
2028 67.1484146118164 47.358768463134766
2029 67.33283233642578 47.631622314453125
2030 67.01998901367188 47.39924621582031
2031 66.88249969482422 47.23411560058594
2032 67.40546417236328 47.59699630737305
2033 67.33634948730469 47.92496109008789
2034 67.3622817993164 47.49482727050781
2035 66.69747924804688 47.15458679199219
2036 67.18478393554688 47.614051818847656
2037 67.1434326171875 47.35868453979492
2038 67.76825714111328 48.404178619384766
2039 67.7959442

2216 63.939456939697266 45.646427154541016
2217 63.53832244873047 45.02824783325195
2218 63.698333740234375 45.213279724121094
2219 63.737117767333984 45.33865737915039
2220 63.67805862426758 45.00619125366211
2221 63.6036491394043 45.33757400512695
2222 63.544593811035156 45.12611770629883
2223 63.80990982055664 45.11144256591797
2224 63.61422348022461 45.2682991027832
2225 63.260353088378906 44.79944610595703
2226 63.853023529052734 45.2604866027832
2227 63.422645568847656 45.16371536254883
2228 63.265995025634766 44.89615249633789
2229 63.532745361328125 44.80424118041992
2230 63.457515716552734 45.33868408203125
2231 63.25870132446289 44.97720718383789
2232 63.731292724609375 45.059837341308594
2233 64.00029754638672 45.621063232421875
2234 63.41191864013672 45.18386459350586
2235 63.45307159423828 44.82365417480469
2236 63.44258117675781 45.39889907836914
2237 63.070045471191406 44.662654876708984
2238 63.403419494628906 44.8907585144043
2239 63.666847229003906 45.069801330566406
