Commit

add end index marker on loss
gogamza committed Jan 30, 2018
1 parent 9c81f42 commit d48802f
Showing 2 changed files with 9 additions and 4 deletions.
main.py: 4 changes (2 additions & 2 deletions)

@@ -147,7 +147,7 @@ def model_init(n_hidden,vocab_size, embed_dim, max_seq_length, embed_weights, ct
     model.embedding.weight.set_data(embed_weights)
 
     trainer = gluon.Trainer(model.collect_params(), 'rmsprop')
-    loss = SoftmaxCrossEntropyLossMask(axis = 2)
+    loss = SoftmaxCrossEntropyLossMask(end_idx, axis = 2)
     return(model, loss, trainer)


@@ -290,7 +290,7 @@ def train(epochs, tr_data_iterator, model, loss, trainer, ctx, start_epoch=1, md
     else:
         print("train start from '{}'".format(opt.init_model))
         model.load_params(opt.init_model, ctx=ctx)
-        trainer_sgd = gluon.Trainer(model.collect_params(), 'sgd', optimizer_params={'learning_rate':0.01,}, kvstore='device')
+        trainer_sgd = gluon.Trainer(model.collect_params(), 'sgd', optimizer_params={'learning_rate':0.1,}, kvstore='local')
         tr_loss, te_loss = train(5, tr_data_iterator, model, loss, trainer_sgd, ctx=ctx, mdl_desc=opt.model_prefix, decay=True)
 
     if opt.test:
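Two things change in main.py: the masked loss now receives the END-token index explicitly instead of relying on the 2.0 hardcoded in mask_loss.py, and the SGD fine-tuning trainer switches from a 0.01 to a 0.1 learning rate and from the 'device' to the 'local' kvstore. The diff does not show where end_idx is defined; a minimal sketch of one plausible source, assuming a token-to-index vocabulary dict (vocab and the 'END' literal are illustrative, not from this commit):

    # hypothetical lookup: the old code hardcoded 2.0, so 'END'
    # presumably sits at index 2 of the vocabulary
    end_idx = vocab['END']
    loss = SoftmaxCrossEntropyLossMask(end_idx, axis = 2)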
mask_loss.py: 9 changes (7 additions & 2 deletions)

@@ -2,16 +2,18 @@


 class SoftmaxCrossEntropyLossMask(Loss):
-    def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
+    def __init__(self, end_idx, axis=-1, sparse_label=True, from_logits=False, weight=None,
                  batch_axis=0, **kwargs):
         super(SoftmaxCrossEntropyLossMask, self).__init__(weight, batch_axis, **kwargs)
         self._axis = axis
         self._sparse_label = sparse_label
         self._from_logits = from_logits
+        self.end_idx = end_idx
 
     def hybrid_forward(self, F, pred, label, sample_weight=None):
+        # extract the index of each label sentence's last token ('END')
         label = F.cast(label, dtype='float32')
-        label_sent_length = F.argmax(F.where(label == 2.0, F.ones_like(label), F.zeros_like(label)), axis=1)
+        label_sent_length = F.argmax(F.where(label == self.end_idx, F.ones_like(label), F.zeros_like(label)), axis=1)
 
 
         if not self._from_logits:
@@ -28,3 +30,6 @@ def hybrid_forward(self, F, pred, label, sample_weight=None):
         loss = F.SequenceMask(loss, sequence_length=label_sent_length + 1, use_sequence_length=True)
         loss = F.transpose(loss, (1,0, 2))
         return F.sum(loss, axis=self._batch_axis, exclude=True)/(label_sent_length + 1)
+
+
+
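For reference, a minimal usage sketch of the masked loss. Shapes, token values, and end_idx = 2 are illustrative assumptions, not from this repo; axis=2 matches the call in main.py, so pred holds one score per (batch, seq_len, vocab_size) position:

    import mxnet as mx
    from mxnet import nd
    from mask_loss import SoftmaxCrossEntropyLossMask

    batch, seq_len, vocab_size = 4, 10, 50
    end_idx = 2  # assumed END-token index, matching the old hardcoded 2.0

    loss_fn = SoftmaxCrossEntropyLossMask(end_idx, axis=2)

    pred = nd.random.uniform(shape=(batch, seq_len, vocab_size))  # raw scores
    label = nd.ones((batch, seq_len)) * 5  # dummy token ids
    label[:, 6] = end_idx                  # each sentence ends at position 6

    # positions after the first END token are masked out by SequenceMask,
    # and each sentence's summed loss is divided by its true length
    print(loss_fn(pred, label))            # one loss value per sentence

Because argmax returns the first occurrence of the maximum, each sentence is measured up to its first END token, which is why the class needs the real end_idx rather than a guessed constant.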