diff --git a/asymetric.go b/asymetric.go new file mode 100644 index 0000000..896f572 --- /dev/null +++ b/asymetric.go @@ -0,0 +1,82 @@ +package ezcrypt + +import ( + "fmt" + "io" + + "golang.org/x/crypto/nacl/box" +) + +func checkAsym(m []byte, k Key, p Pair) error { + err := checkEncrypt(m, k) + + if err != nil { + return err + } + + if p == nil { + return fmt.Errorf("Pair is missing.") + } + + if p.private() == nil { + return fmt.Errorf("Pair is missing private key.") + } + + if p.private().Bytes() == nil { + return fmt.Errorf("Private key is missing data.") + } + + return nil +} + +func encryptAsym(m []byte, k Key, p Pair, in io.Reader) ([]byte, error) { + err := checkAsym(m, k, p) + + if err != nil { + return nil, err + } + + n, err := generateNonce(in) + + if err != nil { + return nil, err + } + + ret := box.Seal(n.Slice(), m, n.Bytes(), k.Bytes(), p.private().Bytes()) + + return ret, nil +} + +func checkDecryptAsym(m []byte, k Key, p Pair) error { + err := checkAsym(m, k, p) + + if err != nil { + return err + } + + if len(m) <= overhead { + return fmt.Errorf("Message is too short.") + } + + return nil +} + +func decryptAsym(m []byte, k Key, p Pair) ([]byte, error) { + err := checkDecryptAsym(m, k, p) + + if err != nil { + return nil, err + } + + n := newNonce(m) + + b := make([]byte, 0, len(m)-overhead) + + ret, ok := box.Open(b, m[nonceSize:], n.Bytes(), k.Bytes(), p.private().Bytes()) + + if !ok { + return nil, fmt.Errorf("Decryption failed.") + } + + return ret, nil +} diff --git a/asymetric_test.go b/asymetric_test.go new file mode 100644 index 0000000..f6ae84e --- /dev/null +++ b/asymetric_test.go @@ -0,0 +1,225 @@ +package ezcrypt + +import ( + "bytes" + "crypto/rand" + "io" + "testing" +) + +var ( + nilPriv = &pair{ + priv: nil, + } + nilPrivData = &pair{ + priv: &badKey{}, + } +) + +func TestAsymetricEncrypt(t *testing.T) { + p, err := GeneratePair(rand.Reader) + + if err != nil { + t.Fatalf("GeneratePair Failed: %s", err) + } + + tcs := []struct { + t string + m []byte + k Key + p Pair + r io.Reader + ok bool + }{ + { + t: "ok", + m: m, + k: p.Public(), + p: p, + r: rand.Reader, + ok: true, + }, + { + t: "nil message", + k: p.Public(), + p: p, + r: rand.Reader, + }, + { + t: "nil public", + m: m, + p: p, + r: rand.Reader, + }, + { + t: "bad public", + m: m, + k: &badKey{}, + p: p, + r: rand.Reader, + }, + { + t: "nil pair", + m: m, + k: p.Public(), + r: rand.Reader, + }, + { + t: "nil private", + m: m, + k: p.Public(), + p: nilPriv, + r: rand.Reader, + }, + { + t: "nil private data", + m: m, + k: p.Public(), + p: nilPrivData, + r: rand.Reader, + }, + { + t: "nil rand", + m: m, + k: p.Public(), + p: p, + }, + } + + for _, tc := range tcs { + ct, err := encryptAsym(tc.m, tc.k, tc.p, tc.r) + + if tc.ok { + if err != nil { + t.Errorf("Error: %s : %s", tc.t, err) + } else { + if bytes.Equal(m, ct) { + t.Fatalf("WTF?") + } + } + } else { + if err == nil { + t.Errorf("No Error: %s", tc.t) + } + } + } +} + +func TestAsymetricDecrypt(t *testing.T) { + ap, err := GeneratePair(rand.Reader) + + if err != nil { + t.Fatalf("GeneratePair Failed: %s", err) + } + + bp, err := GeneratePair(rand.Reader) + + if err != nil { + t.Fatalf("GeneratePair Failed: %s", err) + } + + cp, err := GeneratePair(rand.Reader) + + if err != nil { + t.Fatalf("GeneratePair Failed: %s", err) + } + + ct, err := encryptAsym(m, ap.Public(), bp, rand.Reader) + + if err != nil { + t.Errorf("Encrypt failed: %s", err) + } + + if bytes.Equal(m, ct) { + t.Fatalf("WTF?") + } + + tcs := []struct { + t string + m []byte + k Key + p Pair + ok bool + }{ + { + t: "ok", + m: ct, + k: bp.Public(), + p: ap, + ok: true, + }, + { + t: "nil message", + k: bp.Public(), + p: ap, + }, + { + t: "short message", + m: []byte("short"), + k: bp.Public(), + p: ap, + }, + { + t: "nil public", + m: ct, + p: ap, + }, + { + t: "nil public data", + m: ct, + k: &badKey{}, + p: ap, + }, + { + t: "nil pair", + m: ct, + k: bp.Public(), + }, + { + t: "nil private", + m: ct, + k: bp.Public(), + p: nilPriv, + }, + { + t: "nil private data", + m: ct, + k: bp.Public(), + p: nilPrivData, + }, + { + t: "wrong private", + m: ct, + k: bp.Public(), + p: cp, + }, + { + t: "wrong public", + m: ct, + k: cp.Public(), + p: ap, + }, + } + + for _, tc := range tcs { + d, err := decryptAsym(tc.m, tc.k, tc.p) + + if tc.ok { + if err != nil { + t.Errorf("Error: %s : %s", tc.t, err) + } else { + if !bytes.Equal(m, d) { + t.Errorf("Decrypt Failed: %s : %s : %s", tc.t, m, d) + } + } + } else { + if err == nil { + t.Errorf("No error: %s", tc.t) + } else { + if bytes.Equal(m, d) { + t.Errorf("WTF?: %s", tc.t) + } + } + } + } +} diff --git a/keys.go b/keys.go index e005616..4f9dedb 100644 --- a/keys.go +++ b/keys.go @@ -30,12 +30,14 @@ type Key interface { // Abstracts a public private key pair. type Pair interface { Public() Key + Encrypt(data []byte, dest Key, in io.Reader) ([]byte, error) + Decrypt(data []byte, source Key) ([]byte, error) Store(public, private string) error private() Key } // Constructs a new key pair. -func NewPair(rand io.Reader) (Pair, error) { +func GeneratePair(rand io.Reader) (Pair, error) { var err error publicKey, privateKey, err := box.GenerateKey(rand) @@ -111,6 +113,14 @@ func (p *pair) private() Key { return p.priv } +func (p *pair) Encrypt(data []byte, dest Key, in io.Reader) ([]byte, error) { + return encryptAsym(data, dest, p, in) +} + +func (p *pair) Decrypt(data []byte, source Key) ([]byte, error) { + return decryptAsym(data, source, p) +} + func (p *pair) Store(public, private string) error { err := writeKey(p.pub, public) diff --git a/keys_test.go b/keys_test.go index 5caf0a1..403c127 100644 --- a/keys_test.go +++ b/keys_test.go @@ -87,15 +87,19 @@ func TestKey(t *testing.T) { t.Fatalf("Loaded invalid key!") } - bytes, err := k.Encrypt([]byte(pubFile), rand.Reader) + ct, err := k.Encrypt(m, rand.Reader) if err != nil { t.Fatalf("Failed to Encrypt: %s", err) } - dec, err := k2.Decrypt(bytes) + if bytes.Equal(m, ct) { + t.Fatalf("WTF!") + } + + dec, err := k2.Decrypt(ct) - if pubFile != string(dec) { + if !bytes.Equal(m, dec) { t.Fatalf("Decrypt failed: expected: %s != %s", pubFile, dec) } @@ -173,18 +177,34 @@ func TestPair(t *testing.T) { t.Fatalf("Failed to write invalid key for test.") } - _, err = NewPair(&errReader{}) + _, err = GeneratePair(&errReader{}) if err == nil { t.Fatalf("Failed to bork with bad rand source.") } - pair, err := NewPair(rand.Reader) + pair, err := GeneratePair(rand.Reader) if err != nil { t.Fatalf("Failed to make pair: %s", err) } + ct, err := pair.Encrypt(m, pair.Public(), rand.Reader) + + if err != nil { + t.Fatalf("Encrypt failed: %s", err) + } + + dt, err := pair.Decrypt(ct, pair.Public()) + + if err != nil { + t.Fatalf("Decrypt failed: %s", err) + } + + if !bytes.Equal(m, dt) { + t.Fatalf("Decrypted message wrong: %s", dt) + } + err = pair.Store(temp, priv) if err == nil { diff --git a/nonce.go b/nonce.go index 580d54f..61dc499 100644 --- a/nonce.go +++ b/nonce.go @@ -22,21 +22,19 @@ func (n *nonce) Slice() []byte { return n.d[:] } -func newNonce(m []byte) (*nonce, error) { - if len(m) < nonceSize { - return nil, fmt.Errorf("Not enough data.") - } +// assumes that m is long enough, which all callers currently do checks for +func newNonce(m []byte) *nonce { n := &nonce{d: new([nonceSize]byte)} copy(n.Slice(), m[:nonceSize]) - return n, nil + return n } func generateNonce(in io.Reader) (*nonce, error) { if in == nil { - return nil, fmt.Errorf("No random source"); + return nil, fmt.Errorf("No random source") } n := &nonce{d: new([nonceSize]byte)} diff --git a/nonce_test.go b/nonce_test.go index 2532df4..85624e3 100644 --- a/nonce_test.go +++ b/nonce_test.go @@ -1,8 +1,8 @@ package ezcrypt import ( - "testing" "bytes" + "testing" ) var ( @@ -10,21 +10,9 @@ var ( ) func TestNonce(t *testing.T) { - _, err := newNonce([]byte("")) - - if err == nil { - t.Fatalf("Expected error"); - } - - b, err := newNonce(nonceBytes); + b := newNonce(nonceBytes) if !bytes.Equal(b.Slice(), nonceBytes) { t.Fatalf("Nonce mismatch.") } - - _, err = generateNonce(nil) - - if err == nil { - t.Fatalf("Expected error"); - } } diff --git a/symetric.go b/symetric.go index 06eb6af..0b149d9 100644 --- a/symetric.go +++ b/symetric.go @@ -9,7 +9,7 @@ import ( ) const ( - overhead = box.Overhead + nonceSize + overhead = box.Overhead + nonceSize ) func checkEncrypt(m []byte, k Key) error { @@ -33,10 +33,6 @@ func encrypt(m []byte, k Key, in io.Reader) ([]byte, error) { return nil, err } - if in == nil { - return nil, fmt.Errorf("No random source") - } - n, err := generateNonce(in) if err != nil { @@ -70,9 +66,8 @@ func decrypt(m []byte, k Key) ([]byte, error) { return nil, err } - // error is not possible here because checkDecrypt validated message length - n, _ := newNonce(m) - + n := newNonce(m) + out := make([]byte, 0, len(m)-overhead) ret, ok := secretbox.Open(out, m[nonceSize:], n.Bytes(), k.Bytes()) diff --git a/symetric_test.go b/symetric_test.go index d7e2564..6aa9734 100644 --- a/symetric_test.go +++ b/symetric_test.go @@ -3,9 +3,9 @@ package ezcrypt import ( "bytes" "crypto/rand" - "testing" - "io" "fmt" + "io" + "testing" ) type badKey struct{} @@ -19,15 +19,15 @@ func (*badKey) Slice() []byte { return []byte{} } -func (*badKey) Encrypt(data []byte, in io.Reader) ([]byte, error) { - return nil, fmt.Errorf("Bad Key!"); +func (*badKey) Encrypt(data []byte, in io.Reader) ([]byte, error) { + return nil, fmt.Errorf("Bad Key!") } -func (*badKey) Decrypt(data []byte) ([]byte, error) { - return nil, fmt.Errorf("Bad Key!"); +func (*badKey) Decrypt(data []byte) ([]byte, error) { + return nil, fmt.Errorf("Bad Key!") } func (*badKey) Store(file string) error { - return fmt.Errorf("Bad Key!"); + return fmt.Errorf("Bad Key!") } var (