Skip to content

Commit

Permalink
Merge pull request #1 from koron-go/modernize-framework
Browse files Browse the repository at this point in the history
Modernize framework
  • Loading branch information
koron committed Apr 3, 2024
2 parents d374719 + 44c9493 commit 96597f8
Show file tree
Hide file tree
Showing 14 changed files with 322 additions and 32 deletions.
10 changes: 0 additions & 10 deletions .circleci/config.yml

This file was deleted.

1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* -text
29 changes: 29 additions & 0 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: Go

on: [push]

env:
GO_VERSION: '>=1.21.0'

jobs:

build:
name: Build
runs-on: ${{ matrix.os }}

strategy:
matrix:
os: [ ubuntu-latest, macos-latest, windows-latest ]
steps:

- uses: actions/checkout@v4

- uses: actions/setup-go@v5
with:
go-version: ${{ env.GO_VERSION }}

- run: go test

- run: go build

# based on: github.com/koron-go/_skeleton/.github/workflows/go.yml
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
*~
default.pgo
tags
tmp/
42 changes: 42 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
TEST_PACKAGE ?= ./...

.PHONY: build
build:
go build -gcflags '-e'

.PHONY: test
test:
go test $(TEST_PACKAGE)

.PHONY: bench
bench:
go test -bench $(TEST_PACKAGE)

.PHONY: tags
tags:
gotags -f tags -R .

.PHONY: cover
cover:
mkdir -p tmp
go test -coverprofile tmp/_cover.out $(TEST_PACKAGE)
go tool cover -html tmp/_cover.out -o tmp/cover.html

.PHONY: checkall
checkall: vet staticcheck

.PHONY: vet
vet:
go vet $(TEST_PACKAGE)

.PHONY: staticcheck
staticcheck:
staticcheck $(TEST_PACKAGE)

.PHONY: clean
clean:
go clean
rm -f tags
rm -f tmp/_cover.out tmp/cover.html

# based on: github.com/koron-go/_skeleton/Makefile
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# koron-go/dialsrv

