<a href="https://colab.research.google.com/github/noahdenicola/RL/blob/main/NoahDeNicola_TSP_InstaDeep.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install jax numpy dm-haiku jumanji

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import jax
import jax.numpy as jnp
import numpy as np
import jumanji
import haiku as hk

from jax import value_and_grad



In [3]:
#Get Key
rng = jax.random.PRNGKey(42)

#Defn Class
class Network(hk.Module):
  def __init__(self, hidden_size=64, output_size=1, name='QNet'):
    super().__init__(name=name)
    self.mlp = hk.nets.MLP(output_sizes=[hidden_size, output_size])

  def __call__(self, x):
    return self.mlp(x)

#FeedForward Method
def ffnet(x):
  module = Network()
  return module(x)

#Initialise Model
model = hk.transform(ffnet)
params = model.init(rng, jnp.array([[1., 2., 3., 4.]]))

for layer_name, weights in params.items():
    print(layer_name)
    print("Weights : {}, Biases : {}\n".format(params[layer_name]["w"].shape,params[layer_name]["b"].shape))



QNet/~/mlp/~/linear_0
Weights : (4, 64), Biases : (64,)

QNet/~/mlp/~/linear_1
Weights : (64, 1), Biases : (1,)



In [26]:
#Training Methods
def MSE(params, x, y):
  y_hat = model.apply(params, rng, x)
  y_hat = y_hat.squeeze()
  return jnp.power(y-y_hat, 2).mean()

def step(params, grads):
  return params - lr*grads

def printParams(all_params):
  print("Main Network Parameters")
  pr = True
  for p in all_params:
    for layer_name, weights in p.items():
        print(layer_name)
        print("Weights : {}, Biases : {}\n".format(params[layer_name]["w"].shape,params[layer_name]["b"].shape))
    if pr:
      pr = False
      print("Target Network Parameters")

def makeState(env_obvservation):
  coords = np.array(env_obvservation.coordinates).flatten()
  pos = np.array(env_obvservation.position)
  visd = np.array(env_obvservation.visited_mask)
  traj = np.array(env_obvservation.trajectory)
  return np.concatenate((np.array([pos]),coords,visd,traj))


def DQN_update(key, model, all_params, memory, batchSize=32, gamma=0.99):
  #Sample Minibatch 
  l = len(memory)
  indicies = jax.random.choice(rng, l, (1,min(batchSize, l))).squeeze()

  #Initialise Loop
  Q = jnp.zeros(l)
  Q_hat = jnp.zeros(l)
  loc = 0

  #for experience in minibatch
  for i in indicies:
    exp = memory[i] #exp = [state, action, nextState, reward, done]
  
    #Get max_{a'}target_Q(state,a')
    nxtQ = 0
    if not exp[4]:
      nxtQ = jnp.max(model.apply(all_params[1], key, exp[0]))

    #q_hat = { r + gamma*maxQ(s,a') if not done, r otherwise
    q_hat = exp[3] + gamma*nxtQ

    #get q = main_Q(state, action)
    q = model.apply(all_params[0], key, exp[0])[exp[1]]

    #store
    Q[loc] = q
    Q_hat[loc] = q_hat
    loc +=1

  return Q, Q_hat

  


  and should_run_async(code)


In [5]:
#Test
from sklearn import datasets
from sklearn.model_selection import train_test_split

X, Y = datasets.load_iris(return_X_y=True)

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, random_state=123)

X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32),\

samples, features = X_train.shape

X = X_train
Y = Y_train

print(X.shape, Y.shape)


(120, 4) (120,)


In [31]:
#Training Loop
epochs = 300
lr = 1e-3

for i in range(epochs):
  loss, grads = value_and_grad(MSE)(params, X, Y)
  params = jax.tree_map(step, params, grads)

  if i%50 == 0:
        print("MSE : {:.2f}".format(loss))

MSE : 5.14
MSE : 0.15
MSE : 0.07
MSE : 0.06
MSE : 0.06
MSE : 0.06


In [27]:
#Rough Code
gamma = 0.95
lr = 1e-3
dropout = 0.0
envName = 'TSP-v1'
totalSteps = 5000


key = jax.random.PRNGKey(42)
#Initialise Enviroment
env = jumanji.make(envName)
state, timestep = jax.jit(env.reset)(key)

playLength = env.num_cities+1 

#Initialise Agent
class Network(hk.Module):
  def __init__(self, hidden_size=64, output_size=env.num_cities, name='QNet'):
    super().__init__(name=name)
    self.mlp = hk.nets.MLP(output_sizes=[hidden_size, output_size])

  def __call__(self, x):
    return self.mlp(x)

#FeedForward Method
def ffnet(x):
  module = Network()
  return module(x)

model = hk.transform(ffnet)

params = model.init(key, jnp.ones(4*env.num_cities+1))
all_params = [params, params] #[main, target]
printParams(all_params) 

steps = 0
epoch = 0 
while steps < totalSteps:
  epoch += 1
  score = 0 
  state, timestep = jax.jit(env.reset)(key)
  for t in range(playLength):
    






#env.render(state)

# Interact with the (jit-able) environment
#action = env.action_spec().generate_value()          # Action selection (dummy value here)
#state, timestep = jax.jit(env.step)(state, 1) 

Main Network Parameters
QNet/~/mlp/~/linear_0
Weights : (81, 64), Biases : (64,)

QNet/~/mlp/~/linear_1
Weights : (64, 20), Biases : (20,)

Target Network Parameters
QNet/~/mlp/~/linear_0
Weights : (81, 64), Biases : (64,)

QNet/~/mlp/~/linear_1
Weights : (64, 20), Biases : (20,)

[-1.          0.39861631  0.66380632  0.90807796  0.36692286  0.1340189
  0.94626236  0.70353913  0.77658904  0.66050398  0.86718833  0.24655128
  0.95453084  0.88447797  0.24636471  0.31162012  0.65564668  0.19815516
  0.45549071  0.19720936  0.75821579  0.22929907  0.80159271  0.43305802
  0.16452789  0.51343548  0.82576847  0.00415325  0.84096241  0.05243218
  0.41828477  0.39441133  0.69486761  0.11139321  0.34498036  0.94357693
  0.20319915  0.30095315  0.84559286  0.13179469  0.6033287   0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.         -1.     

