Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions aepsych/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,16 @@ def update(
# Validate the parameter-specific block
self._check_param_settings(par_name)

lb[i] = self[par_name].get("lower_bound", fallback="0")
ub[i] = self[par_name].get("upper_bound", fallback="1")
if self[par_name]["par_type"] == "categorical":
raise NotImplementedError(
"Categorical parameters not supported yet"
)
choices = self.getlist(par_name, "choices", element_type=str)
lb[i] = "0"
ub[i] = str(len(choices) - 1)
else:
lb[i] = self[par_name].get("lower_bound", fallback="0")
ub[i] = self[par_name].get("upper_bound", fallback="1")

self["common"]["lb"] = f"[{', '.join(lb)}]"
self["common"]["ub"] = f"[{', '.join(ub)}]"
Expand Down Expand Up @@ -397,6 +405,12 @@ def _check_param_settings(self, param_name: str) -> None:
f"Parameter {param_name} is fixed and needs to have value set."
)

elif param_block["par_type"] == "categorical":
# Need a choices array
if "choices" not in param_block:
raise ValueError(
f"Parameter {param_name} is missing the choices setting."
)
else:
raise ParameterConfigError(
f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}."
Expand Down
3 changes: 2 additions & 1 deletion aepsych/transforms/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .categorical import Categorical
from .fixed import Fixed
from .log10_plus import Log10Plus
from .normalize_scale import NormalizeScale
from .round import Round

__all__ = ["Log10Plus", "NormalizeScale", "Round", "Fixed"]
__all__ = ["Categorical", "Fixed", "Log10Plus", "NormalizeScale", "Round"]
97 changes: 97 additions & 0 deletions aepsych/transforms/ops/categorical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any

import torch
from aepsych.config import Config
from aepsych.transforms.ops.base import StringParameterMixin, Transform


class Categorical(Transform, StringParameterMixin):
# These attributes do nothing here but ensures compat.
is_one_to_many = False
transform_on_train = True
transform_on_eval = True
transform_on_fantasize = True
training = True
reverse = False

def __init__(
self,
indices: list[int],
categories: dict[int, list[str]],
) -> None:
"""Initialize a categorical transform. The transform itself does not
change the tensors. Instead, this class allows passing in NumPy object
arrays where the categorical values are stored as strings. This provides
a convenient API to turn mixed categorical/continuous data into the
expected form for models.

Args:
indices (list[int]): The indices of the inputs that are categorical.
categories (dict[int, list[str]]): A dictionary mapping indices to
the list of categories for that input. There must be a list for
each index in `indices`.
"""
self.indices = indices
self.categories = categories
self.string_map = self.categories

def _transform(self, X: torch.Tensor) -> torch.Tensor:
r"""This is a no-op as these transforms should be acting on indices
already.

Args:
X (torch.Tensor): A `batch_shape x n x d`-dim tensor of inputs.

Returns:
torch.Tensor: The input tensor.
"""
return X

def _untransform(self, X: torch.Tensor) -> torch.Tensor:
r"""This is a no-op as these transforms should be acting on indices
already.

Args:
X (torch.Tensor): A `batch_shape x n x d`-dim tensor of transformed inputs.

Returns:
torch.Tensor: The input tensor.
"""
return X

