Skip to content

Commit

Permalink
[Relay] DQN Port (apache#2009)
Browse files Browse the repository at this point in the history
  • Loading branch information
joshpoll authored and eqy committed Oct 29, 2018
1 parent 581d248 commit b00adb6
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/tvm/relay/testing/__init__.py
Expand Up @@ -6,3 +6,4 @@
from . import dqn
from . import dcgan
from . import mobilenet
from . import dqn
1 change: 0 additions & 1 deletion python/tvm/relay/testing/dqn.py
Expand Up @@ -30,7 +30,6 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"
"""get symbol of nature dqn"""
data_shape = (batch_size,) + image_shape
data = relay.var("data", shape=data_shape, dtype=dtype)

conv1_bias = relay.var("conv1_bias")
conv1 = layers.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0),
channels=32, name="conv1")
Expand Down
4 changes: 4 additions & 0 deletions tests/python/relay/test_ir_text_printer.py
Expand Up @@ -115,6 +115,10 @@ def test_mobilenet():
net, params = tvm.relay.testing.mobilenet.get_workload(batch_size=1)
net.astext()

def test_dqn():
net, params = tvm.relay.testing.dqn.get_workload(batch_size=1)
show(net.astext())

if __name__ == "__main__":
do_print[0] = True
test_resnet()
Expand Down

0 comments on commit b00adb6

Please sign in to comment.