diff --git a/Extensions/UTPeerExchange.cs b/Extensions/UTPeerExchange.cs index b59d07d..f49dc4d 100644 --- a/Extensions/UTPeerExchange.cs +++ b/Extensions/UTPeerExchange.cs @@ -103,9 +103,38 @@ public void OnExtendedMessage(PeerWireClient peerWireClient, byte[] bytes) } } - public void SendMessage(PeerWireClient peerWireClient, IPEndPoint[] addedEndPoints, byte[] flags) - { - - } + public void SendMessage(PeerWireClient peerWireClient, IPEndPoint[] addedEndPoints, byte[] flags, IPEndPoint[] droppedEndPoints) + { + if (addedEndPoints == null && droppedEndPoints == null) return; + + BDict d = new BDict(); + + if (addedEndPoints != null) + { + byte[] added = new byte[addedEndPoints.Length * 6]; + for (int x = 0; x < addedEndPoints.Length; x++) + { + addedEndPoints[x].Address.GetAddressBytes().CopyTo(added, x * 6); + BitConverter.GetBytes((ushort)addedEndPoints[x].Port).CopyTo(added, (x * 6)+4); + } + + d.Add("added", new BString { ByteValue = added }); + } + + if (droppedEndPoints != null) + { + byte[] dropped = new byte[droppedEndPoints.Length * 6]; + for (int x = 0; x < droppedEndPoints.Length; x++) + { + droppedEndPoints[x].Address.GetAddressBytes().CopyTo(dropped, x * 6); + + dropped.SetValue((ushort)droppedEndPoints[x].Port, (x * 6) + 2); + } + + d.Add("dropped", new BString { ByteValue = dropped }); + } + + peerWireClient.SendExtended(peerWireClient.GetOutgoingMessageID(this), BencodingUtils.EncodeBytes(d)); + } } } diff --git a/PeerWireClient.cs b/PeerWireClient.cs index 331f9cb..c14ea86 100644 --- a/PeerWireClient.cs +++ b/PeerWireClient.cs @@ -46,16 +46,18 @@ public class PeerWireClient internal readonly Socket Socket; private byte[] _internalBuffer; //async internal buffer private readonly List _protocolExtensions; - private readonly Dictionary _extOutgoing = new Dictionary(); - private readonly Dictionary _extIncoming = new Dictionary(); + private readonly Dictionary _extOutgoing = new Dictionary(); + private readonly Dictionary _extIncoming = new Dictionary(); public Int32 Timeout { get; private set; } public bool[] PeerBitField { get; set; } public bool KeepConnectionAlive { get; set; } public bool UseExtended { get; set; } + public bool UseFast { get; set; } + public bool UseDHT { get; set; } public bool RemoteUsesExtended { get; private set; } public bool RemoteUsesFast { get; private set; } - public bool UseFast { get; set; } + public bool RemoteUsesDHT { get; private set; } public String LocalPeerID { get; set; } public String RemotePeerID { get; private set; } public String Hash { get; set; } @@ -83,11 +85,9 @@ public PeerWireClient(Int32 timeout) Timeout = timeout; - Socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp) - { - ReceiveTimeout = timeout*1000, - SendTimeout = timeout*1000 - }; + Socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + Socket.ReceiveTimeout = timeout*1000; + Socket.SendTimeout = timeout*1000; _internalBuffer = new byte[0]; } @@ -118,7 +118,7 @@ public void Connect(String ipHost, Int32 port) public void Disconnect() { Socket.Disconnect(false); - Socket.Close(); + //Socket.Close(); } public void Handshake() @@ -143,8 +143,9 @@ public void Handshake(byte[] hash, byte[] peerId) if (peerId.Length != 20) throw new ArgumentOutOfRangeException("peerId", "Peer ID must be 20 bytes exactly"); byte[] reservedBytes = {0, 0, 0, 0, 0, 0, 0, 0}; - if(UseExtended) reservedBytes[5] |= 0x10; - if(UseFast) reservedBytes[7] |= 0x04; + reservedBytes[5] |= (byte)(UseExtended ? 0x10 : 0x00); + reservedBytes[7] |= (byte)(UseFast ? 0x04 : 0x00); + reservedBytes[7] |= (byte)(UseDHT ? 0x1 : 0x00); byte[] sendBuf = (new[] { (byte)_bitTorrentProtocolHeader.Length }).Concat(_bitTorrentProtocolHeader).Concat(reservedBytes).Concat(hash).Concat(peerId).ToArray(); @@ -152,7 +153,7 @@ public void Handshake(byte[] hash, byte[] peerId) { BDict handshakeDict = new BDict(); BDict mDict = new BDict(); - Int32 i = 1; + byte i = 1; foreach (IBTExtension extension in _protocolExtensions) { _extOutgoing.Add(extension.Protocol, i); @@ -196,14 +197,20 @@ public void Handshake(byte[] hash, byte[] peerId) Int32 resLen = readBuf[0]; if (resLen != 19) { - Socket.Disconnect(false); - Socket.Close(); - throw new InvalidProgramException("Invalid response received from peer"); + if (resLen == 0) + { + // keep alive? + Thread.Sleep(100); + + Disconnect(); + return; + } } byte[] recReserved = readBuf.Skip(20).Take(8).ToArray(); RemoteUsesExtended = (recReserved[5] & 0x10) == 0x10; RemoteUsesFast = (recReserved[7] & 0x04) == 0x04; + RemoteUsesDHT = (recReserved[7] & 0x1) == 0x1; byte[] recBuffer = new byte[128]; Socket.BeginReceive(recBuffer, 0, 128, SocketFlags.None, OnReceived, recBuffer); @@ -274,7 +281,10 @@ public void SendBitField(bool[] bitField, bool obsf) int x = (int)Math.Floor((double)i/8); ushort p = (ushort) (i%8); - if(bitField[i]) bytes[x] = bytes[x].SetBit(p); + if (bitField[i]) + { + bytes[x] = bytes[x].SetBit(p); + } } Socket.Send(Pack.Int32(1 + bitField.Length, Pack.Endianness.Big).Concat(new byte[] { 5 }).Concat(bytes).ToArray()); @@ -303,9 +313,11 @@ public void SendCancel(Int32 index, Int32 start, Int32 length) Socket.Send(Pack.Int32(13, Pack.Endianness.Big).Concat(new byte[] { 8 }).Concat(Pack.Int32(index)).Concat(Pack.Int32(start)).Concat(Pack.Int32(length)).ToArray()); } - public void SendExtended(Int32 extMsgId, Int32 start, Int32 length) + public void SendExtended(byte extMsgId, byte[] bytes) { - + Int32 length = 2 + bytes.Length; + + Socket.Send(Pack.Int32(length, Pack.Endianness.Big).Concat(new [] { (byte)20} ).Concat(new [] { extMsgId }).Concat(bytes).ToArray()); } public void OnReceived(IAsyncResult ar) @@ -322,21 +334,17 @@ public void OnReceived(IAsyncResult ar) } byte[] recBuffer = new byte[128]; - if (Socket.Connected) Socket.BeginReceive(recBuffer, 0, 128, SocketFlags.None, OnReceived, recBuffer); + + if (Socket.Connected) + { + Socket.BeginReceive(recBuffer, 0, 128, SocketFlags.None, OnReceived, recBuffer); + } } public bool Process() { Thread.Sleep(10); - /*if (_socket.Connected && _socket.Available > 0) - { - byte[] recBuffer = new byte[_socket.Available]; - _socket.Receive(recBuffer); - - _internalBuffer = _internalBuffer == null ? recBuffer : _internalBuffer.Concat(recBuffer).ToArray(); - }*/ - if (_internalBuffer.Length < 4) { if (!Socket.Connected) return false; @@ -560,10 +568,16 @@ private void ProcessReject() private void ProcessExtended(Int32 length) { Int32 msgId = _internalBuffer[0]; - lock (_locker) _internalBuffer = _internalBuffer.Skip(1).ToArray(); + lock (_locker) + { + _internalBuffer = _internalBuffer.Skip(1).ToArray(); + } byte[] buffer = _internalBuffer.Take(length-1).ToArray(); - lock (_locker) _internalBuffer = _internalBuffer.Skip(length - 1).ToArray(); + lock (_locker) + { + _internalBuffer = _internalBuffer.Skip(length - 1).ToArray(); + } if (msgId == 0) { @@ -573,7 +587,7 @@ private void ProcessExtended(Int32 length) foreach (KeyValuePair pair in mDict) { BInt i = (BInt)pair.Value; - _extIncoming.Add(i, pair.Key); + _extIncoming.Add((byte)i, pair.Key); IBTExtension ext = _protocolExtensions.FirstOrDefault(f => f.Protocol == pair.Key); @@ -585,7 +599,7 @@ private void ProcessExtended(Int32 length) } else { - KeyValuePair pair = _extIncoming.FirstOrDefault(f => f.Key == msgId); + KeyValuePair pair = _extIncoming.FirstOrDefault(f => f.Key == msgId); IBTExtension ext = _protocolExtensions.FirstOrDefault(f => f.Protocol == pair.Value); if (ext != null) @@ -599,7 +613,10 @@ private void ProcessAllowFast() { Int32 index = Unpack.Int32(_internalBuffer, 0, Unpack.Endianness.Big); - lock (_locker) _internalBuffer = _internalBuffer.Skip(4).ToArray(); + lock (_locker) + { + _internalBuffer = _internalBuffer.Skip(4).ToArray(); + } OnAllowFast(index); } @@ -610,52 +627,82 @@ private void ProcessAllowFast() private void OnKeepAlive() { - if (KeepAlive != null) KeepAlive(this); + if (KeepAlive != null) + { + KeepAlive(this); + } } private void OnChoke() { - if (Choke != null) Choke(this); + if (Choke != null) + { + Choke(this); + } } private void OnUnChoke() { - if (UnChoke != null) UnChoke(this); + if (UnChoke != null) + { + UnChoke(this); + } } private void OnInterested() { - if (Interested != null) Interested(this); + if (Interested != null) + { + Interested(this); + } } private void OnNotInterested() { - if (NotInterested != null) NotInterested(this); + if (NotInterested != null) + { + NotInterested(this); + } } private void OnHave(Int32 pieceIndex) { - if (Have != null) Have(this, pieceIndex); + if (Have != null) + { + Have(this, pieceIndex); + } } private void OnBitField(Int32 size, bool[] bitField) { - if (BitField != null) BitField(this, size, bitField); + if (BitField != null) + { + BitField(this, size, bitField); + } } private void OnRequest(Int32 index, Int32 begin, Int32 length) { - if (Request != null) Request(this, index, begin, length); + if (Request != null) + { + Request(this, index, begin, length); + } } private void OnPiece(Int32 index, Int32 begin, byte[] bytes) { - if (Piece != null) Piece(this, index, begin, bytes); + if (Piece != null) + { + Piece(this, index, begin, bytes); + } } private void OnCancel(Int32 index, Int32 begin, Int32 length) { - if (Cancel != null) Cancel(this, index, begin, length); + if (Cancel != null) + { + Cancel(this, index, begin, length); + } } private void OnPort(UInt16 port) @@ -676,22 +723,34 @@ private void OnSuggest(Int32 pieceIndex) private void OnHaveAll() { - if (HaveAll != null) HaveAll(this); + if (HaveAll != null) + { + HaveAll(this); + } } private void OnHaveNone() { - if (HaveNone != null) HaveNone(this); + if (HaveNone != null) + { + HaveNone(this); + } } private void OnReject(Int32 index, Int32 begin, Int32 length) { - if (Reject != null) Reject(this, index, begin, length); + if (Reject != null) + { + Reject(this, index, begin, length); + } } private void OnAllowFast(Int32 pieceIndex) { - if (AllowedFast != null) AllowedFast(this, pieceIndex); + if (AllowedFast != null) + { + AllowedFast(this, pieceIndex); + } } #endregion @@ -707,14 +766,14 @@ public void UnregisterProtocolExtension(IBTExtension extension) extension.Deinit(this); } - public Int64 GetOutgoingMessageID(IBTExtension extension) + public byte GetOutgoingMessageID(IBTExtension extension) { if (_extOutgoing.ContainsKey(extension.Protocol)) { return _extOutgoing[extension.Protocol]; } - return -1; + return 0; } } }