Skip to content

Commit c485d21

Browse files
committed
Add CIDR ip notation to ALLOWED_HOSTS
1 parent b1c7ae9 commit c485d21

File tree

5 files changed

+217
-68
lines changed

5 files changed

+217
-68
lines changed

plain/plain/http/hosts.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""
2+
Host validation utilities for ALLOWED_HOSTS functionality.
3+
4+
This module provides functions for validating hosts against allowed patterns,
5+
including domain patterns, wildcards, and CIDR notation for IP ranges.
6+
"""
7+
8+
import ipaddress
9+
10+
from plain.utils.regex_helper import _lazy_re_compile
11+
12+
# Pattern for a raw Host value: either a run of lowercase letters, digits,
# dots and hyphens, or a square-bracketed literal made of hex digits, colons
# and dots (IPv6 style), optionally followed by ":<digits>" for the port.
host_validation_re = _lazy_re_compile(
13+
r"^([a-z0-9.-]+|\[[a-f0-9]*:[a-f0-9\.:]+\])(:[0-9]+)?$"
14+
)
15+
16+
17+
def split_domain_port(host: str) -> tuple[str, str]:
    """
    Split a raw host value into its ``(domain, port)`` parts.

    The host is lowercased before anything else happens. A host that does
    not match ``host_validation_re`` yields ``("", "")``. A bracketed IPv6
    literal without a port is returned unchanged (brackets included) with an
    empty port. In every other case the text after the last colon is the
    port, and a single trailing dot is dropped from the domain.
    """
    lowered = host.lower()

    if host_validation_re.match(lowered) is None:
        return "", ""

    # A closing bracket in last position means "[ipv6]" with no port part.
    if lowered.endswith("]"):
        return lowered, ""

    domain, colon, port = lowered.rpartition(":")
    if not colon:
        # No colon anywhere: the whole value is the domain.
        domain, port = lowered, ""

    # "example.com." and "example.com" name the same host.
    return domain.removesuffix("."), port
37+
38+
39+
def _is_same_domain(host: str, pattern: str) -> bool:
    """
    Tell whether ``host`` matches ``pattern`` exactly or as a wildcard.

    A pattern with a leading period covers the bare domain and every
    subdomain (``.example.com`` accepts ``example.com`` and
    ``foo.example.com``). Any other pattern must equal the host exactly.
    The pattern is lowercased before comparing; the host is used as given.
    An empty pattern never matches.
    """
    if not pattern:
        return False

    wanted = pattern.lower()

    # Exact string equality wins regardless of the pattern's shape.
    if wanted == host:
        return True

    # Without a leading dot there is no wildcard behaviour to fall back on.
    if not wanted.startswith("."):
        return False

    # ".example.com" -> bare "example.com" or anything ending in ".example.com".
    return host == wanted[1:] or host.endswith(wanted)
57+
58+
59+
def _parse_ip_address(
60+
host: str,
61+
) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None:
62+
"""
63+
Parse a host string as an IP address (IPv4 or IPv6).
64+
65+
Returns the ipaddress.ip_address object if valid, None otherwise.
66+
Handles both bracketed and non-bracketed IPv6 addresses.
67+
"""
68+
# Remove brackets from IPv6 addresses
69+
if host.startswith("[") and host.endswith("]"):
70+
host = host[1:-1]
71+
72+
try:
73+
return ipaddress.ip_address(host)
74+
except ValueError:
75+
return None
76+
77+
78+
def _parse_cidr_pattern(
79+
pattern: str,
80+
) -> ipaddress.IPv4Network | ipaddress.IPv6Network | None:
81+
"""
82+
Parse a CIDR pattern and return the network object if valid.
83+
84+
Returns the ipaddress.ip_network object if valid CIDR notation, None otherwise.
85+
"""
86+
# Check if it contains a slash (required for CIDR)
87+
if "/" not in pattern:
88+
return None
89+
90+
# Remove brackets from IPv6 CIDR patterns
91+
test_pattern = pattern
92+
if pattern.startswith("[") and "]/" in pattern:
93+
# Handle format like [2001:db8::]/32
94+
bracket_end = pattern.find("]/")
95+
if bracket_end != -1:
96+
ip_part = pattern[1:bracket_end]
97+
cidr_part = pattern[bracket_end + 2 :]
98+
test_pattern = f"{ip_part}/{cidr_part}"
99+
elif pattern.startswith("[") and pattern.endswith("]") and "/" in pattern:
100+
# Handle format like [2001:db8::/32] (slash inside brackets)
101+
test_pattern = pattern[1:-1]
102+
103+
try:
104+
return ipaddress.ip_network(test_pattern, strict=False)
105+
except ValueError:
106+
return None
107+
108+
109+
def validate_host(host: str, allowed_hosts: list[str]) -> bool:
    """
    Decide whether ``host`` is permitted by any entry in ``allowed_hosts``.

    Each entry is interpreted as one of:

    - ``*``: accepts every host
    - ``192.168.1.0/24`` or ``[2001:db8::]/32``: CIDR range; accepts the
      host only when it is an IP address inside that network
    - ``.example.com``: accepts ``example.com`` and all of its subdomains
    - ``example.com`` / ``192.168.1.1``: accepts exactly that string

    An entry that parses as CIDR is only ever used for range matching; it is
    never compared to the host as a literal string.

    Note: the host is expected to be lowercased already, with any port
    removed.

    Return ``True`` if some entry accepts the host, ``False`` otherwise.
    """
    # Parsed once up front; stays None when the host is not an IP literal.
    candidate_ip = _parse_ip_address(host)

    def accepts(pattern: str) -> bool:
        if pattern == "*":
            return True

        network = _parse_cidr_pattern(pattern)
        if network is None:
            # Not CIDR: fall back to exact / leading-dot domain matching.
            return _is_same_domain(host, pattern)

        # CIDR entries can only ever be satisfied by an IP host.
        return candidate_ip is not None and candidate_ip in network

    return any(accepts(pattern) for pattern in allowed_hosts)

plain/plain/http/request.py

Lines changed: 2 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,9 @@
2727
MultiValueDict,
2828
)
2929
from plain.utils.encoding import iri_to_uri
30-
from plain.utils.http import is_same_domain, parse_header_parameters
31-
from plain.utils.regex_helper import _lazy_re_compile
30+
from plain.utils.http import parse_header_parameters
3231

33-
host_validation_re = _lazy_re_compile(
34-
r"^([a-z0-9.-]+|\[[a-f0-9]*:[a-f0-9\.:]+\])(:[0-9]+)?$"
35-
)
32+
from .hosts import split_domain_port, validate_host
3633

3734

3835
class UnreadablePostError(OSError):
@@ -702,47 +699,5 @@ def bytes_to_text(s, encoding):
702699
return s
703700

704701

705-
def split_domain_port(host):
706-
"""
707-
Return a (domain, port) tuple from a given host.
708-
709-
Returned domain is lowercased. If the host is invalid, the domain will be
710-
empty.
711-
"""
712-
host = host.lower()
713-
714-
if not host_validation_re.match(host):
715-
return "", ""
716-
717-
if host[-1] == "]":
718-
# It's an IPv6 address without a port.
719-
return host, ""
720-
bits = host.rsplit(":", 1)
721-
domain, port = bits if len(bits) == 2 else (bits[0], "")
722-
# Remove a trailing dot (if present) from the domain.
723-
domain = domain.removesuffix(".")
724-
return domain, port
725-
726-
727-
def validate_host(host, allowed_hosts):
728-
"""
729-
Validate the given host for this site.
730-
731-
Check that the host looks valid and matches a host or host pattern in the
732-
given list of ``allowed_hosts``. Any pattern beginning with a period
733-
matches a domain and all its subdomains (e.g. ``.example.com`` matches
734-
``example.com`` and any subdomain), ``*`` matches anything, and anything
735-
else must match exactly.
736-
737-
Note: This function assumes that the given host is lowercased and has
738-
already had the port, if any, stripped off.
739-
740-
Return ``True`` for a valid host, ``False`` otherwise.
741-
"""
742-
return any(
743-
pattern == "*" or is_same_domain(host, pattern) for pattern in allowed_hosts
744-
)
745-
746-
747702
def parse_accept_header(header):
748703
return [MediaType(token) for token in header.split(",") if token.strip()]

plain/plain/runtime/global_settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
# Hosts/domain names that are valid for this site.
2626
# "*" matches anything, ".example.com" matches example.com and all subdomains
27+
# "192.168.1.0/24" matches IP addresses in that CIDR range
2728
ALLOWED_HOSTS: list[str] = []
2829

2930
# Default headers for all responses.

plain/plain/utils/http.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -92,26 +92,6 @@ def int_to_base36(i):
9292
return b36
9393

9494

95-
def is_same_domain(host, pattern):
96-
"""
97-
Return ``True`` if the host is either an exact match or a match
98-
to the wildcard pattern.
99-
100-
Any pattern beginning with a period matches a domain and all of its
101-
subdomains. (e.g. ``.example.com`` matches ``example.com`` and
102-
``foo.example.com``). Anything else is an exact string match.
103-
"""
104-
if not pattern:
105-
return False
106-
107-
pattern = pattern.lower()
108-
return (
109-
pattern[0] == "."
110-
and (host.endswith(pattern) or host == pattern[1:])
111-
or pattern == host
112-
)
113-
114-
11595
def escape_leading_slashes(url):
11696
"""
11797
If redirecting to an absolute path (two leading slashes), a slash must be
Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from plain.http.request import split_domain_port, validate_host
3+
from plain.http.hosts import split_domain_port, validate_host
44

