In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import ops

In [2]:
def selu(x):
    """Scaled exponential linear unit.

    selu(x) = scale * x                       for x >= 0
              scale * alpha * (exp(x) - 1)    for x < 0

    Args:
        x: float tensor of any shape.

    Returns:
        Tensor of the same shape as `x`.
    """
    # Scope was previously named 'elu', which mislabelled these ops in the graph.
    with ops.name_scope('selu'):
        # Fixed constants from the SELU paper.
        alpha = 1.6732632423543772848170429916717
        scale = 1.0507009873554804934193349852946
        # tf.nn.elu(x) == exp(x) - 1 for x < 0, so alpha * elu(x) is the negative branch.
        return scale * tf.where(x >= 0.0, x, alpha * tf.nn.elu(x))
    
def init(*shape):
    """Return a float32 array of zeros whose dimensions are the positional args.

    Used as the initial value of every tf.Variable in this notebook; the zeros
    are presumably placeholders until real (ported) weights are assigned —
    TODO confirm.
    """
    return np.full(shape, 0, dtype=np.float32)
    
def BatchNorm(input, channel=8, eps=1e-05):
    """Batch normalization using stored running statistics.

    Computes (input - running_mean) / sqrt(running_var + eps) * weight + bias
    with per-channel variables; no batch statistics are computed here.

    Args:
        input: tensor whose last axis has `channel` entries
            (presumably (batch, length, channel) — TODO confirm).
        channel: number of channels / size of each per-channel variable.
        eps: variance epsilon (default 1e-05, the value previously hard-coded).

    Returns:
        Normalized tensor with the same shape as `input`.
    """
    with tf.variable_scope('BatchNorm'):
        weight = tf.Variable(init(channel), name='weight')
        bias = tf.Variable(init(channel), name='bias')
        # Running statistics are fixed buffers, not learnable parameters; keep
        # them out of tf.trainable_variables() so an optimizer never updates them.
        mean = tf.Variable(init(channel), name='running_mean', trainable=False)
        var = tf.Variable(init(channel), name='running_var', trainable=False)
        return tf.nn.batch_normalization(input, mean, var, bias, weight, eps)

def Conv1d(input, in_channel, out_channel, kernel_size, dilation=1, bias=False):
    """Stride-1, 'SAME'-padded 1-D convolution with optional dilation and bias.

    Args:
        input: tensor with `in_channel` channels in the last axis
            (presumably (batch, length, in_channel) — TODO confirm).
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: filter width.
        dilation: dilation rate; values > 1 use tf.nn.atrous_conv2d.
        bias: if True, add a per-output-channel bias variable.

    Returns:
        Convolved tensor with `out_channel` channels.
    """
    with tf.variable_scope('Conv1d'):
        kernel = tf.Variable(init(kernel_size, in_channel, out_channel), name='weight')
        if dilation <= 1:
            out = tf.nn.conv1d(input, kernel, 1, 'SAME')
        else:
            # No 1-D atrous op is used here: lift to a height-1 NHWC image with a
            # height-1 filter, convolve, then drop the dummy axis again.
            kernel_2d = tf.expand_dims(kernel, 0)
            image = tf.expand_dims(input, 1)
            out = tf.squeeze(tf.nn.atrous_conv2d(image, kernel_2d, dilation, 'SAME'), 1)
        if bias:
            bias_var = tf.Variable(init(out_channel), name='bias')
            out = out + bias_var
    return out
    
def MaxPool1d(input, kernel_size=2, stride=None, padding='SAME'):
    """1-D max pooling along the length axis.

    A height-1 axis is inserted at position 1 so tf.nn.max_pool (NHWC) can be
    used, then removed again.

    Args:
        input: rank-3 tensor, presumably (batch, length, channels) — TODO confirm.
        kernel_size: pooling window width (default 2, the previous hard-coded value).
        stride: step between windows; defaults to `kernel_size`.
        padding: 'SAME' (default, output length ceil(L / stride)) or 'VALID'.

    Returns:
        Pooled tensor with the same rank as `input`.

    NOTE(review): PyTorch's nn.MaxPool1d defaults to no padding (floor), so
    for odd lengths 'SAME' yields one more step than a default PyTorch pool;
    pass padding='VALID' if matching the PyTorch model — confirm.
    """
    if stride is None:
        stride = kernel_size
    with tf.variable_scope('MaxPool1d'):
        x = tf.expand_dims(input, 1)
        x = tf.nn.max_pool(x, [1, 1, kernel_size, 1], [1, 1, stride, 1], padding)
        x = tf.squeeze(x, 1)
    return x
    

