In [20]:
import pickle
import numpy as np
import torch

# add the directory that contains 'networks.py' to the import search path (edit to match your machine)
import sys
sys.path.insert(0, "/home/luke/mymujoco/rl/")
import networks

In [9]:
# where the saved policy network lives: directory (with trailing slash) and file name
filepath = "/home/luke/mymujoco/rl/models/paper_baseline_4/07-03-23/luke-PC_13:37_A10/"
filename = "DQN_150x100x50_policy_net_001.pickle"

# the directory string already ends in '/', so plain concatenation yields the full path
model_path = filepath + filename

# read the pickled network object back from disk (binary mode is required by pickle)
with open(model_path, "rb") as f:
  loaded_network = pickle.load(f)

In [47]:
# size of the state vector fed to the network and the number of actions it scores
n_inputs = 59
n_outputs = 8

# what are the actions? X=gripper prismatic joint, Y=gripper revolute joint, Z=gripper palm, H=gripper height
action_names = ["X_close", "X_open", "Y_close", "Y_open", "Z_close", "Z_open", "H_down", "H_up"]
assert len(action_names) == n_outputs, "there must be exactly one action name per network output"

# lets do an example state vector: a batch of one row holding n_inputs random numbers from [-1, 1)
state = 2 * np.random.rand(1, n_inputs) - 1
state_tensor = torch.tensor(state, dtype=torch.float32)      # convert to pytorch and change double->float

# NOTE(review): if the network has dropout/batch-norm layers it should be switched
# to evaluation mode before this call - confirm against 'networks.py'
with torch.no_grad():
  # max(1) reduces over dimension 1 and returns a (values, indices) pair
  # [1] picks the indices, i.e. the position of the largest output in each row
  # view(1, 1) reshapes that single index into a 1x1 tensor
  action = loaded_network(state_tensor).max(1)[1].view(1, 1)

# extract the chosen action, which are numbered 0-7, as a plain python int so that
# indexing the name list does not depend on implicit tensor->index conversion
action_index = action.item()
print(f"Action number is: {action_index}, this means {action_names[action_index]}")

[[-0.61787562 -0.94073609  0.03166923 -0.7079431   0.33917209  0.55970318
  -0.71272315  0.39500408  0.16143928 -0.98320414 -0.29304594  0.08019227
  -0.31357218 -0.11360578  0.59223452 -0.58075262  0.45089215 -0.7680106
  -0.0413638  -0.80830399  0.49493504 -0.16220988 -0.46356009  0.83005717
   0.81971388 -0.27010761 -0.80884322  0.98472305 -0.75593761 -0.26230838
  -0.77631088  0.92367255  0.71301499 -0.45417705  0.12526529 -0.91493138
   0.3928103  -0.18362171 -0.03480841 -0.09075251  0.5743666   0.5820564
   0.49736643  0.88740174 -0.93908319 -0.83135959  0.73986872  0.78408892
   0.34748686  0.04207258  0.07864777 -0.23729045 -0.91806079 -0.88801257
  -0.11574382 -0.95400668  0.30324101  0.93010949  0.36115624]]
Action number is: 0, this means X_close
