Skip to content

Commit

Permalink
Add greedy controller.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 537427210
  • Loading branch information
Max Schwarzer authored and JesseFarebro committed Jun 2, 2023
1 parent b5ef4dc commit 8d1275f
Show file tree
Hide file tree
Showing 6 changed files with 412 additions and 9 deletions.
54 changes: 54 additions & 0 deletions putting_dune/action_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,57 @@ def action_spec(self) -> specs.BoundedArray:
minimum=np.asarray([-1.0, -1.0, 0.0]),
maximum=np.asarray([1.0, 1.0, 1.0]),
)


class RelativeToSiliconMaterialFrameActionAdapter(
    RelativeToSiliconActionAdapter
):
  """Action adapter for beam offsets expressed relative to the silicon atom.

  The first two entries of an action are a (dx, dy) displacement from the
  silicon, specified in Angstroms. When the dwell time is not fixed, the third
  entry selects the dwell time.
  """

  def get_action(
      self,
      previous_observation: microscope_utils.MicroscopeObservation,
      action: np.ndarray,
  ) -> List[microscope_utils.BeamControlMicroscopeFrame]:
    """Turns an agent action into a single beam control.

    Args:
      previous_observation: The latest observation; its `grid` is searched for
        the silicon atom and its `fov` is used for the frame conversion.
      action: The agent action. `action[:2]` is the [dx, dy] offset in
        Angstroms. `action[2]` is only read when the dwell time is not fixed,
        and is clipped to [0, 1] before being scaled to the dwell range.

    Returns:
      A one-element list holding the beam control. The control position is
      clipped to [0, 1] on both axes.

    Raises:
      RuntimeError: If the silicon positions found in the grid do not have
        shape (1, 2).
    """
    offset_angstroms = action[:2]

    silicon_position = graphene.get_silicon_positions(previous_observation.grid)
    if silicon_position.shape != (1, 2):
      raise RuntimeError(
          'Expected to find one silicon with x, y coordinates. Instead, '
          f'got {silicon_position.shape[0]} silicon atoms with '
          f'{silicon_position.shape[1]} dimensions.'
      )
    silicon_xy = np.reshape(silicon_position, (2,))

    # NOTE(review): the offset is a delta vector, yet it goes through the same
    # conversion presumably used for points. Confirm that
    # material_frame_to_microscope_frame applies no translation, otherwise the
    # resulting control position is shifted.
    fov = previous_observation.fov
    offset_microscope = fov.material_frame_to_microscope_frame(offset_angstroms)
    target_xy = np.clip(silicon_xy + offset_microscope, 0.0, 1.0)

    # Default to the minimum dwell; only consult action[2] when the dwell time
    # is allowed to vary.
    dwell_seconds = self._min_dwell_seconds
    if not self._fixed_dwell_time:
      dwell_fraction = np.clip(action[2], 0.0, 1.0)
      dwell_span = self._max_dwell_seconds - self._min_dwell_seconds
      dwell_seconds = dwell_fraction * dwell_span + self._min_dwell_seconds

    beam_control = microscope_utils.BeamControl(
        geometry.Point(*target_xy), dt.timedelta(seconds=dwell_seconds)
    )
    return [microscope_utils.BeamControlMicroscopeFrame(beam_control)]
112 changes: 110 additions & 2 deletions putting_dune/agents/agent_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import abc
import enum
import functools
from typing import Sequence, Union

from typing import Callable, Optional, Sequence, Union
import dm_env
import numpy as np
from putting_dune import geometry


@enum.unique
Expand Down Expand Up @@ -76,3 +76,111 @@ def step(self, time_step: dm_env.TimeStep) -> np.ndarray:

def set_mode(self, mode: AgentMode) -> None:
pass # No action required.


