From 89040b169b9e8c0b6bac1e2294182f12ee32631d Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 16 May 2023 13:07:38 -0400 Subject: [PATCH] Some basic unit testing for custom TLS context factories. --- test/test_web.py | 83 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 81 insertions(+), 2 deletions(-) diff --git a/test/test_web.py b/test/test_web.py index 3933ef05..a6a7a0f4 100644 --- a/test/test_web.py +++ b/test/test_web.py @@ -1,6 +1,7 @@ from mock import Mock +from twisted.web.client import BrowserLikePolicyForHTTPS from twisted.trial import unittest from twisted.internet import defer @@ -14,6 +15,11 @@ from txtorcon.circuit import TorCircuitEndpoint +class CustomTLSContextFactory(BrowserLikePolicyForHTTPS): + def creatorForNetloc(self, hostname, port): + return super().creatorForNetloc(b"custom.domain", port) + + class WebAgentTests(unittest.TestCase): if not _HAVE_WEB: skip = "Missing web" @@ -56,7 +62,12 @@ def test_socks_agent_tcp_host_port(self): def getConnection(key, endpoint): self.assertTrue(isinstance(endpoint, TorSocksEndpoint)) - self.assertTrue(endpoint._tls) + self.assertIsInstance( + endpoint._tls, + BrowserLikePolicyForHTTPS().creatorForNetloc(b"host", 443).__class__ + ) + # This uses a Twisted private interface... + self.assertEqual(endpoint._tls._hostname, "meejah.ca") self.assertEqual(endpoint._host, u'meejah.ca') self.assertEqual(endpoint._port, 443) return defer.succeed(proto) @@ -70,6 +81,38 @@ def getConnection(key, endpoint): res = yield agent.request(b'GET', b'https://meejah.ca') self.assertIs(res, gold) + @defer.inlineCallbacks + def test_socks_agent_custom_tls_context_factory(self): + reactor = Mock() + config = Mock() + config.SocksPort = [] + proto = Mock() + gold = object() + proto.request = Mock(return_value=defer.succeed(gold)) + + def getConnection(key, endpoint): + + self.assertIsInstance( + endpoint._tls, + BrowserLikePolicyForHTTPS().creatorForNetloc(b"host", 443).__class__ + ) + # This uses a Twisted private interface... + self.assertEqual(endpoint._tls._hostname, "custom.domain") + self.assertEqual(endpoint._host, 'meejah.ca') + return defer.succeed(proto) + pool = Mock() + pool.getConnection = getConnection + + # do the test + agent = yield agent_for_socks_port( + reactor, config, '127.0.0.50:1234', pool=pool, + tls_context_factory=CustomTLSContextFactory() + ) + + # apart from the getConnection asserts... + res = yield agent.request(b'GET', b'https://meejah.ca') + self.assertIs(res, gold) + @defer.inlineCallbacks def test_agent(self): reactor = Mock() @@ -95,7 +138,12 @@ def test_agent_with_circuit(self): def getConnection(key, endpoint): self.assertTrue(isinstance(endpoint, TorCircuitEndpoint)) target = endpoint._target_endpoint - self.assertTrue(target._tls) + self.assertIsInstance( + target._tls, + BrowserLikePolicyForHTTPS().creatorForNetloc(b"host", 443).__class__ + ) + # This uses a Twisted private interface... + self.assertEqual(target._tls._hostname, "meejah.ca") self.assertEqual(target._host, u'meejah.ca') self.assertEqual(target._port, 443) return defer.succeed(proto) @@ -107,3 +155,34 @@ def getConnection(key, endpoint): # apart from the getConnection asserts... res = yield agent.request(b'GET', b'https://meejah.ca') self.assertIs(res, gold) + + @defer.inlineCallbacks + def test_agent_with_circuit_tls_context_factory(self): + reactor = Mock() + circuit = Mock() + socks_ep = Mock() + proto = Mock() + gold = object() + proto.request = Mock(return_value=defer.succeed(gold)) + + def getConnection(key, endpoint): + target = endpoint._target_endpoint + self.assertIsInstance( + target._tls, + BrowserLikePolicyForHTTPS().creatorForNetloc(b"host", 443).__class__ + ) + # This uses a Twisted private interface... + self.assertEqual(target._tls._hostname, "custom.domain") + self.assertEqual(target._host, 'meejah.ca') + return defer.succeed(proto) + pool = Mock() + pool.getConnection = getConnection + + agent = yield tor_agent( + reactor, socks_ep, circuit=circuit, pool=pool, + tls_context_factory=CustomTLSContextFactory() + ) + + # apart from the getConnection asserts... + res = yield agent.request(b'GET', b'https://meejah.ca') + self.assertIs(res, gold)