[![GoDoc](https://godoc.org/github.com/koron-go/dialsrv?status.svg)](https://godoc.org/github.com/koron-go/dialsrv)
[![CircleCI](https://img.shields.io/circleci/project/github/koron-go/dialsrv/master.svg)](https://circleci.com/gh/koron-go/dialsrv/tree/master)
[![PkgGoDev](https://pkg.go.dev/badge/github.com/koron-go/dialsrv)](https://pkg.go.dev/github.com/koron-go/dialsrv)
[![Actions/Go](https://github.com/koron-go/dialsrv/workflows/Go/badge.svg)](https://github.com/koron-go/dialsrv/actions?query=workflow%3AGo)
[![Go Report Card](https://goreportcard.com/badge/github.com/koron-go/dialsrv)](https://goreportcard.com/report/github.com/koron-go/dialsrv)

Dialer with SRV lookup.
Expand Down
21 changes: 12 additions & 9 deletions dial.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
/*
Package dialsrv provides a net.Dialer implementation that can reference SRV
records to DNS servers.
*/
package dialsrv

import (
Expand All @@ -10,15 +14,17 @@ import (

// Dialer wraps net.Dialer with SRV lookup.
type Dialer struct {
nd *net.Dialer
drv driver
}

// New creates a new Dialer with base *net.Dialer.
func New(d *net.Dialer) *Dialer {
if d == nil {
d = &net.Dialer{}
}
return &Dialer{nd: d}
return &Dialer{
drv: &netDialerDriver{d},
}
}

// Dial connects to the address on the named network.
Expand All @@ -32,26 +38,23 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
if fa := parseAddr(network, address); fa != nil {
return d.dialSRV(ctx, fa)
}
return d.nd.DialContext(ctx, network, address)
return d.drv.DialContext(ctx, network, address)
}

func (d Dialer) dialSRV(ctx context.Context, fa *FlavoredAddr) (net.Conn, error) {
r := d.nd.Resolver
if r == nil {
r = net.DefaultResolver
}
host, err := splitHost(fa.Name)
if err != nil {
return nil, err
}
_, addrs, err := r.LookupSRV(ctx, fa.Service, fa.Proto, host)
_, addrs, err := d.drv.LookupSRV(ctx, fa.Service, fa.Proto, host)
if err != nil {
return nil, err
}
if len(addrs) == 0 {
return nil, fmt.Errorf("no SRV records for %s", fa.String())
}
return d.nd.DialContext(ctx, fa.Network, address(addrs[0]))
// TODO: consider the case of len(addrs) >= 2. Use with rotation or random?
return d.drv.DialContext(ctx, fa.Network, address(addrs[0]))
}

func splitHost(s string) (string, error) {
Expand Down
188 changes: 178 additions & 10 deletions dial_test.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,52 @@
package dialsrv

import (
"context"
"net"
"reflect"
"testing"
"time"
)

func TestParseAddr(t *testing.T) {
for _, d := range []struct {
n, a string
fa FlavoredAddr
str string
}{
{"tcp", "srv+myservice+example.com",
FlavoredAddr{"tcp", "myservice", "tcp", "example.com"}},
{"udp", "srv+myservice+example.com",
FlavoredAddr{"udp", "myservice", "udp", "example.com"}},
{"tcp", "srv+myapi+example.com",
FlavoredAddr{"tcp", "myapi", "tcp", "example.com"}},
{"tcp", "srv+myservice+foo.example.org",
FlavoredAddr{"tcp", "myservice", "tcp", "foo.example.org"}},
{"tcp", "srv+example.com",
FlavoredAddr{"tcp", "", "", "example.com"}},
{
"tcp", "srv+myservice+example.com",
FlavoredAddr{"tcp", "myservice", "tcp", "example.com"},
"_myservice._tcp.example.com",
},
{
"udp", "srv+myservice+example.com",
FlavoredAddr{"udp", "myservice", "udp", "example.com"},
"_myservice._udp.example.com",
},
{
"tcp", "srv+myapi+example.com",
FlavoredAddr{"tcp", "myapi", "tcp", "example.com"},
"_myapi._tcp.example.com",
},
{
"tcp", "srv+myservice+foo.example.org",
FlavoredAddr{"tcp", "myservice", "tcp", "foo.example.org"},
"_myservice._tcp.foo.example.org",
},
{
"tcp", "srv+example.com",
FlavoredAddr{"tcp", "", "", "example.com"},
"example.com",
},
} {
act := parseAddr(d.n, d.a)
if !reflect.DeepEqual(act, &d.fa) {
t.Errorf("unexpected parse %s, %s: %#v", d.n, d.a, act)
}
if want, got := d.str, act.String(); want != got {
t.Errorf("unexpected string:\nwant=%s\n got=%s", want, got)
}
}
}

Expand All @@ -43,3 +65,149 @@ func TestParseAddrNil(t *testing.T) {
}
}
}

type testConn struct {
network string
address string
}

func (*testConn) Read([]byte) (int, error) { return 0, nil }
func (*testConn) Write([]byte) (int, error) { return 0, nil }
func (*testConn) Close() error { return nil }
func (*testConn) LocalAddr() net.Addr { return nil }
func (*testConn) RemoteAddr() net.Addr { return nil }
func (*testConn) SetDeadline(time.Time) error { return nil }
func (*testConn) SetReadDeadline(time.Time) error { return nil }
func (*testConn) SetWriteDeadline(time.Time) error { return nil }

type dialContextParams struct {
network string
address string
}

type dialContextResults struct {
conn net.Conn
err error
}

type lookupSRVParams struct {
service string
proto string
name string
}

type lookupSRVResults struct {
cname string
addrs []*net.SRV
err error
}

type testDriver struct {
dialContextParams *dialContextParams
dialContextResults *dialContextResults
lookupSRVParams *lookupSRVParams
lookupSRVResults *lookupSRVResults
}

func (d *testDriver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
d.dialContextParams = &dialContextParams{
network: network,
address: address,
}
if d.dialContextResults == nil {
return &testConn{
network: network,
address: address,
}, nil
}
return d.dialContextResults.conn, d.dialContextResults.err
}

func (d *testDriver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) {
d.lookupSRVParams = &lookupSRVParams{
service: service,
proto: proto,
name: name,
}
if d.lookupSRVResults == nil {
target := name
if service != "" {
target = service + "." + name
}
return "sample", []*net.SRV{
{Target: target, Port: 1234, Priority: 1, Weight: 100},
}, nil
}
return d.lookupSRVResults.cname, d.lookupSRVResults.addrs, d.lookupSRVResults.err
}

var _ driver = (*testDriver)(nil)

func TestDial(t *testing.T) {
for _, c := range []struct {
network string
address string
want dialContextResults
wantDialParams *dialContextParams
wantLookupParams *lookupSRVParams
}{
{ // without "srv+" prefix
"tcp",
"example.com",
dialContextResults{&testConn{"tcp", "example.com"}, nil},
&dialContextParams{"tcp", "example.com"},
nil,
},
{ // with simple "srv+"
"tcp",
"srv+example.com",
dialContextResults{&testConn{"tcp", "example.com:1234"}, nil},
&dialContextParams{"tcp", "example.com:1234"},
&lookupSRVParams{"", "", "example.com"},
},
{ // with "srv+" and "ldap" service
"tcp",
"srv+ldap+example.com",
dialContextResults{&testConn{"tcp", "ldap.example.com:1234"}, nil},
&dialContextParams{"tcp", "ldap.example.com:1234"},
&lookupSRVParams{"ldap", "tcp", "example.com"},
},
{ // with "srv+", "ldap" service and specify port
"tcp",
"srv+ldap+example.com:443",
dialContextResults{&testConn{"tcp", "ldap.example.com:1234"}, nil},
&dialContextParams{"tcp", "ldap.example.com:1234"},
&lookupSRVParams{"ldap", "tcp", "example.com"},
},
} {
driver := &testDriver{}
d := Dialer{drv: driver}
gotConn, gotErr := d.Dial(c.network, c.address)
if gotErr != nil {
if c.want.err == nil {
t.Errorf("unexpected error: %v", gotErr)
continue
}
if want, got := c.want.err.Error(), gotErr.Error(); got != want {
t.Errorf("unmatch error:\nwant=%v\ngot=%v", want, got)
}
continue
}
if want, got := c.want.conn, gotConn; !reflect.DeepEqual(want, got) {
t.Errorf("unmatch conn:\nwant=%+v\ngot=%+v", want, got)
}
if want, got := c.wantDialParams, driver.dialContextParams; !reflect.DeepEqual(want, got) {
t.Errorf("unmatch dial params:\nwant=%+v\ngot=%+v", want, got)
}
if want, got := c.wantLookupParams, driver.lookupSRVParams; !reflect.DeepEqual(want, got) {
t.Errorf("unmatch lookup params:\nwant=%+v\ngot=%+v", want, got)
}
}
}

func TestNew(t *testing.T) {
d := New(nil)
if want, got := (&net.Dialer{}), d.drv.(*netDialerDriver).Dialer; !reflect.DeepEqual(want, got) {
t.Errorf("unexpected underlying dialer:\nwant=%+v\ngot=%+v", want, got)
}
}
25 changes: 25 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package dialsrv

import (
"context"
"net"
)

type driver interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
}

type netDialerDriver struct {
*net.Dialer
}

var _ driver = (*netDialerDriver)(nil)

func (ndd *netDialerDriver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) {
r := ndd.Resolver
if r == nil {
r = net.DefaultResolver
}
return r.LookupSRV(ctx, service, proto, name)
}
Loading

0 comments on commit 96597f8

Please sign in to comment.