From 5c7713c99d13ddf9cb27900d325ba70793c32157 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Thu, 21 Mar 2024 15:27:14 +0800 Subject: [PATCH] :white_check_mark: Add tests for refactored reading support --- tests/test_00_components.py | 47 +++++++++---------------------- tests/test_04_cross_validation.py | 9 ++++-- 2 files changed, 21 insertions(+), 35 deletions(-) diff --git a/tests/test_00_components.py b/tests/test_00_components.py index 7f42441..54a32ef 100644 --- a/tests/test_00_components.py +++ b/tests/test_00_components.py @@ -1,16 +1,14 @@ -import socket from io import BytesIO import pytest from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from vmess_aead.kdf import kdf -from vmess_aead.utils import SM4GCM +from vmess_aead.utils import SM4GCM, Shake128Reader from vmess_aead.utils.reader import ( BufferedReader, BytesReader, IOReader, ReadOutOfBoundError, - SocketReader, StreamCipherReader, ) @@ -34,7 +32,6 @@ def test_bytes_reader(): reader.append(b"67890") assert reader.read(5) == b"12345" assert reader.offset == 5 - assert reader.read_before() == b"12345" assert reader.read_all() == b"67890" assert reader.offset == 10 assert reader.remaining == 0 @@ -51,6 +48,7 @@ def test_buffered_reader(): assert reader.read_all() == b"67890" assert reader.offset == 10 + assert reader.remaining == 0 with pytest.raises(ReadOutOfBoundError): reader.read(5) @@ -62,6 +60,7 @@ def test_io_reader(): assert reader.read_all() == b"67890" assert reader.offset == 10 + assert reader.remaining == 0 with pytest.raises(ReadOutOfBoundError): reader.read(5) @@ -79,41 +78,23 @@ def test_stream_cipher_reader(): assert reader.read_all() == b"90" assert reader.offset == 10 + assert reader.remaining == 0 with pytest.raises(ReadOutOfBoundError): reader.read(5) -def test_socket_reader(): - server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server.bind(("localhost", 0)) - server.listen() - - client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - client.connect(server.getsockname()) - client.send(b"1234567890") - client.close() - - connection, src_addr = server.accept() - reader = SocketReader(connection) - assert reader.read(5) == b"12345" - assert reader.offset == 5 - assert reader.read(5) == b"67890" - assert reader.offset == 10 - - with pytest.raises(ReadOutOfBoundError): - reader.read(5) - - client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - client.connect(server.getsockname()) - client.send(b"1234567890") - client.close() - - connection, src_addr = server.accept() - reader = SocketReader(connection, buffer_size=2) - assert reader.read(5) == b"12345" +def test_shake128_reader(): + nonce = b"1234567890123456" + reader = Shake128Reader(nonce, increase_length=10) + assert reader.read(5).hex() == "c31adff3ca" assert reader.offset == 5 - assert reader.read_all() == b"67890" + assert reader.read_all().hex() == "ff74f708b9" assert reader.offset == 10 + # test increase length + assert reader.remaining == 0 + assert reader.read(5).hex() == "bb1ff0ee84" + # buffer size doubled, 10*2 + 10 = 30 bytes + assert reader.buffer_size == 30 def test_sm4_gcm(): diff --git a/tests/test_04_cross_validation.py b/tests/test_04_cross_validation.py index 451e8d4..ed65dc9 100644 --- a/tests/test_04_cross_validation.py +++ b/tests/test_04_cross_validation.py @@ -25,7 +25,7 @@ ) from vmess_aead.headers.response import VMessAEADResponsePacketHeader from vmess_aead.utils import generate_response_key -from vmess_aead.utils.reader import SocketReader +from vmess_aead.utils.reader import IOReader if platform.system() != "Linux" or platform.machine() != "x86_64": pytest.skip("Cross validation only works on Linux x86_64", allow_module_level=True) @@ -155,7 +155,12 @@ def test_as_client( server_connection.settimeout(5) server_connection.send(b"ok") - reader = SocketReader(client) + reader = IOReader(client.makefile("rb")) + + # socket file is not seekable + with pytest.raises(NotImplementedError): + assert reader.remaining + resp_iv = generate_response_key(header_packet.payload.body_iv) resp_key = generate_response_key(header_packet.payload.body_key) resp_header = VMessAEADResponsePacketHeader.from_packet(