55

66
@pytest.mark.parametrize(
@@ -146,3 +146,69 @@ def test_split_domain_port(host, expected_domain, expected_port):
146146
def test_validate_host(host, allowed_hosts, expected):
147147
"""Test validate_host function with various inputs."""
148148
assert validate_host(host, allowed_hosts) is expected
149+
150+
151+
@pytest.mark.parametrize(
    ("host", "allowed_hosts", "expected"),
    [
        # IPv4 hosts against a /24 and a /8
        ("192.168.1.100", ["192.168.1.0/24"], True),
        ("192.168.1.1", ["192.168.1.0/24"], True),
        ("192.168.1.254", ["192.168.1.0/24"], True),
        ("192.168.2.100", ["192.168.1.0/24"], False),
        ("10.0.5.1", ["10.0.0.0/8"], True),
        ("172.16.0.1", ["10.0.0.0/8"], False),
        # A /32 covers exactly one IPv4 address
        ("192.168.1.1", ["192.168.1.1/32"], True),
        ("192.168.1.2", ["192.168.1.1/32"], False),
        # Wider IPv4 networks
        ("172.16.5.10", ["172.16.0.0/12"], True),
        ("172.32.5.10", ["172.16.0.0/12"], False),
        ("127.0.0.1", ["127.0.0.0/8"], True),
        # IPv6 hosts against bracketed CIDR patterns
        ("[2001:db8::1]", ["[2001:db8::]/32"], True),
        ("[2001:db8:1::1]", ["[2001:db8::]/32"], True),
        ("[2001:db9::1]", ["[2001:db8::]/32"], False),
        ("[::1]", ["[::]/0"], True),  # /0 covers every IPv6 address
        ("[2001:db8::1]", ["[fe80::]/10"], False),
        # IPv6 CIDR patterns written without brackets are accepted too
        ("[2001:db8::1]", ["2001:db8::/32"], True),
        ("[2001:db9::1]", ["2001:db8::/32"], False),
        # A /128 covers exactly one IPv6 address
        ("[::1]", ["[::1]/128"], True),
        ("[::2]", ["[::1]/128"], False),
        # CIDR and domain patterns side by side
        ("192.168.1.50", ["192.168.1.0/24", ".example.com"], True),
        ("sub.example.com", ["192.168.1.0/24", ".example.com"], True),
        ("192.168.2.50", ["192.168.1.0/24", ".example.com"], False),
        ("other.com", ["192.168.1.0/24", ".example.com"], False),
        # Several CIDR patterns in one list
        ("192.168.1.50", ["10.0.0.0/8", "192.168.0.0/16"], True),
        ("10.5.0.1", ["10.0.0.0/8", "192.168.0.0/16"], True),
        ("172.16.0.1", ["10.0.0.0/8", "192.168.0.0/16"], False),
        # Domain names never fall inside a CIDR range
        ("example.com", ["192.168.1.0/24"], False),
        ("192.168.1.com", ["192.168.1.0/24"], False),
        # Hosts that are not complete IP addresses never match CIDR
        ("not-an-ip", ["192.168.1.0/24"], False),
        ("192.168.1", ["192.168.1.0/24"], False),  # only three octets
        # A valid CIDR pattern is never compared as a literal string, so a
        # host spelled exactly like the pattern is still rejected
        ("192.168.1.0/24", ["192.168.1.0/24"], False),
        # Malformed CIDR patterns are not networks; they are compared to the
        # host as plain strings and therefore do not match an address
        ("192.168.1.100", ["192.168.1.0/999"], False),  # prefix out of range
        ("192.168.1.100", ["192.168.1.0/"], False),  # prefix missing
        # "*" alongside CIDR: the wildcard accepts everything
        ("anything.com", ["*", "192.168.1.0/24"], True),
        ("172.16.0.1", ["*", "192.168.1.0/24"], True),
        # Boundary addresses against the all-encompassing networks
        ("0.0.0.0", ["0.0.0.0/0"], True),  # /0 covers every IPv4 address
        ("255.255.255.255", ["0.0.0.0/0"], True),
        (
            "[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]",
            ["[::]/0"],
            True,
        ),  # highest IPv6 address is still inside ::/0
    ],
)
def test_validate_host_cidr(host, allowed_hosts, expected):
    """Check validate_host against CIDR-style allowed-host patterns."""
    outcome = validate_host(host, allowed_hosts)
    assert outcome is expected

0 commit comments

Comments
 (0)