Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/config/test_legacy_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def test_env_pending_overrides_apply_on_activation(mock_config_dir, config_manag
assert current_manager.profile == "dev"
dev_web = current_manager.get_section("web")
assert dev_web.enable_caching is False
assert dev_web.ssl_version == ssl.TLSVersion.TLSv1_2
assert dev_web.ssl_version == "TLSv1_2"
assert Env.current.enable_caching is False
assert Env.current.ssl_version == ssl.TLSVersion.TLSv1_2
assert Env.current.ssl_version == "TLSv1_2"
finally:
reload_config(profile="default")

Expand Down
22 changes: 22 additions & 0 deletions tests/config/test_web_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from __future__ import annotations

import ssl

import pytest

from tidy3d.config.sections import WebConfig


Expand All @@ -21,3 +25,21 @@ def test_build_api_url_returns_base_for_empty_path():
def test_build_api_url_without_base_returns_path():
web = WebConfig.model_construct(api_endpoint="")
assert web.build_api_url("/v1/tasks") == "v1/tasks"


@pytest.mark.parametrize(
("value", "expected"),
[
(ssl.TLSVersion.TLSv1, "TLSv1"),
(ssl.TLSVersion.TLSv1_1, "TLSv1_1"),
],
)
def test_web_config_normalizes_ssl_version_aliases(value, expected):
web = WebConfig(ssl_version=value)
assert web.ssl_version == expected


@pytest.mark.parametrize("value", ["", "TLSv2", "SSLv3", "udp1.0"])
def test_web_config_rejects_invalid_ssl_version(value):
with pytest.raises(ValueError):
WebConfig(ssl_version=value)
8 changes: 4 additions & 4 deletions tests/test_web/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ def test_tidy3d_env():


def test_set_ssl_version():
Env.set_ssl_version(ssl.TLSVersion.TLSv1_3)
assert Env.current.ssl_version == ssl.TLSVersion.TLSv1_3
Env.set_ssl_version("TLSv1_3")
assert Env.current.ssl_version == "TLSv1_3"

Env.set_ssl_version(ssl.TLSVersion.TLSv1_2)
assert Env.current.ssl_version == ssl.TLSVersion.TLSv1_2
assert Env.current.ssl_version == "TLSv1_2"

Env.set_ssl_version(ssl.TLSVersion.TLSv1_1)
assert Env.current.ssl_version == ssl.TLSVersion.TLSv1_1
assert Env.current.ssl_version == "TLSv1_1"

Env.set_ssl_version(None)
assert Env.current.ssl_version is None
9 changes: 4 additions & 5 deletions tidy3d/config/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from __future__ import annotations

import os
import ssl
import warnings
from pathlib import Path
from typing import Any, Optional
Expand Down Expand Up @@ -160,7 +159,7 @@ def __init__(
s3_region: Optional[str] = None,
ssl_verify: Optional[bool] = None,
enable_caching: Optional[bool] = None,
ssl_version: Optional[ssl.TLSVersion] = None,
ssl_version: Optional[str] = None,
env_vars: Optional[dict[str, str]] = None,
environment: Optional[LegacyEnvironment] = None,
) -> None:
Expand Down Expand Up @@ -239,11 +238,11 @@ def enable_caching(self, value: Optional[bool]) -> None:
self._set_pending("enable_caching", value)

@property
def ssl_version(self) -> Optional[ssl.TLSVersion]:
def ssl_version(self) -> Optional[str]:
return self._value("ssl_version")

@ssl_version.setter
def ssl_version(self, value: Optional[ssl.TLSVersion]) -> None:
def ssl_version(self, value: Optional[str]) -> None:
self._set_pending("ssl_version", value)

@property
Expand Down Expand Up @@ -363,7 +362,7 @@ def enable_caching(self, enable_caching: Optional[bool] = True) -> None:
config.enable_caching = enable_caching
self._sync_to_manager()

def set_ssl_version(self, ssl_version: Optional[ssl.TLSVersion]) -> None:
def set_ssl_version(self, ssl_version: Optional[str]) -> None:
config = self.current
config.ssl_version = ssl_version
self._sync_to_manager()
Expand Down
36 changes: 30 additions & 6 deletions tidy3d/config/sections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import os
import ssl
from os import PathLike
from pathlib import Path
from typing import Any, Literal, Optional
Expand All @@ -28,6 +27,8 @@
from .registry import get_manager as _get_attached_manager
from .registry import register_handler, register_section

TLS_VERSION_CHOICES = {"TLSv1", "TLSv1_1", "TLSv1_2", "TLSv1_3"}


class ConfigSection(BaseModel):
"""Base class for configuration sections."""
Expand Down Expand Up @@ -328,10 +329,13 @@ class WebConfig(ConfigSection):
le=300,
)

ssl_version: Optional[ssl.TLSVersion] = Field(
ssl_version: Optional[str] = Field(
None,
title="SSL/TLS version",
description="Optional SSL/TLS version to enforce for requests.",
description=(
"Optional TLS version override to enforce for requests. Accepts values such as "
"'TLSv1_2'."
),
)

env_vars: dict[str, str] = Field(
Expand All @@ -349,14 +353,34 @@ def to_dict(self, *, mask_secrets: bool = True) -> dict[str, Any]:
secret = data.get("apikey")
if isinstance(secret, SecretStr):
data["apikey"] = secret.get_secret_value()
ssl_version = data.get("ssl_version")
if isinstance(ssl_version, ssl.TLSVersion):
data["ssl_version"] = ssl_version.value
for field in ("api_endpoint", "website_endpoint"):
if field in data and data[field] is not None:
data[field] = str(data[field])
return data

@field_validator("ssl_version", mode="before")
@classmethod
def _convert_and_check_ssl_version_name(cls, value: Any) -> Optional[str]:
"""Convert SSL enum to string and check if valid.

Accepted examples:
"TLSv1"
"TLSv1_2"
ssl.TLSVersion.TLSv1_2.name -> "TLSv1_2"
"""
if value is None:
return None

# Prefer enum.name if present, otherwise raw string
candidate = getattr(value, "name", value)
candidate = str(candidate).strip()

if candidate not in TLS_VERSION_CHOICES:
allowed = ", ".join(sorted(TLS_VERSION_CHOICES))
raise ValueError(f"Invalid TLS version {candidate!r}. Must be one of: {allowed}")

return candidate

@field_validator("api_endpoint", "website_endpoint", mode="before")
@classmethod
def _validate_http_url(cls, value: Any) -> str:
Expand Down
13 changes: 12 additions & 1 deletion tidy3d/web/core/http_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
import os
import ssl
from enum import Enum
from functools import wraps
from typing import Any, Callable, Optional, TypeAlias
Expand All @@ -12,6 +13,7 @@
from requests.adapters import HTTPAdapter
from urllib3.util.ssl_ import create_urllib3_context

from tidy3d import log
from tidy3d.config import config

from . import core_config
Expand Down Expand Up @@ -200,7 +202,16 @@ def wrapper(*args: Any, **kwargs: Any) -> JSONType:

class TLSAdapter(HTTPAdapter):
def init_poolmanager(self, *args: Any, **kwargs: Any) -> None:
context = create_urllib3_context(ssl_version=config.web.ssl_version)
try:
ssl_version = (
ssl.TLSVersion[config.web.ssl_version]
if config.web.ssl_version is not None
else None
)
except KeyError:
log.warning(f"Invalid SSL/TLS version '{config.web.ssl_version}', using default")
ssl_version = None
context = create_urllib3_context(ssl_version=ssl_version)
kwargs["ssl_context"] = context
return super().init_poolmanager(*args, **kwargs)

Expand Down