@@ -0,0 +1,118 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from typing import Iterable, Optional

from netaddr import AddrFormatError, IPNetwork, IPSet

# IP ranges that are considered private / unroutable / don't make sense.
DEFAULT_IP_RANGE_BLACKLIST = [
# Localhost
"127.0.0.0/8",
# Private networks.
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
# Carrier grade NAT.
"100.64.0.0/10",
# Address registry.
"192.0.0.0/24",
# Link-local networks.
"169.254.0.0/16",
# Formerly used for 6to4 relay.
"192.88.99.0/24",
# Testing networks.
"198.18.0.0/15",
"192.0.2.0/24",
"198.51.100.0/24",
"203.0.113.0/24",
# Multicast.
"224.0.0.0/4",
# Localhost
"::1/128",
# Link-local addresses.
"fe80::/10",
# Unique local addresses.
"fc00::/7",
# Testing networks.
"2001:db8::/32",
# Multicast.
"ff00::/8",
# Site-local addresses
"fec0::/10",
]


def generate_ip_set(
ip_addresses: Optional[Iterable[str]],
extra_addresses: Optional[Iterable[str]] = None,
config_path: Optional[Iterable[str]] = None,
) -> IPSet:
"""
Generate an IPSet from a list of IP addresses or CIDRs.
Additionally, for each IPv4 network in the list of IP addresses, also
includes the corresponding IPv6 networks.
This includes:
* IPv4-Compatible IPv6 Address (see RFC 4291, section 2.5.5.1)
* IPv4-Mapped IPv6 Address (see RFC 4291, section 2.5.5.2)
* 6to4 Address (see RFC 3056, section 2)
Args:
ip_addresses: An iterable of IP addresses or CIDRs.
extra_addresses: An iterable of IP addresses or CIDRs.
config_path: The path in the configuration for error messages.
Returns:
A new IP set.
"""
result = IPSet()
for ip in itertools.chain(ip_addresses or (), extra_addresses or ()):
try:
network = IPNetwork(ip)
except AddrFormatError as e:
raise Exception(
"Invalid IP range provided: %s." % (ip,), config_path
) from e
result.add(network)

# It is possible that these already exist in the set, but that's OK.
if ":" not in str(network):
result.add(IPNetwork(network).ipv6(ipv4_compatible=True))
result.add(IPNetwork(network).ipv6(ipv4_compatible=False))
result.add(_6to4(network))

return result


def _6to4(network: IPNetwork) -> IPNetwork:
"""Convert an IPv4 network into a 6to4 IPv6 network per RFC 3056."""

# 6to4 networks consist of:
# * 2002 as the first 16 bits
# * The first IPv4 address in the network hex-encoded as the next 32 bits
# * The new prefix length needs to include the bits from the 2002 prefix.
hex_network = hex(network.first)[2:]
hex_network = ("0" * (8 - len(hex_network))) + hex_network
return IPNetwork(
"2002:%s:%s::/%d"
% (
hex_network[:4],
hex_network[4:],
16 + network.prefixlen,
)
)
@@ -0,0 +1,243 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from mock import patch
from netaddr import IPSet
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.test.proto_helpers import StringTransport
from twisted.trial.unittest import TestCase
from twisted.web.client import Agent

from sydent.http.blacklisting_reactor import BlacklistingReactorWrapper
from sydent.http.srvresolver import Server
from tests.utils import make_request, make_sydent


class BlacklistingAgentTest(TestCase):
def setUp(self):
config = {
"general": {
"ip.blacklist": "5.0.0.0/8",
"ip.whitelist": "5.1.1.1",
},
}

self.sydent = make_sydent(test_config=config)

self.reactor = self.sydent.reactor

self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"

# Configure the reactor's DNS resolver.
for (domain, ip) in (
(self.safe_domain, self.safe_ip),
(self.unsafe_domain, self.unsafe_ip),
(self.allowed_domain, self.allowed_ip),
):
self.reactor.lookups[domain.decode()] = ip.decode()
self.reactor.lookups[ip.decode()] = ip.decode()

self.ip_whitelist = self.sydent.ip_whitelist
self.ip_blacklist = self.sydent.ip_blacklist

def test_reactor(self):
"""Apply the blacklisting reactor and ensure it properly blocks
connections to particular domains and IPs.
"""
agent = Agent(
BlacklistingReactorWrapper(
self.reactor,
ip_whitelist=self.ip_whitelist,
ip_blacklist=self.ip_blacklist,
),
)

# The unsafe domains and IPs should be rejected.
for domain in (self.unsafe_domain, self.unsafe_ip):
self.failureResultOf(
agent.request(b"GET", b"http://" + domain), DNSLookupError
)

self.reactor.tcpClients = []

# The safe domains IPs should be accepted.
for domain in (
self.safe_domain,
self.allowed_domain,
self.safe_ip,
self.allowed_ip,
):
agent.request(b"GET", b"http://" + domain)

# Grab the latest TCP connection.
(
host,
port,
client_factory,
_timeout,
_bindAddress,
) = self.reactor.tcpClients.pop()

@patch("sydent.http.srvresolver.SrvResolver.resolve_service")
def test_federation_client_allowed_ip(self, resolver):
self.sydent.run()

request, channel = make_request(
self.sydent.reactor,
"POST",
"/_matrix/identity/v2/account/register",
{
"access_token": "foo",
"expires_in": 300,
"matrix_server_name": "example.com",
"token_type": "Bearer",
},
)

resolver.return_value = defer.succeed(
[
Server(
host=self.allowed_domain,
port=443,
priority=1,
weight=1,
expires=100,
)
]
)

