In [15]:
import torch
import numpy as np

!pip install gym[box2d]
import gym
from cogment_verse_torch_agents.muzero.agent import reward_transform, reward_transform_inverse, DynamicsAdapter
from cogment_verse_torch_agents.muzero.networks import resnet, Distributional




You should consider upgrading via the '/usr/local/bin/python -m pip install --upgrade pip' command.[0m


In [43]:
obs_dim = 8
act_dim = 4
hidden_dim = 32
hidden_layers = 2
rbins = 32
rmin = -100
rmax = 100

reward_distribution = Distributional(
    rmin,
    rmax,
    hidden_dim,
    rbins,
    reward_transform,
    reward_transform_inverse,
)

dynamics = DynamicsAdapter(
    resnet(
        obs_dim + act_dim,
        hidden_dim,
        hidden_dim,
        hidden_layers - 1,
        final_act=torch.nn.LeakyReLU(),
    ),
    act_dim,
    hidden_dim,
    obs_dim,
    reward_dist=reward_distribution,
)

In [44]:
replay_buffer = []
batch_size = 32

In [53]:
env = gym.make("LunarLander-v2")
state = env.reset()

def cross_entropy(pred, target):
    return torch.mean(torch.sum(-target * torch.log(pred), dim=1))

def train_step(dynamics, batch):
    state, action, reward, new_state = batch
    pred_next_state, pred_reward_probs, pred_reward = dynamics(state, action.view(-1))
    target_reward_probs = reward_distribution.compute_target(reward)
    loss_r = cross_entropy(pred_reward_probs, target_reward_probs)
    loss_s = torch.mean((pred_next_state-new_state)**2)
    return dict(
        loss_r=loss_r,
        loss_s=loss_s,
        loss=loss_r+loss_s,
    )

for step in range(100):
    action = np.random.randint(0, 4)
    new_state, reward, done, info = env.step(action)
    replay_buffer.append((state, action, reward, new_state))
    
    if done:
        state = env.reset()
    else:
        state = new_state
    
    if len(replay_buffer) > 2 * batch_size:
        idx = np.random.randint(0, len(replay_buffer), batch_size)
        batch = [torch.vstack([torch.tensor(replay_buffer[i][j]) for i in idx]) for j in range(4)]
        info = train_step(dynamics, batch)
        print(info)
    

{'loss_r': tensor(3.4937, grad_fn=<MeanBackward0>), 'loss_s': tensor(0.7542, grad_fn=<MeanBackward0>), 'loss': tensor(4.2479, grad_fn=<AddBackward0>)}
{'loss_r': tensor(3.5100, grad_fn=<MeanBackward0>), 'loss_s': tensor(0.7821, grad_fn=<MeanBackward0>), 'loss': tensor(4.2921, grad_fn=<AddBackward0>)}
{'loss_r': tensor(3.4573, grad_fn=<MeanBackward0>), 'loss_s': tensor(0.7632, grad_fn=<MeanBackward0>), 'loss': tensor(4.2204, grad_fn=<AddBackward0>)}
{'loss_r': tensor(3.5797, grad_fn=<MeanBackward0>), 'loss_s': tensor(0.7996, grad_fn=<MeanBackward0>), 'loss': tensor(4.3793, grad_fn=<AddBackward0>)}
{'loss_r': tensor(3.4521, grad_fn=<MeanBackward0>), 'loss_s': tensor(0.8148, grad_fn=<MeanBackward0>), 'loss': tensor(4.2670, grad_fn=<AddBackward0>)}
{'loss_r': tensor(3.3879, grad_fn=<MeanBackward0>), 'loss_s': tensor(0.7485, grad_fn=<MeanBackward0>), 'loss': tensor(4.1364, grad_fn=<AddBackward0>)}
{'loss_r': tensor(3.4879, grad_fn=<MeanBackward0>), 'loss_s': tensor(0.7793, grad_fn=<MeanBack