Skip to content

Commit

Permalink
added an API to export the remote public key
Browse files Browse the repository at this point in the history
  • Loading branch information
mimoo committed Oct 8, 2018
1 parent caf14f3 commit 660a491
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 9 deletions.
43 changes: 40 additions & 3 deletions libdisco/apis.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,33 @@ func Client(conn net.Conn, config *Config) *Conn {
}

// A listener implements a network listener (net.Listener) for Disco connections.
type listener struct {
type Listener struct {
net.Listener
config *Config
}

// Accept waits for and returns the next incoming Disco connection.
// The returned connection is of type *Conn.
func (l *listener) Accept() (net.Conn, error) {
func (l *Listener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return Server(c, l.config), nil
}

// Accept waits for and returns the next incoming Disco connection.
// The returned connection is of type *Conn.
func (l *Listener) AcceptDisco() (*Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
conn := Server(c, l.config)
conn.Write([]byte{})
return conn, nil
}

// Listen creates a Disco listener accepting connections on the
// given network address using net.Listen.
// The configuration config must be non-nil.
Expand All @@ -67,7 +79,32 @@ func Listen(network, laddr string, config *Config) (net.Listener, error) {
}

// create new libdisco.listener
discoListener := new(listener)
discoListener := new(Listener)
discoListener.Listener = l
discoListener.config = config
return discoListener, nil
}

// ListenDisco creates a Disco listener accepting connections on the
// given network address using net.Listen.
// The configuration config must be non-nil.
func ListenDisco(network, laddr string, config *Config) (*Listener, error) {
// check Config
if config == nil {
return nil, errors.New("Disco: no Config set")
}
if err := checkRequirements(false, config); err != nil {
panic(err)
}

// make net.Conn listen
l, err := net.Listen(network, laddr)
if err != nil {
return nil, err
}

// create new libdisco.listener
discoListener := new(Listener)
discoListener.Listener = l
discoListener.config = config
return discoListener, nil
Expand Down
17 changes: 15 additions & 2 deletions libdisco/conn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package libdisco

import (
"encoding/hex"
"errors"
"io"
"net"
Expand All @@ -18,12 +19,12 @@ type Conn struct {

// handshake
config *Config // configuration passed to constructor
hs handshakeState
handshakeComplete bool
handshakeMutex sync.Mutex

// Authentication thingies
isRemoteAuthenticated bool
remotePublicKey string

// input/output
in, out *strobe.Strobe
Expand Down Expand Up @@ -323,6 +324,9 @@ ContinueHandshake:
if !c.config.PublicKeyVerifier(hs.rs.PublicKey[:], receivedPayload) {
return errors.New("disco: the received public key could not be authenticated")
}
// authenticated!
c.isRemoteAuthenticated = true
c.remotePublicKey = hex.EncodeToString(hs.rs.PublicKey[:]) // so that it can be accessed later
}
}

Expand All @@ -343,7 +347,7 @@ ContinueHandshake:

// TODO: preserve c.hs.symmetricState.h
// At that point the HandshakeState should be deleted except for the hash value h, which may be used for post-handshake channel binding (see Section 11.2).
c.hs.clear()
hs.clear()

// no errors :)
c.handshakeComplete = true
Expand All @@ -355,6 +359,15 @@ func (c *Conn) IsRemoteAuthenticated() bool {
return c.isRemoteAuthenticated
}

// RemotePublicKey returns the static key of the remote peer. It is useful in case the
// static key is only transmitted during the handshake.
func (c *Conn) RemotePublicKey() (string, error) {
if !c.handshakeComplete {
return "", errors.New("disco: handshake not completed")
}
return c.remotePublicKey, nil
}

//
// input/output functions
//
Expand Down
4 changes: 0 additions & 4 deletions libdisco/disco.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,6 @@ func (h *handshakeState) readMessage(message []byte, payloadBuffer *[]byte) (c1,
if err != nil {
return
}
// if we already know the remote static, compare
copy(h.rs.PublicKey[:], plaintext)
offset += dhLen + tagLen

Expand Down Expand Up @@ -403,7 +402,4 @@ func (kp *KeyPair) clear() {
for i := 0; i < len(kp.PrivateKey); i++ {
kp.PrivateKey[i] = 0
}
for i := 0; i < len(kp.PublicKey); i++ {
kp.PublicKey[i] = 0
}
}

0 comments on commit 660a491

Please sign in to comment.