Skip to content

Commit

Permalink
ELBv2: added validation for target group (#6808)
Browse files Browse the repository at this point in the history
  • Loading branch information
macnev2013 committed Sep 14, 2023
1 parent 9d8c11f commit c1a6609
Show file tree
Hide file tree
Showing 7 changed files with 297 additions and 62 deletions.
5 changes: 5 additions & 0 deletions moto/elbv2/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,8 @@ def __init__(self, msg: str):
class InvalidLoadBalancerActionException(ELBClientError):
def __init__(self, msg: str):
super().__init__("InvalidLoadBalancerAction", msg)


class ValidationError(ELBClientError):
def __init__(self, msg: str):
super().__init__("ValidationError", msg)
171 changes: 150 additions & 21 deletions moto/elbv2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
InvalidModifyRuleArgumentsError,
InvalidStatusCodeActionTypeError,
InvalidLoadBalancerActionException,
ValidationError,
)

ALLOWED_ACTIONS = [
Expand Down Expand Up @@ -79,42 +80,49 @@ def __init__(
healthcheck_port: Optional[str] = None,
healthcheck_path: Optional[str] = None,
healthcheck_interval_seconds: Optional[str] = None,
healthcheck_timeout_seconds: Optional[int] = None,
healthcheck_timeout_seconds: Optional[str] = None,
healthcheck_enabled: Optional[str] = None,
healthy_threshold_count: Optional[str] = None,
unhealthy_threshold_count: Optional[str] = None,
matcher: Optional[Dict[str, Any]] = None,
target_type: Optional[str] = None,
ip_address_type: Optional[str] = None,
):
# TODO: default values differs when you add Network Load balancer
self.name = name
self.arn = arn
self.vpc_id = vpc_id
self.protocol = protocol
self.protocol_version = protocol_version or "HTTP1"
if target_type == "lambda":
self.protocol = None
self.protocol_version = None
elif target_type == "alb":
self.protocol = "TCP"
self.protocol_version = None
else:
self.protocol = protocol
self.protocol_version = protocol_version
self.port = port
self.healthcheck_protocol = healthcheck_protocol or self.protocol
self.healthcheck_port = healthcheck_port
self.healthcheck_port = healthcheck_port or "traffic-port"
self.healthcheck_path = healthcheck_path
self.healthcheck_interval_seconds = healthcheck_interval_seconds or 30
self.healthcheck_timeout_seconds = healthcheck_timeout_seconds
if not healthcheck_timeout_seconds:
# Default depends on protocol
if protocol in ["TCP", "TLS"]:
self.healthcheck_timeout_seconds = 6
elif protocol in ["HTTP", "HTTPS", "GENEVE"]:
self.healthcheck_timeout_seconds = 5
else:
self.healthcheck_timeout_seconds = 30
self.healthcheck_enabled = healthcheck_enabled
self.healthy_threshold_count = healthy_threshold_count or 5
self.unhealthy_threshold_count = unhealthy_threshold_count or 2
self.healthcheck_interval_seconds = healthcheck_interval_seconds or "30"
self.healthcheck_timeout_seconds = healthcheck_timeout_seconds or "10"
self.ip_address_type = (
ip_address_type or "ipv4" if self.protocol != "GENEVE" else None
)
self.healthcheck_enabled = (
healthcheck_enabled.lower() == "true"
if healthcheck_enabled in ["true", "false"]
else True
)
self.healthy_threshold_count = healthy_threshold_count or "5"
self.unhealthy_threshold_count = unhealthy_threshold_count or "2"
self.load_balancer_arns: List[str] = []
if self.healthcheck_protocol != "TCP":
self.matcher: Dict[str, Any] = matcher or {"HttpCode": "200"}
self.healthcheck_path = self.healthcheck_path or "/"
self.healthcheck_path = self.healthcheck_path
self.healthcheck_port = self.healthcheck_port or str(self.port)
self.target_type = target_type
self.target_type = target_type or "instance"

self.attributes = {
"deregistration_delay.timeout_seconds": 300,
Expand Down Expand Up @@ -1030,6 +1038,9 @@ def _validate_fixed_response_action(
)

def create_target_group(self, name: str, **kwargs: Any) -> FakeTargetGroup:
protocol = kwargs.get("protocol")
target_type = kwargs.get("target_type")

if len(name) > 32:
raise InvalidTargetGroupNameError(
f"Target group name '{name}' cannot be longer than '32' characters"
Expand Down Expand Up @@ -1080,6 +1091,124 @@ def create_target_group(self, name: str, **kwargs: Any) -> FakeTargetGroup:
"HttpCode must be like 200 | 200-399 | 200,201 ...",
)

if target_type in ("instance", "ip", "alb"):
for param in ("protocol", "port", "vpc_id"):
if not kwargs.get(param):
param = "VPC ID" if param == "vpc_id" else param.lower()
raise ValidationError(f"A {param} must be specified")

if kwargs.get("vpc_id"):
from moto.ec2.exceptions import InvalidVPCIdError

try:
self.ec2_backend.get_vpc(kwargs.get("vpc_id"))
except InvalidVPCIdError:
raise ValidationError(
f"The VPC ID '{kwargs.get('vpc_id')}' is not found"
)

kwargs_patch = {}

conditions: Dict[str, Any] = {
"target_lambda": {
"healthcheck_interval_seconds": 35,
"healthcheck_timeout_seconds": 30,
"unhealthy_threshold_count": 2,
"healthcheck_enabled": "false",
"healthcheck_path": "/",
},
"target_alb": {
"healthcheck_protocol": "HTTP",
"healthcheck_path": "/",
"healthcheck_timeout_seconds": 6,
"matcher": {"HttpCode": "200-399"},
},
"protocol_GENEVE": {
"healthcheck_interval_seconds": 10,
"healthcheck_port": 80,
"healthcheck_timeout_seconds": 5,
"healthcheck_protocol": "TCP",
"unhealthy_threshold_count": 2,
},
"protocol_HTTP_HTTPS": {
"healthcheck_timeout_seconds": 5,
"protocol_version": "HTTP1",
"healthcheck_path": "/",
"unhealthy_threshold_count": 2,
"healthcheck_interval_seconds": 30,
},
"protocol_TCP": {
"healthcheck_timeout_seconds": 10,
},
"protocol_TCP_TCP_UDP_UDP_TLS": {
"healthcheck_protocol": "TCP",
"unhealthy_threshold_count": 2,
"healthcheck_interval_seconds": 30,
},
}

if target_type == "lambda":
kwargs_patch.update(
{k: kwargs.get(k) or v for k, v in conditions["target_lambda"].items()}
)

if protocol == "GENEVE":
kwargs_patch.update(
{
k: kwargs.get(k) or v
for k, v in conditions["protocol_GENEVE"].items()
}
)

if protocol in ("HTTP", "HTTPS"):
kwargs_patch.update(
{
k: kwargs.get(k) or v
for k, v in conditions["protocol_HTTP_HTTPS"].items()
}
)

if protocol == "TCP":
kwargs_patch.update(
{k: kwargs.get(k) or v for k, v in conditions["protocol_TCP"].items()}
)

if protocol in ("TCP", "TCP_UDP", "UDP", "TLS"):
kwargs_patch.update(
{
k: kwargs.get(k) or v
for k, v in conditions["protocol_TCP_TCP_UDP_UDP_TLS"].items()
}
)

if target_type == "alb":
kwargs_patch.update(
{k: kwargs.get(k) or v for k, v in conditions["target_alb"].items()}
)

kwargs.update(kwargs_patch)

healthcheck_timeout_seconds = int(
str(kwargs.get("healthcheck_timeout_seconds") or "10")
)
healthcheck_interval_seconds = int(
str(kwargs.get("healthcheck_interval_seconds") or "30")
)

if (
healthcheck_timeout_seconds is not None
and healthcheck_interval_seconds is not None
):

if healthcheck_interval_seconds < healthcheck_timeout_seconds:
raise ValidationError(
"Health check interval must be greater than the timeout."
)
if healthcheck_interval_seconds == healthcheck_timeout_seconds:
raise ValidationError(
f"Health check timeout '{healthcheck_timeout_seconds}' must be smaller than the interval '{healthcheck_interval_seconds}'"
)

arn = make_arn_for_target_group(
account_id=self.account_id, name=name, region_name=self.region_name
)
Expand Down Expand Up @@ -1524,11 +1653,11 @@ def modify_target_group(
health_check_port: Optional[str] = None,
health_check_path: Optional[str] = None,
health_check_interval: Optional[str] = None,
health_check_timeout: Optional[int] = None,
health_check_timeout: Optional[str] = None,
healthy_threshold_count: Optional[str] = None,
unhealthy_threshold_count: Optional[str] = None,
http_codes: Optional[str] = None,
health_check_enabled: Optional[str] = None,
health_check_enabled: Optional[bool] = None,
) -> FakeTargetGroup:
target_group = self.target_groups.get(arn)
if target_group is None:
Expand Down
26 changes: 21 additions & 5 deletions moto/elbv2/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def create_target_group(self) -> str:
name = params.get("Name")
vpc_id = params.get("VpcId")
protocol = params.get("Protocol")
protocol_version = params.get("ProtocolVersion", "HTTP1")
protocol_version = params.get("ProtocolVersion")
port = params.get("Port")
healthcheck_protocol = self._get_param("HealthCheckProtocol")
healthcheck_port = self._get_param("HealthCheckPort")
Expand All @@ -196,7 +196,8 @@ def create_target_group(self) -> str:
healthy_threshold_count = self._get_param("HealthyThresholdCount")
unhealthy_threshold_count = self._get_param("UnhealthyThresholdCount")
matcher = params.get("Matcher")
target_type = params.get("TargetType")
target_type = params.get("TargetType", "instance")
ip_address_type = params.get("IpAddressType")
tags = params.get("Tags")

target_group = self.elbv2_backend.create_target_group(
Expand All @@ -214,6 +215,7 @@ def create_target_group(self) -> str:
healthy_threshold_count=healthy_threshold_count,
unhealthy_threshold_count=unhealthy_threshold_count,
matcher=matcher,
ip_address_type=ip_address_type,
target_type=target_type,
tags=tags,
)
Expand Down Expand Up @@ -797,21 +799,32 @@ def remove_listener_certificates(self) -> str:
<TargetGroupName>{{ target_group.name }}</TargetGroupName>
{% if target_group.protocol %}
<Protocol>{{ target_group.protocol }}</Protocol>
{% if target_group.protocol_version %}
<ProtocolVersion>{{ target_group.protocol_version }}</ProtocolVersion>
{% endif %}
{% endif %}
{% if target_group.port %}
<Port>{{ target_group.port }}</Port>
{% endif %}
{% if target_group.vpc_id %}
<VpcId>{{ target_group.vpc_id }}</VpcId>
{% endif %}
<HealthCheckProtocol>{{ target_group.healthcheck_protocol }}</HealthCheckProtocol>
{% if target_group.healthcheck_port %}<HealthCheckPort>{{ target_group.healthcheck_port }}</HealthCheckPort>{% endif %}
{% if target_group.healthcheck_enabled %}
{% if target_group.healthcheck_port %}
<HealthCheckPort>{{ target_group.healthcheck_port }}</HealthCheckPort>
{% endif %}
{% if target_group.healthcheck_protocol %}
<HealthCheckProtocol>{{ target_group.healthcheck_protocol or "None" }}</HealthCheckProtocol>
{% endif %}
{% endif %}
{% if target_group.healthcheck_path %}
<HealthCheckPath>{{ target_group.healthcheck_path or '' }}</HealthCheckPath>
{% endif %}
<HealthCheckIntervalSeconds>{{ target_group.healthcheck_interval_seconds }}</HealthCheckIntervalSeconds>
<HealthCheckTimeoutSeconds>{{ target_group.healthcheck_timeout_seconds }}</HealthCheckTimeoutSeconds>
<HealthCheckEnabled>{{ target_group.healthcheck_enabled and 'true' or 'false' }}</HealthCheckEnabled>
<HealthyThresholdCount>{{ target_group.healthy_threshold_count }}</HealthyThresholdCount>
<UnhealthyThresholdCount>{{ target_group.unhealthy_threshold_count }}</UnhealthyThresholdCount>
<HealthCheckEnabled>{{ target_group.healthcheck_enabled and 'true' or 'false' }}</HealthCheckEnabled>
{% if target_group.matcher %}
<Matcher>
<HttpCode>{{ target_group.matcher['HttpCode'] }}</HttpCode>
Expand All @@ -820,6 +833,9 @@ def remove_listener_certificates(self) -> str:
{% if target_group.target_type %}
<TargetType>{{ target_group.target_type }}</TargetType>
{% endif %}
{% if target_group.ip_address_type %}
<IpAddressType>{{ target_group.ip_address_type }}</IpAddressType>
{% endif %}
</member>
</TargetGroups>
</CreateTargetGroupResult>
Expand Down
2 changes: 1 addition & 1 deletion tests/test_autoscaling/test_elbv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def setUp(self) -> None:
HealthCheckPort="8080",
HealthCheckPath="/",
HealthCheckIntervalSeconds=5,
HealthCheckTimeoutSeconds=5,
HealthCheckTimeoutSeconds=3,
HealthyThresholdCount=5,
UnhealthyThresholdCount=2,
Matcher={"HttpCode": "200"},
Expand Down
Loading

0 comments on commit c1a6609

Please sign in to comment.