In [102]:
import sys
sys.path.insert(0, "../../python")
import mxnet as mx
import numpy as np
from collections import namedtuple
import time
import math
# Recurrent state handed from one time step to the next within a layer:
# `c` is the memory cell, `h` the hidden output.
LSTMState = namedtuple("LSTMState", ("c", "h"))

# Per-layer weights/biases, shared by every time step of that layer:
# input-to-hidden and hidden-to-hidden projections.
LSTMParam = namedtuple("LSTMParam", ("i2h_weight", "i2h_bias",
                                     "h2h_weight", "h2h_bias"))

# Bundle describing a bound, unrolled LSTM model (executor, symbol,
# state arrays, sequence inputs/outputs and parameter blocks).
LSTMModel = namedtuple("LSTMModel", ("rnn_exec", "symbol",
                                     "init_states", "last_states",
                                     "seq_data", "seq_labels", "seq_outputs",
                                     "param_blocks"))

def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
    """Build the symbol graph for one LSTM cell at one time step.

    Parameters
    ----------
    num_hidden : int
        Hidden size of the cell. Both projections emit ``num_hidden * 4``
        units, which are split into four gate pre-activations.
    indata : mx.sym.Symbol
        Input to the cell for this step.
    prev_state : LSTMState
        ``(c, h)`` produced by the previous step of the same layer.
    param : LSTMParam
        Weight/bias variables of this layer.
    seqidx, layeridx : int
        Time-step and layer indices; used only to name the generated nodes.
    dropout : float
        Dropout probability applied to ``indata``; values <= 0 disable it.

    Returns
    -------
    LSTMState
        The new ``(c, h)`` of this cell.
    """
    if dropout > 0.:
        indata = mx.sym.Dropout(data=indata, p=dropout)

    # Input and recurrent projections; their sum holds all four gates.
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))
    pre_activations = mx.sym.SliceChannel(i2h + h2h, num_outputs=4,
                                          name="t%d_l%d_slice" % (seqidx, layeridx))

    # Slice k gets the k-th nonlinearity, in the order:
    # input gate, candidate transform, forget gate, output gate.
    # The generator is consumed in order, so nodes are created 0..3.
    in_gate, in_transform, forget_gate, out_gate = (
        mx.sym.Activation(pre_activations[k], act_type=kind)
        for k, kind in enumerate(("sigmoid", "tanh", "sigmoid", "sigmoid")))

    kept = forget_gate * prev_state.c
    written = in_gate * in_transform
    next_c = kept + written
    squashed = mx.sym.Activation(next_c, act_type="tanh")
    next_h = out_gate * squashed
    return LSTMState(c=next_c, h=next_h)


# A dedicated unrolling function is defined here because the original one in
# lstm.py concatenates all the labels at the last layer, which makes the label
# mini-batch size differ from the data mini-batch size. The existing
# data-parallelization code would presumably need changes to support that.
def lstm_unroll(num_lstm_layer, seq_len, input_size,
                num_hidden, num_embed, num_label, dropout=0.):
    """Build the unrolled, stacked-LSTM training symbol.

    ``data`` is embedded and split into ``seq_len`` steps along axis 1. Each
    step runs through ``num_lstm_layer`` stacked LSTM cells. The top-layer
    outputs of all steps are concatenated along axis 0 (step 0 first) and fed
    to one shared classifier and ``SoftmaxOutput``.

    Parameters
    ----------
    num_lstm_layer : int
        Number of stacked LSTM layers.
    seq_len : int
        Number of time steps to unroll.
    input_size : int
        ``input_dim`` of the Embedding layer (vocabulary size).
    num_hidden : int
        Hidden units per LSTM cell.
    num_embed : int
        Embedding dimension.
    num_label : int
        Output units of the classifier ``pred``.
    dropout : float
        Dropout probability used on the inputs of layers above the first and
        on each step's top-layer output; values <= 0 disable it.

    Returns
    -------
    mx.sym.Symbol
        The ``SoftmaxOutput`` symbol named ``softmax``.
    """
    embed_weight = mx.sym.Variable("embed_weight")
    cls_weight = mx.sym.Variable("cls_weight")
    cls_bias = mx.sym.Variable("cls_bias")

    # One parameter set per layer, reused at every time step.
    param_cells = [LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % layer),
                             i2h_bias=mx.sym.Variable("l%d_i2h_bias" % layer),
                             h2h_weight=mx.sym.Variable("l%d_h2h_weight" % layer),
                             h2h_bias=mx.sym.Variable("l%d_h2h_bias" % layer))
                   for layer in range(num_lstm_layer)]
    # Initial (c, h) of every layer are free inputs of the graph.
    last_states = [LSTMState(c=mx.sym.Variable("l%d_init_c" % layer),
                             h=mx.sym.Variable("l%d_init_h" % layer))
                   for layer in range(num_lstm_layer)]
    assert len(last_states) == num_lstm_layer

    # Embedding layer, then one squeezed slice per time step.
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=input_size,
                             weight=embed_weight, output_dim=num_embed, name='embed')
    wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)

    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        for layer, params in enumerate(param_cells):
            # The first layer reads the embedding directly, without dropout.
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[layer],
                              param=params,
                              seqidx=seqidx, layeridx=layer,
                              dropout=0. if layer == 0 else dropout)
            last_states[layer] = next_state
            hidden = next_state.h
        if dropout > 0.:
            hidden = mx.sym.Dropout(data=hidden, p=dropout)
        hidden_all.append(hidden)

    # Stack the per-step outputs along axis 0, step 0 first.
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
                                 weight=cls_weight, bias=cls_bias, name='pred')

    # Put the labels in the same step-major order as `pred`: transpose, then
    # flatten. Assumes `softmax_label` is (batch, seq_len); TODO confirm.
    label = mx.sym.Reshape(data=mx.sym.transpose(data=label), target_shape=(0,))

    return mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

