This repository has been archived by the owner on Apr 4, 2018. It is now read-only.

Commit c8f4a04: fixed tests after refactor
lukovnikov committed Aug 25, 2016
1 parent 8aac4ee commit c8f4a04
Showing 14 changed files with 75 additions and 166 deletions.
6 changes: 3 additions & 3 deletions teafacto/blocks/kgraph/fbencdec.py
@@ -86,7 +86,7 @@ def __init__(self, wordembdim=50, entembdim=200, innerdim=200, attdim=100, outdi
self.dec = SeqDecoder(
[VectorEmbed(indim=self.outdim, dim=self.entembdim), GRU(dim=self.entembdim, innerdim=self.decinnerdim)],
attention=Attention(attgen, attcon),
outconcat=True,
outconcat=True, inconcat=False,
innerdim=self.encinnerdim + self.decinnerdim
)

@@ -116,7 +116,7 @@ def __init__(self, wordembdim=50, wordencdim=50, entembdim=200, innerdim=200, at
self.dec = SeqDecoder(
[VectorEmbed(indim=self.outdim, dim=self.entembdim), GRU(dim=self.entembdim, innerdim=self.decinnerdim)],
attention=Attention(attgen, attcon),
outconcat=True,
outconcat=True, inconcat=False,
innerdim=self.encinnerdim + self.decinnerdim)

def apply(self, inpseq, outseq):
@@ -283,7 +283,7 @@ def init(self):

self.dec = SeqDecoder(
[self.memblock, GRU(dim=self.entembdim + self.encinnerdim, innerdim=self.decinnerdim)],
outconcat=True,
outconcat=True, inconcat=False,
attention=Attention(attgen, attcon),
innerdim=self.decinnerdim + self.encinnerdim,
softmaxoutblock=self.softmaxoutblock
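Not part of the diff: the three SeqDecoder calls above now pass inconcat=False explicitly because the refactored SeqDecoder in teafacto/blocks/seq/rnn.py (further down in this commit) takes both flags and defaults to inconcat=True. A minimal sketch of what the two flags control per decoding step, mirroring the rec() logic added in this commit; the function and argument names below are illustrative only, not part of the library:

import numpy as np

def decoder_step(x_t_emb, ctx_tm1, h_tm1, cell, attend, enc_states,
                 inconcat=False, outconcat=True):
    # inconcat: the previous attention context is concatenated onto the
    # embedded input symbol before the recurrent update
    i_t = np.concatenate([x_t_emb, ctx_tm1]) if inconcat else x_t_emb
    h_t = cell(i_t, h_tm1)              # recurrent update, e.g. one GRU step
    ctx_t = attend(h_t, enc_states)     # fresh context from attention
    # outconcat: the fresh context is concatenated onto the decoder state
    # before the softmax block, which is why the callers above set
    # innerdim = encinnerdim + decinnerdim
    pre_y_t = np.concatenate([h_t, ctx_t]) if outconcat else h_t
    return pre_y_t, ctx_t, h_t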
6 changes: 3 additions & 3 deletions teafacto/blocks/seq/enc.py
@@ -1,6 +1,6 @@
from blocks.basic import IdxToOneHot, VectorEmbed
from blocks.pool import Pool
from blocks.seq.rnn import MakeRNU
from teafacto.blocks.basic import IdxToOneHot, VectorEmbed
from teafacto.blocks.pool import Pool
from teafacto.blocks.seq.rnn import MakeRNU
from teafacto.blocks.seq.oldseqproc import Vec2Idx, SimpleVec2Idx
from teafacto.blocks.seq.rnn import SeqEncoder, MaskMode
from teafacto.core.base import Block, tensorops as T
27 changes: 17 additions & 10 deletions teafacto/blocks/seq/encdec.py
@@ -1,8 +1,8 @@
from teafacto.core.base import Block, asblock, Val, issequence
from teafacto.blocks.seq.rnn import SeqEncoder, MaskMode, MaskSetMode, SeqDecAtt, BiRNU
from teafacto.blocks.seq.rnn import SeqEncoder, MaskMode, MaskSetMode, SeqDecoder, BiRNU
from teafacto.blocks.seq.rnu import GRU
from teafacto.blocks.seq.attention import Attention, LinearGateAttentionGenerator, WeightedSumAttCon
from teafacto.blocks.basic import VectorEmbed, IdxToOneHot, MatDot, Softmax
from teafacto.blocks.basic import VectorEmbed, IdxToOneHot, MatDot


class SeqEncDec(Block):
@@ -55,17 +55,18 @@ def rec(self, x_t, *states):

class SeqEncDecAtt(SeqEncDec):
def __init__(self, enclayers, declayers, attgen, attcon,
decinnerdim, statetrans=None, vecout=False, **kw):
decinnerdim, statetrans=None, vecout=False,
inconcat=True, outconcat=False, **kw):
enc = SeqEncoder(*enclayers)\
.with_outputs\
.with_mask\
.maskoptions(-1, MaskMode.AUTO, MaskSetMode.ZERO)
smo = False if vecout else None
dec = SeqDecAtt(
dec = SeqDecoder(
declayers,
attention=Attention(attgen, attcon),
innerdim=decinnerdim,
softmaxoutblock=smo,
innerdim=decinnerdim, inconcat=inconcat,
softmaxoutblock=smo, outconcat=outconcat
)
super(SeqEncDecAtt, self).__init__(enc, dec, statetrans=statetrans, **kw)

@@ -83,6 +84,8 @@ def __init__(self,
rnu=GRU,
statetrans=None,
vecout=False,
inconcat=True,
outconcat=False,
**kw):
encinnerdim = [encdim] if not issequence(encdim) else encdim
decinnerdim = [decdim] if not issequence(decdim) else decdim
@@ -91,10 +94,12 @@ def __init__(self,
self.getenclayers(inpembdim, inpvocsize, encinnerdim, bidir, rnu)

self.declayers = \
self.getdeclayers(outembdim, outvocsize, lastencinnerdim, decinnerdim, rnu)
self.getdeclayers(outembdim, outvocsize, lastencinnerdim,
decinnerdim, rnu, inconcat)

# attention
lastdecinnerdim = decinnerdim[-1]
argdecinnerdim = lastdecinnerdim if outconcat is False else lastencinnerdim + lastdecinnerdim
attgen = LinearGateAttentionGenerator(indim=lastencinnerdim + lastdecinnerdim,
attdim=attdim)
attcon = WeightedSumAttCon()
@@ -106,7 +111,8 @@ def __init__(self,
statetrans = MatDot(lastencinnerdim, lastdecinnerdim)

super(SimpleSeqEncDecAtt, self).__init__(self.enclayers, self.declayers,
attgen, attcon, lastdecinnerdim, statetrans=statetrans, vecout=vecout, **kw)
attgen, attcon, argdecinnerdim, statetrans=statetrans, vecout=vecout,
inconcat=inconcat, outconcat=outconcat, **kw)

def getenclayers(self, inpembdim, inpvocsize, encinnerdim, bidir, rnu):
if inpembdim is None:
@@ -132,7 +138,8 @@ def getenclayers(self, inpembdim, inpvocsize, encinnerdim, bidir, rnu):
enclayers = [inpemb] + encrnus
return enclayers, lastencinnerdim

def getdeclayers(self, outembdim, outvocsize, lastencinnerdim, decinnerdim, rnu):
def getdeclayers(self, outembdim, outvocsize, lastencinnerdim,
decinnerdim, rnu, inconcat):
if outembdim is None:
outemb = IdxToOneHot(outvocsize)
outembdim = outvocsize
@@ -142,7 +149,7 @@ def getdeclayers(self, outembdim, outvocsize, lastencinnerdim, decinnerdim, rnu)
else:
outemb = VectorEmbed(indim=outvocsize, dim=outembdim)
decrnus = []
firstdecdim = outembdim + lastencinnerdim
firstdecdim = outembdim + lastencinnerdim if inconcat else outembdim
dims = [firstdecdim] + decinnerdim
i = 1
while i < len(dims):
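Not part of the diff: with the changes above, SimpleSeqEncDecAtt forwards inconcat and outconcat to the decoder, widens the first decoder RNU input to outembdim + lastencinnerdim only when inconcat is set, and passes argdecinnerdim = lastencinnerdim + lastdecinnerdim when outconcat is set. A usage sketch based on the keyword names visible in this hunk; the vocabulary and dimension values are illustrative, and any keyword not shown in the diff should be treated as an assumption:

from teafacto.blocks.seq.encdec import SimpleSeqEncDecAtt

# outconcat wiring: the softmax sees [decoder state ; attention context]
m_out = SimpleSeqEncDecAtt(inpvocsize=400, outvocsize=100,
                           encdim=300, decdim=200, attdim=50,
                           inconcat=False, outconcat=True)

# inconcat wiring (the default): the previous context is fed back into
# the decoder input, so the first decoder RNU is built with a wider indim
m_in = SimpleSeqEncDecAtt(inpvocsize=400, outvocsize=100,
                          encdim=300, decdim=200, attdim=50,
                          inconcat=True, outconcat=False)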
8 changes: 2 additions & 6 deletions teafacto/blocks/seq/oldseqproc.py
@@ -9,7 +9,7 @@
from teafacto.core.stack import stack
from teafacto.util import issequence


"""
class SeqEncDec(Block):
def __init__(self, enc, dec, statetrans=None, **kw):
super(SeqEncDec, self).__init__(**kw)
@@ -37,10 +37,6 @@ def apply(self, inpseq, outseq, maskseq=None):
return deco
def get_init_info(self, inpseq, batsize, maskseq=None): # TODO: must evaluate enc here, in place, without any side effects
"""
VERY DIFFERENT FROM THE PURELY SYMBOLIC GET_INIT_INFO IN REAL REC BLOCKS !!!
This one is used in decoder/prediction
"""
enco, allenco, encmask = self.enc.predict(inpseq, mask=maskseq)
if self.statetrans is not None:
@@ -159,7 +155,7 @@ def getdeclayers(self, outembdim, outvocsize, encinnerdim,
i += 1
declayers = [outemb] + decrnus
return declayers

"""

class SeqTransducer(Block):
def __init__(self, embedder, *layers, **kw):
170 changes: 38 additions & 132 deletions teafacto/blocks/seq/rnn.py
@@ -309,14 +309,17 @@ def setreturn(self, *args):
return self


class SeqDecAtt(Block):
class SeqDecoder(Block):
""" seq decoder with attention with new inconcat implementation """
def __init__(self, layers, softmaxoutblock=None, innerdim=None, attention=None, **kw):
super(SeqDecAtt, self).__init__(**kw)
def __init__(self, layers, softmaxoutblock=None, innerdim=None,
attention=None, inconcat=True, outconcat=False, **kw):
super(SeqDecoder, self).__init__(**kw)
self.embedder = layers[0]
self.block = RecStack(*layers[1:])
self.outdim = innerdim
self.attention = attention
self.inconcat = inconcat
self.outconcat = outconcat
self._mask = False
self._attention = None
assert(isinstance(self.block, ReccableBlock))
@@ -333,13 +336,14 @@ def __init__(self, layers, softmaxoutblock=None, innerdim=None, attention=None,
def numstates(self):
return self.block.numstates

def apply(self, context, seq, initstates=None, mask=None, encmask=None, **kw): # context: (batsize, enc.innerdim), seq: idxs-(batsize, seqlen)
def apply(self, context, seq, context_0=None, initstates=None, mask=None, encmask=None, **kw): # context: (batsize, enc.innerdim), seq: idxs-(batsize, seqlen)
if initstates is None:
initstates = seq.shape[0]
elif issequence(initstates):
if len(initstates) < self.numstates: # fill up with batsizes for lower layers
initstates = [seq.shape[0]] * (self.numstates - len(initstates)) + initstates
init_info, nonseq = self.get_init_info(context, initstates, encmask=encmask) # sets init states to provided ones
init_info, nonseq = self.get_init_info(context, initstates,
ctx_0=context_0, encmask=encmask) # sets init states to provided ones
outputs, _ = T.scan(fn=self.rec,
sequences=seq.dimswap(1, 0),
outputs_info=[None] + init_info,
@@ -363,136 +367,38 @@ def applymask(cls, xseq, maskseq):
ret = xseq * mask + masker * (1.0 - mask)
return ret

def get_init_info(self, context, initstates, encmask=None):
ret = self.block.get_init_info(initstates)
return [0] + ret, [encmask, context]
def get_init_info(self, context, initstates, ctx_0=None, encmask=None):
initstates = self.block.get_init_info(initstates)
ctx_0 = self._get_ctx_t(context, initstates, encmask) if ctx_0 is None else ctx_0
if encmask is None:
encmask = T.ones(context.shape[:2], dtype="float32")
return [ctx_0, 0] + initstates, [encmask, context]

def _get_ctx_t(self, ctx, states_tm1, encmask):
if ctx.d.ndim == 2: # static context
ctx_t = ctx
elif ctx.d.ndim > 2:
# ctx is 3D, always dynamic context
assert(self.attention is not None)
h_tm1 = states_tm1[0] # ??? --> will it also work with multi-state RNUs?
ctx_t = self.attention(h_tm1, ctx, mask=encmask)
return ctx_t

def rec(self, x_t, t, *args): # x_t: (batsize), context: (batsize, enc.innerdim)
states_tm1 = args[:-1]
def rec(self, x_t, ctx_tm1, t, *args): # x_t: (batsize), context: (batsize, enc.innerdim)
states_tm1 = args[:-2]
ctx = args[-1]
encmask = args[-2]
h_tm1 = states_tm1[0] # ???
x_t_emb = self.embedder(x_t) # i_t: (batsize, embdim)
# get context with attention
ctx_t = self.attention(h_tm1, ctx, mask=encmask)
# do inconcat
i_t = T.concatenate([x_t_emb, ctx_t], axis=1)
i_t = T.concatenate([x_t_emb, ctx_tm1], axis=1) if self.inconcat else x_t_emb
rnuret = self.block.rec(i_t, *states_tm1)
t += 1
pre_y_t = rnuret[0]
h_t = rnuret[0]
states_t = rnuret[1:]
y_t = self.softmaxoutblock(pre_y_t)
return [y_t, t] + states_t


class SeqDecoder(Block):
'''
Decodes a sequence of symbols given context
output: probabilities over symbol space: float: (batsize, seqlen, vocabsize)
! must pass in a recurrent block that takes two arguments: context_t and x_t
! first input is TERMINUS ==> suggest to set TERMINUS(0) embedding to all zeroes (in s2vf)
! first layer must be an embedder or IdxToOneHot, otherwise, an IdxToOneHot is created automatically based on given dim
'''
def __init__(self, layers, softmaxoutblock=None, innerdim=None, attention=None, inconcat=False, outconcat=False, **kw): # limit says at most how many is produced
super(SeqDecoder, self).__init__(**kw)
self.embedder = layers[0]
self.block = RecStack(*layers[1:])
self.outdim = innerdim
self.inconcat = inconcat
self.outconcat = outconcat
self.attention = attention
self._mask = False
self._attention = None
assert(isinstance(self.block, ReccableBlock))
if softmaxoutblock is None: # default softmax out block
sm = Softmax()
self.lin = MatDot(indim=self.outdim, dim=self.embedder.indim)
self.softmaxoutblock = asblock(lambda x: sm(self.lin(x)))
elif softmaxoutblock is False:
self.softmaxoutblock = asblock(lambda x: x)
else:
self.softmaxoutblock = softmaxoutblock

@property
def numstates(self):
return self.block.numstates

def apply(self, context, seq, context_0=None, initstates=None, mask=None, encmask=None, **kw): # context: (batsize, enc.innerdim), seq: idxs-(batsize, seqlen)
if initstates is None:
initstates = seq.shape[0]
elif issequence(initstates):
if len(initstates) < self.numstates: # fill up with batsizes for lower layers
initstates = [seq.shape[0]]*(self.numstates - len(initstates)) + initstates
init_info, ctx = self.get_init_info(context, context_0, initstates, encmask=encmask) # sets init states to provided ones
outputs, _ = T.scan(fn=self.rec,
sequences=seq.dimswap(1, 0),
outputs_info=[None] + init_info,
non_sequences=ctx)
ret = outputs[0].dimswap(1, 0) # returns probabilities of symbols --> (batsize, seqlen, vocabsize)
if mask == "auto":
mask = (seq > 0).astype("int32")
ret = self.applymask(ret, mask)
return ret

@classmethod
def applymask(cls, xseq, maskseq):
if maskseq is None:
return xseq
else:
mask = T.tensordot(maskseq, T.ones((xseq.shape[2],)), 0) # f32^(batsize, seqlen, outdim) -- maskseq stacked
masker = T.concatenate(
[T.ones((xseq.shape[0], xseq.shape[1], 1)),
T.zeros((xseq.shape[0], xseq.shape[1], xseq.shape[2] - 1))],
axis=2) # f32^(batsize, seqlen, outdim) -- gives 100% prob to output 0
ret = xseq * mask + masker * (1.0 - mask)
return ret

def get_init_info(self, context, context_0, initstates, encmask=None):
# TODO: get ctx_0 with new inconcat idea
ret = self.block.get_init_info(initstates)
context_0 = self._get_ctx_t0(context, context_0)
return [context_0, 0] + ret, [encmask, context]

def _get_ctx_t0(self, ctx, ctx_0=None):
if ctx_0 is None:
if ctx.d.ndim == 2: # static context
ctx_0 = ctx
elif ctx.d.ndim > 2: # dynamic context (batsize, inseqlen, inencdim)
assert(self.attention is not None) # 3D context only processable with attention (dynamic context)
w_0 = T.ones((ctx.shape[0], ctx.shape[1]), dtype=T.config.floatX) / ctx.shape[1].astype(T.config.floatX) # ==> make uniform weights (??)
ctx_0 = self.attention.attentionconsumer(ctx, w_0)
'''else:
ctx_0 = ctx[:, -1, :] # take the last context'''
else:
print "sum ting wong in SeqDecoder _get_ctx_t0()"
return ctx_0

def rec(self, x_t, ctx_tm1, t, *args): # x_t: (batsize), context: (batsize, enc.innerdim)
# TODO: implement new inconcat
states_tm1 = args[:-1]
ctx = args[-1]
encmask = args[-2]
i_t = self.embedder(x_t) # i_t: (batsize, embdim)
j_t = self._get_j_t(i_t, ctx_tm1)
rnuret = self.block.rec(j_t, *states_tm1) # list of matrices (batsize, **somedims**)
ret = rnuret
t = t + 1
h_t = ret[0]
states_t = ret[1:]
ctx_t = self._gen_context(ctx, h_t, encmask)
g_t = self._get_g_t(h_t, ctx_t)
y_t = self.softmaxoutblock(g_t)
return [y_t, ctx_t, t] + states_t #, {}, T.until( (i > 1) * T.eq(mask.norm(1), 0) )

def _get_j_t(self, i_t, ctx_tm1):
return T.concatenate([i_t, ctx_tm1], axis=1) if self.inconcat else i_t

def _get_g_t(self, h_t, ctx_t):
return T.concatenate([h_t, ctx_t], axis=1) if self.outconcat else h_t

def _gen_context(self, multicontext, criterion, encmask):
return self.attention(criterion, multicontext, mask=encmask) if self.attention is not None else multicontext
ctx_t = self._get_ctx_t(ctx, states_t, encmask) # get context with attention
_y_t = T.concatenate([h_t, ctx_t], axis=1) if self.outconcat else h_t
y_t = self.softmaxoutblock(_y_t)
return [y_t, ctx_t, t] + states_t

# ----------------------------------------------------------------------------------------------------------------------

@@ -505,7 +411,7 @@ def __init__(self, innerdim=50, input_vocsize=100, output_vocsize=100, **kw):
encrec = GRU(dim=input_vocsize, innerdim=innerdim)
decrecrnu = GRU(dim=output_vocsize, innerdim=innerdim)
self.enc = SeqEncoder(input_embedder, encrec)
self.dec = SeqDecoder([output_embedder, decrecrnu], outconcat=True, innerdim=innerdim+innerdim)
self.dec = SeqDecoder([output_embedder, decrecrnu], outconcat=True, inconcat=False, innerdim=innerdim+innerdim)

def apply(self, inpseq, outseq):
enco = self.enc(inpseq)
@@ -564,7 +470,7 @@ def __init__(self, vocsize=25, outvocsize=20, encdim=200, innerdim=200, attdim=5
attcon = SeqEncoder(None,
GRU(dim=vocsize, innerdim=encdim))
self.dec = SeqDecoder([IdxToOneHot(outvocsize), GRU(dim=outvocsize, innerdim=innerdim)],
outconcat=True,
outconcat=True, inconcat=False,
attention=Attention(attgen, attcon),
innerdim=innerdim+encdim)

@@ -596,7 +502,7 @@ def __init__(self, vocsize=25, outvocsize=25, encdim=300, innerdim=200, attdim=5
attgen = LinearGateAttentionGenerator(indim=innerdim+encdim, innerdim=attdim)
attcon = WeightedSumAttCon()
self.dec = SeqDecoder([IdxToOneHot(outvocsize), GRU(dim=outvocsize, innerdim=innerdim)],
outconcat=True,
outconcat=True, inconcat=False,
attention=Attention(attgen, attcon),
innerdim=innerdim+encdim
)
@@ -614,7 +520,7 @@ def __init__(self, vocsize=25, outvocsize=25, encdim=300, innerdim=200, attdim=5
attgen = LinearGateAttentionGenerator(indim=innerdim+encdim*2, innerdim=attdim)
attcon = WeightedSumAttCon()
self.dec = SeqDecoder([IdxToOneHot(outvocsize), GRU(dim=outvocsize, innerdim=innerdim)],
outconcat=True,
outconcat=True, inconcat=False,
attention=Attention(attgen, attcon),
innerdim=innerdim+encdim*2
)
@@ -632,7 +538,7 @@ def __init__(self, vocsize=25, outvocsize=25, encdim=300, innerdim=200, attdim=5
attgen = LinearGateAttentionGenerator(indim=innerdim+encdim*2, innerdim=attdim)
attcon = WeightedSumAttCon()
self.dec = SeqDecoder([IdxToOneHot(outvocsize), GRU(dim=outvocsize+encdim*2, innerdim=innerdim)],
inconcat=True,
inconcat=True, outconcat=False,
attention=Attention(attgen, attcon),
innerdim=innerdim
)
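Not part of the diff: after this file's changes, the single SeqDecoder (with inconcat and outconcat) covers both wirings that previously needed SeqDecAtt and the old SeqDecoder. A sketch of the two configurations, mirroring the example blocks in this file; the dimensions are illustrative and assume a non-bidirectional encoder context of size encdim:

from teafacto.blocks.basic import IdxToOneHot
from teafacto.blocks.seq.rnu import GRU
from teafacto.blocks.seq.rnn import SeqDecoder
from teafacto.blocks.seq.attention import Attention, LinearGateAttentionGenerator, WeightedSumAttCon

outvocsize, encdim, innerdim, attdim = 25, 300, 200, 50

def make_att():
    attgen = LinearGateAttentionGenerator(indim=innerdim + encdim, innerdim=attdim)
    return Attention(attgen, WeightedSumAttCon())

# outconcat wiring: the context joins the decoder state before the softmax,
# so innerdim is the sum of decoder and encoder dimensions
dec_out = SeqDecoder([IdxToOneHot(outvocsize), GRU(dim=outvocsize, innerdim=innerdim)],
                     attention=make_att(), inconcat=False, outconcat=True,
                     innerdim=innerdim + encdim)

# inconcat wiring: the context joins the decoder input, so the GRU input
# dimension grows instead and innerdim stays at the decoder size
dec_in = SeqDecoder([IdxToOneHot(outvocsize), GRU(dim=outvocsize + encdim, innerdim=innerdim)],
                    attention=make_att(), inconcat=True, outconcat=False,
                    innerdim=innerdim)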
