Skip to content

Commit

Permalink
*: migrate to netip
Browse files Browse the repository at this point in the history
  • Loading branch information
jzelinskie committed Apr 14, 2022
1 parent 7455c2a commit e56ad81
Show file tree
Hide file tree
Showing 26 changed files with 445 additions and 471 deletions.
129 changes: 54 additions & 75 deletions bittorrent/bittorrent.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,34 @@
package bittorrent

import (
"encoding/binary"
"fmt"
"net"
"net/netip"
"time"

"github.com/chihaya/chihaya/pkg/iputil"
"github.com/chihaya/chihaya/pkg/log"
)

// PeerID represents a peer ID.
type PeerID [20]byte

// PeerIDFromBytes creates a PeerID from a byte slice.
//
// It panics if b is not 20 bytes long.
func PeerIDFromBytes(b []byte) PeerID {
if len(b) != 20 {
panic("peer ID must be 20 bytes")
}

var buf [20]byte
copy(buf[:], b)
return PeerID(buf)
}

// String implements fmt.Stringer, returning the base16 encoded PeerID.
func (p PeerID) String() string {
return fmt.Sprintf("%x", p[:])
}
func (p PeerID) String() string { return fmt.Sprintf("%x", p[:]) }

// RawString returns a 20-byte string of the raw bytes of the ID.
func (p PeerID) RawString() string {
return string(p[:])
}
// MarshalBinary returns a 20-byte string of the raw bytes of the ID.
func (p PeerID) MarshalBinary() []byte { return p[:] }

