Skip to content

Commit

Permalink
Update Capsule_Keras.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bojone committed Mar 19, 2018
1 parent 8c14be4 commit af9959f
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions Capsule_Keras.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#! -*- coding: utf-8 -*-
# refer: https://kexue.fm/archives/5112

from keras import activations
from keras import backend as K
from keras.engine.topology import Layer

Expand All @@ -17,17 +19,16 @@ def softmax(x, axis=-1):

#A Capsule Implement with Pure Keras
class Capsule(Layer):
def __init__(self, num_capsule, dim_capsule, routings=3, share_weights=True, activation='default', **kwargs):
def __init__(self, num_capsule, dim_capsule, routings=3, share_weights=True, activation='squash', **kwargs):
super(Capsule, self).__init__(**kwargs)
self.num_capsule = num_capsule
self.dim_capsule = dim_capsule
self.routings = routings
self.share_weights = share_weights
if activation == 'default':
if activation == 'squash':
self.activation = squash
else:
self.activation = activation
print self.activation
self.activation = activations.get(activation)

def build(self, input_shape):
super(Capsule, self).build(input_shape)
Expand Down Expand Up @@ -66,13 +67,13 @@ def call(self, u_vecs):
o = K.batch_dot(c, u_hat_vecs, [2, 2])
if K.backend() == 'theano':
o = K.sum(o, axis=1)
o = self.activation(o)
if i < self.routings - 1:
o = K.l2_normalize(o, -1)
b = K.batch_dot(o, u_hat_vecs, [2, 3])
if K.backend() == 'theano':
b = K.sum(b, axis=1)

return o
return self.activation(o)

def compute_output_shape(self, input_shape):
return (None, self.num_capsule, self.dim_capsule)

0 comments on commit af9959f

Please sign in to comment.