# Setting up training with the codebase

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from air_hockey_challenge.framework import AgentBase

  logger.warn(
pybullet build time: May 20 2022 19:44:17


### Example dummy agent
- This agent holds its current position/velocity.
- It doesn't do anything else.
- It is only used to show how to set up an agent.

### Set up a PyTorch neural network that maps the observation space to the action space
### Requirements
- Set up a policy gradient method
- Set up the policy
    - a neural network (this is what we are learning)
    - a TorchApproximator (connects Torch with MushroomRL)
    - plug the approximator into a parametric policy
- Plug the policy into the policy gradient method
- Train

In [2]:
# Deliberately simple feed-forward network used as the policy-mean approximator.
class ActionGenerator(nn.Module):
    """Feed-forward MLP mapping an observation vector to an action vector.

    Passed as ``network=ActionGenerator`` to ``TorchApproximator`` in the
    training cell, so the signature keeps the ``use_cuda`` / ``dropout``
    keywords even though this module does not use them
    (NOTE(review): presumably forwarded by TorchApproximator — kept for
    compatibility).

    Layout: ``Linear(in, W) -> act``, then ``num_layers - 1`` blocks of
    ``Linear(W, W) -> act``, then a final ``Linear(W, out)``. The output layer
    is intentionally linear: a trailing LeakyReLU(0.1) would shrink every
    negative output by 10x and bias the policy mean toward positive actions.

    Args:
        input_dim: observation size, an int or a shape tuple such as ``(12,)``;
            tuples are reduced to their element count.
        output_dim: action size, an int or a shape tuple (reduced the same way).
        use_cuda: accepted for compatibility; unused by this module.
        dropout: accepted for compatibility; unused by this module.
        activation: hidden-layer nonlinearity module. ``None`` (default) builds
            a fresh ``nn.LeakyReLU(0.1)`` for this instance.
        num_layers: number of hidden layers (each followed by ``activation``).
        layer_width: number of units in every hidden layer.

    Raises:
        ValueError: if ``num_layers`` is less than 1.
    """

    def __init__(self, input_dim, output_dim, use_cuda=False, dropout=False,
                 activation=None, num_layers=20, layer_width=10):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        if activation is None:
            # Created per instance: a module used as a default argument would be
            # built once at class-definition time and shared by all instances.
            activation = nn.LeakyReLU(0.1)

        # Accept MushroomRL-style shape tuples as well as plain ints.
        in_features = int(np.prod(input_dim))
        out_features = int(np.prod(output_dim))

        layers = [nn.Linear(in_features, layer_width), activation]
        for _ in range(num_layers - 1):
            layers += [nn.Linear(layer_width, layer_width), activation]
        # No activation on the output layer so outputs are unconstrained.
        layers.append(nn.Linear(layer_width, out_features))

        self.model = nn.Sequential(*layers)

    def forward(self, obs):
        """Map observations to actions.

        Args:
            obs: tensor whose last dimension has the flattened ``input_dim``
                size; it is cast to float32 first so integer or double inputs
                are accepted.

        Returns:
            Tensor with the same leading dimensions as ``obs`` and the
            flattened ``output_dim`` size in its last dimension.
        """
        return self.model(obs.float())

In [3]:
# Sanity check on a small instance: a batch of one 8-dim observation must map to
# a batch of one 6-dim action. The assert makes Restart & Run All fail loudly
# instead of relying on someone reading the printed shape.
network = ActionGenerator(8, 6)
sanity_out = network(torch.zeros(1, 8))
assert sanity_out.shape == (1, 6), f"unexpected output shape {tuple(sanity_out.shape)}"
print(sanity_out.shape)

torch.Size([1, 6])


### Set up the DeepDummy agent

In [4]:
from mushroom_rl.algorithms.policy_search import REINFORCE
from mushroom_rl.policy.deterministic_policy import DeterministicPolicy
from mushroom_rl.policy.gaussian_policy import GaussianPolicy
from mushroom_rl.approximators.parametric import TorchApproximator
from mushroom_rl.utils.optimizers import AdaptiveOptimizer

  from collections import Hashable


In [5]:
#setting this up

import numpy as np
from air_hockey_challenge.framework.air_hockey_challenge_wrapper import AirHockeyChallengeWrapper
from air_hockey_challenge.framework.challenge_core import ChallengeCore, CustomChallengeCore


def custom_reward_function(base_env, state, action, next_state, absorbing):
    """Placeholder reward that ignores all of its arguments.

    Every transition receives the same value, so the reward carries no
    information about the agent's behavior; it only exercises the pipeline.
    """
    placeholder_reward = 9000  # it's over 9000
    return placeholder_reward

# Seed NumPy and PyTorch up front so a fresh Restart & Run All reproduces the
# network initialisation and the sampled actions.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

OBS_DIM = 12    # network input size for "3dof-defend" — TODO confirm against mdp.info.observation_space.shape
ACTION_DIM = 6  # flat action size; the actions logged during learn() have 6 entries

mdp = AirHockeyChallengeWrapper(env="3dof-defend", action_type="position-velocity", interpolation_order=3,
                                custom_reward_function=custom_reward_function, debug=True)

# Wrap ActionGenerator so it can serve as the mean of the Gaussian policy below.
approximator = TorchApproximator(input_shape=OBS_DIM,
                                 output_shape=ACTION_DIM,
                                 network=ActionGenerator,
                                 loss=F.smooth_l1_loss,
                                 use_cuda=False)

# Identity covariance as a NumPy array: MushroomRL's GaussianPolicy documents
# `sigma` as an np.ndarray (a torch tensor only works through implicit conversion).
policy = GaussianPolicy(approximator, np.eye(ACTION_DIM))
# Explicit copy so the later "after - before" weight comparison cannot alias live parameters.
old_weights = policy.get_weights().copy()

algorithm_params = {'mdp_info': mdp.info,
                    'policy': policy,
                    'optimizer': AdaptiveOptimizer(eps=0.01)}

reinforce = REINFORCE(**algorithm_params)

core = CustomChallengeCore(reinforce, mdp)

# 10 episodes with n_episodes_per_fit=10 -> a single fit; render=True lets us visualize the episodes.
core.learn(n_episodes=10, n_episodes_per_fit=10, render=False)

  0%|                                                                                                                                                      | 0/10 [00:00<?, ?it/s]

action: [-0.39131326 -0.46994038 -0.47879491 -1.27660013  1.27516779 -0.66490673]
action: [-0.21617386 -0.19224714  0.56376019  0.34252576  0.07282565 -1.71391308]
action: [-2.5188251  -0.49129517 -0.37864947  1.42563805  0.17539347  0.16668239]
action: [ 0.09410767 -2.76261118  1.06996078 -0.79642386 -1.25866185  0.76582028]
action: [ 1.29784936  0.90683773 -0.48266058  0.29947014 -0.91968283  1.08861356]
action: [-1.4640004  -0.08133104 -0.18319387 -0.75359144  0.68253284  1.21444013]
action: [-0.30600307  0.13810532 -1.38500151 -1.76729852 -1.54856355  0.680039  ]
action: [ 1.00209426  0.1516443   2.15399483 -0.98494164  1.59209347 -0.22366864]
action: [-1.20555     0.25182752 -0.81699814  0.31917586  1.69200789  0.32416088]
action: [ 1.03401035  0.0786174   1.54750229 -0.49911523  0.38093219 -0.25513629]
action: [ 0.37208849  0.60936936  1.31307868 -0.19137297  1.38319803  1.77349167]
action: [-0.2179854  -1.45983857  1.44908606  1.49328223 -0.44923266  1.38211782]
action: [-2.5806

 10%|██████████████▏                                                                                                                               | 1/10 [00:01<00:17,  2.00s/it]

action: [-0.71191336 -0.58792257 -0.34280715  0.4550729  -1.05616797  0.96562253]
action: [-0.6914489   0.54701369 -0.33031059 -0.35258188 -0.68198997 -1.71923052]
action: [-1.21284787 -0.15462207 -1.40747172 -0.6928222  -0.70780173  0.10144848]
action: [-0.34782578 -0.91783384  1.499225   -0.19965432 -1.12896997  0.59097819]
action: [-0.35339096  0.53108156 -0.81217979 -0.42778068  1.29276037 -0.07323511]
action: [-0.35054624  0.95347277  0.77699119 -0.02418953 -0.02439107 -0.5744979 ]
action: [0.91468484 0.42140629 0.78455425 1.68516243 0.01959182 0.36112668]
action: [ 0.72729871  0.39185007  0.56852156 -0.77393934  0.3057662   0.245969  ]
action: [ 0.9484214  -1.49488211  0.60580328  0.46621118 -0.77168836  0.77207234]
action: [ 0.28185286 -2.48881102 -0.251213    0.61332488 -0.18369432 -0.62130228]
action: [-0.64856948 -1.8050655   0.33670209 -1.34401794  0.02867507 -0.42603342]
action: [ 0.34348368  0.20638676  1.07921312  0.46225445 -0.51238492  0.61825108]
action: [ 1.44144003  

 20%|████████████████████████████▍                                                                                                                 | 2/10 [00:04<00:16,  2.01s/it]

action: [-1.14400131 -0.77310929 -0.82004956 -1.12483469 -0.77158581 -2.72581591]
action: [-0.18668898  0.31251182  0.13416429 -0.56995721 -0.99493544 -1.60145272]
action: [-1.48212944 -2.03771136  0.27018362 -0.65648983 -0.28655045 -2.43404672]
action: [-1.46474714  1.00277384 -0.63826135  1.49259041 -1.95602366 -0.51774968]
action: [ 0.63021752 -2.09157309 -1.0972941   0.67221575 -0.80533757  0.64143419]
action: [ 2.20168444  0.47002425  2.11595183  1.509672   -0.02673872  0.65124185]
action: [ 2.19936044 -0.52081312  1.81300176 -0.14257226  0.00625643 -0.1827686 ]
action: [-1.06575575 -1.00927422 -0.62291392 -0.5552798  -0.79069422  0.48321638]
action: [ 0.87524147  1.14532072  0.63442697 -1.03757279 -1.05163012 -0.20620796]
action: [0.75976488 0.05823422 0.2383258  0.34722183 1.36591803 0.30515748]
action: [-0.66654389  0.37545604  0.94709607  0.23387257 -0.76759149  0.10753474]
action: [-0.56107188 -0.21649215  0.81417527 -1.69114136  0.23576977 -1.47632087]
action: [ 0.73341905 -

 30%|██████████████████████████████████████████▌                                                                                                   | 3/10 [00:05<00:11,  1.59s/it]

action: [-2.64246418  1.05943835  0.28641706  0.47627715  0.95492821 -0.68081769]
action: [ 0.02765099  0.36142377  0.33304558 -1.2856108  -1.72378995 -0.99140609]
action: [-1.46052775  0.71843192 -0.13046396  0.15933956 -1.23823097 -1.11508456]
action: [-0.4765427   0.77145023 -0.81156983  0.53583944 -1.07091182 -0.95381638]
action: [-0.40778523 -1.06812871  1.03098593  1.9746943   0.67570006  1.05163674]
action: [ 0.11570477  0.13164409 -2.28531326  2.10678624  1.10237567  0.4129752 ]
action: [ 0.00223543  0.94770892 -0.77978494  0.24280193 -0.87781224  0.59972232]
action: [ 0.30948785 -1.20187232  1.29927016  0.18895047 -0.83333065  0.20965994]
action: [ 2.01167552 -0.44380987 -0.40249039  0.79536672  0.97082109  0.14545621]
action: [-0.67015052 -1.05154894  1.05367668 -0.9259004   1.32483296  1.00088405]
action: [-0.9510998   0.89966813  1.88885182  0.20254746  0.90457888 -0.42753467]
action: [ 1.2420643  -0.90387306 -0.25090955  1.67901191  1.08791265  0.23700536]
action: [ 2.6553

 40%|████████████████████████████████████████████████████████▊                                                                                     | 4/10 [00:05<00:06,  1.01s/it]

action: [2.64978544 0.25435053 0.35964183 1.3461777  0.4139422  0.49402342]
action: [ 0.47131753 -1.62918388 -1.45913586 -0.51074138 -1.25570642  1.55621304]
action: [-7.90858720e-01  1.23502842e+00  4.97178262e-01  1.35815908e+00
  1.25905751e-03  1.78086646e+00]
action: [-0.2674446  -1.13940424 -0.1050199  -1.82676069 -0.35855194  0.29751204]
action: [ 0.02712833 -0.13343415 -0.19671037 -1.01214273  0.3417082   0.65301323]
action: [ 1.83753201  0.22912414 -1.00398267  1.27922555  0.43128678 -1.16910484]
action: [-0.44199068 -0.36577209  0.46649815 -0.39870599  1.4922382   0.43169711]
action: [ 0.21825748  1.32544937 -0.31897945 -1.15460958  0.96359466 -1.97568344]
action: [ 0.25029215  0.06595547 -0.5767035   1.337974   -0.40472047  0.65192359]
action: [ 0.64769048 -0.98381349 -0.42279614 -1.13137134 -1.0876888   0.89345232]
action: [-0.20535269 -0.16020358 -1.47112879 -0.15904827  0.69601385 -0.21691827]
action: [-1.13407446 -0.09383413  1.51617276 -0.0062559   1.23086504 -0.2018008

 50%|███████████████████████████████████████████████████████████████████████                                                                       | 5/10 [00:05<00:03,  1.43it/s]

action: [ 1.33892523  0.3719316  -0.68269458 -0.92626602  1.57877912 -0.90321781]
action: [ 0.25246689  1.4504588   0.57697344  0.37270145  2.74983384 -0.46962835]
action: [ 0.3810391   0.074497    0.04473679 -0.68787133  0.78720263 -1.04204718]
action: [ 0.0870974   1.20960008  0.92197739 -0.42413918  0.21019261  0.61384013]
action: [ 0.57889862 -1.24197776  0.89103957  0.18034671  0.53122108 -0.72574512]
action: [ 0.11137915 -1.08453477 -0.74868454 -1.3648586   0.58347775 -0.10793907]
action: [-0.30830111 -1.50983923  0.0145948  -2.15292934 -1.0046782   0.73829298]
action: [-0.26985067 -1.8640582   2.05116217 -0.41022921 -1.10695031  0.56462212]
action: [-0.44994435 -0.49853145  0.25129605 -1.05906362  1.2463222  -0.03608763]
action: [ 0.94634328 -0.21832587 -0.37920554 -1.29648712 -2.36474555 -0.5697451 ]
action: [ 1.68475948 -0.89903678  0.03761468 -0.62498331 -0.63881646  1.11276723]
action: [ 2.84625734 -1.33265233 -1.00224184 -0.31394405  0.31376662 -0.63331345]
action: [ 0.7618

 60%|█████████████████████████████████████████████████████████████████████████████████████▏                                                        | 6/10 [00:06<00:03,  1.20it/s]

action: [ 0.0950752  -0.15027134  2.26882588  1.07495635  1.64698408  1.74531249]
action: [ 0.14607153  0.15486544  0.17296594  1.74911449 -1.35691847  0.25506755]
action: [ 0.88954778  0.27121388 -2.03177358  1.13496532  0.08183907  1.0506409 ]
action: [-1.70673131 -0.19339329  1.16778227 -1.33866048 -1.44951867 -0.84214807]
action: [-0.43065185 -1.83014     1.94900297 -0.95531007 -0.48864976  1.3954576 ]
action: [-0.1190283   1.01186376  0.79390169 -1.20999658  0.81972836 -0.33585202]
action: [ 0.08652517  0.2875121   0.42291679  1.13637813 -0.75563711 -0.7375999 ]
action: [ 1.41637107  0.90651263  0.90169653  2.1523331  -0.10187649  2.0834983 ]
action: [ 1.45181609 -1.5155333   0.09330817  0.2296852  -0.44472764  0.2240887 ]
action: [-1.10720624  0.6475167   1.05913033  0.25156928 -0.46325913 -0.90694364]
action: [ 0.55077307  1.94884186  0.17204629  1.25500908  2.67386861 -1.38014167]
action: [-0.45610353 -0.20426875 -0.20869342  1.51758532  0.32197698 -0.07130552]
action: [1.30423

 70%|███████████████████████████████████████████████████████████████████████████████████████████████████▍                                          | 7/10 [00:07<00:02,  1.26it/s]

action: [-0.13482711  0.38217668  1.49909451  0.15293719 -1.31216444  0.09670235]
action: [ 1.54606768 -0.1052624   1.37230817 -0.38032708  0.01837523  1.32543289]
action: [ 0.43417619  0.07754381 -0.901594    0.98749097 -0.10726816  0.01518159]
action: [ 0.28127111 -0.77261592  0.36563785  1.40688135  0.33567796 -0.06712616]
action: [-0.63418135 -1.36620525 -2.42714664  1.85230483  0.10902778 -0.18656893]
action: [ 0.53186294  0.37121391 -0.51309969  2.1651988  -0.38075266 -0.74484576]
action: [ 0.0446768   0.12118785 -0.47535256 -1.39540357  0.22815325  0.61862409]
action: [ 1.08960295  0.57365672 -0.8817394  -1.03693986 -1.28221111 -0.78067537]
action: [ 0.70595759  0.16945646 -0.81521146 -0.1666423  -1.24907311  0.37938561]
action: [-0.10788752 -1.29432381  0.56091054  0.80447414  0.16081573  0.04575995]
action: [-1.16420698 -0.18289048  0.09767254 -1.4750456  -0.06588637 -0.1613353 ]
action: [-0.64810705 -1.78405427  0.40636235 -1.66152154 -0.86215827 -0.0965166 ]
action: [0.10656

 80%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                            | 8/10 [00:08<00:02,  1.00s/it]

action: [-1.25614038 -0.93498059  1.32164874  0.03377199  0.26755734  0.24535362]
action: [ 2.79349571  0.31234411 -0.07721096  1.78821002  0.68723226  0.08321195]
action: [ 0.56522273 -0.73121578  0.19792773  0.36380256 -1.83250727 -0.3693527 ]
action: [-0.47655706  0.31766416  0.22200561 -1.63791162  1.25518228 -1.77362138]
action: [-0.15845106 -0.25788373 -1.95794319 -1.71970486 -0.18726898  1.11383494]
action: [ 1.29223016 -0.15999942 -0.39320907  0.85833123  0.49750279  0.10349535]
action: [ 0.77798203  0.25091222 -1.6808504  -0.26793655  1.12458892  0.04950339]
action: [ 0.02287513  0.38335582 -0.69513374 -0.62366402  0.46983151  0.89882367]
action: [ 0.121025   -1.37401313 -0.27126952  2.11610939 -0.71560794  0.42816698]
action: [0.44376348 0.81870233 0.84782657 2.03097095 0.96254046 0.95115049]
action: [-1.79970454  0.10762331 -0.07636185 -1.02392674 -2.25825911 -0.82454865]
action: [-0.52129522 -0.54184753 -0.5888459   1.30080843 -0.2424908   1.34308722]
action: [ 1.6164578   

 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊              | 9/10 [00:08<00:00,  1.34it/s]

action: [-0.84032537 -1.16752218 -0.5168026   0.37437424  1.30652525  1.9600445 ]
action: [ 1.50526641 -1.43570087 -0.83949371 -1.16235352 -1.21805977 -0.74602047]
action: [-0.36806939  0.33627598  0.65688739  0.76556146  0.8241358  -0.27593445]
action: [ 0.52463859  0.79296724  0.10124396 -0.52435312  0.49566478 -0.68492711]
action: [-2.22984483  1.71050221 -2.12814024  1.26626434  0.27248134  0.26368728]
action: [ 1.46518125 -1.24212755  1.16053007  0.86176995  0.0926549   0.05074622]
action: [ 0.69887946  0.46215809  1.15790417 -0.44949462 -1.28911117 -0.92270658]
action: [-0.2739624   0.41983447  0.21960169  0.69764256  0.09398149 -0.52596001]
action: [-0.02941637  0.31404942  0.50774337  0.69138565  0.85116165  0.18641767]
action: [ 0.9959364   0.7565918   1.54109252  0.43744984 -1.44294608  1.73999337]
action: [-0.59882379 -0.30221911 -0.7005324  -0.83569414  0.49709122 -1.40320122]
action: [ 1.47373645 -2.48460346 -0.2693317  -0.70605175  0.10129699 -0.35743703]
action: [ 3.8033

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:09<00:00,  1.45it/s]

action: [ 1.40038919  0.26878955  0.15585358 -0.61727728 -1.29413548 -0.37114202]
action: [ 0.88628662  1.29735756  1.47122948  0.86402537 -0.45986128 -0.61824051]
action: [ 1.11212966  1.03174023 -0.00823401 -1.50113142  2.41236145  2.03208717]
action: [-0.06290584  0.79629518  0.18876509 -0.67594342  1.82565645 -1.08961714]
action: [-0.39999169  0.01369526 -0.30621392 -0.13036461  1.68043977 -0.99701507]
action: [ 0.49409734 -2.08250157  2.0054156   1.73028067 -0.54113823 -1.04139121]
action: [-0.8904026   1.43372274  2.20775013 -0.85945168  0.07719268 -0.39912197]
action: [ 2.21293044 -0.09193083 -0.05295176 -0.54053385 -1.11687319 -0.11945959]
action: [-0.27261388  0.32330537  1.20969125  0.41098488 -0.95654606 -1.32537158]
action: [-0.53868659  0.59225871 -0.38299125 -1.09593225 -2.27503209 -0.63400407]
action: [-0.0449157   0.00350214 -0.51906284  0.83624272  0.46714134 -0.92824168]
action: [ 0.83189234 -0.6775416   0.76185682 -1.65683962  1.31824122 -1.6653066 ]
action: [ 1.3715

                                                                                                                                                                                  

In [6]:
# Per-parameter change produced by core.learn(); nonzero entries show the weights moved.
weight_change = policy.get_weights() - old_weights
weight_change

array([ 0.0000000e+00, -1.8626451e-09,  0.0000000e+00, ...,
        1.0030009e-02, -1.3227418e-02,  1.5266195e-02], dtype=float32)

In [16]:
# Current (post-training) policy parameters as a single flat array.
trained_weights = policy.get_weights()
trained_weights

array([-0.14311804,  0.2781309 , -0.17962386, ...,  0.14129315,
        0.23999515,  0.06588355], dtype=float32)

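## Success! We can now create a training pipeline using MushroomRL

The pipeline runs end to end, and the policy weights change after `core.learn` (see the weight difference above). Note that `custom_reward_function` returns a constant 9000 for every transition. This run therefore only demonstrates the plumbing; a real reward is needed before the policy can learn meaningful behavior.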