Adding support for Distilbert #21

Open
wants to merge 3 commits into base: main
3 changes: 3 additions & 0 deletions pkg/converter/converter.go
@@ -10,6 +10,7 @@ import (

"github.com/nlpodyssey/cybertron/pkg/converter/bart"
"github.com/nlpodyssey/cybertron/pkg/converter/bert"
"github.com/nlpodyssey/cybertron/pkg/converter/distilbert"
"github.com/nlpodyssey/cybertron/pkg/converter/flair"
"github.com/nlpodyssey/cybertron/pkg/models"
"github.com/nlpodyssey/spago/mat/float"
@@ -29,6 +30,8 @@ func Convert[T float.DType](modelPath string, overwriteIfExists bool) error {
switch modelType {
case "bert", "electra":
return bert.Convert[T](modelPath, overwriteIfExists)
case "distilbert":
return distilbert.Convert[T](modelPath, overwriteIfExists)
case "bart", "marian", "pegasus":
return bart.Convert[T](modelPath, overwriteIfExists)
case "flair":
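For context, a minimal sketch of how the new case would be exercised from calling code. The model directory path below is a hypothetical placeholder, and error handling is reduced to a fatal log:

package main

import (
	"github.com/nlpodyssey/cybertron/pkg/converter"
	"github.com/rs/zerolog/log"
)

func main() {
	// "models/distilbert-base-uncased" is a placeholder for a locally
	// downloaded DistilBert checkpoint directory containing config.json,
	// vocab.txt and pytorch_model.bin.
	if err := converter.Convert[float32]("models/distilbert-base-uncased", false); err != nil {
		log.Fatal().Err(err).Msg("distilbert conversion failed")
	}
}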
214 changes: 214 additions & 0 deletions pkg/converter/distilbert/convert.go
@@ -0,0 +1,214 @@
package distilbert

import (
"fmt"
"github.com/nlpodyssey/cybertron/pkg/models/distilbert"
"os"
"path/filepath"
"strings"

"github.com/nlpodyssey/cybertron/pkg/converter/pytorch"
"github.com/nlpodyssey/cybertron/pkg/vocabulary"
"github.com/nlpodyssey/spago/embeddings/store/diskstore"
"github.com/nlpodyssey/spago/mat"
"github.com/nlpodyssey/spago/mat/float"
"github.com/nlpodyssey/spago/nn"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)

const (
// defaultConfigFilename is the default DistilBert JSON configuration filename.
defaultConfigFilename = "config.json"
// defaultVocabularyFile is the default DistilBert model's vocabulary filename.
defaultVocabularyFile = "vocab.txt"
// defaultPyModelFilename is the default DistilBert PyTorch model filename.
defaultPyModelFilename = "pytorch_model.bin"
// defaultGoModelFilename is the default DistilBert spaGO model filename.
defaultGoModelFilename = "spago_model.bin"
)

// mappingParam is a mapping between a Hugging Face Transformers parameter and a Cybertron parameter.
type mappingParam struct {
value mat.Matrix
matched bool
}

// Convert converts a DistilBert PyTorch model to a Spago (Cybertron) model.
func Convert[T float.DType](modelDir string, overwriteIfExist bool) error {
var (
configFilename = filepath.Join(modelDir, defaultConfigFilename)
pyModelFilename = filepath.Join(modelDir, defaultPyModelFilename)
goModelFilename = filepath.Join(modelDir, defaultGoModelFilename)
vocabFilename = filepath.Join(modelDir, defaultVocabularyFile)
)

if info, err := os.Stat(goModelFilename); !overwriteIfExist && err == nil && !info.IsDir() {
log.Info().Str("model", goModelFilename).Msg("model file already exists, skipping conversion")
return nil
}

config, err := distilbert.ConfigFromFile[distilbert.Config](configFilename)
if err != nil {
return err
}

vocab, err := vocabulary.NewFromFile(vocabFilename)
if err != nil {
return err
}

repo, err := diskstore.NewRepository(filepath.Join(modelDir, "repo"), diskstore.ReadWriteMode)
if err != nil {
panic(err)
}
defer func() {
err = repo.Close()
if err != nil {
panic(err)
}
}()
if err := repo.DropAll(); err != nil {
panic(err)
}

{
// Enable training mode, so that we have writing permissions
// (for example, for embeddings storage files).
config.Cybertron.Training = true
config.Cybertron.TokensStoreName = "tokens"
config.Cybertron.PositionsStoreName = "positions"
config.Cybertron.TokenTypesStoreName = "token_types"

if config.ModelType == "distilbert" || config.EmbeddingsSize == 0 {
config.EmbeddingsSize = config.HiddenSize
}
}

pyParams := pytorch.NewParamsProvider[T]().
WithNameMapping(fixParamsName).
WithPreProcessing(fixAttentionLayers[T](config))

if err = pyParams.Load(pyModelFilename); err != nil {
return err
}

params := make(paramsMap)
baseModel := mapBaseModel[T](config, repo, pyParams, params, vocab)
finalModel := mapSpecificArchitecture[T](baseModel, config.Architectures, params)

mapping := make(map[string]*mappingParam)
for k, v := range params {
mapping[k] = &mappingParam{value: v, matched: false}
}

err = pyParams.Iterate(func(name string, value []T) error {
param, ok := mapping[name]
if !ok {
return nil
}
if param.value.Size() != len(value) {
return fmt.Errorf("error setting %s: dim mismatch", name)
}
mat.SetData[T](param.value, value)
param.matched = true
return nil
})
if err != nil {
return err
}

if zerolog.GlobalLevel() <= zerolog.DebugLevel {
log.Debug().Msg("Reporting possible conversion mapping anomalies")
for key, value := range mapping {
if !value.matched {
log.Debug().Str("parameter", key).Msg("parameter not initialized")
}
}
err = pyParams.Iterate(func(name string, _ []T) error {
if _, ok := mapping[name]; !ok {
log.Debug().Str("parameter", name).Msg("parameter not mapped")
}
return nil
})
if err != nil {
return err
}
}

fmt.Printf("Serializing model to \"%s\"... ", goModelFilename)
err = nn.DumpToFile(finalModel, goModelFilename)
if err != nil {
return err
}

fmt.Println("Done.")

return nil
}

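// mapBaseModel instantiates the base DistilBert model, copies the word and position
// embeddings from the PyTorch parameters into the disk store, and registers the
// remaining embedding and transformer parameters in the given params map.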
func mapBaseModel[T float.DType](config distilbert.Config, repo *diskstore.Repository, pyParams *pytorch.ParamsProvider[T], params paramsMap, vocab *vocabulary.Vocabulary) *distilbert.Model {
baseModel := distilbert.New[T](config, repo)

{
source := pyParams.Pop("distilbert.embeddings.word_embeddings.weight")
size := baseModel.Embeddings.Tokens.Config.Size
for i := 0; i < config.VocabSize; i++ {
key, _ := vocab.Term(i)
if len(key) == 0 {
continue // skip empty key
}
item, _ := baseModel.Embeddings.Tokens.Embedding(key)
item.ReplaceValue(mat.NewVecDense[T](source[i*size : (i+1)*size]))
}
}

cols := config.HiddenSize

{
source := pyParams.Pop("distilbert.embeddings.position_embeddings.weight")
dest := baseModel.Embeddings.Positions
for i := 0; i < config.MaxPositionEmbeddings; i++ {
item, _ := dest.Embedding(i)
item.ReplaceValue(mat.NewVecDense[T](source[i*cols : (i+1)*cols]))
}
}

mapEmbeddingsLayerNorm(baseModel.Embeddings.Norm, params)
mapTransformerParams(baseModel.Transformer, params)

return baseModel
}

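// mapSpecificArchitecture wraps the base model according to the first declared
// architecture (defaulting to DistilBertBase) and registers any head-specific
// parameters, such as the masked-LM layers.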
func mapSpecificArchitecture[T float.DType](baseModel *distilbert.Model, architectures []string, params paramsMap) nn.Model {
if architectures == nil {
architectures = append(architectures, "DistilBertBase")
}

switch architectures[0] {
case "DistilBertBase":
return baseModel
case "DistilBertModel":
m := distilbert.NewModelForSequenceEncoding(baseModel)
return m
case "DistilBertForMaskedLM":
m := distilbert.NewModelForMaskedLM[T](baseModel)
mapMaskedLM(m.Layers, params)
return m
default:
panic(fmt.Errorf("distilbert: unsupported architecture %s", architectures[0]))
}
}

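// fixParamsName renames Hugging Face parameters to the names expected by the mapper:
// legacy .gamma/.beta suffixes become .weight/.bias, and top-level embeddings.* and
// transformer.* parameters gain the distilbert. prefix.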
func fixParamsName(from string) (to string) {
to = from
to = strings.Replace(to, ".gamma", ".weight", -1)
to = strings.Replace(to, ".beta", ".bias", -1)
if strings.HasPrefix(to, "embeddings.") {
to = fmt.Sprintf("distilbert.%s", to)
}
if strings.HasPrefix(to, "transformer.") {
to = fmt.Sprintf("distilbert.%s", to)
}
return
}
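As an illustration, two renamings this function would produce (parameter names taken from a hypothetical checkpoint):

embeddings.LayerNorm.gamma            -> distilbert.embeddings.LayerNorm.weight
transformer.layer.0.ffn.lin1.bias     -> distilbert.transformer.layer.0.ffn.lin1.bias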
60 changes: 60 additions & 0 deletions pkg/converter/distilbert/mapper.go
@@ -0,0 +1,60 @@
package distilbert

import (
"fmt"
"github.com/nlpodyssey/cybertron/pkg/models/distilbert"

"github.com/nlpodyssey/spago/mat"
"github.com/nlpodyssey/spago/nn"
"github.com/nlpodyssey/spago/nn/linear"
"github.com/nlpodyssey/spago/nn/normalization/layernorm"
)

type paramsMap map[string]mat.Matrix

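// mapTransformerParams registers the per-head self-attention projections, the attention
// output projection, the feed-forward layers and the layer-norm parameters of every
// transformer layer under their (pre-processed) Hugging Face names.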
func mapTransformerParams(transformer *distilbert.Transformer, params paramsMap) {
for i := 0; i < transformer.Config.NumHiddenLayers; i++ {
layer := any(transformer.Layers[i]).(*distilbert.TransformerLayer)
prefixBase := fmt.Sprintf("distilbert.transformer.layer.%d", i)

block1 := layer.SelfAttention
for j := 0; j < transformer.Config.NumAttentionHeads; j++ {
attention := block1.Attention.Heads[j]
prefix := fmt.Sprintf("%s.%d.attention", prefixBase, j)
params[fmt.Sprintf("%s.q_lin.weight", prefix)] = attention.Query.W.Value()
params[fmt.Sprintf("%s.q_lin.bias", prefix)] = attention.Query.B.Value()
params[fmt.Sprintf("%s.k_lin.weight", prefix)] = attention.Key.W.Value()
params[fmt.Sprintf("%s.k_lin.bias", prefix)] = attention.Key.B.Value()
params[fmt.Sprintf("%s.v_lin.weight", prefix)] = attention.Value.W.Value()
params[fmt.Sprintf("%s.v_lin.bias", prefix)] = attention.Value.B.Value()
}
prefix := fmt.Sprintf("distilbert.transformer.layer.%d.attention", i)
params[fmt.Sprintf("%s.out_lin.weight", prefix)] = block1.Attention.OutputMerge.W.Value()
params[fmt.Sprintf("%s.out_lin.bias", prefix)] = block1.Attention.OutputMerge.B.Value()

params[fmt.Sprintf("%s.sa_layer_norm.weight", prefixBase)] = block1.Norm.W.Value()
params[fmt.Sprintf("%s.sa_layer_norm.bias", prefixBase)] = block1.Norm.B.Value()

block2 := layer.FF
params[fmt.Sprintf("%s.ffn.lin1.weight", prefixBase)] = block2.ModuleList[0].(*linear.Model).W.Value()
params[fmt.Sprintf("%s.ffn.lin1.bias", prefixBase)] = block2.ModuleList[0].(*linear.Model).B.Value()
params[fmt.Sprintf("%s.ffn.lin2.weight", prefixBase)] = block2.ModuleList[2].(*linear.Model).W.Value()
params[fmt.Sprintf("%s.ffn.lin2.bias", prefixBase)] = block2.ModuleList[2].(*linear.Model).B.Value()
params[fmt.Sprintf("%s.output_layer_norm.weight", prefixBase)] = block2.Norm.W.Value()
params[fmt.Sprintf("%s.output_layer_norm.bias", prefixBase)] = block2.Norm.B.Value()
}
}

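// mapEmbeddingsLayerNorm registers the embeddings layer-normalization weight and bias.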
func mapEmbeddingsLayerNorm(embeddingsNorm *layernorm.Model, params paramsMap) {
params["distilbert.embeddings.LayerNorm.weight"] = embeddingsNorm.W.Value()
params["distilbert.embeddings.LayerNorm.bias"] = embeddingsNorm.B.Value()
}

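// mapMaskedLM registers the masked-language-model head parameters: the vocabulary
// transform, its layer normalization, and the vocabulary projector.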
func mapMaskedLM(layers []nn.StandardModel, params paramsMap) {
params["vocab_transform.weight"] = layers[0].(*linear.Model).W.Value()
params["vocab_transform.bias"] = layers[0].(*linear.Model).B.Value()
params["vocab_projector.weight"] = layers[3].(*linear.Model).W.Value()
params["vocab_projector.bias"] = layers[3].(*linear.Model).B.Value()
params["vocab_layer_norm.weight"] = layers[2].(*layernorm.Model).W.Value()
params["vocab_layer_norm.bias"] = layers[2].(*layernorm.Model).B.Value()
}
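For reference, the per-head attention keys registered above are expected to line up with the names produced by the pre-processing step in preprocessing.go, e.g. for layer 0, head 0 (hypothetical example):

distilbert.transformer.layer.0.0.attention.q_lin.weight
distilbert.transformer.layer.0.attention.out_lin.weight
distilbert.transformer.layer.0.sa_layer_norm.weight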
51 changes: 51 additions & 0 deletions pkg/converter/distilbert/preprocessing.go
@@ -0,0 +1,51 @@
package distilbert

import (
"fmt"
"github.com/nlpodyssey/cybertron/pkg/models/distilbert"

"github.com/nlpodyssey/cybertron/pkg/converter/pytorch"
"github.com/nlpodyssey/spago/mat/float"
)

type paramsPostProcessing[T float.DType] struct {
*pytorch.ParamsProvider[T]
c distilbert.Config
}

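// fixAttentionLayers returns a pre-processing function that splits each layer's fused
// q/k/v projections into per-attention-head parameters.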
func fixAttentionLayers[T float.DType](c distilbert.Config) pytorch.PreProcessingFunc[T] {
return func(params *pytorch.ParamsProvider[T]) error {
p := paramsPostProcessing[T]{
ParamsProvider: params,
c: c,
}
p.fixTransformerSelfAttention()
return nil
}
}

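// fixTransformerSelfAttention pops each layer's q_lin/k_lin/v_lin weights and biases and
// re-registers them as NumAttentionHeads per-head slices under per-head parameter names.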
func (p *paramsPostProcessing[T]) fixTransformerSelfAttention() {
for i := 0; i < p.c.NumHiddenLayers; i++ {
prefix := fmt.Sprintf("distilbert.transformer.layer.%d.attention", i)
queryWeight := p.Pop(fmt.Sprintf("%s.q_lin.weight", prefix))
queryBias := p.Pop(fmt.Sprintf("%s.q_lin.bias", prefix))
keyWeight := p.Pop(fmt.Sprintf("%s.k_lin.weight", prefix))
keyBias := p.Pop(fmt.Sprintf("%s.k_lin.bias", prefix))
valueWeight := p.Pop(fmt.Sprintf("%s.v_lin.weight", prefix))
valueBias := p.Pop(fmt.Sprintf("%s.v_lin.bias", prefix))

dim := len(queryBias) / p.c.NumAttentionHeads
dim2 := len(queryBias)
for j := 0; j < p.c.NumAttentionHeads; j++ {
from := j * dim
to := (j + 1) * dim
newPrefix := fmt.Sprintf("distilbert.transformer.layer.%d.%d.attention", i, j)
p.Set(fmt.Sprintf("%s.q_lin.weight", newPrefix), queryWeight[from*dim2:to*dim2])
p.Set(fmt.Sprintf("%s.q_lin.bias", newPrefix), queryBias[from:to])
p.Set(fmt.Sprintf("%s.k_lin.weight", newPrefix), keyWeight[from*dim2:to*dim2])
p.Set(fmt.Sprintf("%s.k_lin.bias", newPrefix), keyBias[from:to])
p.Set(fmt.Sprintf("%s.v_lin.weight", newPrefix), valueWeight[from*dim2:to*dim2])
p.Set(fmt.Sprintf("%s.v_lin.bias", newPrefix), valueBias[from:to])
}
}
}
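As a sanity check on the slicing above, assuming the stock distilbert-base configuration (HiddenSize 768, 12 attention heads): dim = 768/12 = 64 and dim2 = 768, so head j receives bias elements [j*64, (j+1)*64) and the flat weight slice [j*64*768, (j+1)*64*768), i.e. 64 contiguous rows of the row-major 768x768 projection matrix.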
13 changes: 7 additions & 6 deletions pkg/downloader/downloadmodel.go
@@ -27,12 +27,13 @@ const (
// supportedModelsFiles contains the set of all supported model types as keys,
// mapped with the set of all related files to download.
var supportedModelsFiles = map[string][]string{
"bart": {"pytorch_model.bin", "vocab.json", "merges.txt"},
"pegasus": {"pytorch_model.bin", "spiece.model"},
"marian": {"pytorch_model.bin", "vocab.json", "source.spm", "target.spm"},
"bert": {"pytorch_model.bin", "vocab.txt", "tokenizer_config.json"},
"electra": {"pytorch_model.bin", "vocab.txt", "tokenizer_config.json"},
"flair": {"pytorch_model.bin"},
"bart": {"pytorch_model.bin", "vocab.json", "merges.txt"},
"pegasus": {"pytorch_model.bin", "spiece.model"},
"marian": {"pytorch_model.bin", "vocab.json", "source.spm", "target.spm"},
"bert": {"pytorch_model.bin", "vocab.txt", "tokenizer_config.json"},
"distilbert": {"pytorch_model.bin", "vocab.txt", "tokenizer_config.json"},
"electra": {"pytorch_model.bin", "vocab.txt", "tokenizer_config.json"},
"flair": {"pytorch_model.bin"},
}

// Download downloads a supported pre-trained model from huggingface.co