diff --git a/README.md b/README.md index ae444e5ee..7eaa944f4 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,8 @@ Features: * [node-endpoint-tcp-server](examples/node-endpoint-tcp-server/main.go) * [node-endpoint-tcp-client](examples/node-endpoint-tcp-client/main.go) * [node-endpoint-custom](examples/node-endpoint-custom/main.go) +* [node-endpoint-custom-client](examples/node-endpoint-custom-client/main.go) +* [node-endpoint-custom-server](examples/node-endpoint-custom-server/main.go) * [node-message-read](examples/node-message-read/main.go) * [node-message-write](examples/node-message-write/main.go) * [node-signature](examples/node-signature/main.go) diff --git a/endpoint_client.go b/endpoint_client.go index 06b688997..a7431f8ef 100644 --- a/endpoint_client.go +++ b/endpoint_client.go @@ -2,6 +2,7 @@ package gomavlib import ( "context" + "errors" "fmt" "io" "net" @@ -13,7 +14,6 @@ import ( var reconnectPeriod = 2 * time.Second type endpointClientConf interface { - isUDP() bool getAddress() string init(*Node) (Endpoint, error) } @@ -27,10 +27,6 @@ type EndpointTCPClient struct { Address string } -func (EndpointTCPClient) isUDP() bool { - return false -} - func (conf EndpointTCPClient) getAddress() string { return conf.Address } @@ -50,10 +46,6 @@ type EndpointUDPClient struct { Address string } -func (EndpointUDPClient) isUDP() bool { - return true -} - func (conf EndpointUDPClient) getAddress() string { return conf.Address } @@ -67,6 +59,30 @@ func (conf EndpointUDPClient) init(node *Node) (Endpoint, error) { return e, err } +// EndpointCustomClient sets up a endpoint that works with a custom implementation +// by providing a Connect func that returns a net.Conn. +type EndpointCustomClient struct { + // domain name or IP of the server to connect to, example: 1.2.3.4:5600 + Address string + // custom connect function that connects to the provided address + Connect func(address string) (net.Conn, error) + // the label of the protocol + Label string +} + +func (conf EndpointCustomClient) getAddress() string { + return conf.Address +} + +func (conf EndpointCustomClient) init(node *Node) (Endpoint, error) { + e := &endpointClient{ + node: node, + conf: conf, + } + err := e.initialize() + return e, err +} + type endpointClient struct { node *Node conf endpointClientConf @@ -103,16 +119,31 @@ func (e *endpointClient) oneChannelAtAtime() bool { func (e *endpointClient) connect() (io.ReadWriteCloser, error) { network := func() string { - if e.conf.isUDP() { + switch e.conf.(type) { + case EndpointTCPClient: + return "tcp4" + case EndpointUDPClient: return "udp4" + case EndpointCustomClient: + return "cust" + default: + return "" } - return "tcp4" }() // in UDP, the only possible error is a DNS failure // in TCP, the handshake must be completed timedContext, timedContextClose := context.WithTimeout(e.ctx, e.node.ReadTimeout) - nconn, err := (&net.Dialer{}).DialContext(timedContext, network, e.conf.getAddress()) + nconn, err := func() (net.Conn, error) { + if network == "cust" { + customConf := e.conf.(EndpointCustomClient) + if customConf.Connect == nil { + return nil, errors.New("no connect function provided on custom endpoint") + } + return customConf.Connect(customConf.Address) + } + return (&net.Dialer{}).DialContext(timedContext, network, e.conf.getAddress()) + }() timedContextClose() if err != nil { @@ -154,9 +185,18 @@ func (e *endpointClient) provide() (string, io.ReadWriteCloser, error) { func (e *endpointClient) label() string { return fmt.Sprintf("%s:%s", func() string { - if e.conf.isUDP() { + switch conf := e.conf.(type) { + case EndpointTCPClient: + return "tcp" + case EndpointUDPClient: return "udp" + case EndpointCustomClient: + if conf.Label != "" { + return conf.Label + } + return "custom" + default: + return "unknown" } - return "tcp" }(), e.conf.getAddress()) } diff --git a/endpoint_server.go b/endpoint_server.go index 58567abe8..bf16b73d4 100644 --- a/endpoint_server.go +++ b/endpoint_server.go @@ -1,17 +1,16 @@ package gomavlib import ( + "errors" "fmt" "io" "net" - "github.com/pion/transport/v2/udp" - "github.com/bluenviron/gomavlib/v3/pkg/timednetconn" + "github.com/pion/transport/v2/udp" ) type endpointServerConf interface { - isUDP() bool getAddress() string init(*Node) (Endpoint, error) } @@ -25,10 +24,6 @@ type EndpointTCPServer struct { Address string } -func (EndpointTCPServer) isUDP() bool { - return false -} - func (conf EndpointTCPServer) getAddress() string { return conf.Address } @@ -50,10 +45,6 @@ type EndpointUDPServer struct { Address string } -func (EndpointUDPServer) isUDP() bool { - return true -} - func (conf EndpointUDPServer) getAddress() string { return conf.Address } @@ -67,6 +58,32 @@ func (conf EndpointUDPServer) init(node *Node) (Endpoint, error) { return e, err } +// EndpointCustomServer sets up a endpoint that works with custom implementations +// by providing a custom Listen func that returns a net.Listener. +// This allows you to use custom protocols that conform to the net.listner. +// A use case could be to add encrypted protocol implementations like DTLS or TCP with TLS. +type EndpointCustomServer struct { + // listen address, example: 0.0.0.0:5600 + Address string + // function to invoke when server should start listening + Listen func(address string) (net.Listener, error) + // the label of the protocol + Label string +} + +func (conf EndpointCustomServer) getAddress() string { + return conf.Address +} + +func (conf EndpointCustomServer) init(node *Node) (Endpoint, error) { + e := &endpointServer{ + node: node, + conf: conf, + } + err := e.initialize() + return e, err +} + type endpointServer struct { node *Node conf endpointServerConf @@ -83,7 +100,8 @@ func (e *endpointServer) initialize() error { return fmt.Errorf("invalid address") } - if e.conf.isUDP() { + switch conf := e.conf.(type) { + case EndpointUDPServer: var addr *net.UDPAddr addr, err = net.ResolveUDPAddr("udp4", e.conf.getAddress()) if err != nil { @@ -94,11 +112,21 @@ func (e *endpointServer) initialize() error { if err != nil { return err } - } else { + + case EndpointTCPServer: e.listener, err = net.Listen("tcp4", e.conf.getAddress()) if err != nil { return err } + + case EndpointCustomServer: + e.listener, err = conf.Listen(e.conf.getAddress()) + if err != nil { + return err + } + + default: + return errors.New("unsupported server-type") } e.terminate = make(chan struct{}) @@ -130,10 +158,19 @@ func (e *endpointServer) provide() (string, io.ReadWriteCloser, error) { } label := fmt.Sprintf("%s:%s", func() string { - if e.conf.isUDP() { + switch conf := e.conf.(type) { + case EndpointTCPServer: + return "tcp" + case EndpointUDPServer: return "udp" + case EndpointCustomServer: + if conf.Label != "" { + return conf.Label + } + return "custom" + default: + return "unknown" } - return "tcp" }(), nconn.RemoteAddr()) conn := timednetconn.New( diff --git a/examples/node-endpoint-custom-client/main.go b/examples/node-endpoint-custom-client/main.go new file mode 100644 index 000000000..0422fb170 --- /dev/null +++ b/examples/node-endpoint-custom-client/main.go @@ -0,0 +1,50 @@ +// Package main contains an example. +package main + +import ( + "crypto/tls" + "log" + "net" + + "github.com/bluenviron/gomavlib/v3" + "github.com/bluenviron/gomavlib/v3/pkg/dialects/ardupilotmega" +) + +// this example shows how to: +// 1) create a node which communicates with a custom TCP/TLS endpoint in client mode. +// 2) print incoming messages. + +func main() { + // create a node which communicates with a TCP endpoint in client mode + node := &gomavlib.Node{ + Endpoints: []gomavlib.EndpointConf{ + gomavlib.EndpointCustomClient{ + Address: "127.0.0.1:5600", + Connect: func(address string) (net.Conn, error) { + tlsConfig := &tls.Config{ + // skip checking the certificate against a CA (just set to true for simplicity of this example) + InsecureSkipVerify: true, + } + + return tls.Dial("tcp", address, tlsConfig) + }, + Label: "TCP/TLS", + }, + }, + Dialect: ardupilotmega.Dialect, + OutVersion: gomavlib.V2, // change to V1 if you're unable to communicate with the target + OutSystemID: 10, + } + err := node.Initialize() + if err != nil { + panic(err) + } + defer node.Close() + + // print incoming messages + for evt := range node.Events() { + if frm, ok := evt.(*gomavlib.EventFrame); ok { + log.Printf("received: id=%d, %+v\n", frm.Message().GetID(), frm.Message()) + } + } +} diff --git a/examples/node-endpoint-custom-server/.gitignore b/examples/node-endpoint-custom-server/.gitignore new file mode 100644 index 000000000..b2290143a --- /dev/null +++ b/examples/node-endpoint-custom-server/.gitignore @@ -0,0 +1 @@ +certs diff --git a/examples/node-endpoint-custom-server/main.go b/examples/node-endpoint-custom-server/main.go new file mode 100644 index 000000000..97724d82d --- /dev/null +++ b/examples/node-endpoint-custom-server/main.go @@ -0,0 +1,153 @@ +// Package main contains an example. +package main + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "log" + "math/big" + "net" + "os" + "time" + + "github.com/bluenviron/gomavlib/v3" + "github.com/bluenviron/gomavlib/v3/pkg/dialects/ardupilotmega" +) + +// this example shows how to: +// 1) create a node which communicates with a custom TCP/TLS endpoint in server mode. +// 2) print incoming messages. + +func main() { + // ensure the certificate and key exists + if err := EnsureCertsExist(); err != nil { + fmt.Println("Error ensuring certificates:", err) + return + } + + // create a node which communicates with a custom TCP/TLS endpoint in server mode + node := &gomavlib.Node{ + Endpoints: []gomavlib.EndpointConf{ + gomavlib.EndpointCustomServer{ + Address: ":5600", + Listen: func(address string) (net.Listener, error) { + // Loads the certificate and key from the generated certs dir + cert, err := tls.LoadX509KeyPair("certs/cert.pem", "certs/key.pem") + if err != nil { + return nil, err + } + + return tls.Listen("tcp", address, &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + }, + Label: "TCP/TLS", + }, + }, + Dialect: ardupilotmega.Dialect, + OutVersion: gomavlib.V2, // change to V1 if you're unable to communicate with the target + OutSystemID: 10, + } + err := node.Initialize() + if err != nil { + panic(err) + } + defer node.Close() + + // print incoming messages + for evt := range node.Events() { + if frm, ok := evt.(*gomavlib.EventFrame); ok { + log.Printf("received: id=%d, %+v\n", frm.Message().GetID(), frm.Message()) + } + } +} + +// Below are just functions to check and generate certificate and private key +// they are just here to make this example simpler to run + +// EnsureCertsExist checks if the cert.pem and key.pem exist in the certs directory, +// and if not, generates them. +func EnsureCertsExist() error { + // Check if cert.pem exists + if _, err := os.Stat("certs/cert.pem"); os.IsNotExist(err) { + fmt.Println("cert.pem not found. Generating certificates...") + return GenerateCertAndKey() + } + + // Check if key.pem exists + if _, err := os.Stat("certs/key.pem"); os.IsNotExist(err) { + fmt.Println("key.pem not found. Generating certificates...") + return GenerateCertAndKey() + } + + return nil +} + +// GenerateCertAndKey generates a self-signed certificate and private key, saving them to the certs/ directory. +func GenerateCertAndKey() error { + // Create the certs directory if it doesn't exist + err := os.MkdirAll("certs", os.ModePerm) + if err != nil { + return fmt.Errorf("failed to create certs directory: %w", err) + } + + // Generate RSA private key + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return fmt.Errorf("failed to generate private key: %w", err) + } + + // Create certificate template + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "gomavlib", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), // valid for 1 year + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + BasicConstraintsValid: true, + } + + // Create the certificate + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return fmt.Errorf("failed to create certificate: %w", err) + } + + // Save the certificate + certOut, err := os.Create("certs/cert.pem") + if err != nil { + return fmt.Errorf("failed to create cert.pem: %w", err) + } + defer certOut.Close() + + err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + if err != nil { + return fmt.Errorf("failed to encode certificate to PEM: %w", err) + } + + // Save the private key + keyOut, err := os.Create("certs/key.pem") + if err != nil { + return fmt.Errorf("failed to create key.pem: %w", err) + } + defer keyOut.Close() + + err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + if err != nil { + return fmt.Errorf("failed to encode private key to PEM: %w", err) + } + + fmt.Println("cert.pem and key.pem generated in the 'certs/' directory.") + return nil +}