class GreedyAgent(Agent):
  """An agent that acts greedily according to a transition function.

  Optionally, some randomization can be added, for the sake of data collection,
  and a fixed offset to the argmax may be specified.

  The argmax is assumed to be calculated for an Si with a neighbor positioned
  directly in the positive X-direction, and is the beam position most likely to
  cause a transition to that neighbor.

  If no transition function is specified, a standard offset of 1.42 A towards
  the neighbor will be used. A manual argmax may also be specified in place of
  this.

  Currently must be used with the SingleSiliconMaterialFrameFeatureConstructor,
  and the RelativeToSiliconMaterialFrameActionAdapter.
  """

  def __init__(
      self,
      rng: Optional[np.random.Generator] = None,
      transition_function: Optional[Callable[[np.ndarray], np.ndarray]] = None,
      argmax: Optional[np.ndarray] = np.asarray([1.42, 0.0]),
      argmax_resolution: float = 0.05,
      position_noise_sigma: float = 0.0,
      fixed_offset: np.ndarray = np.zeros(2, dtype=np.float32),
      low: Union[float, np.ndarray] = -5,
      high: Union[float, np.ndarray] = 5,
  ):
    """GreedyAgent constructor.

    Args:
      rng: The Generator used for beam position noise. A fresh default
        generator is created when None.
      transition_function: Function that takes a beam position and predicts
        transition probabilities. When given, it takes precedence over
        `argmax`, and the greedy position is found by grid search.
      argmax: Manual specification of the greedy position (for a silicon with
        a neighbor at (1.42, 0)). Only used when `transition_function` is
        None.
      argmax_resolution: Resolution in Angstroms for argmax-finding.
      position_noise_sigma: Standard deviation for extra beam position noise.
      fixed_offset: Fixed vector to add to beam position (on top of argmax).
      low: The lowest coordinate value of the argmax search grid.
      high: The highest coordinate value of the argmax search grid.

    Raises:
      ValueError: If both `transition_function` and `argmax` are None.
    """
    self._position_noise_sigma = position_noise_sigma
    # Normalize to arrays so that list inputs add element-wise in `step`
    # instead of concatenating.
    self._fixed_offset = np.asarray(fixed_offset)
    self._rng = rng if rng is not None else np.random.default_rng()
    self._low = low
    self._high = high
    if transition_function is not None:
      self._argmax = self.find_argmax(transition_function, argmax_resolution)
    elif argmax is not None:
      self._argmax = np.asarray(argmax)
    else:
      raise ValueError('One of transition_function or argmax must be set.')

  def find_argmax(
      self,
      transition_function: Callable[[np.ndarray], np.ndarray],
      resolution: float = 0.05,
  ) -> np.ndarray:
    """Finds the argmax of a transition function by grid search.

    The search covers the square [low, high] x [low, high], endpoints
    included, with a grid spacing of at most `resolution`.

    Args:
      transition_function: A function taking a numpy array of shape (2,) and
        returning the probability (or rate) of transitioning to three
        neighbors.
      resolution: Grid resolution (in Angstroms) to use for the search.

    Returns:
      The approximate argmax of the transition function with respect to
      transitioning to neighbor 0.

    Raises:
      ValueError: If `resolution` is not positive or `high` is below `low`.
    """
    if resolution <= 0:
      raise ValueError(f'resolution must be positive, got {resolution}.')
    # float() rejects multi-element arrays, so only scalar bounds are
    # supported by the grid search.
    span = float(self._high - self._low)
    if span < 0:
      raise ValueError(
          f'high ({self._high}) must not be below low ({self._low}).'
      )
    # ceil(...) + 1 counts both endpoints and keeps the spacing no coarser
    # than `resolution`. Floor division would drop points here, e.g.
    # 3.0 // 0.05 == 59.0 because 0.05 is not exactly representable.
    num_points = int(np.ceil(span / resolution)) + 1
    points_1d = np.linspace(self._low, self._high, num_points, dtype=np.float32)
    points_x = np.tile(points_1d[None], (num_points, 1))
    points_y = np.tile(points_1d[:, None], (1, num_points))
    points = np.stack([points_x, points_y], axis=-1)
    points = np.reshape(points, (-1, points.shape[-1]))
    transition_probabilities = np.stack(
        [transition_function(x) for x in points], 0
    )
    # Column 0 holds the value for neighbor 0, the one being maximized.
    return points[np.argmax(transition_probabilities[..., 0], axis=-1)]

  def step(self, time_step: dm_env.TimeStep) -> np.ndarray:
    """Selects the beam position aimed at the neighbor closest to the goal.

    Args:
      time_step: The current time step. Its observation must have shape (10,):
        entries 2..7 are read as three (x, y) neighbor deltas and the last two
        entries as the goal delta.

    Returns:
      The greedy beam position (argmax plus fixed offset plus noise), rotated
      by the angle of the chosen neighbor.
    """
    observation = time_step.observation
    assert observation.shape == (10,), (
        f'Expected an observation of shape (10,), got {observation.shape}.'
    )
    neighbor_deltas = observation[2:-2].reshape(3, 2)
    goal_delta = observation[-2:]

    # Choose the neighbor whose delta is closest to the goal delta.
    neighbor_scores = np.linalg.norm(
        neighbor_deltas - goal_delta[None], axis=-1
    )
    best_neighbor = np.argmin(neighbor_scores, axis=-1)
    angle = geometry.get_angles(neighbor_deltas)[best_neighbor]

    beam_position = self._argmax + self._fixed_offset
    beam_position_noise = self._rng.normal(
        0, self._position_noise_sigma, size=2
    )
    beam_position = beam_position + beam_position_noise

    # The stored argmax assumes a neighbor along +X, so rotate it towards the
    # neighbor that was actually chosen.
    return geometry.rotate_coordinates(beam_position, angle)

  def set_mode(self, mode: AgentMode) -> None:
    pass  # No action required.
