Skip to content

Commit

Permalink
optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
fanpei91 committed Jan 23, 2019
1 parent 07d8f0a commit d0bbb49
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 89 deletions.
34 changes: 19 additions & 15 deletions blacklist.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,22 @@ type blackList struct {
ll *list.List
cache map[string]*list.Element
expiredAfter time.Duration
maxEntries int
limit int
}

func newBlackList(expiredAfter time.Duration, maxEntries int) *blackList {
func newBlackList(expiredAfter time.Duration, limit int) *blackList {
return &blackList{
ll: list.New(),
cache: make(map[string]*list.Element),
expiredAfter: expiredAfter,
maxEntries: maxEntries,
limit: limit,
}
}

func (b *blackList) add(addr string) {
var next *list.Element
for elem := b.ll.Front(); elem != nil; elem = next {
next = elem.Next()
e := elem.Value.(*entry)
if time.Now().Sub(e.ctime) < b.expiredAfter {
break
}
b.ll.Remove(elem)
delete(b.cache, e.addr)
}
b.removeExpired()

if b.ll.Len() >= b.maxEntries {
if b.ll.Len() >= b.limit {
return
}

Expand All @@ -49,7 +40,6 @@ func (b *blackList) add(addr string) {
})
b.cache[addr] = e
}

}

