Skip to content

Commit

Permalink
✅ Add tests for refactored reading support
Browse files Browse the repository at this point in the history
  • Loading branch information
mnixry committed Mar 21, 2024
1 parent a9e8828 commit 5c7713c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 35 deletions.
47 changes: 14 additions & 33 deletions tests/test_00_components.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import socket
from io import BytesIO

import pytest
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from vmess_aead.kdf import kdf
from vmess_aead.utils import SM4GCM
from vmess_aead.utils import SM4GCM, Shake128Reader
from vmess_aead.utils.reader import (
BufferedReader,
BytesReader,
IOReader,
ReadOutOfBoundError,
SocketReader,
StreamCipherReader,
)

Expand All @@ -34,7 +32,6 @@ def test_bytes_reader():
reader.append(b"67890")
assert reader.read(5) == b"12345"
assert reader.offset == 5
assert reader.read_before() == b"12345"
assert reader.read_all() == b"67890"
assert reader.offset == 10
assert reader.remaining == 0
Expand All @@ -51,6 +48,7 @@ def test_buffered_reader():
assert reader.read_all() == b"67890"
assert reader.offset == 10

assert reader.remaining == 0
with pytest.raises(ReadOutOfBoundError):
reader.read(5)

Expand All @@ -62,6 +60,7 @@ def test_io_reader():
assert reader.read_all() == b"67890"
assert reader.offset == 10

assert reader.remaining == 0
with pytest.raises(ReadOutOfBoundError):
reader.read(5)

Expand All @@ -79,41 +78,23 @@ def test_stream_cipher_reader():
assert reader.read_all() == b"90"
assert reader.offset == 10

assert reader.remaining == 0
with pytest.raises(ReadOutOfBoundError):
reader.read(5)


def test_socket_reader():
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(("localhost", 0))
server.listen()

client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect(server.getsockname())
client.send(b"1234567890")
client.close()

connection, src_addr = server.accept()
reader = SocketReader(connection)
assert reader.read(5) == b"12345"
assert reader.offset == 5
assert reader.read(5) == b"67890"
assert reader.offset == 10

with pytest.raises(ReadOutOfBoundError):
reader.read(5)

client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect(server.getsockname())
client.send(b"1234567890")
client.close()

connection, src_addr = server.accept()
reader = SocketReader(connection, buffer_size=2)
assert reader.read(5) == b"12345"
def test_shake128_reader():
nonce = b"1234567890123456"
reader = Shake128Reader(nonce, increase_length=10)
assert reader.read(5).hex() == "c31adff3ca"
assert reader.offset == 5
assert reader.read_all() == b"67890"
assert reader.read_all().hex() == "ff74f708b9"
assert reader.offset == 10
# test increase length
assert reader.remaining == 0
assert reader.read(5).hex() == "bb1ff0ee84"
# buffer size doubled, 10*2 + 10 = 30 bytes
assert reader.buffer_size == 30


def test_sm4_gcm():
Expand Down
9 changes: 7 additions & 2 deletions tests/test_04_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from vmess_aead.headers.response import VMessAEADResponsePacketHeader
from vmess_aead.utils import generate_response_key
from vmess_aead.utils.reader import SocketReader
from vmess_aead.utils.reader import IOReader

if platform.system() != "Linux" or platform.machine() != "x86_64":
pytest.skip("Cross validation only works on Linux x86_64", allow_module_level=True)
Expand Down Expand Up @@ -155,7 +155,12 @@ def test_as_client(
server_connection.settimeout(5)
server_connection.send(b"ok")

reader = SocketReader(client)
reader = IOReader(client.makefile("rb"))

# socket file is not seekable
with pytest.raises(NotImplementedError):
assert reader.remaining

resp_iv = generate_response_key(header_packet.payload.body_iv)
resp_key = generate_response_key(header_packet.payload.body_key)
resp_header = VMessAEADResponsePacketHeader.from_packet(
Expand Down

0 comments on commit 5c7713c

Please sign in to comment.