120 changes: 119 additions & 1 deletion putting_dune/agents/agent_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,17 @@
observation=np.zeros((4,), dtype=np.float32),
)

# Observation in the layout read by GreedyAgent.step: entries 2..7 are the
# (x, y) deltas to three neighbors and the last two entries are the goal delta.
# Here the goal coincides with the first neighbor, at (1.42, 0).
# NOTE(review): the meaning of the two leading zeros is not visible from this
# file; presumably the silicon's own position. Confirm against the feature
# constructor.
_CANONICAL_GREEDY_STEP = dm_env.TimeStep(
    step_type=dm_env.StepType.MID,
    reward=0.0,
    discount=0.99,
    observation=np.array(
        [0, 0, 1.42, 0, -0.71, 1.23, -0.71, -1.23, 1.42, 0], dtype=np.float32
    ),
)


class AgentTest(parameterized.TestCase):
class UniformRandomAgentTest(parameterized.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -63,5 +72,114 @@ def test_uniform_agent_selects_actions_with_correct_shape(
self.assertEqual(action.shape, shape)


class GreedyAgentTest(parameterized.TestCase):
  """Tests for agent_lib.GreedyAgent."""

  def setUp(self):
    super().setUp()
    # Seeded so that any noise drawn by the agent is reproducible.
    self._rng = np.random.default_rng(0)

  @staticmethod
  def _make_step(neighbors: np.ndarray, goal: np.ndarray) -> dm_env.TimeStep:
    """Builds a MID time step with observation [0, 0, *neighbors, *goal]."""
    observation = np.concatenate([
        np.zeros(2),
        neighbors,
        goal,
    ])
    return dm_env.TimeStep(
        step_type=dm_env.StepType.MID,
        reward=0.0,
        discount=0.99,
        observation=observation,
    )

  @parameterized.named_parameters(
      dict(
          testcase_name='classic_argmax',
          argmax=np.array([1.42, 0]),
      ),
      dict(
          testcase_name='weird_argmax',
          # Drawn from a seeded generator: parameters are built at import
          # time, and an unseeded draw would make failures unreproducible.
          argmax=np.random.default_rng(0).normal(size=2),
      ),
  )
  def test_greedy_agent_selects_argmax(self, argmax: np.ndarray):
    agent = agent_lib.GreedyAgent(rng=self._rng, argmax=argmax)

    action = agent.step(_CANONICAL_GREEDY_STEP)
    np.testing.assert_allclose(action, argmax)

  @parameterized.named_parameters(
      dict(
          testcase_name='neighbor_0',
          argmax=np.array([1.42, 0]),
          neighbors=np.array([1.42, 0, -0.71, 1.23, -0.71, -1.23]),
          goal=np.array([1.42, 0]),
      ),
      dict(
          testcase_name='neighbor_1',
          argmax=np.array([1.42, 0]),
          neighbors=np.array([1.42, 0, -0.71, 1.23, -0.71, -1.23]),
          goal=np.array([-0.71, 1.23]),
      ),
      dict(
          testcase_name='neighbor_2',
          argmax=np.array([1.42, 0]),
          neighbors=np.array([1.42, 0, -0.71, 1.23, -0.71, -1.23]),
          goal=np.array([-0.71, -1.23]),
      ),
  )
  def test_greedy_agent_selects_rotated_argmax(
      self, argmax: np.ndarray, neighbors: np.ndarray, goal: np.ndarray
  ):
    agent = agent_lib.GreedyAgent(rng=self._rng, argmax=argmax)

    action = agent.step(self._make_step(neighbors, goal))
    # The goal sits exactly on a neighbor, so the rotated argmax should land
    # on the goal itself.
    np.testing.assert_allclose(action, goal, atol=0.01)

  @parameterized.named_parameters(
      dict(
          testcase_name='l2_rate',
          neighbors=np.array([1.42, 0, -0.71, 1.23, -0.71, -1.23]),
          goal=np.array([1.42, 0]),
          low=-1.5,
          high=1.5,
          resolution=0.05,
      ),
  )
  def test_greedy_agent_finds_argmax(
      self,
      neighbors: np.ndarray,
      goal: np.ndarray,
      low: float,
      high: float,
      resolution: float,
  ):
    neighbor = np.reshape(neighbors, (3, 2))

    def transition_function(beam_position):
      # Negative distance to each neighbor: maximal when the beam sits on it.
      return -np.linalg.norm(neighbor - beam_position[..., None, :], axis=-1)

    agent = agent_lib.GreedyAgent(
        rng=self._rng,
        argmax=None,
        transition_function=transition_function,
        argmax_resolution=resolution,
        low=low,
        high=high,
    )

    action = agent.step(self._make_step(neighbors, goal))
    np.testing.assert_allclose(action, neighbor[0], atol=resolution)


if __name__ == '__main__':
absltest.main()
50 changes: 49 additions & 1 deletion putting_dune/experiments/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,25 @@ def _get_relative_random_agent(
)


def _get_greedy_agent(
    rng: np.random.Generator,
    adapters_and_goal: experiments.AdaptersAndGoal,
    argmax=np.asarray([1.42, 0.0]),
    transition_function=None,
    fixed_offset=np.zeros(2),
) -> agent_lib.GreedyAgent:
  """Creates a GreedyAgent bounded by the action adapter's action spec.

  Args:
    rng: Random generator handed to the agent.
    adapters_and_goal: Supplies the action adapter whose action spec minimum
      and maximum become the agent's `low` and `high`.
    argmax: Manually specified greedy beam position, forwarded to the agent.
    transition_function: Optional transition function, forwarded to the agent.
    fixed_offset: Fixed beam offset, forwarded to the agent.

  Returns:
    The configured GreedyAgent.
  """
  spec = adapters_and_goal.action_adapter.action_spec
  agent_kwargs = dict(
      rng=rng,
      argmax=argmax,
      transition_function=transition_function,
      fixed_offset=fixed_offset,
      low=spec.minimum,
      high=spec.maximum,
  )
  return agent_lib.GreedyAgent(**agent_kwargs)


@dataclasses.dataclass(frozen=True)
class _TfAgentCreator:
"""Gets a tf eval agent, loading from the specified path."""
Expand Down Expand Up @@ -163,7 +182,26 @@ def __call__(self) -> experiments.AdaptersAndGoal:
dwell_time_range=self.dwell_time_range,
max_distance_angstroms=self.max_distance_angstroms,
),
feature_constructor=feature_constructors.SingleSiliconPristineGraphineFeatureConstuctor(),
feature_constructor=feature_constructors.SingleSiliconPristineGrapheneFeatureConstuctor(),
goal=goals.SingleSiliconGoalReaching(),
)


