|
17 | 17 | from mcp.client.auth.utils import ( |
18 | 18 | build_oauth_authorization_server_metadata_discovery_urls, |
19 | 19 | build_protected_resource_metadata_discovery_urls, |
| 20 | + create_client_info_from_metadata_url, |
20 | 21 | create_oauth_metadata_request, |
21 | 22 | extract_field_from_www_auth, |
22 | 23 | extract_resource_metadata_from_www_auth, |
23 | 24 | extract_scope_from_www_auth, |
24 | 25 | get_client_metadata_scopes, |
25 | 26 | handle_registration_response, |
| 27 | + is_valid_client_metadata_url, |
| 28 | + should_use_client_metadata_url, |
26 | 29 | ) |
27 | 30 | from mcp.shared.auth import ( |
28 | 31 | OAuthClientInformationFull, |
@@ -1783,3 +1786,294 @@ def test_extract_field_from_www_auth_invalid_cases( |
1783 | 1786 |
|
1784 | 1787 | result = extract_field_from_www_auth(init_response, field_name) |
1785 | 1788 | assert result is None, f"Should return None for {description}" |
| 1789 | + |
| 1790 | + |
| 1791 | +class TestCIMD: |
| 1792 | + """Test SEP-991 Client ID Metadata Document (CIMD) support.""" |
| 1793 | + |
| 1794 | + @pytest.mark.parametrize( |
| 1795 | + "url,expected", |
| 1796 | + [ |
| 1797 | + # Valid CIMD URLs |
| 1798 | + ("https://example.com/client", True), |
| 1799 | + ("https://example.com/client-metadata.json", True), |
| 1800 | + ("https://example.com/path/to/client", True), |
| 1801 | + ("https://example.com:8443/client", True), |
| 1802 | + # Invalid URLs - HTTP (not HTTPS) |
| 1803 | + ("http://example.com/client", False), |
| 1804 | + # Invalid URLs - root path |
| 1805 | + ("https://example.com", False), |
| 1806 | + ("https://example.com/", False), |
| 1807 | + # Invalid URLs - None or empty |
| 1808 | + (None, False), |
| 1809 | + ("", False), |
| 1810 | + ], |
| 1811 | + ) |
| 1812 | + def test_is_valid_client_metadata_url(self, url: str | None, expected: bool): |
| 1813 | + """Test CIMD URL validation.""" |
| 1814 | + assert is_valid_client_metadata_url(url) == expected |
| 1815 | + |
| 1816 | + def test_should_use_client_metadata_url_when_server_supports(self): |
| 1817 | + """Test that CIMD is used when server supports it and URL is provided.""" |
| 1818 | + oauth_metadata = OAuthMetadata( |
| 1819 | + issuer=AnyHttpUrl("https://auth.example.com"), |
| 1820 | + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), |
| 1821 | + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), |
| 1822 | + client_id_metadata_document_supported=True, |
| 1823 | + ) |
| 1824 | + assert should_use_client_metadata_url(oauth_metadata, "https://example.com/client") is True |
| 1825 | + |
| 1826 | + def test_should_not_use_client_metadata_url_when_server_does_not_support(self): |
| 1827 | + """Test that CIMD is not used when server doesn't support it.""" |
| 1828 | + oauth_metadata = OAuthMetadata( |
| 1829 | + issuer=AnyHttpUrl("https://auth.example.com"), |
| 1830 | + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), |
| 1831 | + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), |
| 1832 | + client_id_metadata_document_supported=False, |
| 1833 | + ) |
| 1834 | + assert should_use_client_metadata_url(oauth_metadata, "https://example.com/client") is False |
| 1835 | + |
| 1836 | + def test_should_not_use_client_metadata_url_when_not_provided(self): |
| 1837 | + """Test that CIMD is not used when no URL is provided.""" |
| 1838 | + oauth_metadata = OAuthMetadata( |
| 1839 | + issuer=AnyHttpUrl("https://auth.example.com"), |
| 1840 | + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), |
| 1841 | + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), |
| 1842 | + client_id_metadata_document_supported=True, |
| 1843 | + ) |
| 1844 | + assert should_use_client_metadata_url(oauth_metadata, None) is False |
| 1845 | + |
| 1846 | + def test_should_not_use_client_metadata_url_when_no_metadata(self): |
| 1847 | + """Test that CIMD is not used when OAuth metadata is None.""" |
| 1848 | + assert should_use_client_metadata_url(None, "https://example.com/client") is False |
| 1849 | + |
| 1850 | + def test_create_client_info_from_metadata_url(self): |
| 1851 | + """Test creating client info from CIMD URL.""" |
| 1852 | + client_info = create_client_info_from_metadata_url( |
| 1853 | + "https://example.com/client", |
| 1854 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 1855 | + ) |
| 1856 | + assert client_info.client_id == "https://example.com/client" |
| 1857 | + assert client_info.token_endpoint_auth_method == "none" |
| 1858 | + assert client_info.redirect_uris == [AnyUrl("http://localhost:3030/callback")] |
| 1859 | + assert client_info.client_secret is None |
| 1860 | + |
| 1861 | + def test_oauth_provider_with_valid_client_metadata_url( |
| 1862 | + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 1863 | + ): |
| 1864 | + """Test OAuthClientProvider initialization with valid client_metadata_url.""" |
| 1865 | + |
| 1866 | + async def redirect_handler(url: str) -> None: |
| 1867 | + pass # pragma: no cover |
| 1868 | + |
| 1869 | + async def callback_handler() -> tuple[str, str | None]: |
| 1870 | + return "test_auth_code", "test_state" # pragma: no cover |
| 1871 | + |
| 1872 | + provider = OAuthClientProvider( |
| 1873 | + server_url="https://api.example.com/v1/mcp", |
| 1874 | + client_metadata=client_metadata, |
| 1875 | + storage=mock_storage, |
| 1876 | + redirect_handler=redirect_handler, |
| 1877 | + callback_handler=callback_handler, |
| 1878 | + client_metadata_url="https://example.com/client", |
| 1879 | + ) |
| 1880 | + assert provider.context.client_metadata_url == "https://example.com/client" |
| 1881 | + |
| 1882 | + def test_oauth_provider_with_invalid_client_metadata_url_raises_error( |
| 1883 | + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 1884 | + ): |
| 1885 | + """Test OAuthClientProvider raises error for invalid client_metadata_url.""" |
| 1886 | + |
| 1887 | + async def redirect_handler(url: str) -> None: |
| 1888 | + pass # pragma: no cover |
| 1889 | + |
| 1890 | + async def callback_handler() -> tuple[str, str | None]: |
| 1891 | + return "test_auth_code", "test_state" # pragma: no cover |
| 1892 | + |
| 1893 | + with pytest.raises(ValueError) as exc_info: |
| 1894 | + OAuthClientProvider( |
| 1895 | + server_url="https://api.example.com/v1/mcp", |
| 1896 | + client_metadata=client_metadata, |
| 1897 | + storage=mock_storage, |
| 1898 | + redirect_handler=redirect_handler, |
| 1899 | + callback_handler=callback_handler, |
| 1900 | + client_metadata_url="http://example.com/client", # HTTP instead of HTTPS |
| 1901 | + ) |
| 1902 | + assert "HTTPS URL with a non-root pathname" in str(exc_info.value) |
| 1903 | + |
| 1904 | + @pytest.mark.anyio |
| 1905 | + async def test_auth_flow_uses_cimd_when_server_supports( |
| 1906 | + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 1907 | + ): |
| 1908 | + """Test that auth flow uses CIMD URL as client_id when server supports it.""" |
| 1909 | + |
| 1910 | + async def redirect_handler(url: str) -> None: |
| 1911 | + pass # pragma: no cover |
| 1912 | + |
| 1913 | + async def callback_handler() -> tuple[str, str | None]: |
| 1914 | + return "test_auth_code", "test_state" # pragma: no cover |
| 1915 | + |
| 1916 | + provider = OAuthClientProvider( |
| 1917 | + server_url="https://api.example.com/v1/mcp", |
| 1918 | + client_metadata=client_metadata, |
| 1919 | + storage=mock_storage, |
| 1920 | + redirect_handler=redirect_handler, |
| 1921 | + callback_handler=callback_handler, |
| 1922 | + client_metadata_url="https://example.com/client", |
| 1923 | + ) |
| 1924 | + |
| 1925 | + provider.context.current_tokens = None |
| 1926 | + provider.context.token_expiry_time = None |
| 1927 | + provider._initialized = True |
| 1928 | + |
| 1929 | + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") |
| 1930 | + auth_flow = provider.async_auth_flow(test_request) |
| 1931 | + |
| 1932 | + # First request |
| 1933 | + request = await auth_flow.__anext__() |
| 1934 | + assert "Authorization" not in request.headers |
| 1935 | + |
| 1936 | + # Send 401 response |
| 1937 | + response = httpx.Response(401, headers={}, request=test_request) |
| 1938 | + |
| 1939 | + # PRM discovery |
| 1940 | + prm_request = await auth_flow.asend(response) |
| 1941 | + prm_response = httpx.Response( |
| 1942 | + 200, |
| 1943 | + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', |
| 1944 | + request=prm_request, |
| 1945 | + ) |
| 1946 | + |
| 1947 | + # OAuth metadata discovery |
| 1948 | + oauth_request = await auth_flow.asend(prm_response) |
| 1949 | + oauth_response = httpx.Response( |
| 1950 | + 200, |
| 1951 | + content=( |
| 1952 | + b'{"issuer": "https://auth.example.com", ' |
| 1953 | + b'"authorization_endpoint": "https://auth.example.com/authorize", ' |
| 1954 | + b'"token_endpoint": "https://auth.example.com/token", ' |
| 1955 | + b'"client_id_metadata_document_supported": true}' |
| 1956 | + ), |
| 1957 | + request=oauth_request, |
| 1958 | + ) |
| 1959 | + |
| 1960 | + # Mock authorization |
| 1961 | + provider._perform_authorization_code_grant = mock.AsyncMock( |
| 1962 | + return_value=("test_auth_code", "test_code_verifier") |
| 1963 | + ) |
| 1964 | + |
| 1965 | + # Should skip DCR and go directly to token exchange |
| 1966 | + token_request = await auth_flow.asend(oauth_response) |
| 1967 | + assert token_request.method == "POST" |
| 1968 | + assert str(token_request.url) == "https://auth.example.com/token" |
| 1969 | + |
| 1970 | + # Verify client_id is the CIMD URL |
| 1971 | + content = token_request.content.decode() |
| 1972 | + assert "client_id=https%3A%2F%2Fexample.com%2Fclient" in content |
| 1973 | + |
| 1974 | + # Verify client info was set correctly |
| 1975 | + assert provider.context.client_info is not None |
| 1976 | + assert provider.context.client_info.client_id == "https://example.com/client" |
| 1977 | + assert provider.context.client_info.token_endpoint_auth_method == "none" |
| 1978 | + |
| 1979 | + # Complete the flow |
| 1980 | + token_response = httpx.Response( |
| 1981 | + 200, |
| 1982 | + content=b'{"access_token": "test_token", "token_type": "Bearer", "expires_in": 3600}', |
| 1983 | + request=token_request, |
| 1984 | + ) |
| 1985 | + |
| 1986 | + final_request = await auth_flow.asend(token_response) |
| 1987 | + assert final_request.headers["Authorization"] == "Bearer test_token" |
| 1988 | + |
| 1989 | + final_response = httpx.Response(200, request=final_request) |
| 1990 | + try: |
| 1991 | + await auth_flow.asend(final_response) |
| 1992 | + except StopAsyncIteration: |
| 1993 | + pass |
| 1994 | + |
| 1995 | + @pytest.mark.anyio |
| 1996 | + async def test_auth_flow_falls_back_to_dcr_when_no_cimd_support( |
| 1997 | + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 1998 | + ): |
| 1999 | + """Test that auth flow falls back to DCR when server doesn't support CIMD.""" |
| 2000 | + |
| 2001 | + async def redirect_handler(url: str) -> None: |
| 2002 | + pass # pragma: no cover |
| 2003 | + |
| 2004 | + async def callback_handler() -> tuple[str, str | None]: |
| 2005 | + return "test_auth_code", "test_state" # pragma: no cover |
| 2006 | + |
| 2007 | + provider = OAuthClientProvider( |
| 2008 | + server_url="https://api.example.com/v1/mcp", |
| 2009 | + client_metadata=client_metadata, |
| 2010 | + storage=mock_storage, |
| 2011 | + redirect_handler=redirect_handler, |
| 2012 | + callback_handler=callback_handler, |
| 2013 | + client_metadata_url="https://example.com/client", |
| 2014 | + ) |
| 2015 | + |
| 2016 | + provider.context.current_tokens = None |
| 2017 | + provider.context.token_expiry_time = None |
| 2018 | + provider._initialized = True |
| 2019 | + |
| 2020 | + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") |
| 2021 | + auth_flow = provider.async_auth_flow(test_request) |
| 2022 | + |
| 2023 | + # First request |
| 2024 | + request = await auth_flow.__anext__() |
| 2025 | + |
| 2026 | + # Send 401 response |
| 2027 | + response = httpx.Response(401, headers={}, request=test_request) |
| 2028 | + |
| 2029 | + # PRM discovery |
| 2030 | + prm_request = await auth_flow.asend(response) |
| 2031 | + prm_response = httpx.Response( |
| 2032 | + 200, |
| 2033 | + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', |
| 2034 | + request=prm_request, |
| 2035 | + ) |
| 2036 | + |
| 2037 | + # OAuth metadata discovery - server does NOT support CIMD |
| 2038 | + oauth_request = await auth_flow.asend(prm_response) |
| 2039 | + oauth_response = httpx.Response( |
| 2040 | + 200, |
| 2041 | + content=( |
| 2042 | + b'{"issuer": "https://auth.example.com", ' |
| 2043 | + b'"authorization_endpoint": "https://auth.example.com/authorize", ' |
| 2044 | + b'"token_endpoint": "https://auth.example.com/token", ' |
| 2045 | + b'"registration_endpoint": "https://auth.example.com/register"}' |
| 2046 | + ), |
| 2047 | + request=oauth_request, |
| 2048 | + ) |
| 2049 | + |
| 2050 | + # Should proceed to DCR instead of skipping it |
| 2051 | + registration_request = await auth_flow.asend(oauth_response) |
| 2052 | + assert registration_request.method == "POST" |
| 2053 | + assert str(registration_request.url) == "https://auth.example.com/register" |
| 2054 | + |
| 2055 | + # Complete the flow to avoid generator cleanup issues |
| 2056 | + registration_response = httpx.Response( |
| 2057 | + 201, |
| 2058 | + content=b'{"client_id": "dcr_client_id", "redirect_uris": ["http://localhost:3030/callback"]}', |
| 2059 | + request=registration_request, |
| 2060 | + ) |
| 2061 | + |
| 2062 | + # Mock authorization |
| 2063 | + provider._perform_authorization_code_grant = mock.AsyncMock( |
| 2064 | + return_value=("test_auth_code", "test_code_verifier") |
| 2065 | + ) |
| 2066 | + |
| 2067 | + token_request = await auth_flow.asend(registration_response) |
| 2068 | + token_response = httpx.Response( |
| 2069 | + 200, |
| 2070 | + content=b'{"access_token": "test_token", "token_type": "Bearer", "expires_in": 3600}', |
| 2071 | + request=token_request, |
| 2072 | + ) |
| 2073 | + |
| 2074 | + final_request = await auth_flow.asend(token_response) |
| 2075 | + final_response = httpx.Response(200, request=final_request) |
| 2076 | + try: |
| 2077 | + await auth_flow.asend(final_response) |
| 2078 | + except StopAsyncIteration: |
| 2079 | + pass |
0 commit comments