-
-
Notifications
You must be signed in to change notification settings - Fork 86
/
fixnorm.go
45 lines (38 loc) · 1.06 KB
/
fixnorm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
// Copyright 2019 spaGO Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package fixnorm implements the fixnorm normalization method.
//
// Reference: "Improving Lexical Choice in Neural Machine Translation" by Toan Q. Nguyen and David Chiang (2018)
// (https://arxiv.org/pdf/1710.01329.pdf)
package fixnorm
import (
"encoding/gob"
"github.com/nlpodyssey/spago/pkg/ml/ag"
"github.com/nlpodyssey/spago/pkg/ml/nn"
)
// Compile-time check that *Model satisfies the nn.Model interface.
var _ nn.Model = (*Model)(nil)
// Model is the fixnorm model. It declares no fields of its own and only
// embeds nn.BaseModel, so it carries no fixnorm-specific state.
type Model struct {
// Embedded base model. Forward calls m.Graph(), which is not declared in
// this file; presumably it is promoted from this embedded type.
nn.BaseModel
}
// init registers the *Model type with encoding/gob at package load time.
func init() {
	gob.Register(new(Model))
}
// New allocates and returns a pointer to a zero-valued Model.
func New() *Model {
	return new(Model)
}
// Forward normalizes every input node independently and returns the results
// in the same order as the inputs.
//
// Each output is x / (norm + 1e-10), where norm is built on the model's graph
// as Sqrt(ReduceSum(Square(x))). The small constant is added to the norm
// before dividing, presumably to avoid a division by zero.
func (m *Model) Forward(xs ...ag.Node) []ag.Node {
	graph := m.Graph()
	out := make([]ag.Node, 0, len(xs))
	// The stabilizing constant is created once and shared by all inputs.
	epsilon := graph.NewScalar(1e-10)
	for _, x := range xs {
		squared := graph.Square(x)
		total := graph.ReduceSum(squared)
		norm := graph.Sqrt(total)
		divisor := graph.AddScalar(norm, epsilon)
		out = append(out, graph.DivScalar(x, divisor))
	}
	return out
}