// PeerIDFromString creates a PeerID from a string.
// PeerIDFromBytes creates a PeerID from bytes.
//
// It panics if s is not 20 bytes long.
func PeerIDFromString(s string) PeerID {
if len(s) != 20 {
func PeerIDFromBytes(b []byte) PeerID {
if len(b) != 20 {
panic("peer ID must be 20 bytes")
}

var buf [20]byte
copy(buf[:], s)
copy(buf[:], b)
return PeerID(buf)
}

Expand Down Expand Up @@ -80,14 +65,10 @@ func InfoHashFromString(s string) InfoHash {
}

// String implements fmt.Stringer, returning the base16 encoded InfoHash.
func (i InfoHash) String() string {
return fmt.Sprintf("%x", i[:])
}
func (i InfoHash) String() string { return fmt.Sprintf("%x", i[:]) }

// RawString returns a 20-byte string of the raw bytes of the InfoHash.
func (i InfoHash) RawString() string {
return string(i[:])
}
// MarshalBinary returns a 20-byte string of the raw bytes of the InfoHash.
func (i InfoHash) MarshalBinary() []byte { return i[:] }

// AnnounceRequest represents the parsed parameters from an announce request.
type AnnounceRequest struct {
Expand Down Expand Up @@ -150,17 +131,17 @@ func (r AnnounceResponse) LogFields() log.Fields {

// ScrapeRequest represents the parsed parameters from a scrape request.
type ScrapeRequest struct {
AddressFamily AddressFamily
InfoHashes []InfoHash
Params Params
Peer
InfoHashes []InfoHash
Params Params
}

// LogFields renders the current response as a set of log fields.
func (r ScrapeRequest) LogFields() log.Fields {
return log.Fields{
"addressFamily": r.AddressFamily,
"infoHashes": r.InfoHashes,
"params": r.Params,
"peer": r.Peer,
"infoHashes": r.InfoHashes,
"params": r.Params,
}
}

Expand All @@ -187,65 +168,63 @@ type Scrape struct {
Incomplete uint32
}

// AddressFamily is the address family of an IP address.
type AddressFamily uint8

func (af AddressFamily) String() string {
switch af {
case IPv4:
return "IPv4"
case IPv6:
return "IPv6"
default:
panic("tried to print unknown AddressFamily")
}
}

// AddressFamily constants.
const (
IPv4 AddressFamily = iota
IPv6
)

// IP is a net.IP with an AddressFamily.
type IP struct {
net.IP
AddressFamily
}

func (ip IP) String() string {
return ip.IP.String()
}

// Peer represents the connection details of a peer that is returned in an
// announce response.
type Peer struct {
ID PeerID
IP IP
Port uint16
ID PeerID
AddrPort netip.AddrPort
}

// String implements fmt.Stringer to return a human-readable representation.
// The string will have the format <PeerID>@[<IP>]:<port>, for example
// "0102030405060708090a0b0c0d0e0f1011121314@[10.11.12.13]:1234"
func (p Peer) String() string {
return fmt.Sprintf("%s@[%s]:%d", p.ID.String(), p.IP.String(), p.Port)
return fmt.Sprintf("%s@[%s]:%d", p.ID, p.AddrPort.Addr(), p.AddrPort.Port())
}

// MarshalBinary encodes a Peer into a memory-efficient byte representation.
//
// The format is:
// 20-byte PeerID
// 2-byte Big Endian Port
// 4-byte or 16-byte IP address
func (p Peer) MarshalBinary() []byte {
ip := p.AddrPort.Addr().Unmap()
b := make([]byte, 20+2+(ip.BitLen()/8))
copy(b[:20], p.ID[:])
binary.BigEndian.PutUint16(b[20:22], p.AddrPort.Port())
copy(b[22:], ip.AsSlice())
return b
}

// PeerFromBytes parses a Peer from its raw representation.
func PeerFromBytes(b []byte) Peer {
return Peer{
ID: PeerIDFromBytes(b[:20]),
AddrPort: netip.AddrPortFrom(
iputil.MustAddrFromSlice(b[22:]).Unmap(),
binary.BigEndian.Uint16(b[20:22]),
),
}
}

// LogFields renders the current peer as a set of Logrus fields.
func (p Peer) LogFields() log.Fields {
return log.Fields{
"ID": p.ID,
"IP": p.IP,
"port": p.Port,
"IP": p.AddrPort.Addr().String(),
"port": p.AddrPort.Port(),
}
}

// Equal reports whether p and x are the same.
func (p Peer) Equal(x Peer) bool { return p.EqualEndpoint(x) && p.ID == x.ID }

// EqualEndpoint reports whether p and x have the same endpoint.
func (p Peer) EqualEndpoint(x Peer) bool { return p.Port == x.Port && p.IP.Equal(x.IP.IP) }
func (p Peer) EqualEndpoint(x Peer) bool {
return p.AddrPort.Port() == x.AddrPort.Port() &&
p.AddrPort.Addr().Compare(x.AddrPort.Addr()) == 0
}

// ClientError represents an error that should be exposed to the client over
// the BitTorrent protocol implementation.
Expand Down
12 changes: 5 additions & 7 deletions bittorrent/bittorrent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package bittorrent

import (
"fmt"
"net"
"net/netip"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -19,17 +19,15 @@ var peerStringTestCases = []struct {
}{
{
input: Peer{
ID: PeerIDFromBytes(b),
IP: IP{net.IPv4(10, 11, 12, 1), IPv4},
Port: 1234,
ID: PeerIDFromBytes(b),
AddrPort: netip.MustParseAddrPort("10.11.12.1:1234"),
},
expected: fmt.Sprintf("%s@[10.11.12.1]:1234", expected),
},
{
input: Peer{
ID: PeerIDFromBytes(b),
IP: IP{net.ParseIP("2001:db8::ff00:42:8329"), IPv6},
Port: 1234,
ID: PeerIDFromBytes(b),
AddrPort: netip.MustParseAddrPort("[2001:db8::ff00:42:8329]:1234"),
},
expected: fmt.Sprintf("%s@[2001:db8::ff00:42:8329]:1234", expected),
},
Expand Down
2 changes: 1 addition & 1 deletion bittorrent/client_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestClientID(t *testing.T) {
t.Run(tt.peerID, func(t *testing.T) {
var clientID ClientID
copy(clientID[:], []byte(tt.clientID))
parsedID := NewClientID(PeerIDFromString(tt.peerID))
parsedID := NewClientID(PeerIDFromBytes([]byte(tt.peerID)))
if parsedID != clientID {
t.Error("Incorrectly parsed peer ID", tt.peerID, "as", parsedID)
}
Expand Down
12 changes: 4 additions & 8 deletions bittorrent/sanitize.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package bittorrent

import (
"net"
"net/netip"

"github.com/chihaya/chihaya/pkg/log"
)
Expand All @@ -15,7 +15,7 @@ var ErrInvalidPort = ClientError("invalid port")
// SanitizeAnnounce enforces a max and default NumWant and coerces the peer's
// IP address into the proper format.
func SanitizeAnnounce(r *AnnounceRequest, maxNumWant, defaultNumWant uint32) error {
if r.Port == 0 {
if r.AddrPort.Port() == 0 {
return ErrInvalidPort
}

Expand All @@ -25,12 +25,8 @@ func SanitizeAnnounce(r *AnnounceRequest, maxNumWant, defaultNumWant uint32) err
r.NumWant = maxNumWant
}

if ip := r.Peer.IP.To4(); ip != nil {
r.Peer.IP.IP = ip
r.Peer.IP.AddressFamily = IPv4
} else if len(r.Peer.IP.IP) == net.IPv6len { // implies r.Peer.IP.To4() == nil
r.Peer.IP.AddressFamily = IPv6
} else {
r.AddrPort = netip.AddrPortFrom(r.AddrPort.Addr().Unmap(), r.AddrPort.Port())
if !r.AddrPort.Addr().IsValid() || r.AddrPort.Addr().IsUnspecified() {
return ErrInvalidIP
}

Expand Down
26 changes: 10 additions & 16 deletions frontend/http/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"
"net"
"net/http"
"net/netip"
"time"

"github.com/julienschmidt/httprouter"
Expand Down Expand Up @@ -307,12 +308,12 @@ func (f *Frontend) announceRoute(w http.ResponseWriter, r *http.Request, ps http
if f.EnableRequestTiming {
start = time.Now()
}
var af *bittorrent.AddressFamily
var addr netip.Addr
defer func() {
if f.EnableRequestTiming {
recordResponseDuration("announce", af, err, time.Since(start))
recordResponseDuration("announce", addr, err, time.Since(start))
} else {
recordResponseDuration("announce", af, err, time.Duration(0))
recordResponseDuration("announce", addr, err, time.Duration(0))
}
}()

Expand All @@ -321,8 +322,7 @@ func (f *Frontend) announceRoute(w http.ResponseWriter, r *http.Request, ps http
_ = WriteError(w, err)
return
}
af = new(bittorrent.AddressFamily)
*af = req.IP.AddressFamily
addr = req.AddrPort.Addr()

ctx := injectRouteParamsToContext(context.Background(), ps)
ctx, resp, err := f.logic.HandleAnnounce(ctx, req)
Expand All @@ -348,12 +348,12 @@ func (f *Frontend) scrapeRoute(w http.ResponseWriter, r *http.Request, ps httpro
if f.EnableRequestTiming {
start = time.Now()
}
var af *bittorrent.AddressFamily
var addr netip.Addr
defer func() {
if f.EnableRequestTiming {
recordResponseDuration("scrape", af, err, time.Since(start))
recordResponseDuration("scrape", addr, err, time.Since(start))
} else {
recordResponseDuration("scrape", af, err, time.Duration(0))
recordResponseDuration("scrape", addr, err, time.Duration(0))
}
}()

Expand All @@ -370,18 +370,12 @@ func (f *Frontend) scrapeRoute(w http.ResponseWriter, r *http.Request, ps httpro
return
}

reqIP := net.ParseIP(host)
if reqIP.To4() != nil {
req.AddressFamily = bittorrent.IPv4
} else if len(reqIP) == net.IPv6len { // implies reqIP.To4() == nil
req.AddressFamily = bittorrent.IPv6
} else {
addr, err = netip.ParseAddr(host)
if err != nil || addr.IsUnspecified() {
log.Error("http: invalid IP: neither v4 nor v6", log.Fields{"RemoteAddr": r.RemoteAddr})
_ = WriteError(w, bittorrent.ErrInvalidIP)
return
}
af = new(bittorrent.AddressFamily)
*af = req.AddressFamily

ctx := injectRouteParamsToContext(context.Background(), ps)
ctx, resp, err := f.logic.HandleScrape(ctx, req)
Expand Down
Loading

0 comments on commit e56ad81

Please sign in to comment.