def lstm_inference_symbol(num_lstm_layer, input_size,
                          num_hidden, num_embed, num_label, dropout=0.):
    """Build a single-time-step stacked-LSTM symbol for inference.

    Uses the same variable names as ``lstm_unroll``, so trained parameters
    can be bound to it. Generated nodes are named as time step 0.

    Parameters
    ----------
    num_lstm_layer : int
        Number of stacked LSTM layers.
    input_size : int
        ``input_dim`` of the Embedding layer.
    num_hidden : int
        Hidden units per LSTM cell.
    num_embed : int
        Embedding dimension.
    num_label : int
        Output units of the classifier ``pred``.
    dropout : float
        Dropout probability for layers above the first and for the top output.

    Returns
    -------
    mx.sym.Symbol
        A grouped symbol: the ``softmax`` output, followed by ``c`` then ``h``
        of each layer (layer 0 first).
    """
    step = 0
    embed_weight = mx.sym.Variable("embed_weight")
    cls_weight = mx.sym.Variable("cls_weight")
    cls_bias = mx.sym.Variable("cls_bias")
    param_cells = [LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % layer),
                             i2h_bias=mx.sym.Variable("l%d_i2h_bias" % layer),
                             h2h_weight=mx.sym.Variable("l%d_h2h_weight" % layer),
                             h2h_bias=mx.sym.Variable("l%d_h2h_bias" % layer))
                   for layer in range(num_lstm_layer)]
    last_states = [LSTMState(c=mx.sym.Variable("l%d_init_c" % layer),
                             h=mx.sym.Variable("l%d_init_h" % layer))
                   for layer in range(num_lstm_layer)]
    assert len(last_states) == num_lstm_layer
    data = mx.sym.Variable("data")

    hidden = mx.sym.Embedding(data=data,
                              input_dim=input_size,
                              output_dim=num_embed,
                              weight=embed_weight,
                              name="embed")
    # Stacked LSTM; the first layer gets no input dropout.
    for layer, params in enumerate(param_cells):
        next_state = lstm(num_hidden, indata=hidden,
                          prev_state=last_states[layer],
                          param=params,
                          seqidx=step, layeridx=layer,
                          dropout=0. if layer == 0 else dropout)
        last_states[layer] = next_state
        hidden = next_state.h

    # Decoder.
    if dropout > 0.:
        hidden = mx.sym.Dropout(data=hidden, p=dropout)
    fc = mx.sym.FullyConnected(data=hidden, num_hidden=num_label,
                               weight=cls_weight, bias=cls_bias, name='pred')
    outputs = [mx.sym.SoftmaxOutput(data=fc, name='softmax')]
    # Expose the new states so the caller can feed them into the next step.
    for state in last_states:
        outputs.extend((state.c, state.h))
    return mx.sym.Group(outputs)

