Skip to content

Commit

Permalink
Add diagnostics for network and regression data
Browse files Browse the repository at this point in the history
Log histograms of the network's weights and biases, their gradients, and their per-update changes.
Also log histograms of the target Q-values, the predicted Q-values, and the prediction errors.
  • Loading branch information
Marcel Nunez committed May 5, 2024
1 parent 67c4dff commit a8e906d
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion train_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,38 @@ def main(symmetry,game_type,double_dqn):
q_prediction_errors = predicted_q_values - q_to_train_single_values
loss_value = mse_loss(q_to_train_single_values, predicted_q_values)


weights_and_biases_flat_before_update = np.concatenate([v.numpy().flatten() for v in agent.model.variables])

grads = tape.gradient(loss_value, agent.model.trainable_variables)
optimizer.apply_gradients(zip(grads, agent.model.trainable_variables))

# Record the loss every step; every RECORD_HISTOGRAMS steps also record diagnostic histograms
writer.add_scalar("loss", loss_value.numpy(), step)
if step % RECORD_HISTOGRAMS == 0:
writer.add_histogram("q-predicted",predicted_q_values.numpy(),step)
writer.add_histogram("q-train",q_to_train_single_values,step)
writer.add_histogram("q-error",q_prediction_errors.numpy(),step)

weights_and_biases_flat = np.concatenate([v.numpy().flatten() for v in agent.model.variables])
writer.add_histogram("weights and biases",weights_and_biases_flat,step)

grads_flat = np.concatenate([v.numpy().flatten() for v in grads])
writer.add_histogram("gradients",grads_flat,step)

weights_and_biases_delta = weights_and_biases_flat - weights_and_biases_flat_before_update
writer.add_histogram("weight-bias-updates",weights_and_biases_delta,step)



# Sync the target network with the agent's model every SYNC_TARGET_NETWORK steps
if step % SYNC_TARGET_NETWORK == 0:

weights_and_biases_flat = np.concatenate([v.numpy().flatten() for v in agent.model.variables])
weights_and_biases_target = np.concatenate([v.numpy().flatten() for v in target_network.variables])
target_parameter_updates = weights_and_biases_flat - weights_and_biases_target
writer.add_histogram("target parameter updates",target_parameter_updates,step)

target_network.set_weights(agent.model.get_weights())

writer.close()
Expand Down

0 comments on commit a8e906d

Please sign in to comment.