/
classification.go
52 lines (45 loc) · 1.29 KB
/
classification.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
// Copyright 2020 spaGO Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sequenceclassification
import (
"encoding/gob"
mat "github.com/nlpodyssey/spago/pkg/mat32"
"github.com/nlpodyssey/spago/pkg/ml/ag"
"github.com/nlpodyssey/spago/pkg/ml/nn"
"github.com/nlpodyssey/spago/pkg/ml/nn/activation"
"github.com/nlpodyssey/spago/pkg/ml/nn/linear"
"github.com/nlpodyssey/spago/pkg/ml/nn/stack"
)
// Compile-time assertion that *Classifier satisfies nn.Model.
var _ nn.Model = (*Classifier)(nil)
// ClassifierConfig holds the settings used by NewClassifier to build a BART
// head for sentence-level classification.
type ClassifierConfig struct {
	// InputSize is the input size of the first linear layer, HiddenSize the
	// size of the intermediate (tanh-activated) layer, and OutputSize the
	// output size of the final linear layer (presumably the number of
	// classes — confirm against callers).
	InputSize, HiddenSize, OutputSize int
	// PoolerDropout is presumably the dropout rate for the pooler layers.
	// NOTE(review): NewClassifier does not currently read this field; the
	// dropout layers are commented out there.
	PoolerDropout mat.Float
}
// Classifier is a BART head for sentence-level classification tasks.
//
// The layers themselves live in the embedded *stack.Model, so the stack's
// methods are promoted onto Classifier.
type Classifier struct {
	// Config holds the settings the classifier was built from.
	Config ClassifierConfig
	*stack.Model
}
// init records *Classifier with encoding/gob so that a Classifier can be
// encoded and decoded when held in an interface-typed value.
func init() {
	gob.Register(new(Classifier))
}
// NewClassifier builds a Classifier from config. The resulting model is the
// sequence:
//
//	linear(InputSize -> HiddenSize) -> tanh -> linear(HiddenSize -> OutputSize)
//
// NOTE(review): config.PoolerDropout is stored in Config but not applied; the
// dropout layers that would precede each linear layer are not part of the
// stack yet.
func NewClassifier(config ClassifierConfig) *Classifier {
	dense := linear.New(config.InputSize, config.HiddenSize)
	tanh := activation.New(ag.OpTanh)
	projection := linear.New(config.HiddenSize, config.OutputSize)

	return &Classifier{
		Config: config,
		Model:  stack.New(dense, tanh, projection),
	}
}