Skip to content

Commit

Permalink
Merge pull request #28 from bodgit/multi
Browse files Browse the repository at this point in the history
Add MultiProvider
  • Loading branch information
bodgit committed Jan 8, 2021
2 parents 5f29113 + 5f37296 commit 1e96a83
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 0 deletions.
51 changes: 51 additions & 0 deletions multi.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package tsig

import "github.com/miekg/dns"

type multiProvider struct {
providers []dns.TsigProvider
}

func (mp *multiProvider) Generate(msg []byte, t *dns.TSIG) (b []byte, err error) {
for _, p := range mp.providers {
b, err = p.Generate(msg, t)
switch err {
case dns.ErrKeyAlg:
break
default:
return
}
}
return nil, dns.ErrKeyAlg
}

func (mp *multiProvider) Verify(msg []byte, t *dns.TSIG) (err error) {
for _, p := range mp.providers {
err = p.Verify(msg, t)
switch err {
case dns.ErrKeyAlg:
break
default:
return
}
}
return dns.ErrKeyAlg
}

// MultiProvider creates a dns.TsigProvider that chains the provided input
// providers. This allows multiple TSIG algorithms.
//
// Each provider is called in turn and if it returns dns.ErrKeyAlg the next
// provider in the list is tried. On success or any other error, the result is
// returned; it does not continue down the list.
func MultiProvider(providers ...dns.TsigProvider) dns.TsigProvider {
allProviders := make([]dns.TsigProvider, 0, len(providers))
for _, p := range providers {
if mp, ok := p.(*multiProvider); ok {
allProviders = append(allProviders, mp.providers...)
} else {
allProviders = append(allProviders, p)
}
}
return &multiProvider{allProviders}
}
117 changes: 117 additions & 0 deletions multi_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package tsig

import (
"errors"
"testing"

"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)

var (
errProvider = errors.New("provider error")
testSignature = []byte("a good signature")
)

type unsupportedProvider struct{}

func (unsupportedProvider) Generate(_ []byte, _ *dns.TSIG) ([]byte, error) {
return nil, dns.ErrKeyAlg
}

func (unsupportedProvider) Verify(_ []byte, _ *dns.TSIG) error {
return dns.ErrKeyAlg
}

type errorProvider struct{}

func (errorProvider) Generate(_ []byte, _ *dns.TSIG) ([]byte, error) {
return nil, errProvider
}

func (errorProvider) Verify(_ []byte, _ *dns.TSIG) error {
return errProvider
}

type testProvider struct{}

func (testProvider) Generate(_ []byte, _ *dns.TSIG) ([]byte, error) {
return testSignature, nil
}

func (testProvider) Verify(_ []byte, _ *dns.TSIG) error {
return nil
}

func TestMultiProviderGenerate(t *testing.T) {
tables := map[string]struct {
provider dns.TsigProvider
signature []byte
err error
}{
"good": {
MultiProvider(new(testProvider)),
testSignature,
nil,
},
"unsupported good": {
MultiProvider(new(unsupportedProvider), new(testProvider)),
testSignature,
nil,
},
"error good": {
MultiProvider(new(errorProvider), new(testProvider)),
nil,
errProvider,
},
"all unsupported": {
MultiProvider(new(unsupportedProvider)),
nil,
dns.ErrKeyAlg,
},
"nested": {
MultiProvider(MultiProvider(new(testProvider))),
testSignature,
nil,
},
}

for name, table := range tables {
t.Run(name, func(t *testing.T) {
b, err := table.provider.Generate(nil, nil)
assert.Equal(t, table.signature, b)
assert.Equal(t, table.err, err)
})
}
}

func TestMultiProviderVerify(t *testing.T) {
tables := map[string]struct {
provider dns.TsigProvider
err error
}{
"good": {
MultiProvider(new(testProvider)),
nil,
},
"unsupported good": {
MultiProvider(new(unsupportedProvider), new(testProvider)),
nil,
},
"error good": {
MultiProvider(new(errorProvider), new(testProvider)),
errProvider,
},
"all unsuppored": {
MultiProvider(new(unsupportedProvider)),
dns.ErrKeyAlg,
},
}

for name, table := range tables {
t.Run(name, func(t *testing.T) {
err := table.provider.Verify(nil, nil)
assert.Equal(t, table.err, err)
})
}
}

0 comments on commit 1e96a83

Please sign in to comment.