This repository has been archived by the owner on Apr 4, 2018. It is now read-only.

Commit

fixed tests 🍌🍌🍌
lukovnikov committed May 23, 2016
1 parent 9adac69 commit e33d4b8
Showing 3 changed files with 22 additions and 9 deletions.
teafacto/blocks/rnn.py: 22 changes (15 additions, 7 deletions)
@@ -274,13 +274,7 @@ def set_init_states(self, *states):

     def apply(self, context, seq, context_0=None, **kw):  # context: (batsize, enc.innerdim), seq: idxs-(batsize, seqlen)
         sequences = [seq.dimswap(1, 0)]  # sequences: (seqlen, batsize)
-        if context_0 is None:
-            if context.d.ndim == 2:  # static context
-                context_0 = context
-            elif context.d.ndim == 3:  # (batsize, inseqlen, inencdim)
-                context_0 = context[:, -1, :]  # take the last context as initial input
-            else:
-                print "sum ting wong in SeqDecoder apply()"
+        context_0 = self._get_ctx_t0(context, context_0)
         if self.init_states is not None:
             init_info = self.block.get_init_info(self.init_states)  # sets init states to provided ones
         else:
@@ -290,6 +284,20 @@ def apply(self, context, seq, context_0=None, **kw):  # context: (batsize, enc
                            outputs_info=[None, context, context_0, 0] + init_info)
         return outputs[0].dimswap(1, 0)  # returns probabilities of symbols --> (batsize, seqlen, vocabsize)

+    def _get_ctx_t0(self, ctx, ctx_0=None):
+        if ctx_0 is None:
+            if ctx.d.ndim == 2:  # static context
+                ctx_0 = ctx
+            elif ctx.d.ndim > 2:  # dynamic context (batsize, inseqlen, inencdim)
+                assert(self.attention is not None)  # 3D context only processable with attention (dynamic context)
+                w_0 = T.ones((ctx.shape[0], ctx.shape[1]), dtype=T.config.floatX) / ctx.shape[1].astype(T.config.floatX)  # ==> make uniform weights (??)
+                ctx_0 = self.attention.attentionconsumer(ctx, w_0)
+            '''else:
+            ctx_0 = ctx[:, -1, :]  # take the last context'''
+            else:
+                print "sum ting wong in SeqDecoder _get_ctx_t0()"
+        return ctx_0
+
     def recwrap(self, x_t, ctx, ctx_tm1, t, *states_tm1):  # x_t: (batsize), context: (batsize, enc.innerdim)
         i_t = self.embedder(x_t)  # i_t: (batsize, embdim)
         j_t = self._get_j_t(i_t, ctx_tm1)
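The refactored _get_ctx_t0 decides the decoder's initial context: a 2D (static) context is passed through unchanged, while a 3D (per-input-timestep) context is now summarized with uniform attention weights instead of taking the last timestep. A minimal numpy sketch of that logic, assuming attention.attentionconsumer reduces to a weighted sum over the input sequence axis (initial_context is an illustrative name, not teafacto's API):

import numpy as np

def initial_context(ctx):
    # Sketch of _get_ctx_t0, assuming attentionconsumer is a weighted sum.
    if ctx.ndim == 2:                    # static context: (batsize, encdim)
        return ctx
    batsize, inseqlen, _ = ctx.shape     # dynamic context: (batsize, inseqlen, inencdim)
    w_0 = np.ones((batsize, inseqlen), dtype="float32") / inseqlen  # uniform weights
    return np.einsum("bs,bse->be", w_0, ctx)  # weighted average over the sequence axis

ctx = np.random.randn(500, 5, 190).astype("float32")
print(initial_context(ctx).shape)        # (500, 190)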
teafacto/core/base.py: 5 changes (5 additions, 0 deletions)
@@ -17,6 +17,10 @@ def __getattr__(cls, item):
         top = getattr(tensor, item)
         return wrapf(top)

+    @property
+    def config(cls):
+        return theano.config
+
     def scan(cls, fn, sequences=None, outputs_info=None, non_sequences=None, n_steps=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None, profile=False, allow_gc=None, strict=False):
         return scan()(fn, sequences=sequences, outputs_info=outputs_info, non_sequences=non_sequences, n_steps=n_steps,
                       truncate_gradient=truncate_gradient, go_backwards=go_backwards, mode=mode, name=name, profile=profile,
@@ -79,6 +83,7 @@ def __getattr__(self, item):
 class tensorops:
     __metaclass__ = TWrapper

+
 class TensorWrapper(type):
     """Wrapper class that provides proxy access to an instance of some
     internal instance."""
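The new config property lives on the metaclass (the tensorops wrapper uses TWrapper as its metaclass), which is what lets T.config.floatX in _get_ctx_t0 above resolve on the class itself and proxy theano.config. A standalone sketch of the pattern, with stand-in names rather than the actual teafacto classes:

class Meta(type):
    @property
    def config(cls):
        return {"floatX": "float32"}     # stand-in for theano.config

class T(metaclass=Meta):                 # the codebase uses Python 2's __metaclass__ spelling
    pass

print(T.config["floatX"])                # resolved on the class, no instance needed

A property defined on an ordinary class would need an instance; putting it on the metaclass makes it readable directly from the class, matching how the rest of the wrapper exposes theano.tensor attributes.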
test/test_rnnautoencoder.py: 4 changes (2 additions, 2 deletions)
@@ -11,9 +11,9 @@ def shiftdata(x):

 class TestRNNAutoEncoder(TestCase):
     def setUp(self):
-        vocsize = 25
+        vocsize = 24
         innerdim = 200
-        encdim = 200
+        encdim = 190
         batsize = 500
         seqlen = 5
         self.exppredshape = (batsize, seqlen, vocsize)
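The test sizes are now pairwise distinct (vocsize=24, encdim=190, innerdim=200 instead of 25/200/200). The commit message does not say why, but a plausible reading is that equal sizes let a transposed or mixed-up tensor pass shape checks silently, while distinct sizes make such bugs fail immediately. A toy illustration of the effect:

import numpy as np

encdim, innerdim = 190, 200              # distinct on purpose
W = np.zeros((encdim, innerdim))
h = np.zeros((7, encdim))
h.dot(W)                                 # intended orientation: (7, 190) x (190, 200)
try:
    h.dot(W.T)                           # wrong orientation: only passes when encdim == innerdim
except ValueError as err:
    print("shape bug caught:", err)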
