Skip to content

Commit

Permalink
Rework hostname validation to make port checking stricter.
Browse files Browse the repository at this point in the history
Instead of using a regex to validate the entire hostname + port
combination, we now split the hostname into components and check each
component separately. This makes the regex a bit simpler and allows us
to validate the port number better, including that it belongs to the
valid range.
  • Loading branch information
Denis Kasak authored and clokep committed Mar 24, 2021
1 parent 809ad96 commit 9e57334
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 6 deletions.
17 changes: 14 additions & 3 deletions sydent/http/servlets/registerservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from sydent.http.servlets import get_args, jsonwrap, deferjsonwrap, send_cors
from sydent.http.httpclient import FederationHttpClient
from sydent.users.tokens import issueToken

from sydent.util.stringutils import is_valid_hostname

logger = logging.getLogger(__name__)

Expand All @@ -47,9 +47,20 @@ def render_POST(self, request):

args = get_args(request, ('matrix_server_name', 'access_token'))

hostname = args['matrix_server_name'].lower()

if not is_valid_hostname(hostname):
request.setResponseCode(400)
return {
'errcode': 'M_INVALID_PARAM',
'error': 'matrix_server_name must be a valid hostname'
}

result = yield self.client.get_json(
"matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s" % (
args['matrix_server_name'], urllib.parse.quote(args['access_token']),
"matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s"
% (
hostname,
urllib.parse.quote(args['access_token']),
),
1024 * 5,
)
Expand Down
42 changes: 41 additions & 1 deletion sydent/util/stringutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,54 @@
# Allowed shape of a client_secret: one or more ASCII letters, digits,
# '.', '=', '_' or '-'.
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")

# Hostname/domain name only. This pattern does NOT match a port; a port, if
# present, has to be split off and validated separately.
# The name is a dot-separated sequence of labels. Each label is 1-63
# characters long, starts and ends with a letter or digit, and may contain
# hyphens in between. Matching is case-insensitive.
# NOTE: in Python, `$` also matches just before a trailing newline, so
# `.match()` accepts "example.com\n"; use `.fullmatch()` for strict checks.
# https://regex101.com/r/OyN1lg/2
hostname_regex = re.compile(
r"^(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)(?:\.[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)*$",
flags=re.IGNORECASE)


def is_valid_client_secret(client_secret: str) -> bool:
    """Validate that a given string matches the client_secret regex defined by the spec

    The whole string has to match: a value with a trailing newline (for
    instance ``"secret\\n"``) is rejected.

    :param client_secret: The client_secret to validate
    :type client_secret: str

    :return: Whether the client_secret is valid
    :rtype: bool
    """
    # fullmatch() rather than match(): the pattern ends in `$`, which also
    # matches just before a trailing newline, so match() would let
    # "secret\n" through.
    return client_secret_regex.fullmatch(client_secret) is not None


def is_valid_hostname(string: str) -> bool:
    """Validate that a given string is a valid hostname or domain name, with an
    optional port number.

    For domain names, this only validates that the form is right (for
    instance, it doesn't check that the TLD is valid). If a port is
    specified, it has to be a valid port number in canonical decimal form
    (no sign, whitespace or leading zeros) in the range 1-65535.

    :param string: The string to validate
    :type string: str

    :return: Whether the input is a valid hostname
    :rtype: bool
    """
    # Split on the first colon only; anything after it must be the port, so
    # "a:1:2" ends up with port "1:2" and is rejected below.
    host, colon, port = string.partition(":")

    # fullmatch() rather than match(): the pattern ends in `$`, which also
    # matches just before a trailing newline, so match() would accept
    # "example.com\n".
    if hostname_regex.fullmatch(host) is None:
        return False

    if not colon:
        # No port given; the hostname alone decides.
        return True

    try:
        port_num = int(port)
    except ValueError:
        # Empty or non-numeric port.
        return False

    # int() is lenient ('08090', ' 8090', '+8090', '8_090' all parse), so
    # require the text to round-trip exactly before checking the range.
    return port == str(port_num) and 1 <= port_num < 65536
4 changes: 2 additions & 2 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def setUp(self):
self.sydent.db.commit()

def test_can_read_token_from_headers(self):
"""Tests that Sydent correct extracts an auth token from request headers"""
"""Tests that Sydent correctly extracts an auth token from request headers"""
self.sydent.run()

request, _ = make_request(
Expand All @@ -59,7 +59,7 @@ def test_can_read_token_from_headers(self):
self.assertEqual(token, self.test_token)

def test_can_read_token_from_query_parameters(self):
"""Tests that Sydent correct extracts an auth token from query parameters"""
"""Tests that Sydent correctly extracts an auth token from query parameters"""
self.sydent.run()

request, _ = make_request(
Expand Down
45 changes: 45 additions & 0 deletions tests/test_register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-

# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.trial import unittest

from tests.utils import make_request, make_sydent


class RegisterTestCase(unittest.TestCase):
    """Tests Sydent's register servlet"""

    def setUp(self):
        # Each test runs against its own freshly built Sydent instance.
        self.sydent = make_sydent()

    def test_sydent_rejects_invalid_hostname(self):
        """Checks that /register answers 400 when matrix_server_name is not a valid hostname"""
        self.sydent.run()

        # '#' is not allowed in a hostname, so this server name must be refused.
        payload = {
            "matrix_server_name": "example.com#",
            "access_token": "foo",
        }

        request, channel = make_request(
            self.sydent.reactor,
            "POST",
            "/_matrix/identity/v2/account/register",
            content=payload,
        )
        request.render(self.sydent.servlets.registerServlet)

        self.assertEqual(channel.code, 400)
26 changes: 26 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from twisted.trial import unittest
from sydent.util.stringutils import is_valid_hostname


class UtilTests(unittest.TestCase):
    """Tests Sydent utility functions."""

    def test_is_valid_hostname(self):
        """Checks is_valid_hostname against a table of names it must accept
        and a table of names it must reject (bad ports, stray characters).
        """
        accepted = (
            "example.com",
            "EXAMPLE.COM",
            "ExAmPlE.CoM",
            "example.com:4242",
            "localhost",
            "localhost:9000",
            "a.b:1234",
        )
        rejected = (
            "example.com:65536",
            "example.com:0",
            "example.com:a",
            "example.com:04242",
            "example.com: 4242",
            "example.com/example.com",
            "example.com#example.com",
        )

        for candidate in accepted:
            self.assertTrue(is_valid_hostname(candidate), repr(candidate))

        for candidate in rejected:
            self.assertFalse(is_valid_hostname(candidate), repr(candidate))

0 comments on commit 9e57334

Please sign in to comment.