Skip to content

Commit

Permalink
randomized rounding for spsa
Browse files Browse the repository at this point in the history
spsa converts floating point parameter values to integers via round(),
which is discontinuous.  This patch adds randomized rounding,
which converts x+p (where x is an integer and p is within [0,1)) to
x with probability 1-p and x+1 with probability p (thus we get back x+p
in expectation), so the rounded value is continuous in expectation.
  • Loading branch information
wfenchel committed Jun 12, 2016
1 parent 7eebda7 commit 5f63500
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
29 changes: 23 additions & 6 deletions fishtest/fishtest/rundb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import os
import random
import math
import time
from datetime import datetime
from bson.objectid import ObjectId
Expand Down Expand Up @@ -355,13 +356,27 @@ def approve_run(self, run_id, approver):
self.runs.save(run)
return True

def spsa_param_clip(param, increment, clipping):
def spsa_param_clip_round(param, increment, clipping, rounding):
value = 0.0

if clipping == 'old':
return min(max(param['theta'] + increment, param['min']), param['max'])
value = min(max(param['theta'] + increment, param['min']), param['max'])
else: #clipping == 'careful':
inc = min(abs(increment), abs(param['theta'] - param['min']) / 2, abs(param['theta'] - param['max']) / 2)
inc_sgn = 0 if increment == 0 else increment / abs(increment)
return param['theta'] + inc_sgn * inc
value = param['theta'] + inc_sgn * inc

#'deterministic' rounding calls round() inside the worker.
#'randomized' says 4.p should be 5 with probability p, 4 with probability 1-p,
# and is continuous (albeit after expectation) unlike round().
if rounding == 'randomized':
fl = math.floor(value) #greatest integer <= value, thus works for negative.
if random.uniform(0,1) < value - fl:
value = fl + 1
else:
value = fl

return value

def request_spsa(self, run_id, task_id):
run = self.get_run(run_id)
Expand All @@ -381,6 +396,8 @@ def request_spsa(self, run_id, task_id):
spsa = run['args']['spsa']
if 'clipping' not in spsa:
spsa['clipping'] = 'old'
if 'rounding' not in spsa:
spsa['rounding'] = 'deterministic'

# Generate the next set of tuning parameters
iter_local = spsa['iter'] + 1 #assume at least one completed, and avoid division by zero
Expand All @@ -391,14 +408,14 @@ def request_spsa(self, run_id, task_id):
flip = 1 if bool(random.getrandbits(1)) else -1
result['w_params'].append({
'name': param['name'],
'value': spsa_param_clip(param, c * flip, spsa['clipping']),
'value': spsa_param_clip_round(param, c * flip, spsa['clipping'], spsa['rounding']),
'R': R,
'c': c,
'flip': flip,
})
result['b_params'].append({
'name': param['name'],
'value': spsa_param_clip(param, -c * flip, spsa['clipping']),
'value': spsa_param_clip_round(param, -c * flip, spsa['clipping'], spsa['rounding']),
})

return result
Expand All @@ -418,7 +435,7 @@ def update_spsa(self, run, spsa_results):
R = spsa_results['w_params'][idx]['R']
c = spsa_results['w_params'][idx]['c']
flip = spsa_results['w_params'][idx]['flip']
param['theta'] = spsa_param_clip(param, R * c * result * flip, spsa['clipping']),
param['theta'] = spsa_param_clip_round(param, R * c * result * flip, spsa['clipping'], 'deterministic'),
summary.append({
'theta': param['theta'],
'R': R,
Expand Down
9 changes: 9 additions & 0 deletions fishtest/fishtest/templates/tests_run.mak
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,15 @@ Cowardice,150,0,200,10,0.0020"""})['raw_params']}</textarea>
</select>
</div>
</div>
<div class="control-group stop_rule spsa">
<label class="control-label">SPSA rounding:</label>
<div class="controls">
<select name="spsa_rounding">
<option value="deterministic">deterministic</option>
<option value="randomized">randomized</option>
</select>
</div>
</div>
<div class="control-group">
<label class="control-label">Time Control:</label>
<div class="controls">
Expand Down
7 changes: 5 additions & 2 deletions fishtest/fishtest/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def validate_form(request):
'iter': 0,
'num_iter': int(data['num_games'] / 2),
'clipping': request.POST['spsa_clipping'],
'rounding': request.POST['spsa_rounding'],
}
data['spsa']['params'] = parse_spsa_params(request.POST['spsa_raw_params'], data['spsa'])
else:
Expand Down Expand Up @@ -679,8 +680,10 @@ def tests_view(request):
if name == 'spsa' and value != '-':
params = ['param: %s, best: %.2f, start: %.2f, min: %.2f, max: %.2f, c %f, a %f' % \
(p['name'], p['theta'], p['start'], p['min'], p['max'], p['c'], p['a']) for p in value['params']]
value = 'Iter: %d, A: %d, alpha %f, gamma %f, clipping %s\n%s' % (value['iter'], value['A'], value['alpha'], \
value['gamma'], value['clipping'] if 'clipping' in value else 'old', '\n'.join(params))
value = 'Iter: %d, A: %d, alpha %f, gamma %f, clipping %s, rounding %s\n%s' % (value['iter'], value['A'], value['alpha'], value['gamma'],
value['clipping'] if 'clipping' in value else 'old',
value['rounding'] if 'rounding' in value else 'deterministic',
'\n'.join(params))

if 'tests_repo' in run['args']:
if name == 'new_tag':
Expand Down

0 comments on commit 5f63500

Please sign in to comment.