request.render(self.sydent.servlets.registerServlet)

transport, protocol = self._get_http_request(
self.allowed_ip.decode("ascii"), 443
)

self.assertRegex(
transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo"
)
self.assertRegex(transport.value(), b"Host: example.com")

# Send it the HTTP response
res_json = '{ "sub": "@test:example.com" }'.encode("ascii")
protocol.dataReceived(
b"HTTP/1.1 200 OK\r\n"
b"Server: Fake\r\n"
b"Content-Type: application/json\r\n"
b"Content-Length: %i\r\n"
b"\r\n"
b"%s" % (len(res_json), res_json)
)

self.assertEqual(channel.code, 200)

@patch("sydent.http.srvresolver.SrvResolver.resolve_service")
def test_federation_client_safe_ip(self, resolver):
self.sydent.run()

request, channel = make_request(
self.sydent.reactor,
"POST",
"/_matrix/identity/v2/account/register",
{
"access_token": "foo",
"expires_in": 300,
"matrix_server_name": "example.com",
"token_type": "Bearer",
},
)

resolver.return_value = defer.succeed(
[
Server(
host=self.safe_domain,
port=443,
priority=1,
weight=1,
expires=100,
)
]
)

request.render(self.sydent.servlets.registerServlet)

transport, protocol = self._get_http_request(self.safe_ip.decode("ascii"), 443)

self.assertRegex(
transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo"
)
self.assertRegex(transport.value(), b"Host: example.com")

# Send it the HTTP response
res_json = '{ "sub": "@test:example.com" }'.encode("ascii")
protocol.dataReceived(
b"HTTP/1.1 200 OK\r\n"
b"Server: Fake\r\n"
b"Content-Type: application/json\r\n"
b"Content-Length: %i\r\n"
b"\r\n"
b"%s" % (len(res_json), res_json)
)

self.assertEqual(channel.code, 200)

@patch("sydent.http.srvresolver.SrvResolver.resolve_service")
def test_federation_client_unsafe_ip(self, resolver):
self.sydent.run()

request, channel = make_request(
self.sydent.reactor,
"POST",
"/_matrix/identity/v2/account/register",
{
"access_token": "foo",
"expires_in": 300,
"matrix_server_name": "example.com",
"token_type": "Bearer",
},
)

resolver.return_value = defer.succeed(
[
Server(
host=self.unsafe_domain,
port=443,
priority=1,
weight=1,
expires=100,
)
]
)

request.render(self.sydent.servlets.registerServlet)

self.assertNot(self.reactor.tcpClients)

self.assertEqual(channel.code, 500)

def _get_http_request(self, expected_host, expected_port):
clients = self.reactor.tcpClients
(host, port, factory, _timeout, _bindAddress) = clients[-1]
self.assertEqual(host, expected_host)
self.assertEqual(port, expected_port)

# complete the connection and wire it up to a fake transport
protocol = factory.buildProtocol(None)
transport = StringTransport()
protocol.makeConnection(transport)

return transport, protocol
@@ -2,9 +2,19 @@
from io import BytesIO
import logging
import os

from typing import Dict
import attr
from six import text_type
from zope.interface import implementer
from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import fail, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
IHostnameResolver,
IReactorPluggableNameResolver,
IResolverSimple,
)

from twisted.internet import address
import twisted.logger
from twisted.web.http_headers import Headers
@@ -53,13 +63,13 @@ def make_sydent(test_config={}):
# Use an in-memory SQLite database. Note that the database isn't cleaned up between
# tests, so by default the same database will be used for each test if changed to be
# a file on disk.
if 'db' not in test_config:
test_config['db'] = {'db.file': ':memory:'}
if "db" not in test_config:
test_config["db"] = {"db.file": ":memory:"}
else:
test_config['db'].setdefault('db.file', ':memory:')
test_config["db"].setdefault("db.file", ":memory:")

reactor = MemoryReactorClock()
return Sydent(reactor=reactor, cfg=parse_config_dict(test_config))
reactor = ResolvingMemoryReactorClock()
return Sydent(reactor=reactor, cfg=parse_config_dict(test_config), use_tls_for_federation=False)


@attr.s
@@ -149,6 +159,7 @@ def getPeerCertificate(self):

class FakeSite:
"""A fake Twisted Web Site."""

pass


@@ -191,10 +202,7 @@ def make_request(
path = path.encode("ascii")

# Decorate it to be the full path, if we're using shorthand
if (
shorthand
and not path.startswith(b"/_matrix")
):
if shorthand and not path.startswith(b"/_matrix"):
path = b"/_matrix/identity/v2/" + path
path = path.replace(b"//", b"/")

@@ -253,10 +261,7 @@ def setup_logging():
"""
root_logger = logging.getLogger()

log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s"
" - %(message)s"
)
log_format = "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s" " - %(message)s"

handler = ToTwistedHandler()
formatter = logging.Formatter(log_format)
@@ -268,3 +273,26 @@ def setup_logging():


setup_logging()


@implementer(IReactorPluggableNameResolver)
class ResolvingMemoryReactorClock(MemoryReactorClock):
"""
A MemoryReactorClock that supports name resolution.
"""

def __init__(self):
lookups = self.lookups = {} # type: Dict[str, str]

@implementer(IResolverSimple)
class FakeResolver:
def getHostByName(self, name, timeout=None):
if name not in lookups:
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
return succeed(lookups[name])

self.nameResolver = SimpleResolverComplexifier(FakeResolver())
super().__init__()

def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
raise NotImplementedError()