In [103]:
### Hyper-parameters of a 2-layer LSTM network; the sequence length is
### supplied when sym_gen() is called.
num_hidden = 200      # hidden units per LSTM cell
num_embed = 200       # embedding dimension
num_lstm_layer = 2    # number of stacked LSTM layers
input_size = 100      # Embedding input_dim (vocabulary size)
output_size = 1       # output units of the classifier (num_label)
# NOTE(review): output_size is the unit count of the layer that lstm_unroll()
# wraps in SoftmaxOutput. A softmax over a single unit is constantly 1, and the
# training response used in this notebook looks real-valued. Confirm whether a
# regression output was intended.
# NOTE(review): input_size is the Embedding input_dim, so `data` presumably has
# to hold integer ids in [0, input_size). Confirm this against the training data.


def sym_gen(seq_len):
    """Return the unrolled LSTM training symbol for ``seq_len`` time steps,
    built from the module-level hyper-parameters defined in this cell."""
    return lstm_unroll(num_lstm_layer=num_lstm_layer,
                       seq_len=seq_len,
                       input_size=input_size,
                       num_hidden=num_hidden,
                       num_embed=num_embed,
                       num_label=output_size)

In [104]:
seq_len = 10
# Build the unrolled training symbol and write it to the JSON file that the
# H2O Deep Water estimator loads further down.
# The result of save() is deliberately NOT bound to `lstm`: that name is the
# LSTM-cell builder that lstm_unroll() calls, and rebinding it would break any
# later sym_gen() call (e.g. re-running the cell).
# NOTE(review): lstm_unroll() slices the embedded input into seq_len steps with
# squeeze_axis=1, while the frame used for training has 100 feature columns.
# Confirm how H2O shapes `data` and whether seq_len must match that count.
lstm_symbol = sym_gen(seq_len)
lstm_symbol.save("/tmp/lstm.json")

In [105]:
import sys, os
import h2o
from h2o.estimators.deepwater import H2ODeepWaterEstimator
import os.path

In [106]:
# Connect to the local H2O cluster. nthreads=-1 presumably means "use all
# available cores" (the recorded output reports 12 of 12 cores allowed).
h2o.init(nthreads=-1)

Checking whether there is an H2O instance running at http://localhost:54321. connected.


0,1
H2O cluster uptime:,04 secs
H2O cluster version:,3.11.0.99999
H2O cluster version age:,1 day
H2O cluster name:,arno
H2O cluster total nodes:,1
H2O cluster free memory:,13.96 Gb
H2O cluster total cores:,12
H2O cluster allowed cores:,12
H2O cluster status:,"accepting new members, healthy"
H2O connection url:,http://localhost:54321


In [107]:
# Synthetic training data: 100 numeric feature columns (C1..C100) plus a
# "response" column, with no binary, categorical or missing values.
# Per the H2O docs, response_factors=1 requests a real-valued response.
# A fixed seed makes the generated frame, and hence the run, reproducible.
# NOTE(review): the LSTM network embeds `data` via mx.sym.Embedding, which
# presumably expects integer ids in [0, input_size). These columns hold values
# roughly in [-100, 100] (see the preview below). Confirm the intended encoding.
FRAME_SEED = 1234
frame = h2o.create_frame(cols=100, binary_fraction=0, missing_fraction=0,
                         categorical_fraction=0, has_response=True,
                         response_factors=1, seed=FRAME_SEED)

Create Frame progress: |██████████████████████████████████████████████████████████████████████| 100%


In [108]:
# Preview the generated frame (the recorded output shows its first 10 rows).
frame