@classmethod
def get_config_options(
cls,
config: Config,
name: str | None = None,
options: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Return a dictionary of the relevant options to initialize a Fixed parameter
transform for the named parameter within the config.

Args:
config (Config): Config to look for options in.
name (str, optional): Parameter to find options for.
options (Dict[str, Any], optional): Options to override from the config.

Returns:
Dict[str, Any]: A dictionary of options to initialize this class with,
including the transformed bounds.
"""
options = super().get_config_options(config=config, name=name, options=options)

if name is None:
raise ValueError(f"{name} must be set to initialize a transform.")

if "categories" not in options:
idx = options["indices"][0] # There should only be one index
cat_dict = {idx: config.getlist(name, "categories", element_type=str)}
options["categories"] = cat_dict

return options
20 changes: 15 additions & 5 deletions aepsych/transforms/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from aepsych.config import Config, ConfigurableMixin
from aepsych.generators.base import AcqfGenerator, AEPsychGenerator
from aepsych.models.base import AEPsychModelMixin
from aepsych.transforms.ops import Fixed, Log10Plus, NormalizeScale, Round
from aepsych.transforms.ops import Categorical, Fixed, Log10Plus, NormalizeScale, Round
from aepsych.transforms.ops.base import Transform
from aepsych.utils import get_bounds
from botorch.acquisition import AcquisitionFunction
Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(
for key in transform.string_map.keys():
if key in fixed_string_map:
raise RuntimeError(
"Conflicting string maps between the Fixed transforms, each parameter can only have a single string map."
"Conflicting string maps between the string transforms, each parameter can only have a single string map."
)

fixed_string_map.update(transform.string_map)
Expand Down Expand Up @@ -145,8 +145,9 @@ def transform_bounds(
) -> torch.Tensor:
r"""Transform bounds of a parameter.

Individual transforms are applied in sequence. Then an adjustment is applied to
ensure the bounds are correct.
Individual transforms are applied in sequence. Looks for a specific
transform_bounds method in each transform to apply that, otherwise uses the
normal transform.

Args:
X (torch.Tensor): A tensor of inputs. Either `[dim]` or `[2, dim]`.
Expand Down Expand Up @@ -255,14 +256,23 @@ def get_config_options(
)
transform_dict[f"{par}_Round"] = round

if par_type == "fixed":
elif par_type == "fixed":
fixed = Fixed.from_config(
config=config, name=par, options=transform_options
)

# We don't mess with bounds since we don't want to modify indices
transform_dict[f"{par}_Fixed"] = fixed

# Categorical variable
elif par_type == "categorical":
categorical = Categorical.from_config(
config=config, name=par, options=transform_options
)

transform_dict[f"{par}_Categorical"] = categorical
continue # Prevents log-scaling or normalizing

# Log scale
if config.getboolean(par, "log_scale", fallback=False):
log10 = Log10Plus.from_config(
Expand Down
33 changes: 31 additions & 2 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest
import uuid

Expand All @@ -20,7 +19,7 @@
ParameterTransformedModel,
ParameterTransforms,
)
from aepsych.transforms.ops import Fixed, Log10Plus, NormalizeScale, Round
from aepsych.transforms.ops import Categorical, Fixed, Log10Plus, NormalizeScale, Round


class TransformsWrapperTest(unittest.TestCase):
Expand Down Expand Up @@ -804,3 +803,33 @@ def test_fixed_conflict(self):

with self.assertRaises(RuntimeError):
_ = ParameterTransforms(fixed1=fixed1, fixed2=fixed2)


class TransformCategorical(unittest.TestCase):
def test_standalone_transform(self):
categories = {1: ["red", "green", "blue"], 3: ["big", "small"]}
input = torch.tensor([[0.2, 2, 4, 0, 1], [0.5, 0, 3, 0, 1], [0.9, 1, 0, 1, 0]])
input_cats = np.array(
[
[0.2, "blue", 4, "big", "right"],
[0.5, "red", 3, "big", "right"],
[0.9, "green", 0, "small", "left"],
],
dtype="O",
)

transforms = ParameterTransforms(
categorical1=Categorical(indices=[1, 3], categories=categories),
categorical2=Categorical(indices=[4], categories={4: ["left", "right"]}),
)

transformed = transforms.transform(input)
untransformed = transforms.untransform(transformed)

self.assertTrue(torch.equal(input, untransformed)) # Test no-op

strings = transforms.indices_to_str(input)
self.assertTrue(np.all(input_cats == strings))

indices = transforms.str_to_indices(input_cats)
self.assertTrue(torch.all(indices == input))