Skip to content

Commit

Permalink
Renamed Client to Dialer
Browse files Browse the repository at this point in the history
  • Loading branch information
oxtoacart committed Dec 3, 2014
1 parent 56fa25d commit 096d164
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 76 deletions.
103 changes: 52 additions & 51 deletions client.go → dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ var (
idleTimeout = 10 * time.Second
)

// ClientConfig captures the configuration of a domain-fronted client.
type ClientConfig struct {
// Config captures the configuration of a domain-fronted dialer.
type Config struct {
// Host: the host (e.g. getiantem.org)
Host string

// Port: the port (e.g. 443)
Port int

// Masquerades: the Masquerades to use when domain-fronting. These will be
// verified when the client starts.
// verified when the Dialer starts.
Masquerades []*Masquerade

// InsecureSkipVerify: if true, server's certificate is not verified.
Expand Down Expand Up @@ -79,45 +79,46 @@ type ClientConfig struct {
OnDialStats func(success bool, domain, addr string, resolutionTime, connectTime, handshakeTime time.Duration)
}

// Client provides a mechanism for dialing domain-fronted servers.
type Client struct {
cfg *ClientConfig
// Dialer implements the proxy.Dialer interface by dialing domain-fronted
// servers.
type Dialer struct {
cfg *Config
masquerades *verifiedMasqueradeSet
connPool *connpool.Pool
enproxyConfig *enproxy.Config
tlsConfigs map[string]*tls.Config
tlsConfigsMutex sync.Mutex
}

// NewClient creates a new client for the given ClientConfig.
func NewClient(cfg *ClientConfig) *Client {
client := &Client{
// NewDialer creates a new Dialer for the given Config.
func NewDialer(cfg *Config) *Dialer {
d := &Dialer{
cfg: cfg,
tlsConfigs: make(map[string]*tls.Config),
}
if client.cfg.Masquerades != nil {
client.masquerades = client.verifiedMasquerades()
if d.cfg.Masquerades != nil {
d.masquerades = d.verifiedMasquerades()
}
client.connPool = &connpool.Pool{
d.connPool = &connpool.Pool{
MinSize: 30,
ClaimTimeout: idleTimeout,
Dial: client.dialServer,
Dial: d.dialServer,
}
client.enproxyConfig = client.enproxyConfigWith(func(addr string) (net.Conn, error) {
return client.connPool.Get()
d.enproxyConfig = d.enproxyConfigWith(func(addr string) (net.Conn, error) {
return d.connPool.Get()
})
return client
return d
}

// Dial dials upstream using domain-fronting.
func (client *Client) Dial(network, addr string) (net.Conn, error) {
func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
if !strings.Contains(network, "tcp") {
return nil, fmt.Errorf("Protocol %s is not supported, only tcp is supported", network)
}

conn := &enproxy.Conn{
Addr: addr,
Config: client.enproxyConfig,
Config: d.enproxyConfig,
}
err := conn.Connect()
if err != nil {
Expand All @@ -128,22 +129,22 @@ func (client *Client) Dial(network, addr string) (net.Conn, error) {

// Close closes the cilent, in particular closing the underlying connection
// pool.
func (client *Client) Close() error {
if client.connPool != nil {
func (d *Dialer) Close() error {
if d.connPool != nil {
// We stop the connPool on a goroutine so as not to wait for Stop to finish
go client.connPool.Stop()
go d.connPool.Stop()
}
if client.masquerades != nil {
go client.masquerades.stop()
if d.masquerades != nil {
go d.masquerades.stop()
}
return nil
}

// HttpClientUsing creates a simple domain-fronted HTTP client using the
// specified Masquerade.
func (client *Client) HttpClientUsing(masquerade *Masquerade) *http.Client {
enproxyConfig := client.enproxyConfigWith(func(addr string) (net.Conn, error) {
return client.dialServerWith(masquerade)
func (d *Dialer) HttpClientUsing(masquerade *Masquerade) *http.Client {
enproxyConfig := d.enproxyConfigWith(func(addr string) (net.Conn, error) {
return d.dialServerWith(masquerade)
})

return &http.Client{
Expand All @@ -163,31 +164,31 @@ func (client *Client) HttpClientUsing(masquerade *Masquerade) *http.Client {
}
}

func (client *Client) enproxyConfigWith(dialProxy func(addr string) (net.Conn, error)) *enproxy.Config {
func (d *Dialer) enproxyConfigWith(dialProxy func(addr string) (net.Conn, error)) *enproxy.Config {
return &enproxy.Config{
DialProxy: dialProxy,
NewRequest: func(upstreamHost string, method string, body io.Reader) (req *http.Request, err error) {
if upstreamHost == "" {
// No specific host requested, use configured one
upstreamHost = client.cfg.Host
upstreamHost = d.cfg.Host
}
return http.NewRequest(method, "http://"+upstreamHost+"/", body)
},
BufferRequests: client.cfg.BufferRequests,
BufferRequests: d.cfg.BufferRequests,
IdleTimeout: idleTimeout, // TODO: make this configurable
}
}

func (client *Client) dialServer() (net.Conn, error) {
func (d *Dialer) dialServer() (net.Conn, error) {
var masquerade *Masquerade
if client.masquerades != nil {
masquerade = client.masquerades.nextVerified()
if d.masquerades != nil {
masquerade = d.masquerades.nextVerified()
}
return client.dialServerWith(masquerade)
return d.dialServerWith(masquerade)
}

func (client *Client) dialServerWith(masquerade *Masquerade) (net.Conn, error) {
dialTimeout := time.Duration(client.cfg.DialTimeoutMillis) * time.Millisecond
func (d *Dialer) dialServerWith(masquerade *Masquerade) (net.Conn, error) {
dialTimeout := time.Duration(d.cfg.DialTimeoutMillis) * time.Millisecond
if dialTimeout == 0 {
dialTimeout = 20 * time.Second
}
Expand All @@ -204,11 +205,11 @@ func (client *Client) dialServerWith(masquerade *Masquerade) (net.Conn, error) {
Timeout: dialTimeout,
},
"tcp",
client.addressForServer(masquerade),
d.addressForServer(masquerade),
sendServerNameExtension,
client.tlsConfig(masquerade))
d.tlsConfig(masquerade))

if client.cfg.OnDialStats != nil {
if d.cfg.OnDialStats != nil {
domain := ""
if masquerade != nil {
domain = masquerade.Domain
Expand All @@ -219,7 +220,7 @@ func (client *Client) dialServerWith(masquerade *Masquerade) (net.Conn, error) {
resultAddr = cwt.Conn.RemoteAddr().String()
}

client.cfg.OnDialStats(err == nil, domain, resultAddr, cwt.ResolutionTime, cwt.ConnectTime, cwt.HandshakeTime)
d.cfg.OnDialStats(err == nil, domain, resultAddr, cwt.ResolutionTime, cwt.ConnectTime, cwt.HandshakeTime)
}

if err != nil && masquerade != nil {
Expand All @@ -229,12 +230,12 @@ func (client *Client) dialServerWith(masquerade *Masquerade) (net.Conn, error) {
}

// Get the address to dial for reaching the server
func (client *Client) addressForServer(masquerade *Masquerade) string {
return fmt.Sprintf("%s:%d", client.serverHost(masquerade), client.cfg.Port)
func (d *Dialer) addressForServer(masquerade *Masquerade) string {
return fmt.Sprintf("%s:%d", d.serverHost(masquerade), d.cfg.Port)
}

func (client *Client) serverHost(masquerade *Masquerade) string {
serverHost := client.cfg.Host
func (d *Dialer) serverHost(masquerade *Masquerade) string {
serverHost := d.cfg.Host
if masquerade != nil {
if masquerade.IpAddress != "" {
serverHost = masquerade.IpAddress
Expand All @@ -248,23 +249,23 @@ func (client *Client) serverHost(masquerade *Masquerade) string {
// tlsConfig builds a tls.Config for dialing the upstream host. Constructed
// tls.Configs are cached on a per-masquerade basis to enable client session
// caching and reduce the amount of PEM certificate parsing.
func (client *Client) tlsConfig(masquerade *Masquerade) *tls.Config {
client.tlsConfigsMutex.Lock()
defer client.tlsConfigsMutex.Unlock()
func (d *Dialer) tlsConfig(masquerade *Masquerade) *tls.Config {
d.tlsConfigsMutex.Lock()
defer d.tlsConfigsMutex.Unlock()

serverName := client.cfg.Host
serverName := d.cfg.Host
if masquerade != nil {
serverName = masquerade.Domain
}
tlsConfig := client.tlsConfigs[serverName]
tlsConfig := d.tlsConfigs[serverName]
if tlsConfig == nil {
tlsConfig = &tls.Config{
ClientSessionCache: tls.NewLRUClientSessionCache(1000),
InsecureSkipVerify: client.cfg.InsecureSkipVerify,
InsecureSkipVerify: d.cfg.InsecureSkipVerify,
ServerName: serverName,
RootCAs: client.cfg.RootCAs,
RootCAs: d.cfg.RootCAs,
}
client.tlsConfigs[serverName] = tlsConfig
d.tlsConfigs[serverName] = tlsConfig
}

return tlsConfig
Expand Down
34 changes: 17 additions & 17 deletions domainfronted_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,26 @@ const (
)

func TestBadProtocol(t *testing.T) {
client := NewClient(&ClientConfig{})
_, err := client.Dial("udp", "127.0.0.1:25324")
d := NewDialer(&Config{})
_, err := d.Dial("udp", "127.0.0.1:25324")
assert.Error(t, err, "Using a non-tcp protocol should have resulted in an error")
}

func TestBadEnproxyConn(t *testing.T) {
client := NewClient(&ClientConfig{
d := NewDialer(&Config{
Host: "localhost",
Port: 3253,
})
_, err := client.Dial("tcp", "www.google.com")
_, err := d.Dial("tcp", "www.google.com")
assert.Error(t, err, "Dialing using a non-existent host should have failed")
}

func TestHttpClientWithBadEnproxyConn(t *testing.T) {
client := NewClient(&ClientConfig{
d := NewDialer(&Config{
Host: "localhost",
Port: 3253,
})
hc := client.HttpClientUsing(nil)
hc := d.HttpClientUsing(nil)
_, err := hc.Get("http://www.google.com/humans.txt")
assert.Error(t, err, "HttpClient using a non-existent host should have failed")
}
Expand Down Expand Up @@ -78,8 +78,8 @@ func TestNonGlobalAddressBadAddr(t *testing.T) {

func doTestNonGlobalAddress(t *testing.T, useRealAddress bool) {
l := startServer(t, false)
client := clientFor(t, l)
defer client.Close()
d := dialerFor(t, l)
defer d.Close()

gotConn := false
var gotConnMutex sync.Mutex
Expand All @@ -98,7 +98,7 @@ func doTestNonGlobalAddress(t *testing.T, useRealAddress bool) {
if !useRealAddress {
addr = "asdflklsdkfjhladskfjhlasdkfjhlsads.asflkjshadlfkadsjhflk:0"
}
conn, err := client.Dial("tcp", addr)
conn, err := d.Dial("tcp", addr)
defer conn.Close()

data := []byte("Some Meaningless Data")
Expand All @@ -112,10 +112,10 @@ func doTestNonGlobalAddress(t *testing.T, useRealAddress bool) {

func TestRoundTrip(t *testing.T) {
l := startServer(t, true)
client := clientFor(t, l)
defer client.Close()
d := dialerFor(t, l)
defer d.Close()

proxy.Test(t, client)
proxy.Test(t, d)
}

// TestIntegration tests against existing domain-fronted servers running on
Expand Down Expand Up @@ -155,7 +155,7 @@ func TestIntegration(t *testing.T) {
actualHandshakeTime := time.Duration(0)
var statsMutex sync.Mutex

client := NewClient(&ClientConfig{
d := NewDialer(&Config{
Host: "fallbacks.getiantem.org",
Port: 443,
Masquerades: masquerades,
Expand All @@ -172,11 +172,11 @@ func TestIntegration(t *testing.T) {
}
},
})
defer client.Close()
defer d.Close()

hc := &http.Client{
Transport: &http.Transport{
Dial: client.Dial,
Dial: d.Dial,
},
}

Expand Down Expand Up @@ -222,15 +222,15 @@ func startServer(t *testing.T, allowNonGlobal bool) net.Listener {
return l
}

func clientFor(t *testing.T, l net.Listener) *Client {
func dialerFor(t *testing.T, l net.Listener) *Dialer {
addrParts := strings.Split(l.Addr().String(), ":")
host := addrParts[0]
port, err := strconv.Atoi(addrParts[1])
if err != nil {
t.Fatalf("Unable to parse port: %s", err)
}

return NewClient(&ClientConfig{
return NewDialer(&Config{
Host: host,
Port: port,
InsecureSkipVerify: true,
Expand Down
16 changes: 8 additions & 8 deletions masquerade.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type MasqueradeSet []*Masquerade
// verifiedMasqueradeSet verifies each configured Masquerade by attempting to
// proxy using it.
type verifiedMasqueradeSet struct {
client *Client
dialer *Dialer
candidatesCh chan *Masquerade
stopCh chan interface{}
verifiedCh chan *Masquerade
Expand All @@ -50,16 +50,16 @@ func (vms *verifiedMasqueradeSet) nextVerified() *Masquerade {
}

// verified sets up a new verifiedMasqueradeSet that verifies each of the
// Masquerades in this MasqueradeSet for the given client.
func (client *Client) verifiedMasquerades() *verifiedMasqueradeSet {
// Masquerades in this MasqueradeSet for the given Dialer.
func (d *Dialer) verifiedMasquerades() *verifiedMasqueradeSet {
// Size verifiedChSize to be able to hold the smaller of MaxMasquerades or
// the number of configured masquerades.
verifiedChSize := len(client.cfg.Masquerades)
verifiedChSize := len(d.cfg.Masquerades)
if MaxMasquerades < verifiedChSize {
verifiedChSize = MaxMasquerades
}
vms := &verifiedMasqueradeSet{
client: client,
dialer: d,
candidatesCh: make(chan *Masquerade),
stopCh: make(chan interface{}, 1),
verifiedCh: make(chan *Masquerade, verifiedChSize),
Expand All @@ -80,8 +80,8 @@ func (client *Client) verifiedMasquerades() *verifiedMasqueradeSet {
// feedCandidates feeds the candidate masquerades to our worker routines in
// random order
func (vms *verifiedMasqueradeSet) feedCandidates() {
for _, i := range rand.Perm(len(vms.client.cfg.Masquerades)) {
if !vms.feedCandidate(vms.client.cfg.Masquerades[i]) {
for _, i := range rand.Perm(len(vms.dialer.cfg.Masquerades)) {
if !vms.feedCandidate(vms.dialer.cfg.Masquerades[i]) {
break
}
}
Expand Down Expand Up @@ -134,7 +134,7 @@ func (vms *verifiedMasqueradeSet) doVerify(masquerade *Masquerade) bool {
}()
go func() {
start := time.Now()
httpClient := vms.client.HttpClientUsing(masquerade)
httpClient := vms.dialer.HttpClientUsing(masquerade)
req, _ := http.NewRequest("HEAD", "http://www.google.com/humans.txt", nil)
resp, err := httpClient.Do(req)
if err != nil {
Expand Down

0 comments on commit 096d164

Please sign in to comment.