def Encoder(input, init_channel, in_channel=1, num_layers=4):
    """Downsampling encoder: a chain of Conv1d -> BatchNorm -> SELU -> MaxPool1d blocks.

    Args:
        input: tensor with `in_channel` channels in the last axis
            (presumably (batch, length, in_channel) — TODO confirm).
        init_channel: output channels of the first block; each following block
            doubles the channel count.
        in_channel: channels of `input` (default 1, previously hard-coded).
        num_layers: number of DownSampleBlocks (default 4, previously hard-coded).

    Returns:
        Output of the last block, with init_channel * 2 ** (num_layers - 1)
        channels; each block halves the length (rounded up, 'SAME' pooling).
    """
    def DownSampleBlock(input, in_channel, out_channel):
        """Conv1d(k=7, bias) -> BatchNorm -> SELU -> MaxPool1d."""
        with tf.variable_scope('DownSampleBlock'):
            x = Conv1d(input, in_channel, out_channel, 7, bias=True)
            x = BatchNorm(x, out_channel)
            x = selu(x)
            x = MaxPool1d(x)
        return x

    with tf.variable_scope('Encoder'):
        x = input
        channels = in_channel
        for layer in range(num_layers):
            out_channels = init_channel * 2 ** layer
            # Re-entering 'DownSampleBlock' yields unique scopes (_1, _2, ...).
            x = DownSampleBlock(x, channels, out_channels)
            channels = out_channels
    return x
    
def ResNet(input, channel, num_blocks=9):
    """Stack of residual DilatedBlocks that keep the channel count fixed.

    Args:
        input: tensor whose last axis has `channel` channels
            (presumably (batch, length, channel) — TODO confirm).
        channel: channels of `input` and of every block.
        num_blocks: number of DilatedBlocks to chain. Defaults to 9, the count
            previously hard-coded as one call followed by range(8).

    Returns:
        Tensor with the same shape as `input`.
    """
    def DilatedBlock(input, channel=8, kernel_size=9, dilation=2):
        """BN -> Conv1d -> SELU -> BN -> dilated Conv1d -> SELU, plus identity skip."""
        # No change in # of channels -> identity mapping
        # NOTE(review): the PyTorch state_dict (block.0 BN, block.1 conv,
        # block.3 BN, block.5 conv) has a parameter-free layer between the
        # second BN and the second conv, whereas SELU here follows the second
        # conv — confirm the order matches dilated_model before porting weights.
        with tf.variable_scope('DilatedBlock'):
            x = BatchNorm(input, channel)
            x = Conv1d(x, channel, channel, kernel_size)
            x = selu(x)
            x = BatchNorm(x, channel)
            x = Conv1d(x, channel, channel, kernel_size, dilation)
            x = selu(x)
        return x + input

    with tf.variable_scope('ResNet'):
        x = input
        for _ in range(num_blocks):
            x = DilatedBlock(x, channel)
    return x

def NET(input, init_channel=32, num_classes=4):
    """Full model: Encoder -> dilated ResNet -> kernel-size-1 Conv1d logits.

    Args:
        input: single-channel sequence batch
            (presumably (batch, length, 1) — TODO confirm).
        init_channel: channels of the first encoder block; the encoder output,
            and every ResNet block, has init_channel * 8 channels.
        num_classes: logits per position (default 4, previously hard-coded).

    Returns:
        Logits with `num_classes` channels; the length is downsampled by the
        encoder's four max-pools (e.g. 256 -> 16).
    """
    x = Encoder(input, init_channel)
    x = ResNet(x, init_channel * 8)
    with tf.variable_scope('Logit'):
        logit = Conv1d(x, init_channel * 8, num_classes, 1)
    return logit

In [6]:
# Build a fresh graph, export it for TensorBoard, and check the output shape on a dummy batch.
tf.reset_default_graph()
# Explicit dtype/shape: the bare enum value `1` previously used only happened to mean float32.
x = tf.placeholder(tf.float32, shape=[None, None, 1], name='input')
y1 = NET(x, 32)
summary_writer = tf.summary.FileWriter('/tmp/model/', graph=tf.get_default_graph())
# Close the writer so the graph event file is flushed to disk.
summary_writer.close()
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(y1.eval({x: np.ones([10, 256, 1])}).shape)

(10, 16, 4)


In [4]:
# List every global variable; repeated scopes are disambiguated with _1, _2, ... suffixes.
for var in tf.global_variables():
    print(var.name)

