Skip to content

Commit

Permalink
feat: allow user to specify destination port via resolve flag (#801)
Browse files Browse the repository at this point in the history
This PR enhances `--resolve` flag by allowing users to specify
destination port during DNS lookup.

Resolves #790.

Signed-off-by: Billy Zha <jinzha1@microsoft.com>
  • Loading branch information
qweeah committed Feb 22, 2023
1 parent 8995628 commit fb68c73
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 68 deletions.
29 changes: 18 additions & 11 deletions cmd/oras/internal/option/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (opts *Remote) ApplyFlagsWithPrefix(fs *pflag.FlagSet, prefix, description
}

if fs.Lookup("resolve") == nil {
fs.StringArrayVarP(&opts.resolveFlag, "resolve", "", nil, "customized DNS formatted in `host:port:address`")
fs.StringArrayVarP(&opts.resolveFlag, "resolve", "", nil, "customized DNS formatted in `host:port:address[:address_port]`")
}
}

Expand Down Expand Up @@ -144,22 +144,29 @@ func (opts *Remote) parseResolve() error {
}
var dialer onet.Dialer
for _, r := range opts.resolveFlag {
parts := strings.SplitN(r, ":", 3)
if len(parts) < 3 {
return formatError(r, "expecting host:port:address")
parts := strings.SplitN(r, ":", 4)
length := len(parts)
if length < 3 {
return formatError(r, "expecting host:port:address[:address_port]")
}

port, err := strconv.Atoi(parts[1])
host := parts[0]
hostPort, err := strconv.Atoi(parts[1])
if err != nil {
return formatError(r, "expecting uint64 port")
return formatError(r, "expecting uint64 host port")
}

// ipv6 zone is not parsed
to := net.ParseIP(parts[2])
if to == nil {
address := net.ParseIP(parts[2])
if address == nil {
return formatError(r, "invalid IP address")
}
dialer.Add(parts[0], port, to)
addressPort := hostPort
if length > 3 {
addressPort, err = strconv.Atoi(parts[3])
if err != nil {
return formatError(r, "expecting uint64 address port")
}
}
dialer.Add(host, hostPort, address, addressPort)
}
opts.resolveDialContext = func(base *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
dialer.Dialer = base
Expand Down
69 changes: 49 additions & 20 deletions cmd/oras/internal/option/remote_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,40 +247,69 @@ func TestRemote_isPlainHttp_localhost(t *testing.T) {

func TestRemote_parseResolve_err(t *testing.T) {
tests := []struct {
name string
opts *Remote
wantErr bool
name string
opts *Remote
}{
{
name: "invalid flag",
opts: &Remote{resolveFlag: []string{"this-shouldn't_work"}},
wantErr: true,
name: "invalid flag",
opts: &Remote{resolveFlag: []string{"this-shouldn't_work"}},
},
{
name: "no host",
opts: &Remote{resolveFlag: []string{":port:address"}},
wantErr: true,
name: "no host",
opts: &Remote{resolveFlag: []string{":port:address"}},
},
{
name: "no address",
opts: &Remote{resolveFlag: []string{"host:port:"}},
wantErr: true,
name: "no address",
opts: &Remote{resolveFlag: []string{"host:port:"}},
},
{
name: "invalid address",
opts: &Remote{resolveFlag: []string{"host:port:invalid-ip"}},
wantErr: true,
name: "invalid address",
opts: &Remote{resolveFlag: []string{"host:port:invalid-ip"}},
},
{
name: "no port",
opts: &Remote{resolveFlag: []string{"host::address"}},
wantErr: true,
name: "no port",
opts: &Remote{resolveFlag: []string{"host::address"}},
},
{
name: "invalid source port",
opts: &Remote{resolveFlag: []string{"host:port:address"}},
},
{
name: "invalid destination port",
opts: &Remote{resolveFlag: []string{"host:443:address:port"}},
},
{
name: "no source port",
opts: &Remote{resolveFlag: []string{"host::address"}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.opts.parseResolve(); err == nil {
t.Errorf("Expecting error in Remote.parseResolve()")
}
})
}
}

func TestRemote_parseResolve(t *testing.T) {
tests := []struct {
name string
opts *Remote
}{
{
name: "fromHost:fromPort:toIp",
opts: &Remote{resolveFlag: []string{"host:443:0.0.0.0"}},
},
{
name: "fromHost:fromPort:toIp:toPort",
opts: &Remote{resolveFlag: []string{"host:443:0.0.0.0:5000"}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.opts.parseResolve(); (err != nil) != tt.wantErr {
t.Errorf("Remote.parseResolve() error = %v, wantErr %v", err, tt.wantErr)
if err := tt.opts.parseResolve(); err != nil {
t.Errorf("Remote.parseResolve() error = %v", err)
}
})
}
Expand Down
8 changes: 4 additions & 4 deletions internal/net/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,18 @@ type Dialer struct {
}

// Add adds an entry for DNS resolve.
func (d *Dialer) Add(from string, port int, to net.IP) {
func (d *Dialer) Add(from string, fromPort int, to net.IP, toPort int) {
if d.resolve == nil {
d.resolve = make(map[string]string)
}
d.resolve[fmt.Sprintf("%s:%d", from, port)] = fmt.Sprintf("%s:%d", to, port)
d.resolve[fmt.Sprintf("%s:%d", from, fromPort)] = fmt.Sprintf("%s:%d", to, toPort)
}

// DialContext connects to the addr on the named network using the provided
// context.
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if resolve, ok := d.resolve[addr]; ok {
addr = resolve
if resolved, ok := d.resolve[addr]; ok {
addr = resolved
}
return d.Dialer.DialContext(ctx, network, addr)
}
38 changes: 5 additions & 33 deletions internal/net/net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,53 +16,25 @@ limitations under the License.
package net

import (
"context"
"fmt"
"net"
"reflect"
"testing"
)

func TestDialer_DialContext(t *testing.T) {
type args struct {
ctx context.Context
network string
addr string
}
tests := []struct {
name string
d *Dialer
args args
want net.Conn
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.d.DialContext(tt.args.ctx, tt.args.network, tt.args.addr)
if (err != nil) != tt.wantErr {
t.Errorf("Dialer.DialContext() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Dialer.DialContext() = %v, want %v", got, tt.want)
}
})
}
}

func TestRemote_parseResolve_ipv4(t *testing.T) {
host := "mockedHost"
port := "12345"
hostPort := 443
address := "192.168.1.1"
addressPort := 12345
var d Dialer
d.Add(host, 12345, net.ParseIP(address))
d.Add(host, hostPort, net.ParseIP(address), addressPort)

if len(d.resolve) != 1 {
t.Fatalf("expect 1 resolve entries but got %v", d.resolve)
}
want := make(map[string]string)
want[host+":"+port] = address + ":" + port
want[host+":"+fmt.Sprint(hostPort)] = address + ":" + fmt.Sprint(addressPort)
if !reflect.DeepEqual(want, d.resolve) {
t.Fatalf("expecting %v but got %v", want, d.resolve)
}
Expand Down

0 comments on commit fb68c73

Please sign in to comment.