|
|
@@ -0,0 +1,243 @@ |
|
|
# Copyright 2021 The Matrix.org Foundation C.I.C. |
|
|
# |
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
# you may not use this file except in compliance with the License. |
|
|
# You may obtain a copy of the License at |
|
|
# |
|
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
# |
|
|
# Unless required by applicable law or agreed to in writing, software |
|
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
# See the License for the specific language governing permissions and |
|
|
# limitations under the License. |
|
|
|
|
|
|
|
|
from mock import patch |
|
|
from netaddr import IPSet |
|
|
from twisted.internet import defer |
|
|
from twisted.internet.error import DNSLookupError |
|
|
from twisted.test.proto_helpers import StringTransport |
|
|
from twisted.trial.unittest import TestCase |
|
|
from twisted.web.client import Agent |
|
|
|
|
|
from sydent.http.blacklisting_reactor import BlacklistingReactorWrapper |
|
|
from sydent.http.srvresolver import Server |
|
|
from tests.utils import make_request, make_sydent |
|
|
|
|
|
|
|
|
class BlacklistingAgentTest(TestCase): |
|
|
def setUp(self): |
|
|
config = { |
|
|
"general": { |
|
|
"ip.blacklist": "5.0.0.0/8", |
|
|
"ip.whitelist": "5.1.1.1", |
|
|
}, |
|
|
} |
|
|
|
|
|
self.sydent = make_sydent(test_config=config) |
|
|
|
|
|
self.reactor = self.sydent.reactor |
|
|
|
|
|
self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4" |
|
|
self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8" |
|
|
self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1" |
|
|
|
|
|
# Configure the reactor's DNS resolver. |
|
|
for (domain, ip) in ( |
|
|
(self.safe_domain, self.safe_ip), |
|
|
(self.unsafe_domain, self.unsafe_ip), |
|
|
(self.allowed_domain, self.allowed_ip), |
|
|
): |
|
|
self.reactor.lookups[domain.decode()] = ip.decode() |
|
|
self.reactor.lookups[ip.decode()] = ip.decode() |
|
|
|
|
|
self.ip_whitelist = self.sydent.ip_whitelist |
|
|
self.ip_blacklist = self.sydent.ip_blacklist |
|
|
|
|
|
def test_reactor(self): |
|
|
"""Apply the blacklisting reactor and ensure it properly blocks |
|
|
connections to particular domains and IPs. |
|
|
""" |
|
|
agent = Agent( |
|
|
BlacklistingReactorWrapper( |
|
|
self.reactor, |
|
|
ip_whitelist=self.ip_whitelist, |
|
|
ip_blacklist=self.ip_blacklist, |
|
|
), |
|
|
) |
|
|
|
|
|
# The unsafe domains and IPs should be rejected. |
|
|
for domain in (self.unsafe_domain, self.unsafe_ip): |
|
|
self.failureResultOf( |
|
|
agent.request(b"GET", b"http://" + domain), DNSLookupError |
|
|
) |
|
|
|
|
|
self.reactor.tcpClients = [] |
|
|
|
|
|
# The safe domains IPs should be accepted. |
|
|
for domain in ( |
|
|
self.safe_domain, |
|
|
self.allowed_domain, |
|
|
self.safe_ip, |
|
|
self.allowed_ip, |
|
|
): |
|
|
agent.request(b"GET", b"http://" + domain) |
|
|
|
|
|
# Grab the latest TCP connection. |
|
|
( |
|
|
host, |
|
|
port, |
|
|
client_factory, |
|
|
_timeout, |
|
|
_bindAddress, |
|
|
) = self.reactor.tcpClients.pop() |
|
|
|
|
|
@patch("sydent.http.srvresolver.SrvResolver.resolve_service") |
|
|
def test_federation_client_allowed_ip(self, resolver): |
|
|
self.sydent.run() |
|
|
|
|
|
request, channel = make_request( |
|
|
self.sydent.reactor, |
|
|
"POST", |
|
|
"/_matrix/identity/v2/account/register", |
|
|
{ |
|
|
"access_token": "foo", |
|
|
"expires_in": 300, |
|
|
"matrix_server_name": "example.com", |
|
|
"token_type": "Bearer", |
|
|
}, |
|
|
) |
|
|
|
|
|
resolver.return_value = defer.succeed( |
|
|
[ |
|
|
Server( |
|
|
host=self.allowed_domain, |
|
|
port=443, |
|
|
priority=1, |
|
|
weight=1, |
|
|
expires=100, |
|
|
) |
|
|
] |
|
|
) |
|
|
|
|
|
request.render(self.sydent.servlets.registerServlet) |
|
|
|
|
|
transport, protocol = self._get_http_request( |
|
|
self.allowed_ip.decode("ascii"), 443 |
|
|
) |
|
|
|
|
|
self.assertRegex( |
|
|
transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo" |
|
|
) |
|
|
self.assertRegex(transport.value(), b"Host: example.com") |
|
|
|
|
|
# Send it the HTTP response |
|
|
res_json = '{ "sub": "@test:example.com" }'.encode("ascii") |
|
|
protocol.dataReceived( |
|
|
b"HTTP/1.1 200 OK\r\n" |
|
|
b"Server: Fake\r\n" |
|
|
b"Content-Type: application/json\r\n" |
|
|
b"Content-Length: %i\r\n" |
|
|
b"\r\n" |
|
|
b"%s" % (len(res_json), res_json) |
|
|
) |
|
|
|
|
|
self.assertEqual(channel.code, 200) |
|
|
|
|
|
@patch("sydent.http.srvresolver.SrvResolver.resolve_service") |
|
|
def test_federation_client_safe_ip(self, resolver): |
|
|
self.sydent.run() |
|
|
|
|
|
request, channel = make_request( |
|
|
self.sydent.reactor, |
|
|
"POST", |
|
|
"/_matrix/identity/v2/account/register", |
|
|
{ |
|
|
"access_token": "foo", |
|
|
"expires_in": 300, |
|
|
"matrix_server_name": "example.com", |
|
|
"token_type": "Bearer", |
|
|
}, |
|
|
) |
|
|
|
|
|
resolver.return_value = defer.succeed( |
|
|
[ |
|
|
Server( |
|
|
host=self.safe_domain, |
|
|
port=443, |
|
|
priority=1, |
|
|
weight=1, |
|
|
expires=100, |
|
|
) |
|
|
] |
|
|
) |
|
|
|
|
|
request.render(self.sydent.servlets.registerServlet) |
|
|
|
|
|
transport, protocol = self._get_http_request(self.safe_ip.decode("ascii"), 443) |
|
|
|
|
|
self.assertRegex( |
|
|
transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo" |
|
|
) |
|
|
self.assertRegex(transport.value(), b"Host: example.com") |
|
|
|
|
|
# Send it the HTTP response |
|
|
res_json = '{ "sub": "@test:example.com" }'.encode("ascii") |
|
|
protocol.dataReceived( |
|
|
b"HTTP/1.1 200 OK\r\n" |
|
|
b"Server: Fake\r\n" |
|
|
b"Content-Type: application/json\r\n" |
|
|
b"Content-Length: %i\r\n" |
|
|
b"\r\n" |
|
|
b"%s" % (len(res_json), res_json) |
|
|
) |
|
|
|
|
|
self.assertEqual(channel.code, 200) |
|
|
|
|
|
@patch("sydent.http.srvresolver.SrvResolver.resolve_service") |
|
|
def test_federation_client_unsafe_ip(self, resolver): |
|
|
self.sydent.run() |
|
|
|
|
|
request, channel = make_request( |
|
|
self.sydent.reactor, |
|
|
"POST", |
|
|
"/_matrix/identity/v2/account/register", |
|
|
{ |
|
|
"access_token": "foo", |
|
|
"expires_in": 300, |
|
|
"matrix_server_name": "example.com", |
|
|
"token_type": "Bearer", |
|
|
}, |
|
|
) |
|
|
|
|
|
resolver.return_value = defer.succeed( |
|
|
[ |
|
|
Server( |
|
|
host=self.unsafe_domain, |
|
|
port=443, |
|
|
priority=1, |
|
|
weight=1, |
|
|
expires=100, |
|
|
) |
|
|
] |
|
|
) |
|
|
|
|
|
request.render(self.sydent.servlets.registerServlet) |
|
|
|
|
|
self.assertNot(self.reactor.tcpClients) |
|
|
|
|
|
self.assertEqual(channel.code, 500) |
|
|
|
|
|
def _get_http_request(self, expected_host, expected_port): |
|
|
clients = self.reactor.tcpClients |
|
|
(host, port, factory, _timeout, _bindAddress) = clients[-1] |
|
|
self.assertEqual(host, expected_host) |
|
|
self.assertEqual(port, expected_port) |
|
|
|
|
|
# complete the connection and wire it up to a fake transport |
|
|
protocol = factory.buildProtocol(None) |
|
|
transport = StringTransport() |
|
|
protocol.makeConnection(transport) |
|
|
|
|
|
return transport, protocol |