In [1]:
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn as nn
from torchvision import datasets,transforms
import torch.nn.functional as F
import sys
sys.path.append("../code/")
from dataloader import APPLIANCE_ORDER, get_train_test
from sklearn.metrics import mean_absolute_error
import os

In [2]:
cuda_av = False
if torch.cuda.is_available():
    cuda_av = True

torch.manual_seed(0)
np.random.seed(0)

In [3]:

class CustomRNN(nn.Module):
    def __init__(self, cell_type, hidden_size, num_layers, bidirectional):
        super(CustomRNN, self).__init__()
        torch.manual_seed(0)

        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1
        if cell_type == "RNN":
            self.rnn = nn.RNN(input_size=1, hidden_size=hidden_size,
                              num_layers=num_layers, batch_first=True,
                              bidirectional=bidirectional)
        elif cell_type == "GRU":
            self.rnn = nn.GRU(input_size=1, hidden_size=hidden_size,
                              num_layers=num_layers, batch_first=True,
                              bidirectional=bidirectional)
        else:
            self.rnn = nn.LSTM(input_size=1, hidden_size=hidden_size,
                               num_layers=num_layers, batch_first=True,
                               bidirectional=bidirectional)

        self.linear = nn.Linear(hidden_size * self.num_directions, 1)
        self.act = nn.ReLU()

    def forward(self, x):
        pred, hidden = self.rnn(x, None)
        pred = self.linear(pred).view(pred.data.shape[0], -1, 1)
        pred = torch.min(pred, x)
        return pred

