From 9d1c771508e49a37307e6d5e7edeafa9b99a8004 Mon Sep 17 00:00:00 2001 From: Philip Zingmark Date: Sun, 6 Apr 2025 10:51:14 +0200 Subject: [PATCH 1/8] Added Custom client and server endpoints --- endpoint_client.go | 61 +++++++++++++++++++++++++++++++------ endpoint_server.go | 76 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 117 insertions(+), 20 deletions(-) diff --git a/endpoint_client.go b/endpoint_client.go index 06b688997..938a4b2c4 100644 --- a/endpoint_client.go +++ b/endpoint_client.go @@ -2,6 +2,7 @@ package gomavlib import ( "context" + "errors" "fmt" "io" "net" @@ -13,7 +14,7 @@ import ( var reconnectPeriod = 2 * time.Second type endpointClientConf interface { - isUDP() bool + clientType() EndpointServerType getAddress() string init(*Node) (Endpoint, error) } @@ -27,8 +28,8 @@ type EndpointTCPClient struct { Address string } -func (EndpointTCPClient) isUDP() bool { - return false +func (EndpointTCPClient) clientType() EndpointServerType { + return EndpointServerType_TCP } func (conf EndpointTCPClient) getAddress() string { @@ -50,8 +51,8 @@ type EndpointUDPClient struct { Address string } -func (EndpointUDPClient) isUDP() bool { - return true +func (EndpointUDPClient) clientType() EndpointServerType { + return EndpointServerType_UDP } func (conf EndpointUDPClient) getAddress() string { @@ -67,6 +68,30 @@ func (conf EndpointUDPClient) init(node *Node) (Endpoint, error) { return e, err } +type EndpointCustomClient struct { + // domain name or IP of the server to connect to, example: 1.2.3.4:5600 + Address string + // custom connect function that connects to the provided address + Connect func(address string) (net.Conn, error) +} + +func (EndpointCustomClient) clientType() EndpointServerType { + return EndpointServerType_CUSTOM +} + +func (conf EndpointCustomClient) getAddress() string { + return conf.Address +} + +func (conf EndpointCustomClient) init(node *Node) (Endpoint, error) { + e := &endpointClient{ + node: node, + conf: conf, + } + err := e.initialize() + return e, err +} + type endpointClient struct { node *Node conf endpointClientConf @@ -103,16 +128,34 @@ func (e *endpointClient) oneChannelAtAtime() bool { func (e *endpointClient) connect() (io.ReadWriteCloser, error) { network := func() string { - if e.conf.isUDP() { + switch e.conf.clientType() { + case EndpointServerType_TCP: + return "tcp4" + case EndpointServerType_UDP: return "udp4" + case EndpointServerType_CUSTOM: + return "cust" + default: + return "" } - return "tcp4" }() // in UDP, the only possible error is a DNS failure // in TCP, the handshake must be completed timedContext, timedContextClose := context.WithTimeout(e.ctx, e.node.ReadTimeout) - nconn, err := (&net.Dialer{}).DialContext(timedContext, network, e.conf.getAddress()) + nconn, err := func() (net.Conn, error) { + if network == "cust" { + if customConf, ok := e.conf.(EndpointCustomClient); ok { + if customConf.Connect == nil { + return nil, errors.New("no connect function provided on custom endpoint") + } + return customConf.Connect(customConf.Address) + } else { + return nil, errors.New("failed type assertion to endpointcustomclient") + } + } + return (&net.Dialer{}).DialContext(timedContext, network, e.conf.getAddress()) + }() timedContextClose() if err != nil { @@ -154,7 +197,7 @@ func (e *endpointClient) provide() (string, io.ReadWriteCloser, error) { func (e *endpointClient) label() string { return fmt.Sprintf("%s:%s", func() string { - if e.conf.isUDP() { + if e.conf.clientType() == EndpointServerType_UDP { return "udp" } return "tcp" diff --git a/endpoint_server.go b/endpoint_server.go index 58567abe8..2ff3a1214 100644 --- a/endpoint_server.go +++ b/endpoint_server.go @@ -1,17 +1,26 @@ package gomavlib import ( + "errors" "fmt" "io" "net" + "github.com/bluenviron/gomavlib/v3/pkg/timednetconn" "github.com/pion/transport/v2/udp" +) - "github.com/bluenviron/gomavlib/v3/pkg/timednetconn" +// Enum for specifying what protocol the endpoint uses +type EndpointServerType int + +const ( + EndpointServerType_TCP EndpointServerType = iota + EndpointServerType_UDP + EndpointServerType_CUSTOM ) type endpointServerConf interface { - isUDP() bool + serverType() EndpointServerType getAddress() string init(*Node) (Endpoint, error) } @@ -25,8 +34,8 @@ type EndpointTCPServer struct { Address string } -func (EndpointTCPServer) isUDP() bool { - return false +func (EndpointTCPServer) serverType() EndpointServerType { + return EndpointServerType_TCP } func (conf EndpointTCPServer) getAddress() string { @@ -50,8 +59,8 @@ type EndpointUDPServer struct { Address string } -func (EndpointUDPServer) isUDP() bool { - return true +func (EndpointUDPServer) serverType() EndpointServerType { + return EndpointServerType_UDP } func (conf EndpointUDPServer) getAddress() string { @@ -67,6 +76,31 @@ func (conf EndpointUDPServer) init(node *Node) (Endpoint, error) { return e, err } +// EndpointCustomServer sets up a endpoint that works with a provided net.listner. +// This allows you to use custom protocols that conform to the net.listner. +// A use case could be to add encrypted protocol implementations like DTLS or TCP with TLS. +type EndpointCustomServer struct { + Listener net.Listener + Label string +} + +func (EndpointCustomServer) serverType() EndpointServerType { + return EndpointServerType_CUSTOM +} + +func (conf EndpointCustomServer) getAddress() string { + return conf.Listener.Addr().String() +} + +func (conf EndpointCustomServer) init(node *Node) (Endpoint, error) { + e := &endpointServer{ + node: node, + conf: conf, + } + err := e.initialize() + return e, err +} + type endpointServer struct { node *Node conf endpointServerConf @@ -83,7 +117,8 @@ func (e *endpointServer) initialize() error { return fmt.Errorf("invalid address") } - if e.conf.isUDP() { + switch e.conf.serverType() { + case EndpointServerType_UDP: var addr *net.UDPAddr addr, err = net.ResolveUDPAddr("udp4", e.conf.getAddress()) if err != nil { @@ -92,13 +127,21 @@ func (e *endpointServer) initialize() error { e.listener, err = udp.Listen("udp4", addr) if err != nil { - return err + return fmt.Errorf("error starting UDP listener: %v", err) } - } else { + case EndpointServerType_TCP: e.listener, err = net.Listen("tcp4", e.conf.getAddress()) if err != nil { return err } + case EndpointServerType_CUSTOM: + if customConf, ok := e.conf.(EndpointCustomServer); ok { + e.listener = customConf.Listener + } else { + return errors.New("type assertion error to endpointcustomserver") + } + default: + return errors.New("unsupported server-type") } e.terminate = make(chan struct{}) @@ -130,10 +173,21 @@ func (e *endpointServer) provide() (string, io.ReadWriteCloser, error) { } label := fmt.Sprintf("%s:%s", func() string { - if e.conf.isUDP() { + switch e.conf.serverType() { + case EndpointServerType_TCP: + return "tcp" + case EndpointServerType_UDP: return "udp" + case EndpointServerType_CUSTOM: + if customConf, ok := e.conf.(EndpointCustomServer); ok { + if customConf.Label != "" { + return customConf.Label + } + } + return "cust" + default: + return "unk" } - return "tcp" }(), nconn.RemoteAddr()) conn := timednetconn.New( From f894b00bd439d62da1796c31ef4f6cd6423647f1 Mon Sep 17 00:00:00 2001 From: Philip Zingmark Date: Sun, 6 Apr 2025 11:11:14 +0200 Subject: [PATCH 2/8] Make custom server use a provided listen func so it connects when initialize is called instead of connecting before creating the endpoint --- endpoint_client.go | 19 +++++++++++++++++-- endpoint_server.go | 20 ++++++++++++++------ 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/endpoint_client.go b/endpoint_client.go index 938a4b2c4..a14475772 100644 --- a/endpoint_client.go +++ b/endpoint_client.go @@ -68,11 +68,15 @@ func (conf EndpointUDPClient) init(node *Node) (Endpoint, error) { return e, err } +// EndpointCustomClient sets up a endpoint that works with a custom implementation +// by providing a Connect func that returns a net.Conn. type EndpointCustomClient struct { // domain name or IP of the server to connect to, example: 1.2.3.4:5600 Address string // custom connect function that connects to the provided address Connect func(address string) (net.Conn, error) + // the label of the protocol + Label string } func (EndpointCustomClient) clientType() EndpointServerType { @@ -197,9 +201,20 @@ func (e *endpointClient) provide() (string, io.ReadWriteCloser, error) { func (e *endpointClient) label() string { return fmt.Sprintf("%s:%s", func() string { - if e.conf.clientType() == EndpointServerType_UDP { + switch e.conf.clientType() { + case EndpointServerType_TCP: + return "tcp" + case EndpointServerType_UDP: return "udp" + case EndpointServerType_CUSTOM: + if customConf, ok := e.conf.(EndpointCustomClient); ok { + if customConf.Label != "" { + return customConf.Label + } + } + return "cust" + default: + return "unk" } - return "tcp" }(), e.conf.getAddress()) } diff --git a/endpoint_server.go b/endpoint_server.go index 2ff3a1214..542abd334 100644 --- a/endpoint_server.go +++ b/endpoint_server.go @@ -76,12 +76,17 @@ func (conf EndpointUDPServer) init(node *Node) (Endpoint, error) { return e, err } -// EndpointCustomServer sets up a endpoint that works with a provided net.listner. +// EndpointCustomServer sets up a endpoint that works with custom implementations +// by providing a custom Listen func that returns a net.Listener. // This allows you to use custom protocols that conform to the net.listner. // A use case could be to add encrypted protocol implementations like DTLS or TCP with TLS. type EndpointCustomServer struct { - Listener net.Listener - Label string + // listen address, example: 0.0.0.0:5600 + Address string + // function to invoke when server should start listening + Listen func(address string) (net.Listener, error) + // the label of the protocol + Label string } func (EndpointCustomServer) serverType() EndpointServerType { @@ -89,7 +94,7 @@ func (EndpointCustomServer) serverType() EndpointServerType { } func (conf EndpointCustomServer) getAddress() string { - return conf.Listener.Addr().String() + return conf.Address } func (conf EndpointCustomServer) init(node *Node) (Endpoint, error) { @@ -127,7 +132,7 @@ func (e *endpointServer) initialize() error { e.listener, err = udp.Listen("udp4", addr) if err != nil { - return fmt.Errorf("error starting UDP listener: %v", err) + return err } case EndpointServerType_TCP: e.listener, err = net.Listen("tcp4", e.conf.getAddress()) @@ -136,7 +141,10 @@ func (e *endpointServer) initialize() error { } case EndpointServerType_CUSTOM: if customConf, ok := e.conf.(EndpointCustomServer); ok { - e.listener = customConf.Listener + e.listener, err = customConf.Listen(e.conf.getAddress()) + if err != nil { + return err + } } else { return errors.New("type assertion error to endpointcustomserver") } From 439b37600d3811e1787a535150190f054e4b3235 Mon Sep 17 00:00:00 2001 From: Philip Zingmark Date: Sun, 6 Apr 2025 11:30:31 +0200 Subject: [PATCH 3/8] add examples for tcp/tls custom client and server --- examples/node-endpoint-custom-client/main.go | 54 +++++++ .../node-endpoint-custom-server/.gitignore | 1 + examples/node-endpoint-custom-server/main.go | 152 ++++++++++++++++++ 3 files changed, 207 insertions(+) create mode 100644 examples/node-endpoint-custom-client/main.go create mode 100644 examples/node-endpoint-custom-server/.gitignore create mode 100644 examples/node-endpoint-custom-server/main.go diff --git a/examples/node-endpoint-custom-client/main.go b/examples/node-endpoint-custom-client/main.go new file mode 100644 index 000000000..cb685583e --- /dev/null +++ b/examples/node-endpoint-custom-client/main.go @@ -0,0 +1,54 @@ +package main + +import ( + "crypto/tls" + "log" + "net" + + "github.com/bluenviron/gomavlib/v3" + "github.com/bluenviron/gomavlib/v3/pkg/dialects/ardupilotmega" +) + +// this example shows how to: +// 1) create a node which communicates with a custom TCP/TLS endpoint in client mode. +// 2) print incoming messages. + +func main() { + // create a node which communicates with a TCP endpoint in client mode + node := &gomavlib.Node{ + Endpoints: []gomavlib.EndpointConf{ + gomavlib.EndpointCustomClient{ + Address: "127.0.0.1:5600", + Connect: func(address string) (net.Conn, error) { + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, // skip checking the certificate against a CA (just set to true for simplicity of this example) + MinVersion: tls.VersionTLS12, + } + + conn, err := tls.Dial("tcp", address, tlsConfig) + if err != nil { + return nil, err + } + + return conn, nil + }, + Label: "TCP/TLS", + }, + }, + Dialect: ardupilotmega.Dialect, + OutVersion: gomavlib.V2, // change to V1 if you're unable to communicate with the target + OutSystemID: 10, + } + err := node.Initialize() + if err != nil { + panic(err) + } + defer node.Close() + + // print incoming messages + for evt := range node.Events() { + if frm, ok := evt.(*gomavlib.EventFrame); ok { + log.Printf("received: id=%d, %+v\n", frm.Message().GetID(), frm.Message()) + } + } +} diff --git a/examples/node-endpoint-custom-server/.gitignore b/examples/node-endpoint-custom-server/.gitignore new file mode 100644 index 000000000..b2290143a --- /dev/null +++ b/examples/node-endpoint-custom-server/.gitignore @@ -0,0 +1 @@ +certs diff --git a/examples/node-endpoint-custom-server/main.go b/examples/node-endpoint-custom-server/main.go new file mode 100644 index 000000000..169b71a73 --- /dev/null +++ b/examples/node-endpoint-custom-server/main.go @@ -0,0 +1,152 @@ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "log" + "math/big" + "net" + "os" + "time" + + "github.com/bluenviron/gomavlib/v3" + "github.com/bluenviron/gomavlib/v3/pkg/dialects/ardupilotmega" +) + +// this example shows how to: +// 1) create a node which communicates with a custom TCP/TLS endpoint in server mode. +// 2) print incoming messages. + +func main() { + // ensure the certificate and key exists + if err := EnsureCertsExist(); err != nil { + fmt.Println("Error ensuring certificates:", err) + return + } + + // create a node which communicates with a custom TCP/TLS endpoint in server mode + node := &gomavlib.Node{ + Endpoints: []gomavlib.EndpointConf{ + gomavlib.EndpointCustomServer{ + Address: ":5600", + Listen: func(address string) (net.Listener, error) { + // Loads the certificate and key from the generated certs dir + cert, err := tls.LoadX509KeyPair("certs/cert.pem", "certs/key.pem") + if err != nil { + return nil, err + } + + return tls.Listen("tcp", address, &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + }, + Label: "TCP/TLS", + }, + }, + Dialect: ardupilotmega.Dialect, + OutVersion: gomavlib.V2, // change to V1 if you're unable to communicate with the target + OutSystemID: 10, + } + err := node.Initialize() + if err != nil { + panic(err) + } + defer node.Close() + + // print incoming messages + for evt := range node.Events() { + if frm, ok := evt.(*gomavlib.EventFrame); ok { + log.Printf("received: id=%d, %+v\n", frm.Message().GetID(), frm.Message()) + } + } +} + +// Below are just functions to check and generate certificate and private key +// they are just here to make this example simpler to run + +// EnsureCertsExist checks if the cert.pem and key.pem exist in the certs directory, +// and if not, generates them. +func EnsureCertsExist() error { + // Check if cert.pem exists + if _, err := os.Stat("certs/cert.pem"); os.IsNotExist(err) { + fmt.Println("cert.pem not found. Generating certificates...") + return GenerateCertAndKey() + } + + // Check if key.pem exists + if _, err := os.Stat("certs/key.pem"); os.IsNotExist(err) { + fmt.Println("key.pem not found. Generating certificates...") + return GenerateCertAndKey() + } + + return nil +} + +// GenerateCertAndKey generates a self-signed certificate and private key, saving them to the certs/ directory. +func GenerateCertAndKey() error { + // Create the certs directory if it doesn't exist + err := os.MkdirAll("certs", os.ModePerm) + if err != nil { + return fmt.Errorf("failed to create certs directory: %v", err) + } + + // Generate RSA private key + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return fmt.Errorf("failed to generate private key: %v", err) + } + + // Create certificate template + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "gomavlib", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), // valid for 1 year + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + BasicConstraintsValid: true, + } + + // Create the certificate + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return fmt.Errorf("failed to create certificate: %v", err) + } + + // Save the certificate + certOut, err := os.Create("certs/cert.pem") + if err != nil { + return fmt.Errorf("failed to create cert.pem: %v", err) + } + defer certOut.Close() + + err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + if err != nil { + return fmt.Errorf("failed to encode certificate to PEM: %v", err) + } + + // Save the private key + keyOut, err := os.Create("certs/key.pem") + if err != nil { + return fmt.Errorf("failed to create key.pem: %v", err) + } + defer keyOut.Close() + + err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + if err != nil { + return fmt.Errorf("failed to encode private key to PEM: %v", err) + } + + fmt.Println("cert.pem and key.pem generated in the 'certs/' directory.") + return nil +} From 97047d7c641d62a5631ac39d58926134b39559e3 Mon Sep 17 00:00:00 2001 From: Philip Zingmark Date: Sun, 6 Apr 2025 11:35:04 +0200 Subject: [PATCH 4/8] add examples to readme --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index ae444e5ee..7eaa944f4 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,8 @@ Features: * [node-endpoint-tcp-server](examples/node-endpoint-tcp-server/main.go) * [node-endpoint-tcp-client](examples/node-endpoint-tcp-client/main.go) * [node-endpoint-custom](examples/node-endpoint-custom/main.go) +* [node-endpoint-custom-client](examples/node-endpoint-custom-client/main.go) +* [node-endpoint-custom-server](examples/node-endpoint-custom-server/main.go) * [node-message-read](examples/node-message-read/main.go) * [node-message-write](examples/node-message-write/main.go) * [node-signature](examples/node-signature/main.go) From f896d87be01a1edf639f7df8a08f2bffc7427834 Mon Sep 17 00:00:00 2001 From: Philip Zingmark Date: Sun, 6 Apr 2025 20:17:47 +0200 Subject: [PATCH 5/8] improve client tcp/tls example --- examples/node-endpoint-custom-client/main.go | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/examples/node-endpoint-custom-client/main.go b/examples/node-endpoint-custom-client/main.go index cb685583e..01ad00d2d 100644 --- a/examples/node-endpoint-custom-client/main.go +++ b/examples/node-endpoint-custom-client/main.go @@ -22,15 +22,9 @@ func main() { Connect: func(address string) (net.Conn, error) { tlsConfig := &tls.Config{ InsecureSkipVerify: true, // skip checking the certificate against a CA (just set to true for simplicity of this example) - MinVersion: tls.VersionTLS12, } - conn, err := tls.Dial("tcp", address, tlsConfig) - if err != nil { - return nil, err - } - - return conn, nil + return tls.Dial("tcp", address, tlsConfig) }, Label: "TCP/TLS", }, From 15f12ea2ae4120d6ed2f5bb8e6babb3c05d6b604 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Thu, 16 Oct 2025 14:44:13 +0200 Subject: [PATCH 6/8] fix lint errors --- endpoint_client.go | 21 ++++++++------- endpoint_server.go | 27 ++++++++++---------- examples/node-endpoint-custom-client/main.go | 4 ++- examples/node-endpoint-custom-server/main.go | 15 ++++++----- 4 files changed, 35 insertions(+), 32 deletions(-) diff --git a/endpoint_client.go b/endpoint_client.go index a14475772..b8cd717d5 100644 --- a/endpoint_client.go +++ b/endpoint_client.go @@ -29,7 +29,7 @@ type EndpointTCPClient struct { } func (EndpointTCPClient) clientType() EndpointServerType { - return EndpointServerType_TCP + return EndpointServerTypeTCP } func (conf EndpointTCPClient) getAddress() string { @@ -52,7 +52,7 @@ type EndpointUDPClient struct { } func (EndpointUDPClient) clientType() EndpointServerType { - return EndpointServerType_UDP + return EndpointServerTypeUDP } func (conf EndpointUDPClient) getAddress() string { @@ -80,7 +80,7 @@ type EndpointCustomClient struct { } func (EndpointCustomClient) clientType() EndpointServerType { - return EndpointServerType_CUSTOM + return EndpointServerTypeCustom } func (conf EndpointCustomClient) getAddress() string { @@ -133,11 +133,11 @@ func (e *endpointClient) oneChannelAtAtime() bool { func (e *endpointClient) connect() (io.ReadWriteCloser, error) { network := func() string { switch e.conf.clientType() { - case EndpointServerType_TCP: + case EndpointServerTypeTCP: return "tcp4" - case EndpointServerType_UDP: + case EndpointServerTypeUDP: return "udp4" - case EndpointServerType_CUSTOM: + case EndpointServerTypeCustom: return "cust" default: return "" @@ -154,9 +154,8 @@ func (e *endpointClient) connect() (io.ReadWriteCloser, error) { return nil, errors.New("no connect function provided on custom endpoint") } return customConf.Connect(customConf.Address) - } else { - return nil, errors.New("failed type assertion to endpointcustomclient") } + return nil, errors.New("failed type assertion to endpointcustomclient") } return (&net.Dialer{}).DialContext(timedContext, network, e.conf.getAddress()) }() @@ -202,11 +201,11 @@ func (e *endpointClient) provide() (string, io.ReadWriteCloser, error) { func (e *endpointClient) label() string { return fmt.Sprintf("%s:%s", func() string { switch e.conf.clientType() { - case EndpointServerType_TCP: + case EndpointServerTypeTCP: return "tcp" - case EndpointServerType_UDP: + case EndpointServerTypeUDP: return "udp" - case EndpointServerType_CUSTOM: + case EndpointServerTypeCustom: if customConf, ok := e.conf.(EndpointCustomClient); ok { if customConf.Label != "" { return customConf.Label diff --git a/endpoint_server.go b/endpoint_server.go index 542abd334..218c66511 100644 --- a/endpoint_server.go +++ b/endpoint_server.go @@ -10,13 +10,14 @@ import ( "github.com/pion/transport/v2/udp" ) -// Enum for specifying what protocol the endpoint uses +// EndpointServerType is an Enum for specifying what protocol the endpoint uses type EndpointServerType int +// endpoint types const ( - EndpointServerType_TCP EndpointServerType = iota - EndpointServerType_UDP - EndpointServerType_CUSTOM + EndpointServerTypeTCP EndpointServerType = iota + EndpointServerTypeUDP + EndpointServerTypeCustom ) type endpointServerConf interface { @@ -35,7 +36,7 @@ type EndpointTCPServer struct { } func (EndpointTCPServer) serverType() EndpointServerType { - return EndpointServerType_TCP + return EndpointServerTypeTCP } func (conf EndpointTCPServer) getAddress() string { @@ -60,7 +61,7 @@ type EndpointUDPServer struct { } func (EndpointUDPServer) serverType() EndpointServerType { - return EndpointServerType_UDP + return EndpointServerTypeUDP } func (conf EndpointUDPServer) getAddress() string { @@ -90,7 +91,7 @@ type EndpointCustomServer struct { } func (EndpointCustomServer) serverType() EndpointServerType { - return EndpointServerType_CUSTOM + return EndpointServerTypeCustom } func (conf EndpointCustomServer) getAddress() string { @@ -123,7 +124,7 @@ func (e *endpointServer) initialize() error { } switch e.conf.serverType() { - case EndpointServerType_UDP: + case EndpointServerTypeUDP: var addr *net.UDPAddr addr, err = net.ResolveUDPAddr("udp4", e.conf.getAddress()) if err != nil { @@ -134,12 +135,12 @@ func (e *endpointServer) initialize() error { if err != nil { return err } - case EndpointServerType_TCP: + case EndpointServerTypeTCP: e.listener, err = net.Listen("tcp4", e.conf.getAddress()) if err != nil { return err } - case EndpointServerType_CUSTOM: + case EndpointServerTypeCustom: if customConf, ok := e.conf.(EndpointCustomServer); ok { e.listener, err = customConf.Listen(e.conf.getAddress()) if err != nil { @@ -182,11 +183,11 @@ func (e *endpointServer) provide() (string, io.ReadWriteCloser, error) { label := fmt.Sprintf("%s:%s", func() string { switch e.conf.serverType() { - case EndpointServerType_TCP: + case EndpointServerTypeTCP: return "tcp" - case EndpointServerType_UDP: + case EndpointServerTypeUDP: return "udp" - case EndpointServerType_CUSTOM: + case EndpointServerTypeCustom: if customConf, ok := e.conf.(EndpointCustomServer); ok { if customConf.Label != "" { return customConf.Label diff --git a/examples/node-endpoint-custom-client/main.go b/examples/node-endpoint-custom-client/main.go index 01ad00d2d..0422fb170 100644 --- a/examples/node-endpoint-custom-client/main.go +++ b/examples/node-endpoint-custom-client/main.go @@ -1,3 +1,4 @@ +// Package main contains an example. package main import ( @@ -21,7 +22,8 @@ func main() { Address: "127.0.0.1:5600", Connect: func(address string) (net.Conn, error) { tlsConfig := &tls.Config{ - InsecureSkipVerify: true, // skip checking the certificate against a CA (just set to true for simplicity of this example) + // skip checking the certificate against a CA (just set to true for simplicity of this example) + InsecureSkipVerify: true, } return tls.Dial("tcp", address, tlsConfig) diff --git a/examples/node-endpoint-custom-server/main.go b/examples/node-endpoint-custom-server/main.go index 169b71a73..97724d82d 100644 --- a/examples/node-endpoint-custom-server/main.go +++ b/examples/node-endpoint-custom-server/main.go @@ -1,3 +1,4 @@ +// Package main contains an example. package main import ( @@ -92,13 +93,13 @@ func GenerateCertAndKey() error { // Create the certs directory if it doesn't exist err := os.MkdirAll("certs", os.ModePerm) if err != nil { - return fmt.Errorf("failed to create certs directory: %v", err) + return fmt.Errorf("failed to create certs directory: %w", err) } // Generate RSA private key priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return fmt.Errorf("failed to generate private key: %v", err) + return fmt.Errorf("failed to generate private key: %w", err) } // Create certificate template @@ -120,31 +121,31 @@ func GenerateCertAndKey() error { // Create the certificate certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { - return fmt.Errorf("failed to create certificate: %v", err) + return fmt.Errorf("failed to create certificate: %w", err) } // Save the certificate certOut, err := os.Create("certs/cert.pem") if err != nil { - return fmt.Errorf("failed to create cert.pem: %v", err) + return fmt.Errorf("failed to create cert.pem: %w", err) } defer certOut.Close() err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) if err != nil { - return fmt.Errorf("failed to encode certificate to PEM: %v", err) + return fmt.Errorf("failed to encode certificate to PEM: %w", err) } // Save the private key keyOut, err := os.Create("certs/key.pem") if err != nil { - return fmt.Errorf("failed to create key.pem: %v", err) + return fmt.Errorf("failed to create key.pem: %w", err) } defer keyOut.Close() err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) if err != nil { - return fmt.Errorf("failed to encode private key to PEM: %v", err) + return fmt.Errorf("failed to encode private key to PEM: %w", err) } fmt.Println("cert.pem and key.pem generated in the 'certs/' directory.") From 41b57634be34e7829619fdf3de436d2273617197 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Thu, 16 Oct 2025 14:45:17 +0200 Subject: [PATCH 7/8] unexport endpointServerType --- endpoint_client.go | 26 +++++++++++++------------- endpoint_server.go | 36 +++++++++++++++++------------------- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/endpoint_client.go b/endpoint_client.go index b8cd717d5..393605dd4 100644 --- a/endpoint_client.go +++ b/endpoint_client.go @@ -14,7 +14,7 @@ import ( var reconnectPeriod = 2 * time.Second type endpointClientConf interface { - clientType() EndpointServerType + clientType() endpointServerType getAddress() string init(*Node) (Endpoint, error) } @@ -28,8 +28,8 @@ type EndpointTCPClient struct { Address string } -func (EndpointTCPClient) clientType() EndpointServerType { - return EndpointServerTypeTCP +func (EndpointTCPClient) clientType() endpointServerType { + return endpointServerTypeTCP } func (conf EndpointTCPClient) getAddress() string { @@ -51,8 +51,8 @@ type EndpointUDPClient struct { Address string } -func (EndpointUDPClient) clientType() EndpointServerType { - return EndpointServerTypeUDP +func (EndpointUDPClient) clientType() endpointServerType { + return endpointServerTypeUDP } func (conf EndpointUDPClient) getAddress() string { @@ -79,8 +79,8 @@ type EndpointCustomClient struct { Label string } -func (EndpointCustomClient) clientType() EndpointServerType { - return EndpointServerTypeCustom +func (EndpointCustomClient) clientType() endpointServerType { + return endpointServerTypeCustom } func (conf EndpointCustomClient) getAddress() string { @@ -133,11 +133,11 @@ func (e *endpointClient) oneChannelAtAtime() bool { func (e *endpointClient) connect() (io.ReadWriteCloser, error) { network := func() string { switch e.conf.clientType() { - case EndpointServerTypeTCP: + case endpointServerTypeTCP: return "tcp4" - case EndpointServerTypeUDP: + case endpointServerTypeUDP: return "udp4" - case EndpointServerTypeCustom: + case endpointServerTypeCustom: return "cust" default: return "" @@ -201,11 +201,11 @@ func (e *endpointClient) provide() (string, io.ReadWriteCloser, error) { func (e *endpointClient) label() string { return fmt.Sprintf("%s:%s", func() string { switch e.conf.clientType() { - case EndpointServerTypeTCP: + case endpointServerTypeTCP: return "tcp" - case EndpointServerTypeUDP: + case endpointServerTypeUDP: return "udp" - case EndpointServerTypeCustom: + case endpointServerTypeCustom: if customConf, ok := e.conf.(EndpointCustomClient); ok { if customConf.Label != "" { return customConf.Label diff --git a/endpoint_server.go b/endpoint_server.go index 218c66511..1acac151f 100644 --- a/endpoint_server.go +++ b/endpoint_server.go @@ -10,18 +10,16 @@ import ( "github.com/pion/transport/v2/udp" ) -// EndpointServerType is an Enum for specifying what protocol the endpoint uses -type EndpointServerType int +type endpointServerType int -// endpoint types const ( - EndpointServerTypeTCP EndpointServerType = iota - EndpointServerTypeUDP - EndpointServerTypeCustom + endpointServerTypeTCP endpointServerType = iota + endpointServerTypeUDP + endpointServerTypeCustom ) type endpointServerConf interface { - serverType() EndpointServerType + serverType() endpointServerType getAddress() string init(*Node) (Endpoint, error) } @@ -35,8 +33,8 @@ type EndpointTCPServer struct { Address string } -func (EndpointTCPServer) serverType() EndpointServerType { - return EndpointServerTypeTCP +func (EndpointTCPServer) serverType() endpointServerType { + return endpointServerTypeTCP } func (conf EndpointTCPServer) getAddress() string { @@ -60,8 +58,8 @@ type EndpointUDPServer struct { Address string } -func (EndpointUDPServer) serverType() EndpointServerType { - return EndpointServerTypeUDP +func (EndpointUDPServer) serverType() endpointServerType { + return endpointServerTypeUDP } func (conf EndpointUDPServer) getAddress() string { @@ -90,8 +88,8 @@ type EndpointCustomServer struct { Label string } -func (EndpointCustomServer) serverType() EndpointServerType { - return EndpointServerTypeCustom +func (EndpointCustomServer) serverType() endpointServerType { + return endpointServerTypeCustom } func (conf EndpointCustomServer) getAddress() string { @@ -124,7 +122,7 @@ func (e *endpointServer) initialize() error { } switch e.conf.serverType() { - case EndpointServerTypeUDP: + case endpointServerTypeUDP: var addr *net.UDPAddr addr, err = net.ResolveUDPAddr("udp4", e.conf.getAddress()) if err != nil { @@ -135,12 +133,12 @@ func (e *endpointServer) initialize() error { if err != nil { return err } - case EndpointServerTypeTCP: + case endpointServerTypeTCP: e.listener, err = net.Listen("tcp4", e.conf.getAddress()) if err != nil { return err } - case EndpointServerTypeCustom: + case endpointServerTypeCustom: if customConf, ok := e.conf.(EndpointCustomServer); ok { e.listener, err = customConf.Listen(e.conf.getAddress()) if err != nil { @@ -183,11 +181,11 @@ func (e *endpointServer) provide() (string, io.ReadWriteCloser, error) { label := fmt.Sprintf("%s:%s", func() string { switch e.conf.serverType() { - case EndpointServerTypeTCP: + case endpointServerTypeTCP: return "tcp" - case EndpointServerTypeUDP: + case endpointServerTypeUDP: return "udp" - case EndpointServerTypeCustom: + case endpointServerTypeCustom: if customConf, ok := e.conf.(EndpointCustomServer); ok { if customConf.Label != "" { return customConf.Label From bdd40a89200c88aa510abeab754e12995a55adce Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Thu, 16 Oct 2025 15:03:36 +0200 Subject: [PATCH 8/8] remove endpointServerType --- endpoint_client.go | 49 +++++++++++++------------------------ endpoint_server.go | 60 ++++++++++++++-------------------------------- 2 files changed, 34 insertions(+), 75 deletions(-) diff --git a/endpoint_client.go b/endpoint_client.go index 393605dd4..a7431f8ef 100644 --- a/endpoint_client.go +++ b/endpoint_client.go @@ -14,7 +14,6 @@ import ( var reconnectPeriod = 2 * time.Second type endpointClientConf interface { - clientType() endpointServerType getAddress() string init(*Node) (Endpoint, error) } @@ -28,10 +27,6 @@ type EndpointTCPClient struct { Address string } -func (EndpointTCPClient) clientType() endpointServerType { - return endpointServerTypeTCP -} - func (conf EndpointTCPClient) getAddress() string { return conf.Address } @@ -51,10 +46,6 @@ type EndpointUDPClient struct { Address string } -func (EndpointUDPClient) clientType() endpointServerType { - return endpointServerTypeUDP -} - func (conf EndpointUDPClient) getAddress() string { return conf.Address } @@ -79,10 +70,6 @@ type EndpointCustomClient struct { Label string } -func (EndpointCustomClient) clientType() endpointServerType { - return endpointServerTypeCustom -} - func (conf EndpointCustomClient) getAddress() string { return conf.Address } @@ -132,12 +119,12 @@ func (e *endpointClient) oneChannelAtAtime() bool { func (e *endpointClient) connect() (io.ReadWriteCloser, error) { network := func() string { - switch e.conf.clientType() { - case endpointServerTypeTCP: + switch e.conf.(type) { + case EndpointTCPClient: return "tcp4" - case endpointServerTypeUDP: + case EndpointUDPClient: return "udp4" - case endpointServerTypeCustom: + case EndpointCustomClient: return "cust" default: return "" @@ -149,13 +136,11 @@ func (e *endpointClient) connect() (io.ReadWriteCloser, error) { timedContext, timedContextClose := context.WithTimeout(e.ctx, e.node.ReadTimeout) nconn, err := func() (net.Conn, error) { if network == "cust" { - if customConf, ok := e.conf.(EndpointCustomClient); ok { - if customConf.Connect == nil { - return nil, errors.New("no connect function provided on custom endpoint") - } - return customConf.Connect(customConf.Address) + customConf := e.conf.(EndpointCustomClient) + if customConf.Connect == nil { + return nil, errors.New("no connect function provided on custom endpoint") } - return nil, errors.New("failed type assertion to endpointcustomclient") + return customConf.Connect(customConf.Address) } return (&net.Dialer{}).DialContext(timedContext, network, e.conf.getAddress()) }() @@ -200,20 +185,18 @@ func (e *endpointClient) provide() (string, io.ReadWriteCloser, error) { func (e *endpointClient) label() string { return fmt.Sprintf("%s:%s", func() string { - switch e.conf.clientType() { - case endpointServerTypeTCP: + switch conf := e.conf.(type) { + case EndpointTCPClient: return "tcp" - case endpointServerTypeUDP: + case EndpointUDPClient: return "udp" - case endpointServerTypeCustom: - if customConf, ok := e.conf.(EndpointCustomClient); ok { - if customConf.Label != "" { - return customConf.Label - } + case EndpointCustomClient: + if conf.Label != "" { + return conf.Label } - return "cust" + return "custom" default: - return "unk" + return "unknown" } }(), e.conf.getAddress()) } diff --git a/endpoint_server.go b/endpoint_server.go index 1acac151f..bf16b73d4 100644 --- a/endpoint_server.go +++ b/endpoint_server.go @@ -10,16 +10,7 @@ import ( "github.com/pion/transport/v2/udp" ) -type endpointServerType int - -const ( - endpointServerTypeTCP endpointServerType = iota - endpointServerTypeUDP - endpointServerTypeCustom -) - type endpointServerConf interface { - serverType() endpointServerType getAddress() string init(*Node) (Endpoint, error) } @@ -33,10 +24,6 @@ type EndpointTCPServer struct { Address string } -func (EndpointTCPServer) serverType() endpointServerType { - return endpointServerTypeTCP -} - func (conf EndpointTCPServer) getAddress() string { return conf.Address } @@ -58,10 +45,6 @@ type EndpointUDPServer struct { Address string } -func (EndpointUDPServer) serverType() endpointServerType { - return endpointServerTypeUDP -} - func (conf EndpointUDPServer) getAddress() string { return conf.Address } @@ -88,10 +71,6 @@ type EndpointCustomServer struct { Label string } -func (EndpointCustomServer) serverType() endpointServerType { - return endpointServerTypeCustom -} - func (conf EndpointCustomServer) getAddress() string { return conf.Address } @@ -121,8 +100,8 @@ func (e *endpointServer) initialize() error { return fmt.Errorf("invalid address") } - switch e.conf.serverType() { - case endpointServerTypeUDP: + switch conf := e.conf.(type) { + case EndpointUDPServer: var addr *net.UDPAddr addr, err = net.ResolveUDPAddr("udp4", e.conf.getAddress()) if err != nil { @@ -133,20 +112,19 @@ func (e *endpointServer) initialize() error { if err != nil { return err } - case endpointServerTypeTCP: + + case EndpointTCPServer: e.listener, err = net.Listen("tcp4", e.conf.getAddress()) if err != nil { return err } - case endpointServerTypeCustom: - if customConf, ok := e.conf.(EndpointCustomServer); ok { - e.listener, err = customConf.Listen(e.conf.getAddress()) - if err != nil { - return err - } - } else { - return errors.New("type assertion error to endpointcustomserver") + + case EndpointCustomServer: + e.listener, err = conf.Listen(e.conf.getAddress()) + if err != nil { + return err } + default: return errors.New("unsupported server-type") } @@ -180,20 +158,18 @@ func (e *endpointServer) provide() (string, io.ReadWriteCloser, error) { } label := fmt.Sprintf("%s:%s", func() string { - switch e.conf.serverType() { - case endpointServerTypeTCP: + switch conf := e.conf.(type) { + case EndpointTCPServer: return "tcp" - case endpointServerTypeUDP: + case EndpointUDPServer: return "udp" - case endpointServerTypeCustom: - if customConf, ok := e.conf.(EndpointCustomServer); ok { - if customConf.Label != "" { - return customConf.Label - } + case EndpointCustomServer: + if conf.Label != "" { + return conf.Label } - return "cust" + return "custom" default: - return "unk" + return "unknown" } }(), nconn.RemoteAddr())