Add support for the "expanded cost features" model.
This commit includes two separate changes:

  1) the addition of extra features which contribute to the "inlining
  cost" estimate, and
  2) an increase in the size and training duration of the model to
  account for the additional features.
jacob-hegna committed Jul 2, 2021
1 parent b235539 commit 8826749
Showing 41 changed files with 44,542 additions and 2,523 deletions.
5 changes: 4 additions & 1 deletion compiler_opt/rl/feature_ops.py
@@ -42,6 +42,7 @@ def _build_quantile_map(quantile_file_dir):
 
 @gin.configurable
 def get_observation_processing_layer_creator(quantile_file_dir,
+                                             with_sqrt=False,
                                              with_z_score_normalization=False,
                                              eps=1e-8):
   """Wrapper for observation_processing_layer."""
@@ -67,7 +68,9 @@ def normalization(obs):
       x = tf.cast(
           tf.raw_ops.Bucketize(input=expanded_obs, boundaries=quantile),
           tf.float32) / len(quantile)
-      features = [x, tf.sqrt(x), x * x]
+      features = [x, x * x]
+      if with_sqrt:
+        features.append(tf.sqrt(x))
       if with_z_score_normalization:
         y = tf.cast(expanded_obs, tf.float32)
         y = (y - mean) / (std + eps)
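
For orientation, here is a self-contained sketch (not part of the commit) of what the updated preprocessing computes for one observation: bucketize against the feature's quantile boundaries, scale into [0, 1], and emit the value and its square, plus an optional square-root channel and an optional z-score channel. The `tf.expand_dims` line and the final `tf.concat` are assumptions about code elided from this hunk; `quantile`, `mean`, and `std` stand in for the per-feature statistics loaded from the vocab directory.

import tensorflow as tf

def expand_features(obs, quantile, mean, std, with_sqrt=False,
                    with_z_score_normalization=False, eps=1e-8):
  # Bucketize the raw observation and rescale the bucket index into [0, 1].
  expanded_obs = tf.expand_dims(obs, -1)
  x = tf.cast(
      tf.raw_ops.Bucketize(input=expanded_obs, boundaries=quantile),
      tf.float32) / len(quantile)
  features = [x, x * x]
  if with_sqrt:
    features.append(tf.sqrt(x))
  if with_z_score_normalization:
    y = tf.cast(expanded_obs, tf.float32)
    features.append((y - mean) / (std + eps))
  return tf.concat(features, axis=-1)

With both flags off, each scalar feature expands to two channels; with_sqrt adds a third and z-score normalization a fourth, which is exactly the shape arithmetic the updated test below performs.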
46 changes: 28 additions & 18 deletions compiler_opt/rl/feature_ops_test.py
@@ -18,11 +18,18 @@
 import os
 
 from absl.testing import parameterized
+import numpy as np
 import tensorflow as tf
 
 from compiler_opt.rl import constant
 from compiler_opt.rl import feature_ops
 
+_WITH_Z_SCORE_SQRT_PRODUCT_VALUES = [('with_sqrt_with_z_score', True, True),
+                                     ('with_sqrt_without_z_score', True, False),
+                                     ('without_sqrt_with_z_score', False, True),
+                                     ('without_sqrt_without_z_score', False,
+                                      False)]
+
 
 class FeatureUtilsTest(tf.test.TestCase, parameterized.TestCase):
 
@@ -50,12 +57,11 @@ def test_build_quantile_map_from_config(self):
     # std
     self.assertAllClose(14.988885, std)
 
-  @parameterized.named_parameters(('with_z_score', True),
-                                  ('without_z_score', False))
-  def test_create_observation_processing_layer(self, with_z_score):
+  @parameterized.named_parameters(*_WITH_Z_SCORE_SQRT_PRODUCT_VALUES)
+  def test_create_observation_processing_layer(self, with_z_score, with_sqrt):
     observation_processing_layer = (
         feature_ops.get_observation_processing_layer_creator(
-            self._quantile_file_dir, with_z_score))
+            self._quantile_file_dir, with_sqrt, with_z_score))
 
     obs_spec = tf.TensorSpec(dtype=tf.float32, shape=(), name='edge_count')
     processing_layer = observation_processing_layer(obs_spec)
@@ -65,24 +71,28 @@ def test_create_observation_processing_layer(self, with_z_score):
 
     outputs = self.evaluate(outputs)
 
+    expected_shape = [2, 1, 2]
+    expected = np.array([[[0.333333, 0.111111]], [[0.777778, 0.604938]]])
+
+    if with_sqrt:
+      expected_shape[2] += 1
+      expected = np.concatenate([expected, [[[0.57735]], [[0.881917]]]],
+                                axis=-1)
+
     if with_z_score:
-      self.assertAllEqual([2, 1, 4], outputs.shape)
-      self.assertAllClose([[[0.333333, 0.57735, 0.111111, -0.555968]],
-                           [[0.777778, 0.881917, 0.604938, -0.155671]]],
-                          outputs)
-    else:
-      self.assertAllEqual([2, 1, 3], outputs.shape)
-      self.assertAllClose(
-          [[[0.333333, 0.57735, 0.111111]], [[0.777778, 0.881917, 0.604938]]],
-          outputs)
+      expected_shape[2] += 1
+      expected = np.concatenate([expected, [[[-0.555968]], [[-0.155671]]]],
+                                axis=-1)
+
+    self.assertAllEqual(expected_shape, outputs.shape)
+    self.assertAllClose(expected.tolist(), outputs)
 
-  @parameterized.named_parameters(('with_z_score', True),
-                                  ('without_z_score', False))
+  @parameterized.named_parameters(*_WITH_Z_SCORE_SQRT_PRODUCT_VALUES)
   def test_create_observation_processing_layer_for_dummy_features(
-      self, with_z_score):
+      self, with_z_score, with_sqrt):
     observation_processing_layer = (
         feature_ops.get_observation_processing_layer_creator(
-            self._quantile_file_dir, with_z_score))
+            self._quantile_file_dir, with_sqrt, with_z_score))
 
     obs_spec = tf.TensorSpec(dtype=tf.float32, shape=(), name='dummy_feature')
     processing_layer = observation_processing_layer(obs_spec)
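
The constants in `expected` are self-consistent: a base column, its square, the optional square root, and the z-score channel. A quick check of the first three columns (not part of the commit; the exact fractions 3/9 and 7/9 are an assumption consistent with the printed decimals):

import math

for x in (3 / 9, 7 / 9):  # assumed bucket_index / len(quantile) values
  print(round(x, 6), round(x * x, 6), round(math.sqrt(x), 6))
# 0.333333 0.111111 0.57735
# 0.777778 0.604938 0.881917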
28 changes: 28 additions & 0 deletions compiler_opt/rl/inlining/config.py
@@ -27,6 +27,7 @@ def get_inlining_signature_spec():
   """Returns (time_step_spec, action_spec) for LLVM inlining."""
   observation_spec = dict(
       (key, tf.TensorSpec(dtype=tf.int64, shape=(), name=key)) for key in (
+          # Base features
           'caller_basic_block_count',
           'caller_conditionally_executed_blocks',
           'caller_users',
@@ -38,6 +39,33 @@ def get_inlining_signature_spec():
           'edge_count',
           'callsite_height',
           'cost_estimate',
+
+          # Expanded cost features
+          'sroa_savings',
+          'sroa_losses',
+          'load_elimination',
+          'call_penalty',
+          'call_argument_setup',
+          'load_relative_intrinsic',
+          'lowered_call_arg_setup',
+          'indirect_call_penalty',
+          'jump_table_penalty',
+          'case_cluster_penalty',
+          'switch_penalty',
+          'unsimplified_common_instructions',
+          'num_loops',
+          'dead_blocks',
+          'simplified_instructions',
+          'constant_args',
+          'constant_offset_ptr_args',
+          'callsite_cost',
+          'cold_cc_penalty',
+          'last_call_to_static_bonus',
+          'is_multiple_blocks',
+          'nested_inlines',
+          'nested_inline_cost_estimate',
+          'threshold',
+
           # inlining_default is not used as feature in training.
           'inlining_default'))
   reward_spec = tf.TensorSpec(dtype=tf.float32, shape=(), name='reward')
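
The new keys are plain scalar int64 specs, exactly like the base features. A standalone count check (not part of the commit) confirms this hunk adds 24 expanded cost features:

import tensorflow as tf

_EXPANDED_COST_FEATURES = (
    'sroa_savings', 'sroa_losses', 'load_elimination', 'call_penalty',
    'call_argument_setup', 'load_relative_intrinsic',
    'lowered_call_arg_setup', 'indirect_call_penalty', 'jump_table_penalty',
    'case_cluster_penalty', 'switch_penalty',
    'unsimplified_common_instructions', 'num_loops', 'dead_blocks',
    'simplified_instructions', 'constant_args', 'constant_offset_ptr_args',
    'callsite_cost', 'cold_cc_penalty', 'last_call_to_static_bonus',
    'is_multiple_blocks', 'nested_inlines', 'nested_inline_cost_estimate',
    'threshold')

# Same construction as the dict comprehension in get_inlining_signature_spec.
expanded_spec = dict((key, tf.TensorSpec(dtype=tf.int64, shape=(), name=key))
                     for key in _EXPANDED_COST_FEATURES)
assert len(expanded_spec) == 24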
@@ -11,13 +11,14 @@ train_eval.batch_size=64
 train_eval.train_sequence_length=1
 
 get_observation_processing_layer_creator.quantile_file_dir='compiler_opt/rl/inlining/vocab'
+get_observation_processing_layer_creator.with_sqrt = False
 get_observation_processing_layer_creator.with_z_score_normalization = False
 
 create_agent.policy_network = @q_network.QNetwork
 
 QNetwork.preprocessing_combiner=@tf.keras.layers.Concatenate()
-QNetwork.fc_layer_params=(40, 20)
-QNetwork.dropout_layer_params=(0.2, 0.2)
+QNetwork.fc_layer_params=(40, 40, 20)
+QNetwork.dropout_layer_params=(0.2, 0.2, 0.2)
 QNetwork.activation_fn=@tf.keras.activations.relu
 
 tf.train.AdamOptimizer.learning_rate = 0.001
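
These QNetwork bindings deepen the DQN value network from two hidden layers to three, with matching dropout. A rough, hedged equivalent of what they configure, using the public TF-Agents QNetwork constructor with placeholder specs (the real specs come from config.py and the preprocessing layers from feature_ops.py):

import tensorflow as tf
from tf_agents.networks import q_network
from tf_agents.specs import tensor_spec

# Placeholder specs; the true observation spec is the dict built in config.py.
observation_spec = tensor_spec.TensorSpec(shape=(34,), dtype=tf.float32)
action_spec = tensor_spec.BoundedTensorSpec(
    shape=(), dtype=tf.int64, minimum=0, maximum=1)  # inline / don't inline

net = q_network.QNetwork(
    observation_spec,
    action_spec,
    fc_layer_params=(40, 40, 20),          # widened from (40, 20)
    dropout_layer_params=(0.2, 0.2, 0.2),  # widened from (0.2, 0.2)
    activation_fn=tf.keras.activations.relu)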
5 changes: 3 additions & 2 deletions compiler_opt/rl/inlining/gin_configs/ppo_nn_agent.gin
@@ -9,19 +9,20 @@ train_eval.get_signature_spec_fn=@config.get_inlining_signature_spec
 train_eval.agent_name='ppo'
 train_eval.warmstart_policy_dir=''
 train_eval.num_policy_iterations=3000
-train_eval.num_iterations=200
+train_eval.num_iterations=300
 train_eval.batch_size=128
 train_eval.train_sequence_length=16
 train_eval.deploy_policy_name='saved_collect_policy'
 train_eval.use_stale_results=False
 
 get_observation_processing_layer_creator.quantile_file_dir='compiler_opt/rl/inlining/vocab'
+get_observation_processing_layer_creator.with_sqrt = False
 get_observation_processing_layer_creator.with_z_score_normalization = False
 
 create_agent.policy_network = @actor_distribution_network.ActorDistributionNetwork
 
 ActorDistributionNetwork.preprocessing_combiner=@tf.keras.layers.Concatenate()
-ActorDistributionNetwork.fc_layer_params=(40, 20)
+ActorDistributionNetwork.fc_layer_params=(40, 40, 20)
 ActorDistributionNetwork.dropout_layer_params=None
 ActorDistributionNetwork.activation_fn=@tf.keras.activations.relu
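
Both gin files leave the new knob off by default, so existing runs are unaffected; enabling the sqrt channel for an experiment is a single binding. A hypothetical override using stock gin-config APIs (nothing here is added by this commit):

import gin

# Either parse a binding string ...
gin.parse_config(
    'get_observation_processing_layer_creator.with_sqrt = True')
# ... or bind the parameter programmatically.
gin.bind_parameter(
    'get_observation_processing_layer_creator.with_sqrt', True)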