-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #28 from bodgit/multi
Add MultiProvider
- Loading branch information
Showing
2 changed files
with
168 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package tsig | ||
|
||
import "github.com/miekg/dns" | ||
|
||
type multiProvider struct { | ||
providers []dns.TsigProvider | ||
} | ||
|
||
func (mp *multiProvider) Generate(msg []byte, t *dns.TSIG) (b []byte, err error) { | ||
for _, p := range mp.providers { | ||
b, err = p.Generate(msg, t) | ||
switch err { | ||
case dns.ErrKeyAlg: | ||
break | ||
default: | ||
return | ||
} | ||
} | ||
return nil, dns.ErrKeyAlg | ||
} | ||
|
||
func (mp *multiProvider) Verify(msg []byte, t *dns.TSIG) (err error) { | ||
for _, p := range mp.providers { | ||
err = p.Verify(msg, t) | ||
switch err { | ||
case dns.ErrKeyAlg: | ||
break | ||
default: | ||
return | ||
} | ||
} | ||
return dns.ErrKeyAlg | ||
} | ||
|
||
// MultiProvider creates a dns.TsigProvider that chains the provided input | ||
// providers. This allows multiple TSIG algorithms. | ||
// | ||
// Each provider is called in turn and if it returns dns.ErrKeyAlg the next | ||
// provider in the list is tried. On success or any other error, the result is | ||
// returned; it does not continue down the list. | ||
func MultiProvider(providers ...dns.TsigProvider) dns.TsigProvider { | ||
allProviders := make([]dns.TsigProvider, 0, len(providers)) | ||
for _, p := range providers { | ||
if mp, ok := p.(*multiProvider); ok { | ||
allProviders = append(allProviders, mp.providers...) | ||
} else { | ||
allProviders = append(allProviders, p) | ||
} | ||
} | ||
return &multiProvider{allProviders} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
package tsig | ||
|
||
import ( | ||
"errors" | ||
"testing" | ||
|
||
"github.com/miekg/dns" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
var ( | ||
errProvider = errors.New("provider error") | ||
testSignature = []byte("a good signature") | ||
) | ||
|
||
type unsupportedProvider struct{} | ||
|
||
func (unsupportedProvider) Generate(_ []byte, _ *dns.TSIG) ([]byte, error) { | ||
return nil, dns.ErrKeyAlg | ||
} | ||
|
||
func (unsupportedProvider) Verify(_ []byte, _ *dns.TSIG) error { | ||
return dns.ErrKeyAlg | ||
} | ||
|
||
type errorProvider struct{} | ||
|
||
func (errorProvider) Generate(_ []byte, _ *dns.TSIG) ([]byte, error) { | ||
return nil, errProvider | ||
} | ||
|
||
func (errorProvider) Verify(_ []byte, _ *dns.TSIG) error { | ||
return errProvider | ||
} | ||
|
||
type testProvider struct{} | ||
|
||
func (testProvider) Generate(_ []byte, _ *dns.TSIG) ([]byte, error) { | ||
return testSignature, nil | ||
} | ||
|
||
func (testProvider) Verify(_ []byte, _ *dns.TSIG) error { | ||
return nil | ||
} | ||
|
||
func TestMultiProviderGenerate(t *testing.T) { | ||
tables := map[string]struct { | ||
provider dns.TsigProvider | ||
signature []byte | ||
err error | ||
}{ | ||
"good": { | ||
MultiProvider(new(testProvider)), | ||
testSignature, | ||
nil, | ||
}, | ||
"unsupported good": { | ||
MultiProvider(new(unsupportedProvider), new(testProvider)), | ||
testSignature, | ||
nil, | ||
}, | ||
"error good": { | ||
MultiProvider(new(errorProvider), new(testProvider)), | ||
nil, | ||
errProvider, | ||
}, | ||
"all unsupported": { | ||
MultiProvider(new(unsupportedProvider)), | ||
nil, | ||
dns.ErrKeyAlg, | ||
}, | ||
"nested": { | ||
MultiProvider(MultiProvider(new(testProvider))), | ||
testSignature, | ||
nil, | ||
}, | ||
} | ||
|
||
for name, table := range tables { | ||
t.Run(name, func(t *testing.T) { | ||
b, err := table.provider.Generate(nil, nil) | ||
assert.Equal(t, table.signature, b) | ||
assert.Equal(t, table.err, err) | ||
}) | ||
} | ||
} | ||
|
||
func TestMultiProviderVerify(t *testing.T) { | ||
tables := map[string]struct { | ||
provider dns.TsigProvider | ||
err error | ||
}{ | ||
"good": { | ||
MultiProvider(new(testProvider)), | ||
nil, | ||
}, | ||
"unsupported good": { | ||
MultiProvider(new(unsupportedProvider), new(testProvider)), | ||
nil, | ||
}, | ||
"error good": { | ||
MultiProvider(new(errorProvider), new(testProvider)), | ||
errProvider, | ||
}, | ||
"all unsuppored": { | ||
MultiProvider(new(unsupportedProvider)), | ||
dns.ErrKeyAlg, | ||
}, | ||
} | ||
|
||
for name, table := range tables { | ||
t.Run(name, func(t *testing.T) { | ||
err := table.provider.Verify(nil, nil) | ||
assert.Equal(t, table.err, err) | ||
}) | ||
} | ||
} |