func (b *blackList) has(addr string) bool {
Expand All @@ -64,3 +54,17 @@ func (b *blackList) has(addr string) bool {

return false
}

func (b *blackList) removeExpired() {
now := time.Now()
var next *list.Element
for elem := b.ll.Front(); elem != nil; elem = next {
next = elem.Next()
e := elem.Value.(*entry)
if now.Sub(e.ctime) < b.expiredAfter {
break
}
b.ll.Remove(elem)
delete(b.cache, e.addr)
}
}
32 changes: 0 additions & 32 deletions bytespool.go

This file was deleted.

25 changes: 11 additions & 14 deletions dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,26 @@ func randBytes(n int) []byte {
func neighborID(target nodeID, local nodeID) nodeID {
const closeness = 15
id := make([]byte, 20)
copy(id[:10], target[:closeness])
copy(id[10:], local[closeness:])
copy(id[:closeness], target[:closeness])
copy(id[closeness:], local[closeness:])
return id
}

func makeQuery(tid string, q string, a map[string]interface{}) map[string]interface{} {
dict := map[string]interface{}{
return map[string]interface{}{
"t": tid,
"y": "q",
"q": q,
"a": a,
}
return dict
}

func makeReply(tid string, r map[string]interface{}) map[string]interface{} {
dict := map[string]interface{}{
return map[string]interface{}{
"t": tid,
"y": "r",
"r": r,
}
return dict
}

func decodeNodes(s string) (nodes []*node) {
Expand Down Expand Up @@ -136,8 +134,8 @@ func newDHT(laddr string, maxFriendsPerSec int) (*dht, error) {
}

func (g *dht) listen() {
buf := make([]byte, 2048)
for {
buf := packetPool.get()
n, addr, err := g.conn.ReadFromUDP(buf)
if err == nil {
g.onMessage(buf[:n], *addr)
Expand All @@ -146,13 +144,12 @@ func (g *dht) listen() {
close(g.die)
break
}
packetPool.put(buf)
}
}

func (g *dht) join() {
const times = 3
for i := 0; i < times; i++ {
const timesForSure = 3
for i := 0; i < timesForSure; i++ {
for _, addr := range g.bootstraps {
g.chNode <- &node{addr: addr, id: string(randBytes(20))}
}
Expand Down Expand Up @@ -189,8 +186,8 @@ func (g *dht) onQuery(dict map[string]interface{}, from net.UDPAddr) {
return
}

if f, ok := g.queryTypes[q]; ok {
f(dict, from)
if handle, ok := g.queryTypes[q]; ok {
handle(dict, from)
}
}

Expand Down Expand Up @@ -229,7 +226,7 @@ func (g *dht) findNode(to string, target nodeID) {
}

func (g *dht) onGetPeersQuery(dict map[string]interface{}, from net.UDPAddr) {
t := dict["t"].(string)
tid := dict["t"].(string)
a, ok := dict["a"].(map[string]interface{})
if !ok {
return
Expand All @@ -240,7 +237,7 @@ func (g *dht) onGetPeersQuery(dict map[string]interface{}, from net.UDPAddr) {
return
}

d := makeReply(t, map[string]interface{}{
d := makeReply(tid, map[string]interface{}{
"id": string(neighborID([]byte(id), g.localID)),
"nodes": "",
"token": g.genToken(from),
Expand Down
34 changes: 13 additions & 21 deletions meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha1"
"encoding/binary"
"errors"
Expand Down Expand Up @@ -35,12 +34,6 @@ var metaWirePool = sync.Pool{
},
}

func randomPeerID() string {
b := make([]byte, 20)
rand.Read(b)
return string(b)
}

type metaWire struct {
infohash string
from string
Expand All @@ -54,12 +47,14 @@ type metaWire struct {
err error
}

func newMetaWire(infohash string, from string) *metaWire {
func newMetaWire(infohash string, from string, timeout time.Duration) *metaWire {
w := metaWirePool.Get().(*metaWire)
w.infohash = infohash
w.from = from
w.peerID = randomPeerID()
w.timeout = 10 * time.Second
w.peerID = string(randBytes(20))
w.timeout = timeout
w.conn = nil
w.err = nil
return w
}

Expand Down Expand Up @@ -111,13 +106,6 @@ func (mw *metaWire) fetchCtx(ctx context.Context) ([]byte, error) {
}

func (mw *metaWire) connect(ctx context.Context) {
select {
case <-ctx.Done():
mw.err = errTimeout
return
default:
}

conn, err := net.DialTimeout("tcp", mw.from, mw.timeout)
if err != nil {
mw.err = fmt.Errorf("connect to remote peer failed: %v", err)
Expand Down Expand Up @@ -197,7 +185,7 @@ func (mw *metaWire) extHandshake(ctx context.Context) {
"ut_metadata": 1,
},
})...)
if err := mw.send(ctx, data); err != nil {
if err := mw.write(ctx, data); err != nil {
mw.err = err
return
}
Expand Down Expand Up @@ -261,7 +249,7 @@ func (mw *metaWire) requestPiece(ctx context.Context, i int) {
"msg_type": 0,
"piece": i,
}))
mw.send(ctx, buf.Bytes())
mw.write(ctx, buf.Bytes())
}

func (mw *metaWire) onExtended(ctx context.Context, ext byte, payload []byte) error {
Expand Down Expand Up @@ -357,7 +345,7 @@ func (mw *metaWire) read(ctx context.Context, size uint32) ([]byte, error) {
return buf.Bytes(), nil
}

func (mw *metaWire) send(ctx context.Context, data []byte) error {
func (mw *metaWire) write(ctx context.Context, data []byte) error {
select {
case <-ctx.Done():
return errTimeout
Expand All @@ -370,8 +358,12 @@ func (mw *metaWire) send(ctx context.Context, data []byte) error {
buf.Write(data)
_, err := mw.conn.Write(buf.Bytes())
if err != nil {
return fmt.Errorf("send message failed: %v", err)
return fmt.Errorf("write message failed: %v", err)
}

return nil
}

func (mw *metaWire) free() {
metaWirePool.Put(mw)
}
13 changes: 6 additions & 7 deletions torsniff.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (t *torrent) String() string {
)
}

func newTorrent(meta []byte, infohashHex string) (*torrent, error) {
func parseTorrent(meta []byte, infohashHex string) (*torrent, error) {
dict, err := bencode.Decode(bytes.NewBuffer(meta))
if err != nil {
return nil, err
Expand Down Expand Up @@ -173,8 +173,8 @@ func (t *torsniff) work(ac *announcement, tokens chan struct{}) {
}
t.mu.RUnlock()

wire := newMetaWire(string(ac.infohash), peerAddr)
defer metaWirePool.Put(wire)
wire := newMetaWire(string(ac.infohash), peerAddr, t.timeout)
defer wire.free()

data, err := wire.fetch()
if err != nil {
Expand All @@ -189,7 +189,7 @@ func (t *torsniff) work(ac *announcement, tokens chan struct{}) {
return
}

torrent, err := newTorrent(data, ac.infohashHex)
torrent, err := parseTorrent(data, ac.infohashHex)
if err != nil {
return
}
Expand Down Expand Up @@ -262,18 +262,17 @@ func main() {
return err
}

log.SetOutput(ioutil.Discard)
if verbose {
log.SetOutput(os.Stdout)
} else {
log.SetOutput(ioutil.Discard)
}

p := &torsniff{
laddr: fmt.Sprintf("%s:%d", addr, port),
timeout: timeout,
maxFriends: maxFriends,
maxPeers: peers,
secret: randomPeerID(),
secret: string(randBytes(20)),
dir: absDir,
blacklist: newBlackList(5*time.Minute, 50000),
}
Expand Down

0 comments on commit d0bbb49

Please sign in to comment.