Fixup for some python scripts
lightvector committed Aug 25, 2018
1 parent f4725e9 commit 86914ce
Showing 3 changed files with 43 additions and 23 deletions.
4 changes: 3 additions & 1 deletion model.py
@@ -928,10 +928,11 @@ def __init__(self,model,for_optimization,require_last_move)
# self.value_loss = 0.5 * (cross_entropy_value_loss + l2_value_loss)
self.value_loss = l2_value_loss

self.weight_sum = tf.reduce_sum(self.target_weights_used)

if for_optimization:
#Prior/Regularization
self.l2_reg_coeff = tf.placeholder(tf.float32)
self.weight_sum = tf.reduce_sum(self.target_weights_used)
self.reg_loss = self.l2_reg_coeff * tf.add_n([tf.nn.l2_loss(variable) for variable in model.reg_variables]) * self.weight_sum

#The loss to optimize
@@ -945,6 +946,7 @@ def __init__(self,model,target_vars,include_debug_stats):
self.top4_prediction = tf.nn.in_top_k(model.policy_output,policy_target_idxs,4)
self.accuracy1 = tf.reduce_sum(target_vars.target_weights_used * tf.cast(self.top1_prediction, tf.float32))
self.accuracy4 = tf.reduce_sum(target_vars.target_weights_used * tf.cast(self.top4_prediction, tf.float32))
self.valueconf = tf.reduce_sum(tf.square(model.value_output))

#Debugging stats
if include_debug_stats:
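
Note on the model.py change: self.weight_sum is hoisted out of the "if for_optimization:" branch so the weighted-example count is available even when the graph is built only for evaluation, and a new valueconf metric sums the squared value-head outputs. Divided by the weight sum (as test.py now does), it measures how decisive the value head's predictions are, independent of which side it favors. A minimal illustration of the statistic, assuming the value head outputs a win/loss estimate in [-1, 1]:

import numpy as np

# Illustrative only, not from the commit: how the new valueconf statistic
# behaves, assuming value_output is a win/loss estimate in [-1, 1].
def value_confidence(value_outputs, weight_sum):
    # Squaring discards the sign (which side is winning) and keeps only
    # the magnitude, so this measures how decisive the predictions are.
    return np.sum(np.square(value_outputs)) / weight_sum

confident = np.array([0.9, -0.85, 0.95])  # near-decided positions
uncertain = np.array([0.05, -0.1, 0.02])  # close positions
print(value_confidence(confident, 3.0))   # ~0.81
print(value_confidence(uncertain, 3.0))   # ~0.004
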
2 changes: 1 addition & 1 deletion play.py
@@ -209,7 +209,7 @@ def lerp(p,x0,x1,y0,y1):
texts_rev.append("%s %.3f" % (str_coord(loc,board),value))

if value_head_output is not None:
texts_value.append("bv %.2f%%" % (50+50*(value_head_output if board.pla == Board.BLACK else -value_head_output)))
texts_value.append("wv %.2f%%" % (50+50*(value_head_output if board.pla == Board.WHITE else -value_head_output)))

gfx_commands.append("TEXT " + ", ".join(texts_value + texts_rev + texts))

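
Note on the play.py change: the displayed value-head number switches from Black's perspective ("bv") to White's ("wv"), flipping the sign whenever White is not the player to move. A minimal sketch of the conversion the new line performs, assuming value_head_output is a win/loss estimate in [-1, 1] from the side-to-move perspective:

# Sketch only: map a side-to-move value estimate to White's winrate percent.
def white_winrate_pct(value_head_output, white_to_move):
    v = value_head_output if white_to_move else -value_head_output
    return 50 + 50 * v  # [-1, 1] -> [0%, 100%]

print(white_winrate_pct(0.2, True))   # 60.0: White to move and slightly ahead
print(white_winrate_pct(0.2, False))  # 40.0: Black to move and ahead, so White trails
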
60 changes: 39 additions & 21 deletions test.py
@@ -27,17 +27,23 @@
parser = argparse.ArgumentParser(description=description)
parser.add_argument('-gamesh5', help='H5 file of preprocessed game data', required=True)
parser.add_argument('-model-file', help='model file prefix to load', required=True)
parser.add_argument('-rank-idx', help='rank to provide to model for inference', required=True)
parser.add_argument('-rank-idx', help='rank to provide to model for inference', required=False)
parser.add_argument('-require-last-move', help='filter down to only instances where last move is provided', required=False, action="store_true")
parser.add_argument('-use-training-set', help='run on training set instead of test set', required=False, action="store_true")
parser.add_argument('-validation-prop', help='only use this proportion of validation set', required=False)
args = vars(parser.parse_args())

gamesh5 = args["gamesh5"]
model_file = args["model_file"]
rank_idx = int(args["rank_idx"])
rank_idx = (int(args["rank_idx"]) if args["rank_idx"] is not None else 0)

require_last_move = args["require_last_move"]
use_training_set = args["use_training_set"]

validation_prop = 1.0
if "validation_prop" in args and args["validation_prop"] is not None:
validation_prop = float(args["validation_prop"])

def log(s):
print(s,flush=True)
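
The two argument changes above make -rank-idx optional (defaulting to 0) and add -validation-prop. Since vars(parser.parse_args()) always contains a key for every declared argument (None when omitted), the later "validation_prop" in args check is redundant but harmless. Not from the commit, but an equivalent way to express both defaults using argparse's own type/default handling:

import argparse

# Equivalent sketch: let argparse do the conversion and defaulting itself.
parser = argparse.ArgumentParser()
parser.add_argument('-rank-idx', type=int, default=0,
                    help='rank to provide to model for inference')
parser.add_argument('-validation-prop', type=float, default=1.0,
                    help='only use this proportion of validation set')
args = vars(parser.parse_args(['-rank-idx', '3']))
print(args['rank_idx'], args['validation_prop'])  # 3 1.0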

@@ -105,7 +111,7 @@ def log(s):
sys.stderr.flush()

log("Began session")
log("Testing on " + str(num_h5_val_rows) + " rows")
log("Testing on " + str(int(num_h5_val_rows * validation_prop)) + "/" + str(num_h5_val_rows) + " rows")
log("h5_chunk_size = " + str(h5_chunk_size))

sys.stdout.flush()
@@ -139,6 +145,9 @@ def run(fetches, rows):
assert(len(model.target_weights_shape) == 0)
assert(len(model.rank_shape) == 1)

if not isinstance(rows, np.ndarray):
rows = np.array(rows)

row_inputs = rows[:,0:input_len].reshape([-1] + model.input_shape)
row_policy_targets = rows[:,policy_target_start:policy_target_start+policy_target_len]
row_value_target = rows[:,value_target_start]
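
run() now coerces rows to a numpy array itself (h5py slices already arrive as ndarrays, but this lets callers pass plain lists), and the feed dict below supplies target_vars.value_target so that the newly reported value loss can be evaluated. Hypothetical shapes, just to illustrate how each flat row is sliced apart; the real input_len and target offsets come from the model:

import numpy as np

# Made-up sizes for the sketch; the real ones come from the model config.
input_len, policy_target_len = 6, 4
rows = np.arange(2 * (input_len + policy_target_len + 1), dtype=np.float32)
rows = rows.reshape(2, -1)  # two flat training rows

row_inputs = rows[:, 0:input_len]                                    # network input planes
row_policy_targets = rows[:, input_len:input_len + policy_target_len]
row_value_target = rows[:, input_len + policy_target_len]            # fed for value_loss
print(row_inputs.shape, row_policy_targets.shape, row_value_target.shape)
# (2, 6) (2, 4) (2,)
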
@@ -152,43 +161,52 @@
model.inputs: row_inputs,
model.ranks: ranks_input,
target_vars.policy_targets: row_policy_targets,
target_vars.value_target: row_value_target,
target_vars.target_weights_from_data: row_target_weights,
model.symmetries: [False,False,False],
model.is_training: False
})

def np_array_str(arr,precision):
return np.array_str(arr, precision=precision, suppress_small = True, max_line_width = 200)
def merge_dicts(dicts,merge_list):
keys = dicts[0].keys()
return dict((key,merge_list([d[key] for d in dicts])) for key in keys)
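
merge_dicts combines a list of per-batch result dicts key by key; passing np.sum as merge_list totals each metric across batches. A self-contained usage sketch:

import numpy as np

def merge_dicts(dicts, merge_list):
    keys = dicts[0].keys()
    return dict((key, merge_list([d[key] for d in dicts])) for key in keys)

batch_results = [
    {"acc1": 10.0, "wsum": 128.0},
    {"acc1": 12.0, "wsum": 128.0},
]
print(merge_dicts(batch_results, np.sum))  # {'acc1': 22.0, 'wsum': 256.0}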

def run_validation_in_batches(fetches):
#Run validation accuracy in batches to avoid out of memory error from processing one supergiant batch
validation_batch_size = 128
num_validation_batches = (num_h5_val_rows+validation_batch_size-1)//validation_batch_size
results = [[] for j in range(len(fetches))]
num_validation_batches = int(num_h5_val_rows * validation_prop + validation_batch_size-1)//validation_batch_size
results = []
for i in range(num_validation_batches):
print(".",end="",flush=True)
rows = h5val[i*validation_batch_size : min((i+1)*validation_batch_size, num_h5_val_rows)]
if not isinstance(rows, np.ndarray):
rows = np.array(rows)

result = run(fetches, rows)
for j in range(len(fetches)):
results[j].extend(result[j])
results.append(result)

print("",flush=True)
return results
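
run_validation_in_batches now returns one fetched dict per batch and scales the batch count by validation_prop, using the usual integer ceiling trick. A quick check of that arithmetic:

# ceil((rows * prop) / batch_size) via integer arithmetic, as above.
num_rows, batch_size, prop = 1000, 128, 0.25
num_batches = int(num_rows * prop + batch_size - 1) // batch_size
print(num_batches)  # 2 -- 250 rows fit in two batches of up to 128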

def validation_stats_str(vacc1,vacc4,vpolicy_loss,vweight_sum):
return "vacc1 %5.2f%% vacc4 %5.2f%% vploss %f vweight_sum %f" % (vacc1*100/vweight_sum, vacc4*100/vweight_sum, vpolicy_loss/vweight_sum, vweight_sum)

(acc1s,acc4s,policy_losses,weight_outputs) = run_validation_in_batches([
metrics.accuracy1,
metrics.accuracy4,
target_vars.policy_loss,
target_vars.target_weights_used
])
(vacc1,vacc4,vpolicy_loss,vweight_sum) = (np.sum(acc1s),np.sum(acc4s),np.sum(policy_losses),np.sum(weight_outputs))
vstr = validation_stats_str(vacc1,vacc4,vpolicy_loss,vweight_sum)
vmetrics = {
"acc1": metrics.accuracy1,
"acc4": metrics.accuracy4,
"ploss": target_vars.policy_loss,
"vloss": target_vars.value_loss,
"vconf": metrics.valueconf,
"wsum": target_vars.weight_sum,
}

def validation_stats_str(vmetrics_evaled):
return "vacc1 %5.2f%% vacc4 %5.2f%% vploss %f vvloss %f vconf %f" % (
vmetrics_evaled["acc1"] * 100 / vmetrics_evaled["wsum"],
vmetrics_evaled["acc4"] * 100 / vmetrics_evaled["wsum"],
vmetrics_evaled["ploss"] / vmetrics_evaled["wsum"],
vmetrics_evaled["vloss"] / vmetrics_evaled["wsum"],
vmetrics_evaled["vconf"] / vmetrics_evaled["wsum"],
)

vmetrics_evaled = merge_dicts(run_validation_in_batches(vmetrics), np.sum)
vstr = validation_stats_str(vmetrics_evaled)

log(vstr)
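
Each metric is summed across batches by merge_dicts(..., np.sum) and then normalized by the summed target weights, so the logged numbers are weighted per-row averages. For example, with hypothetical totals:

# Hypothetical totals, showing the normalization validation_stats_str applies.
evaled = {"acc1": 900.0, "wsum": 1200.0}
print("vacc1 %5.2f%%" % (evaled["acc1"] * 100 / evaled["wsum"]))  # vacc1 75.00%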

