In [38]:
# Make the cells wider in the browser window
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

from importlib import reload
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import waveglow_model as model
import pandas as pd
import numpy as np

In [12]:
# Import the data
trainset = pd.read_pickle("wind_power_data/wind_power_train.pickle").values
testset = pd.read_pickle("wind_power_data/wind_power_test.pickle").values

In [191]:
# Try instantiating a network
reload(model)
net = model.WaveGlow(
    n_context_channels=96, 
    n_flows=6, 
    n_group=24, 
    n_early_every=3,
    n_early_size=8,
    n_layers=2,
    dilation_list=[1,2],
    n_channels=96,
    kernel_size=3);

Channels:  24
Channels:  24
Channels:  24
Channels:  16
Channels:  16
Channels:  16


In [31]:
trainset.shape

(70080,)

In [81]:
samp = trainset[:96]
samp = np.reshape(samp, (1,96))
print(samp.shape)

samp_torch = torch.autograd.Variable(torch.FloatTensor(samp))

context = samp[:, :, None]
print(context.shape)
context_torch = torch.autograd.Variable(torch.FloatTensor(context))

samp_torch.unfold(1, 24, 24).shape

(1, 96)
(1, 96, 1)


torch.Size([1, 4, 24])

In [107]:
context.shape

(1, 96, 1)

In [192]:
z, log_s_list, log_det_w_list, early_out_shapes = net(samp_torch, context_torch)

In [189]:
z.shape

torch.Size([1, 24, 4])

In [193]:
forecast = net.generate(context_torch, latent_z=z, early_assignment_shapes=early_out_shapes)

In [202]:
forecast

tensor([[ 0.0760,  0.0600,  0.0600,  0.0640,  0.0620,  0.0530,  0.0670,  0.1120,
          0.3530,  0.5690,  0.7650,  0.8170,  0.7950,  1.0850,  1.1670,  1.4590,
          1.5330,  1.4970,  1.6760,  1.8240,  2.0530,  2.7030,  3.1530,  4.1340,
          4.6270,  5.7560,  6.8790,  8.1850,  8.8390,  9.0900, 11.5000, 13.3160,
         14.6710, 15.1870, 15.5030, 15.4410, 15.5680, 15.8660, 15.9200, 15.9150,
         15.9010, 15.8920, 15.9310, 15.9300, 15.9590, 15.9880, 15.9780, 15.9920,
         15.9940, 15.9940, 15.9950, 15.9920, 15.9940, 15.9970, 15.9930, 15.9940,
         15.9900, 15.9530, 15.9420, 15.4970, 13.5840, 10.8160,  8.6550,  4.9210,
          2.5130,  1.7040,  1.6020,  1.5680,  1.2830,  1.0180,  0.9310,  1.4980,
          1.5950,  1.4840,  1.4300,  1.4200,  1.0690,  0.8060,  0.7370,  0.5880,
          0.3710,  0.1410,  0.0590,  0.0880,  0.2560,  0.3530,  0.2980,  0.3040,
          0.4070,  0.5120,  0.4870,  0.6190,  0.6440,  0.8720,  1.4810,  1.6690]])

In [203]:
samp_torch

tensor([[ 0.0760,  0.0600,  0.0600,  0.0640,  0.0620,  0.0530,  0.0670,  0.1120,
          0.3530,  0.5690,  0.7650,  0.8170,  0.7950,  1.0850,  1.1670,  1.4590,
          1.5330,  1.4970,  1.6760,  1.8240,  2.0530,  2.7030,  3.1530,  4.1340,
          4.6270,  5.7560,  6.8790,  8.1850,  8.8390,  9.0900, 11.5000, 13.3160,
         14.6710, 15.1870, 15.5030, 15.4410, 15.5680, 15.8660, 15.9200, 15.9150,
         15.9010, 15.8920, 15.9310, 15.9300, 15.9590, 15.9880, 15.9780, 15.9920,
         15.9940, 15.9940, 15.9950, 15.9920, 15.9940, 15.9970, 15.9930, 15.9940,
         15.9900, 15.9530, 15.9420, 15.4970, 13.5840, 10.8160,  8.6550,  4.9210,
          2.5130,  1.7040,  1.6020,  1.5680,  1.2830,  1.0180,  0.9310,  1.4980,
          1.5950,  1.4840,  1.4300,  1.4200,  1.0690,  0.8060,  0.7370,  0.5880,
          0.3710,  0.1410,  0.0590,  0.0880,  0.2560,  0.3530,  0.2980,  0.3040,
          0.4070,  0.5120,  0.4870,  0.6190,  0.6440,  0.8720,  1.4810,  1.6690]])

In [204]:
forecast - samp_torch

tensor([[-1.1548e-06, -7.8604e-07,  1.0245e-06,  4.0978e-07,  4.0978e-07,
          1.6689e-06, -2.3842e-07,  5.4389e-07,  3.8743e-07,  5.3644e-07,
         -5.3644e-07,  7.1526e-07, -3.5763e-07, -4.7684e-07,  0.0000e+00,
         -2.9802e-06, -3.5763e-06,  4.7684e-07, -4.7684e-07,  0.0000e+00,
          7.1526e-07,  1.4305e-06, -4.7684e-07,  0.0000e+00,  1.4305e-06,
         -1.9073e-06, -8.5831e-06, -9.5367e-07,  1.4305e-05,  1.6212e-05,
         -1.2398e-05,  1.4305e-05,  1.5259e-05,  5.7220e-06,  4.7684e-06,
          1.9073e-06, -2.8610e-06, -4.7684e-06,  7.6294e-06, -1.1444e-05,
         -8.5831e-06,  9.5367e-06, -7.6294e-06, -9.5367e-07,  2.2888e-05,
          5.7220e-06,  1.9073e-06,  8.5831e-06, -9.5367e-07, -9.5367e-07,
         -7.6294e-06, -4.7684e-06,  3.8147e-06,  8.5831e-06, -1.2398e-05,
          6.6757e-06,  1.0490e-05,  2.8610e-06, -3.8147e-06, -1.9073e-06,
          2.8610e-06,  9.5367e-07,  3.8147e-06, -6.6757e-06, -9.5367e-07,
         -5.9605e-07, -4.1723e-06,  4.

In [205]:
torch.allclose(forecast, samp_torch, atol=1e-5)

True