In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import numpy as np

from mlagents_envs.environment import UnityEnvironment

In [None]:
# Pick the compute device: the CUDA GPU when PyTorch can see one, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Also report which GPU model was selected.
    print(device, torch.cuda.get_device_name(0))
else:
    print(device)

cpu


  return torch._C._cuda_getDeviceCount() > 0


In [None]:
# Layer sizes used when building the policy network.
# NOTE(review): these presumably have to match the observation and action
# sizes of the Unity behavior being driven -- confirm against the environment.
N_STATES = 210  # width of the network input
N_ACTIONS = 4   # width of the network output

In [None]:
# Build a (1, N_ACTIONS) tensor of ones and show it together with its shape.
a = torch.ones((1, N_ACTIONS))
print(a, a.shape)

tensor([[1., 1.]]) torch.Size([1, 2])


In [None]:
# Wrap an all-zero (1, N_ACTIONS) tensor in nn.Parameter so that it carries
# gradients (requires_grad=True) and can be registered on a module.
a = nn.Parameter(torch.zeros(1, N_ACTIONS))
print(a)

Parameter containing:
tensor([[0., 0.]], requires_grad=True)


In [None]:
# Learnable log standard deviation: one zero-initialised entry per action.
log_std = nn.Parameter(torch.zeros(1, N_ACTIONS))
print(log_std)

Parameter containing:
tensor([[0., 0.]], requires_grad=True)


In [None]:
# Exponentiate the log-std to get the standard deviation; log_std is all
# zeros here, so every entry becomes exp(0) = 1.
log_std.exp()

tensor([[1., 1.]], grad_fn=<ExpBackward>)

In [None]:
# Example mean tensor: a single row of four ones (float32).
mu = torch.tensor([[1.0, 1.0, 1.0, 1.0]])

In [None]:
# Broadcast the standard deviation to the same shape as mu.
log_std.exp().expand_as(mu)

tensor([[1., 1.]], grad_fn=<ExpandBackward>)