class CustomCNN(nn.Module):
    def __init__(self):
        super(CustomCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=7, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(20)

        self.conv2 = nn.Conv2d(20, 16, kernel_size=2, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(16)

        self.conv3 = nn.Conv2d(16, 64, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(64)

        self.conv4 = nn.ConvTranspose2d(64, 16, kernel_size=4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(16)

        self.conv5 = nn.ConvTranspose2d(16, 6, kernel_size=4, stride=2, padding=1)
        self.bn5 = nn.BatchNorm2d(6)

        self.conv6 = nn.ConvTranspose2d(6, 1, kernel_size=5, stride=1, padding=2) 
        
        self.act = nn.ReLU()
        
    # forward method
    def forward(self, input):
        
        e1 = self.conv1(input)
        bn1 = self.bn1(self.act(e1))
        e2 = self.bn2(self.conv2(bn1))        
        e5 = self.bn5(self.conv5(e2))
        e6 = self.conv6(e5)
        return e6

In [5]:
def preprocess(train, valid, test):
    out_train = [None for temp in range(len(ORDER))]
    for a_num, appliance in enumerate(ORDER):
        out_train[a_num] = Variable(
            torch.Tensor(train[:, APPLIANCE_ORDER.index(appliance), :, :].reshape((train.shape[0], 1, -1, 24))))
        if cuda_av:
            out_train[a_num] = out_train[a_num].cuda()

    out_valid = [None for temp in range(len(ORDER))]
    for a_num, appliance in enumerate(ORDER):
        out_valid[a_num] = Variable(
            torch.Tensor(valid[:, APPLIANCE_ORDER.index(appliance), :, :].reshape((valid.shape[0], 1, -1, 24))))
        if cuda_av:
            out_valid[a_num] = out_valid[a_num].cuda()
            
    out_test = [None for temp in range(len(ORDER))]
    for a_num, appliance in enumerate(ORDER):
        out_test[a_num] = Variable(
            torch.Tensor(test[:, APPLIANCE_ORDER.index(appliance), :, :].reshape((test.shape[0], 1, -1, 24))))
        if cuda_av:
            out_test[a_num] = out_test[a_num].cuda()

    return out_train, out_valid, out_test

In [7]:
dataset = 1
fold_num = 0
num_folds = 5
train, test = get_train_test(dataset, num_folds=num_folds, fold_num=fold_num)
valid = train[int(0.8*len(train)):].copy()
train = train[:int(0.8 * len(train))].copy()
train_aggregate = train[:, 0, :, :].reshape(train.shape[0], 1, -1, 24)
valid_aggregate = valid[:, 0, :, :].reshape(valid.shape[0], 1, -1, 24)
test_aggregate = test[:, 0, :, :].reshape(test.shape[0], 1, -1, 24)

In [12]:
ORDER = ['hvac', 'dr']
out_train, out_valid, out_test = preprocess(train, valid, test)

    


if cuda_av:
    model = model.cuda()
    loss_func = loss_func.cuda()

In [14]:
inp = Variable(torch.Tensor(train_aggregate), requires_grad=False)
valid_inp = Variable(torch.Tensor(valid_aggregate), requires_grad=False)
test_inp = Variable(torch.Tensor(test_aggregate), requires_grad=False)
if cuda_av:
    inp = inp.cuda()
    valid_inp = valid_inp.cuda()
    test_inp = test_inp.cuda()
valid_out = torch.cat([out_valid[appliance_num] for appliance_num, appliance in enumerate(ORDER)])
test_out = torch.cat([out_test[appliance_num] for appliance_num, appliance in enumerate(ORDER)])
train_out = torch.cat([out_train[appliance_num] for appliance_num, appliance in enumerate(ORDER)])

In [54]:
class AppliancesRNNCNN(nn.Module):
    def __init__(self, cell_type, hidden_size, num_layers, bidirectional, num_appliance):
        super(AppliancesRNNCNN, self).__init__()
        self.num_appliance = num_appliance
        self.preds = {}
        self.order = ORDER
        for appliance in range(self.num_appliance):
            if ORDER[appliance] in ['hvac', 'fridge']:
                print("use RNN")
                if cuda_av:
                    setattr(self, "Appliance_" + str(appliance), CustomRNN(cell_type, hidden_size,
                                                                           num_layers, bidirectional).cuda())
                else:
                   
                    setattr(self, "Appliance_" + str(appliance), CustomRNN(cell_type, hidden_size,
                                                                           num_layers, bidirectional))
            else:
                print("use CNN")
                if cuda_av:
                    setattr(self, "Appliance_" + str(appliance), CustomCNN().cuda())
                else:
                    setattr(self, "Appliance_" + str(appliance), CustomCNN())


    def forward(self, *args):
        agg_current = args[0]
        flag = False
        if np.random.random() > args[1]:
            flag = True
        # print("Subtracting prediction")
        else:
            pass
        # print("Subtracting true")
        for appliance in range(self.num_appliance):
#             print(agg_current.shape)
            # print(agg_current.mean().data[0])
#             print(ORDER[appliance])
            # print self.order[appliance]
            # print args[2+appliance]
            # print(getattr(self, "Appliance_" + str(appliance)))
            if ORDER[appliance] in ['hvac', 'fridge']:
                agg_current = agg_current.view(agg_current.shape[0], -1, 1)
            else:
                agg_current = agg_current.view(agg_current.shape[0], 1, -1, 24)
            
            self.preds[appliance] = getattr(self, "Appliance_" + str(appliance))(agg_current)
            
            agg_current = agg_current.view(agg_current.shape[0], 1, -1, 24)
            self.preds[appliance] = self.preds[appliance].view(self.preds[appliance].shape[0], 1, -1, 24)
            
            
            if flag:
                agg_current = agg_current - self.preds[appliance]
            else:
                agg_current = agg_current - args[2 + appliance]

        return torch.cat([self.preds[a] for a in range(self.num_appliance)])

In [55]:
loss_func = nn.L1Loss()
model = AppliancesRNNCNN('GRU', 50, 2, True, len(ORDER))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
p=0
params = [inp, p]
for a_num, appliance in enumerate(ORDER):
    params.append(out_train[a_num])

if cuda_av:
    train_out = train_out.cuda()

use RNN
use CNN


In [61]:
for k in range(2000):
    pred = model(*params)
    optimizer.zero_grad()
    loss = loss_func(pred, train_out)
    
    
    
    if cuda_av:
        test_inp = test_inp.cuda()
    test_params = [test_inp, -2]
    for i in range(len(ORDER)):
        test_params.append(None)
    test_pr = model(*test_params)
    test_loss = loss_func(test_pr, test_out)
    
    print(k, loss.data[0], test_loss.data[0])

    loss.backward()
    optimizer.step()

0 273.1723327636719 201.75656127929688
1 272.9757995605469 201.95533752441406
2 272.6414489746094 201.6727294921875
3 272.8826599121094 201.45474243164062
4 272.5438232421875 201.78038024902344
5 272.3073425292969 201.37948608398438
6 272.17718505859375 201.28274536132812
7 271.6128234863281 200.67788696289062
8 272.0034484863281 200.94090270996094
9 272.0603332519531 201.56915283203125
10 271.2385559082031 200.28973388671875
11 271.56103515625 200.23284912109375
12 271.3169860839844 200.85321044921875
13 271.0174560546875 200.341796875
14 270.42578125 199.78273010253906
15 270.12249755859375 199.34695434570312
16 270.4627380371094 199.83856201171875
17 270.2785339355469 199.85655212402344
18 269.6317443847656 198.58682250976562
19 269.86083984375 199.27371215820312
20 269.41619873046875 199.0166473388672
21 269.682373046875 198.70680236816406
22 269.4401550292969 199.04991149902344
23 268.8383483886719 198.4703369140625
24 269.70831298828125 199.0587615966797
25 268.46966552734375 197

204 239.45858764648438 173.69772338867188
205 239.90162658691406 173.31170654296875
206 239.01512145996094 173.36692810058594
207 238.93096923828125 173.39190673828125
208 239.58755493164062 173.92198181152344
209 237.71556091308594 171.91567993164062
210 238.8876190185547 173.6354217529297
211 238.06248474121094 171.89210510253906
212 238.23985290527344 172.75665283203125
213 237.73245239257812 172.3212127685547
214 238.12937927246094 172.20640563964844
215 237.26800537109375 171.6285400390625
216 237.06976318359375 171.55555725097656
217 237.446044921875 172.01812744140625
218 236.9294891357422 171.8428192138672
219 236.49539184570312 170.55337524414062
220 236.3893280029297 170.74560546875
221 236.26641845703125 170.92840576171875
222 236.06829833984375 170.16537475585938
223 235.77337646484375 170.416259765625
224 235.6127471923828 170.2527313232422
225 235.33668518066406 169.62725830078125
226 235.5931396484375 170.37109375
227 235.12896728515625 169.72911071777344
228 234.9379730

405 212.02035522460938 150.70555114746094
406 211.53094482421875 150.3167266845703
407 211.41188049316406 150.3098907470703
408 211.32118225097656 149.91305541992188
409 211.55023193359375 150.4292755126953
410 211.08604431152344 149.91880798339844
411 210.96141052246094 149.70205688476562
412 210.8001251220703 149.38604736328125
413 210.59706115722656 149.553466796875
414 210.35357666015625 149.4113006591797
415 210.286376953125 149.16213989257812
416 210.00143432617188 148.99081420898438
417 209.9775390625 149.0963897705078
418 210.27816772460938 149.43783569335938
419 210.2119140625 148.95228576660156
420 210.7103729248047 149.75172424316406
421 209.58592224121094 148.6278076171875
422 209.46334838867188 148.54588317871094
423 209.5900115966797 148.64096069335938
424 209.95046997070312 148.82798767089844
425 210.58827209472656 149.81382751464844
426 209.19590759277344 148.34715270996094
427 209.13816833496094 148.2649383544922
428 209.7763671875 149.2461395263672
429 209.67114257812

605 190.9490203857422 133.55606079101562
606 191.1710968017578 133.76876831054688
607 190.6215057373047 132.7982940673828
608 190.00686645507812 132.66513061523438
609 190.11669921875 133.0387420654297
610 190.20204162597656 132.95278930664062
611 190.6278839111328 133.2840118408203
612 189.74057006835938 132.05377197265625
613 189.6992645263672 132.37925720214844
614 189.7307586669922 132.89341735839844
615 190.21925354003906 132.9923095703125
616 190.6023712158203 133.36004638671875
617 189.44223022460938 131.7746124267578
618 189.41464233398438 132.0253448486328
619 190.5427703857422 133.47891235351562
620 189.62014770507812 132.6376953125
621 189.2143096923828 132.3661346435547
622 188.90078735351562 131.5287628173828
623 189.34823608398438 132.0608673095703
624 189.04115295410156 132.13243103027344
625 188.844482421875 132.0927734375
626 188.75192260742188 131.72706604003906
627 188.53619384765625 130.78211975097656
628 188.24061584472656 131.0926513671875
629 188.24363708496094 1

805 174.59674072265625 120.3808364868164
806 173.74278259277344 119.81752014160156
807 174.22015380859375 119.88391876220703
808 173.3979949951172 118.97427368164062
809 174.0126495361328 119.86548614501953
810 173.8107452392578 119.76155853271484
811 173.59205627441406 119.69459533691406
812 173.58602905273438 119.5434341430664
813 172.9713134765625 118.82258605957031
814 173.98561096191406 119.79635620117188
815 173.5985565185547 119.57686614990234
816 173.3058624267578 119.31814575195312
817 173.17408752441406 118.8399658203125
818 174.24403381347656 119.65442657470703
819 174.98231506347656 120.0309829711914
820 177.8375701904297 122.220703125
821 177.93289184570312 123.48493194580078
822 174.5870361328125 119.77616119384766
823 177.31455993652344 121.42859649658203
824 175.6305694580078 121.62413787841797
825 173.58599853515625 118.65707397460938
826 174.0253448486328 119.29814910888672
827 175.8561553955078 121.67542266845703
828 172.86477661132812 118.37825775146484
829 173.6846

1004 160.3384246826172 108.88152313232422
1005 160.40232849121094 109.41190338134766
1006 160.137939453125 108.21102905273438
1007 160.37765502929688 109.15350341796875
1008 159.31849670410156 107.6144027709961
1009 159.81271362304688 108.35252380371094
1010 159.54742431640625 108.21207427978516
1011 160.9040069580078 109.06147766113281
1012 160.28392028808594 109.5420150756836
1013 159.60400390625 107.99201965332031
1014 158.62103271484375 107.30663299560547
1015 159.36526489257812 108.53913116455078
1016 159.05426025390625 107.79631805419922
1017 159.55142211914062 108.48489379882812
1018 158.43272399902344 107.08830261230469
1019 158.84942626953125 107.97462463378906
1020 158.48541259765625 107.41993713378906
1021 158.57009887695312 107.36164855957031
1022 159.22740173339844 108.44905853271484
1023 158.82879638671875 107.99757385253906
1024 158.857421875 108.18611907958984
1025 158.0026092529297 106.73713684082031
1026 158.13336181640625 106.9557876586914
1027 158.4843292236328 107.

1199 148.46629333496094 99.73311614990234
1200 149.17022705078125 100.58901977539062
1201 148.08811950683594 99.93183898925781
1202 148.7703857421875 100.4800033569336
1203 147.91639709472656 99.42465209960938
1204 148.27789306640625 99.90087890625
1205 147.83236694335938 99.80708312988281
1206 147.76817321777344 99.5570297241211
1207 147.79396057128906 99.46680450439453
1208 147.43153381347656 99.16230773925781
1209 147.54019165039062 99.50897979736328
1210 147.30931091308594 99.02588653564453
1211 147.09654235839844 98.88038635253906
1212 147.25648498535156 99.28862762451172
1213 146.89462280273438 98.71074676513672
1214 147.0923309326172 98.86119079589844
1215 146.70582580566406 98.39602661132812
1216 146.76771545410156 98.65730285644531
1217 146.84693908691406 98.92440032958984
1218 146.54098510742188 98.37097930908203
1219 146.67311096191406 98.49398803710938
1220 146.63731384277344 98.45095825195312
1221 147.10716247558594 99.34565734863281
1222 147.2189178466797 98.9953384399414

1398 139.1291961669922 93.94102478027344
1399 138.41323852539062 92.82249450683594
1400 138.73487854003906 93.91596221923828
1401 137.994873046875 92.64462280273438
1402 138.83157348632812 93.34588623046875
1403 138.27796936035156 93.22704315185547
1404 139.72386169433594 94.48747253417969
1405 139.3846435546875 94.62467193603516
1406 138.60572814941406 93.52590942382812
1407 137.8629608154297 92.78009033203125
1408 138.09268188476562 92.9134750366211
1409 138.3255157470703 93.1946792602539
1410 138.3124237060547 93.36114501953125
1411 139.52394104003906 94.16786193847656
1412 139.20681762695312 94.40058135986328
1413 138.8223876953125 93.99844360351562
1414 138.24029541015625 93.3433609008789
1415 137.82723999023438 92.66337585449219
1416 137.17198181152344 91.79549407958984
1417 137.54049682617188 92.52616119384766
1418 137.7568359375 92.19866180419922
1419 138.11416625976562 93.62140655517578
1420 137.78909301757812 92.81314849853516
1421 138.18992614746094 93.45779418945312
1422 13

1597 130.92037963867188 88.99141693115234
1598 132.4929962158203 89.36418914794922
1599 131.73594665527344 89.57096862792969
1600 130.26173400878906 87.90912628173828
1601 132.5745391845703 90.26834869384766
1602 131.09503173828125 88.51387023925781
1603 133.103515625 91.2822036743164
1604 132.2488555908203 90.0224838256836
1605 131.62582397460938 89.05835723876953
1606 132.6194305419922 90.63753509521484
1607 130.34654235839844 88.80036163330078
1608 133.09622192382812 91.40516662597656
1609 130.3072052001953 88.57563781738281
1610 132.41131591796875 90.53435516357422
1611 130.427978515625 88.25836181640625
1612 130.5752410888672 88.96041870117188
1613 131.39450073242188 89.28691101074219
1614 130.48130798339844 88.56818389892578
1615 129.93251037597656 88.00931549072266
1616 130.77676391601562 89.36349487304688
1617 129.06614685058594 86.79096221923828
1618 130.2446746826172 88.19248962402344
1619 129.56785583496094 87.80134582519531
1620 129.06893920898438 87.89669036865234
1621 129

1795 125.57569885253906 86.2852783203125
1796 127.35929870605469 88.14125061035156
1797 125.22235107421875 86.19430541992188
1798 126.01432800292969 86.77001953125
1799 124.84109497070312 86.07984924316406
1800 124.13787078857422 84.55838775634766
1801 125.6111831665039 85.52264404296875
1802 124.65687561035156 85.58899688720703
1803 124.38614654541016 85.24848937988281
1804 124.9825210571289 85.70516204833984
1805 125.11815643310547 85.97432708740234
1806 124.61065673828125 84.97535705566406
1807 124.34251403808594 84.19889068603516
1808 126.15084838867188 86.66463470458984
1809 124.29841613769531 85.2948989868164
1810 126.42147827148438 86.92237854003906
1811 126.33612060546875 86.9515609741211
1812 126.7012710571289 88.90938568115234
1813 126.63148498535156 88.65696716308594
1814 125.63407897949219 86.40997314453125
1815 126.20293426513672 87.16549682617188
1816 124.86512756347656 85.82388305664062
1817 125.04318237304688 85.44197082519531
1818 124.7743148803711 86.47364044189453
18

1993 118.24585723876953 82.0374984741211
1994 119.91090393066406 83.5634765625
1995 118.16361236572266 81.93636322021484
1996 119.22080993652344 82.60604858398438
1997 120.14643859863281 83.79107666015625
1998 119.56328582763672 83.48152923583984
1999 118.8980712890625 83.64216613769531


In [65]:
test_pred = torch.split(test_pr, test_aggregate.shape[0])
test_fold = [None for x in range(len(ORDER))]
if cuda_av:
    for appliance_num, appliance in enumerate(ORDER):
        test_fold[appliance_num] = test_pred[appliance_num].cpu().data.numpy().reshape(-1, 24)
else:
    for appliance_num, appliance in enumerate(ORDER):
        test_fold[appliance_num] = test_pred[appliance_num].data.numpy().reshape(-1, 24)

In [66]:
test_gt_fold = [None for x in range(len(ORDER))]
for appliance_num, appliance in enumerate(ORDER):
    test_gt_fold[appliance_num] = test[:, APPLIANCE_ORDER.index(appliance), :, :].reshape(
        test_aggregate.shape[0],
        -1, 1).reshape(-1, 24)

In [68]:
test_error = {}

for appliance_num, appliance in enumerate(ORDER):
    test_error[appliance] = mean_absolute_error(test_fold[appliance_num], test_gt_fold[appliance_num])

In [69]:
test_error

{'dr': 39.242717150844463, 'hvac': 128.04161381433951}