response,C1,C2,C3,C4,C5,C6,C7,C8,C9,C10,C11,C12,C13,C14,C15,C16,C17,C18,C19,C20,C21,C22,C23,C24,C25,C26,C27,C28,C29,C30,C31,C32,C33,C34,C35,C36,C37,C38,C39,C40,C41,C42,C43,C44,C45,C46,C47,C48,C49,C50,C51,C52,C53,C54,C55,C56,C57,C58,C59,C60,C61,C62,C63,C64,C65,C66,C67,C68,C69,C70,C71,C72,C73,C74,C75,C76,C77,C78,C79,C80,C81,C82,C83,C84,C85,C86,C87,C88,C89,C90,C91,C92,C93,C94,C95,C96,C97,C98,C99,C100
96.7365,-63.1,-47.6417,46.5744,-6,-68.885,96.9477,-89,86,9.99778,-20,38.4544,-10,-3.54834,-72.9411,90.1386,-41,-17.1353,-21.4727,88,47.3278,-18.2263,-88.9352,-37,-56.3288,-1.9214,5.57803,-69.0953,27.1838,99.9547,85.4753,73,-12.7777,72.7898,-9.24588,45.6384,-93.0865,-15.9331,-13.1023,-64.2107,23,-51.6539,27.7314,31.7985,-79.5775,-99.8303,-57.7458,95.7458,-57,-6.41109,-71,-38,73,-27.6499,1.98144,58,80,-56.1474,-18.7159,98,-8,4.83277,7,-22.3071,95.9057,95.7646,70,-6.61345,-40.1997,66.4516,38.9605,93.8146,84.1604,-16,40,25,-91.3696,-93.6916,47.1687,22.2209,84.7104,-11,-74.8743,-30.5307,81,-7.42803,61.0709,-17.5682,-38.659,-34.4877,-17.4602,-39.3993,96.5913,-88.3433,-56,83.2917,55.2055,16.9876,35,-71,80.2196
80.2196,64.9665,2.00453,32.6502,50,-27.6605,-72.0948,-42,64,-76.2626,-66,-16.7308,-83,74.1814,-61.4134,13.2972,22,80.9712,-41.2175,74,-88.3022,-28.3357,79.4659,34,45.777,34.6672,27.7169,46.8899,-20.4999,58.5507,84.6901,52,-35.9207,56.2729,63.7637,76.6114,-73.3252,-57.3884,-34.5031,91.6094,-78,-76.3004,24.2685,51.5267,5.4758,1.32537,33.2929,80.6263,-85,74.2339,81,-62,-16,-46.6101,-5.06858,-39,-32,72.7696,12.5647,11,-56,74.7774,12,-13.4609,-63.7515,80.4971,-26,37.8467,-0.0160082,86.2291,-69.0556,-57.0797,-67.996,20,55,-77,90.9576,-55.1556,-55.5587,-33.4265,58.9762,-25,-61.0377,30.597,-60,33.2487,57.4991,36.5986,68.6564,60.5324,-62.5551,89.5762,-56.6127,63.8185,-93,7.00835,8.6771,-70.0697,84,9,87.02
87.02,8.3077,-40.4146,-38.7572,-78,39.2995,62.879,-90,-18,68.5036,35,-63.888,13,68.746,29.5101,38.0018,30,96.3695,5.22922,-91,-16.4422,37.1118,45.9995,3,-97.3549,63.1142,-19.5494,37.1899,98.8963,35.7732,6.51401,-25,-18.0164,-48.1853,88.6309,39.2518,-85.0542,-29.536,-30.4109,71.0317,-76,28.2696,8.58083,-18.7272,-67.491,-99.9378,85.392,68.9017,-20,45.8495,34,34,-77,-27.1306,-95.545,-83,-95,-91.0532,-15.1906,-8,-92,16.4308,-23,-95.4656,46.0312,18.1831,-6,-46.28,-89.5301,90.1433,79.8685,-42.7771,-29.5158,-66,24,48,-91.4428,37.0413,-40.4754,53.7552,76.3168,82,-23.3425,43.3854,-82,63.3403,-79.5318,58.126,-13.5294,-29.1723,59.8394,-29.0674,95.3866,59.8875,18,-91.161,-50.4076,-43.8731,48,68,18.9306
18.9306,51.5524,-27.8915,2.39198,-59,70.7798,6.98306,-73,-97,-65.9225,27,-6.77124,26,-1.41661,-8.41517,58.7498,-46,-25.1891,-78.4035,16,-44.8241,10.7573,-95.7129,-59,40.4209,42.5928,-42.8793,43.6818,-24.1504,-79.6651,66.2368,81,-85.0853,-31.0196,28.7591,-14.4936,65.5248,-14.5567,-69.0826,-31.509,24,7.54629,95.4029,24.644,12.0591,-87.501,87.7862,-14.3206,-26,37.554,-12,54,-95,-48.8596,-23.3127,-84,-78,59.3215,-73.8663,53,-91,76.5399,-46,-22.0482,80.9321,64.61,-59,-35.6819,72.4878,51.6868,5.35061,-13.6894,55.4389,-89,-49,9,31.3214,0.486513,-91.9354,89.8814,-5.90396,48,63.2368,-17.5378,27,-44.9251,85.8133,-80.8117,-99.3299,-80.935,-55.4165,-59.2884,-9.92572,83.9287,-54,90.8215,-5.84039,-83.9271,77,-3,68.2784
68.2784,-92.1579,57.9068,62.8852,-65,-55.2837,-59.6201,38,-46,36.1635,97,25.7822,81,-55.1398,-84.1493,-21.3195,12,-40.7044,-55.8637,72,-10.4243,-82.1317,-22.7336,75,24.2086,9.27167,25.2063,13.4561,29.6597,84.3318,-20.0071,-3,-76.573,47.6092,73.5785,-64.106,-5.66817,73.0423,-18.8753,8.74555,83,-29.0357,37.0468,22.8334,-72.4816,30.8898,-27.5817,70.5949,47,72.2814,-71,68,-41,-15.8581,-84.9483,97,-28,98.9827,-2.64886,72,-37,-54.0482,89,-4.85298,1.8985,-58.5264,71,-62.5459,-76.9662,-10.246,10.5572,54.0291,-71.2033,-52,9,-14,84.2678,-73.7258,2.17149,-66.3275,95.6223,65,8.04374,5.79764,-12,10.6401,-21.3786,-46.2408,60.2095,7.75586,6.45017,86.7454,68.8757,-87.9177,35,0.352519,-68.0268,22.1385,-86,-11,80.0716
80.0716,62.0426,72.7915,75.7595,-58,86.5048,95.6307,60,-47,60.5242,-12,11.5251,17,41.0371,98.3555,-87.582,-27,48.2181,-64.7353,5,68.5704,48.573,48.1119,-79,7.32828,-68.9459,45.5623,59.2843,41.8177,-80.9009,81.9697,39,-27.7881,-96.6256,-77.1793,-35.0056,-53.6352,35.2003,-74.7817,-83.8736,-78,-54.249,31.0844,-13.2023,40.6421,-75.3586,10.1055,-80.9009,2,37.5853,-85,24,-55,-12.2767,89.8525,-9,64,55.5448,-71.2419,-4,17,19.6694,98,7.34388,-39.2241,-2.24953,59,-67.9681,74.955,29.1898,-49.477,-52.5628,-47.934,33,48,-100,13.984,-31.176,-79.7664,60.06,28.8325,57,3.96136,-9.48581,96,-4.80689,-98.32,35.3278,99.9754,-1.32057,34.8858,75.4466,-81.4732,78.2463,35,33.7557,-32.6836,-18.1566,-65,-1,23.8786
23.8786,-15.8316,48.4462,25.4972,-12,8.55949,-85.2158,-85,-33,53.8978,13,-85.1984,25,-75.3867,70.2064,26.1099,66,77.5901,-36.471,-48,-42.4893,85.8797,-80.3715,-10,10.2968,-63.4214,59.4654,22.2456,-92.6651,7.19019,22.7035,76,-20.7429,-80.1269,-51.1768,29.5199,3.61183,-52.0439,59.1579,24.5468,-64,30.118,-39.758,-86.1529,-79.6123,20.3741,89.6231,-44.3597,23,-29.7203,34,18,27,72.4801,61.632,90,55,-5.2798,-6.38671,-79,-83,-88.4431,6,18.0785,1.29855,52.3517,-21,72.5362,54.8591,-38.2625,84.5683,-35.7829,31.1437,28,-86,-84,-29.8228,-16.7207,8.16542,56.6053,30.4332,53,-22.2883,-93.1966,23,21.4129,-39.8963,-96.4321,3.95247,14.0111,64.5449,-84.2632,28.9514,55.8052,15,10.4123,-96.8416,-24.916,88,5,15.2426
15.2426,-82.0204,34.5112,-6.49161,-54,-79.6877,42.0253,-91,-40,78.3511,64,-5.7329,-23,43.2067,2.15528,-56.4426,59,-50.7397,-12.4803,-81,-26.2191,-30.5429,49.9745,66,58.5258,9.00132,-49.8233,-20.6065,64.0135,-25.0485,38.2829,-12,-57.8023,-34.1596,31.0128,-23.0026,14.6701,-24.2864,86.2343,-99.8598,-26,51.8444,27.2101,39.7756,-74.9601,-20.1254,-75.7224,9.63159,-56,-56.6146,-79,8,-92,-59.7287,-74.5252,-55,-45,33.1544,88.4236,96,-7,-46.4467,38,68.2664,-19.5659,46.3823,30,42.8921,-56.4746,28.7968,88.7903,6.72548,-77.7896,55,5,71,33.6897,-25.386,66.6122,-47.6384,-63.6182,97,-78.2067,3.09561,100,-55.0149,46.0723,52.5333,21.654,86.1827,8.753,-50.9718,-86.9612,-7.39168,65,6.58803,95.1427,-56.5418,-31,-10,-63.5082
-63.5082,-36.7488,45.3556,-35.5629,-7,22.2262,-12.0369,-41,45,-29.5911,72,-94.7021,-57,-22.2863,-73.1568,96.5711,-62,-77.819,-56.0461,-27,-96.0299,-32.1678,-88.7937,1,-43.4573,17.1982,99.2725,-18.7636,-41.3647,74.5974,17.1087,93,65.3622,-82.3422,-54.5502,3.93187,68.9357,67.0794,93.759,-29.78,-67,-52.8796,-26.4282,-34.8039,29.0259,5.34485,21.6008,21.037,-73,36.1429,-24,-76,51,-27.6194,57.4163,20,16,77.4522,-53.9499,98,66,-45.3066,77,10.9849,33.7346,-58.6457,48,1.80895,-64.5703,-4.72468,76.2207,39.6972,26.3991,98,31,23,-69.807,-6.57768,-16.1716,-28.9209,83.1149,3,66.589,-25.1312,75,58.059,-1.4891,-91.9537,64.8425,97.5772,-15.5862,-34.6945,17.6513,67.7186,-28,52.3945,-48.3118,-25.1681,17,-89,95.7437
95.7437,-1.77429,4.73747,-29.4937,-5,-24.5009,15.0896,22,-68,-48.6214,71,-81.0229,79,87.7459,-83.3655,-12.8973,30,-52.3515,-45.1389,-9,-1.18014,-37.3663,3.84566,-54,-53.0538,-46.7009,-42.1595,-30.0414,53.832,-42.1825,-74.5234,35,-86.4942,57.0463,-63.4034,-55.7669,27.7981,3.00889,-91.7048,-75.4427,-64,82.0869,-27.8453,-77.6114,67.4127,18.6908,-62.9398,-59.7133,-85,89.9131,-80,67,61,-85.5882,-35.3335,-52,38,43.1189,-59.2282,72,58,80.5106,12,48.7146,23.5186,64.487,-43,99.2664,2.2822,3.96507,1.35055,0.117817,-21.8003,86,63,-20,80.5345,-4.24406,-47.3694,-49.3457,-78.7091,-25,-28.8267,-0.128148,93,-3.71066,-14.4566,-24.4394,-23.2254,-75.7661,-24.1967,-47.3814,43.0861,11.5987,77,42.8061,76.2937,-66.0331,97,-14,78.2598




In [109]:
# Train the custom MXNet network (serialized to /tmp/lstm.json earlier) on every
# column except the response. Deriving the predictors from the frame's column
# names avoids hard-coded positions 1..100, which silently depend on cols=100
# and on "response" being column 0.
# NOTE(review): the last recorded run lost its connection to the H2O server
# during the build (see the output below). Investigate before relying on this.
predictors = [col for col in frame.columns if col != "response"]
model = H2ODeepWaterEstimator(network_definition_file="/tmp/lstm.json")
model.train(x=predictors, y="response", training_frame=frame)
model.show()

deepwater Model Build progress: |

H2OConnectionError: Unexpected HTTP error: HTTPConnectionPool(host='localhost', port=54321): Max retries exceeded with url: /3/Jobs/$0301c0a80b0232d4ffffffff$_ac916bab89fe048a6feba849192a406f (Caused by NewConnectionError('<requests.packages.urllib3.connection.HTTPConnection object at 0x7f5c5d921e90>: Failed to establish a new connection: [Errno 111] Connection refused',))