Encoder/DownSampleBlock/Conv1d/weight:0
Encoder/DownSampleBlock/Conv1d/bias:0
Encoder/DownSampleBlock/BatchNorm/weight:0
Encoder/DownSampleBlock/BatchNorm/bias:0
Encoder/DownSampleBlock/BatchNorm/running_mean:0
Encoder/DownSampleBlock/BatchNorm/running_var:0
Encoder/DownSampleBlock_1/Conv1d/weight:0
Encoder/DownSampleBlock_1/Conv1d/bias:0
Encoder/DownSampleBlock_1/BatchNorm/weight:0
Encoder/DownSampleBlock_1/BatchNorm/bias:0
Encoder/DownSampleBlock_1/BatchNorm/running_mean:0
Encoder/DownSampleBlock_1/BatchNorm/running_var:0
Encoder/DownSampleBlock_2/Conv1d/weight:0
Encoder/DownSampleBlock_2/Conv1d/bias:0
Encoder/DownSampleBlock_2/BatchNorm/weight:0
Encoder/DownSampleBlock_2/BatchNorm/bias:0
Encoder/DownSampleBlock_2/BatchNorm/running_mean:0
Encoder/DownSampleBlock_2/BatchNorm/running_var:0
Encoder/DownSampleBlock_3/Conv1d/weight:0
Encoder/DownSampleBlock_3/Conv1d/bias:0
Encoder/DownSampleBlock_3/BatchNorm/weight:0
Encoder/DownSampleBlock_3/BatchNorm/bias:0
Encoder/DownSampleBlock_3/Bat

In [5]:
# Total number of global variables in the graph built above.
num_global_vars = len(tf.global_variables())
num_global_vars

115

In [8]:
import dilated_model as dm
import torch as th
sd = dm.EncodeWideResNet(in_channel=1, init_channel=32,
                         num_enc_layer=4, N_res_in_block=1, use_selu=True)
# Inference mode: BatchNorm uses its running statistics (as the TF port does)
# instead of per-batch statistics, and forward passes don't update them.
sd.eval();


In [11]:
# th.FloatTensor(10, 1, 1000) allocates uninitialized memory (garbage/NaN input);
# use a seeded random batch of shape (batch, channels=1, length) instead.
th.manual_seed(0)
res = sd(th.autograd.Variable(th.randn(10, 1, 1000)))

In [16]:
# Index of the max along dim 1 (presumably the class axis — confirm), shown as a column.
res.max(dim=1)[1].unsqueeze(1)

Variable containing:
    0
    2
    0
    0
    0
    0
    2
    0
    0
    2
[torch.LongTensor of size 10x1]

In [154]:
# `sd` is an nn.Module, which has no .items(); iterate over its state_dict keys instead.
for name in sd.state_dict():
    print(name)

encoder.0.weight
encoder.0.bias
encoder.1.weight
encoder.1.bias
encoder.1.running_mean
encoder.1.running_var
encoder.4.weight
encoder.4.bias
encoder.5.weight
encoder.5.bias
encoder.5.running_mean
encoder.5.running_var
encoder.8.weight
encoder.8.bias
encoder.9.weight
encoder.9.bias
encoder.9.running_mean
encoder.9.running_var
encoder.12.weight
encoder.12.bias
encoder.13.weight
encoder.13.bias
encoder.13.running_mean
encoder.13.running_var
resnet.0.residuals.0.block.0.weight
resnet.0.residuals.0.block.0.bias
resnet.0.residuals.0.block.0.running_mean
resnet.0.residuals.0.block.0.running_var
resnet.0.residuals.0.block.1.weight
resnet.0.residuals.0.block.3.weight
resnet.0.residuals.0.block.3.bias
resnet.0.residuals.0.block.3.running_mean
resnet.0.residuals.0.block.3.running_var
resnet.0.residuals.0.block.5.weight
resnet.0.residuals.1.block.0.weight
resnet.0.residuals.1.block.0.bias
resnet.0.residuals.1.block.0.running_mean
resnet.0.residuals.1.block.0.running_var
resnet.0.residuals.1.block.

In [62]:
# `nn` is never imported in this notebook; reach it through the `th` (torch) alias.
a = th.nn.Conv1d(3, 4, 5)

In [63]:
# Freshly initialized parameters: weight is (out=4, in=3, kernel=5), bias is (4,).
conv_params = a.state_dict()
conv_params

OrderedDict([('weight', 
              (0 ,.,.) = 
               -0.1830  0.1597  0.1484 -0.1916  0.2417
               -0.1660  0.0768 -0.1457  0.1034 -0.2297
               -0.1860  0.0322  0.0990  0.2096 -0.1983
              
              (1 ,.,.) = 
               -0.1521  0.1666 -0.0709  0.0931  0.1345
                0.1034 -0.0586  0.1383 -0.1212 -0.1902
               -0.2440 -0.0988 -0.0854  0.1089  0.1279
              
              (2 ,.,.) = 
                0.0536  0.2212  0.0824  0.0108  0.2315
               -0.2318  0.0847 -0.2136 -0.0735  0.2240
                0.2383  0.0270  0.2266 -0.1901 -0.0234
              
              (3 ,.,.) = 
                0.2279  0.2184  0.0534 -0.0578  0.1607
                0.1085  0.1057  0.0550 -0.0299  0.0537
                0.0227  0.1641 -0.1961  0.0515 -0.1527
              [torch.FloatTensor of size 4x3x5]), ('bias', 
               0.1855
               0.2336
              -0.0204
              -0.0937
              [tor