@dataclasses.dataclass(frozen=True)
class _SingleSiliconGoalReachingMaterialFrame:
  """Builds the adapters and goal for material-frame single-silicon control."""

  # Dwell time range forwarded to the action adapter; presumably (min, max),
  # confirm against the adapter's constructor.
  dwell_time_range: Tuple[dt.timedelta, dt.timedelta] = (
      dt.timedelta(seconds=1.5),
      dt.timedelta(seconds=1.5),
  )
  # Forwarded to the action adapter; defaults to two carbon bond lengths.
  max_distance_angstroms: float = constants.CARBON_BOND_DISTANCE_ANGSTROMS * 2.0

  def __call__(self) -> experiments.AdaptersAndGoal:
    """Returns a freshly constructed action adapter, feature constructor, goal."""
    adapter = action_adapters.RelativeToSiliconMaterialFrameActionAdapter(
        dwell_time_range=self.dwell_time_range,
        max_distance_angstroms=self.max_distance_angstroms,
    )
    features = (
        feature_constructors.SingleSiliconMaterialFrameFeatureConstructor()
    )
    goal = goals.SingleSiliconGoalReaching()
    return experiments.AdaptersAndGoal(
        action_adapter=adapter,
        feature_constructor=features,
        goal=goal,
    )

Expand Down Expand Up @@ -245,6 +283,16 @@ def _get_human_prior_rates_config() -> experiments.SimulatorConfig:
max_distance_angstroms=3 * constants.CARBON_BOND_DISTANCE_ANGSTROMS,
),
),
'greedy_simple': experiments.MicroscopeExperiment(
get_agent=_get_greedy_agent,
get_adapters_and_goal=_SingleSiliconGoalReachingMaterialFrame(
dwell_time_range=(
dt.timedelta(seconds=5.0),
dt.timedelta(seconds=5.0),
),
max_distance_angstroms=2 * constants.CARBON_BOND_DISTANCE_ANGSTROMS,
),
),
'ppo_simple_images_tf': experiments.MicroscopeExperiment(
get_agent=_GET_PPO_SIMPLE_IMAGES_TF,
get_adapters_and_goal=_SingleSiliconGoalReachingFromPixels(),
Expand Down
Loading

0 comments on commit 8d1275f

Please sign in to comment.