In [None]:
def init_weights(m):
    """Re-initialise ``m`` in place when it is an ``nn.Linear`` layer.

    Linear layers get weights drawn from a normal distribution with mean 0
    and standard deviation 0.1, and every bias entry set to 0.1. Modules of
    any other type are left untouched.

    Args:
        m: The module to (possibly) re-initialise.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.normal_(m.weight, mean=0.0, std=0.1)
    nn.init.constant_(m.bias, 0.1)

In [None]:
class Net(nn.Module):
    """Gaussian policy: an MLP that outputs the mean, plus a learnable log-std.

    ``forward`` maps a batch of observations to a ``Normal`` distribution
    whose mean comes from the actor MLP and whose standard deviation is
    ``exp(log_std)``, expanded to the shape of the mean.

    Args:
        n_states: Number of input features of the first linear layer.
            Defaults to the notebook-level ``N_STATES``.
        n_actions: Number of output features of the last linear layer and
            width of ``log_std``. Defaults to the notebook-level ``N_ACTIONS``.
        hidden_size: Width of the two hidden layers.
    """

    def __init__(self, n_states=None, n_actions=None, hidden_size=256):
        super().__init__()

        # Resolve defaults at call time so that ``Net()`` keeps following the
        # notebook constants even if they are edited and the cell is re-run.
        if n_states is None:
            n_states = N_STATES
        if n_actions is None:
            n_actions = N_ACTIONS

        # NOTE(review): there is no activation function between the layers;
        # LayerNorm is the only non-affine step. Confirm this is intended.
        self.actor = nn.Sequential(
            nn.Linear(n_states, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, n_actions),
        )
        # Log standard deviation, independent of the input; zero means std = 1.
        self.log_std = nn.Parameter(torch.zeros(1, n_actions))
        # Apply the notebook's initialiser to every submodule.
        self.apply(init_weights)

    def forward(self, x):
        """Return the action distribution for a batch of observations.

        Args:
            x: Tensor whose last dimension equals the actor's input width.

        Returns:
            ``Normal(mu, std)`` with ``mu = actor(x)`` and ``std`` equal to
            ``exp(log_std)`` expanded to the shape of ``mu``.
        """
        mu = self.actor(x)
        std = self.log_std.exp().expand_as(mu)
        return Normal(mu, std)

In [None]:
# Build the policy network and move its parameters to the selected device.
net = Net().to(device)

### Connect to Unity

In [None]:
# Open the connection to Unity on port 5004.
# NOTE(review): per the ML-Agents docs, file_name=None attaches to a running
# Unity Editor instead of launching a built executable -- confirm for the
# installed version.
env = UnityEnvironment(file_name= None, base_port=5004)

This means that some features may not work unless you upgrade the package with the lower version.Please find the versions that work best together from our release page.
https://github.com/Unity-Technologies/ml-agents/releases


In [None]:
# Reset the environment, then read the available behavior names and keep the
# first one as the behavior to drive.
env.reset()
behaviorNames = list(env.behavior_specs.keys())
behaviorName = behaviorNames[0]

In [None]:
# Fetch the current step data for this behavior: one batch for agents that
# need a decision and one for agents whose episode just ended.
DecisionSteps, TerminalSteps = env.get_steps(behaviorName)

### Send decision steps to NN to calculate actions

In [None]:
# Inspect the raw observations: a list of arrays.
DecisionSteps.obs

[array([[-0.01467304, -0.01468306, -0.5208206 ,  4.        , -0.79952097,
          0.        ,  0.        ,  0.        ],
        [-0.02614026,  0.03401016, -0.45768166,  4.        , -0.0055027 ,
          0.        ,  0.        ,  0.        ],
        [ 0.06363224,  0.03799658, -1.1360741 ,  4.        , -0.4150591 ,
          0.        ,  0.        ,  0.        ]], dtype=float32)]

In [None]:
# Use the first observation array as the network input.
states = DecisionSteps.obs[0]

In [None]:
# Convert the observations to a float tensor on the same device as the network.
# NOTE(review): this rebinds `states`, so the cell is not idempotent on re-run.
states = torch.FloatTensor(states).to(device)

In [None]:
# Forward pass: the network returns a Normal distribution over actions.
dist = net(states)
print(dist)

Normal(loc: torch.Size([3, 2]), scale: torch.Size([3, 2]))


In [None]:
# Draw one action sample per row of the distribution and show its shape.
actions = dist.sample()
print(actions, actions.shape)

tensor([[ 4.2590,  1.0900],
        [ 0.3650,  1.4251],
        [ 2.1630, -0.7001]]) torch.Size([3, 2])


In [None]:
# Move the sampled actions to host memory and convert them to a NumPy array.
actions = actions.cpu().detach().numpy()
print(actions)

[[ 4.2590265   1.0899847 ]
 [ 0.36503994  1.4251024 ]
 [ 2.1630166  -0.70009315]]


In [None]:
# Hand the actions for this behavior to the environment.
# NOTE(review): presumably they take effect on the next env.step() -- confirm
# against the ML-Agents API docs.
env.set_actions(behaviorName, actions)

In [None]:
# Advance the Unity simulation.
env.step()

In [None]:
# Close the connection; the next section opens a new environment on the same port.
env.close()

# Play for N steps

In [None]:
# Open a fresh connection to Unity on port 5004 for the multi-step rollout.
# NOTE(review): per the ML-Agents docs, file_name=None attaches to a running
# Unity Editor instead of launching a built executable -- confirm for the
# installed version.
env = UnityEnvironment(file_name= None, base_port=5004)

This means that some features may not work unless you upgrade the package with the lower version.Please find the versions that work best together from our release page.
https://github.com/Unity-Technologies/ml-agents/releases


In [None]:
# Reset the environment, then read the available behavior names and keep the
# first one as the behavior to drive.
env.reset()
behaviorNames = list(env.behavior_specs.keys())
behaviorName = behaviorNames[0]

In [None]:
# Drive the environment for 200 frames with actions sampled from the policy.
for frame in range(200):
    DecisionSteps, TerminalSteps = env.get_steps(behaviorName)

    # Report every agent whose episode ended on this frame.
    for AgentID in TerminalSteps.agent_id:
        print("step", frame, "agent ", AgentID, "has terminal step")

    # Only compute actions when at least one agent is asking for a decision.
    if len(DecisionSteps.agent_id) > 0:
        state = DecisionSteps.obs[0]
        state = torch.FloatTensor(state).to(device)
        # Pure inference: no autograd graph is needed to sample actions.
        with torch.no_grad():
            dist = net(state)
            action = dist.sample()
        action = action.cpu().detach().numpy()
        env.set_actions(behaviorName, action)

    # Step on every frame, even when no agent requested a decision. With the
    # step inside the `if`, such a frame left the simulation stalled and the
    # remaining iterations re-read the same steps.
    env.step()

step 22 agent  2 has terminal step
step 38 agent  1 has terminal step
step 49 agent  0 has terminal step
step 61 agent  2 has terminal step
step 70 agent  1 has terminal step
step 83 agent  0 has terminal step
step 92 agent  2 has terminal step
step 97 agent  1 has terminal step
step 106 agent  0 has terminal step
step 125 agent  2 has terminal step
step 128 agent  1 has terminal step
step 132 agent  0 has terminal step
step 147 agent  2 has terminal step
step 157 agent  1 has terminal step
step 169 agent  2 has terminal step
step 173 agent  0 has terminal step
step 186 agent  1 has terminal step
step 188 agent  2 has terminal step
step 198 agent  0 has terminal step


In [None]:
# Close the connection to Unity now that the rollout is finished.
env.close()