In [1]:
import gym
import torch, numpy as np
from torch import nn
import tianshou as ts
from agents import TwoAgentPolicy
from agents.lib_agents import SinePolicy
from utils import make_render_env, make_env, make_discrete_env, make_render_discrete_env
from tianshou.utils import TensorboardLogger, WandbLogger
import time

In [2]:
# create environment

env = make_discrete_env()
train_envs = ts.env.DummyVectorEnv([make_discrete_env for _ in range(3)])
test_envs = ts.env.DummyVectorEnv([make_render_discrete_env for _ in range(5)])



In [3]:
# setup network

class Net(nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape)),
        )

    def forward(self, obs, state=None, info={}):
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)
        batch = obs.shape[0]
        # print("obs", obs)
        # print("view", obs.view(batch, -1))
        logits = self.model(obs.view(batch, -1))
        return logits, state

In [4]:
state_shape = env.observation_space.shape or env.observation_space.n
puck_action_shape = env.action_space['puck'].shape or env.action_space['puck'].n
net1 = Net(state_shape, puck_action_shape)
bar_action_shape = env.action_space['bar'].shape or env.action_space['bar'].n
net2 = Net(state_shape, bar_action_shape)

In [5]:
optim1 = torch.optim.Adam(net1.parameters(), lr=1e-4)
optim2 = torch.optim.Adam(net2.parameters(), lr=1e-4)

In [6]:
# p1 = ts.policy.DQNPolicy(net1, optim1, discount_factor=0.9, estimation_step=3, target_update_freq=320)
# p1 = testPolicy()
p1 = SinePolicy()
p2 = ts.policy.DQNPolicy(net2, optim2, discount_factor=0.9, estimation_step=3, target_update_freq=320)
policy = TwoAgentPolicy(observation_space=env.observation_space, action_space=env.action_space, policies=(p1, p2))

In [7]:
# setup collector

train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(2000, 10), exploration_noise=True)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)

In [8]:
logger = WandbLogger(
    save_interval=1,
    config={
        'time': time.time(),
    }
)
result = ts.trainer.offpolicy_trainer(
policy, train_collector, test_collector,
max_epoch=10, step_per_epoch=10000, step_per_collect=10,
update_per_step=0.1, episode_per_test=100, batch_size=64, logger=logger)
print(f'Finished training! Use {result}')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: sarthakrout (use `wandb login --relogin` to force relogin)
  warn("The `IPython.html` package has been deprecated since IPython 4.0. "


Epoch #1: 10008it [00:51, 193.25it/s, env_step=10008, len=0, loss=0.083, n/ep=0, n/st=12, rew=0.00]
Epoch #2:   0%|          | 48/10000 [00:00<00:28, 347.93it/s, env_step=10056, len=0, loss=0.083, n/ep=0, n/st=12, rew=0.00]

Epoch #1: test_reward: 26.630872 ± 10.370685, best_reward: 28.439432 ± 4.876546 in #0


Epoch #2: 10008it [00:49, 202.27it/s, env_step=20016, len=0, loss=0.119, n/ep=0, n/st=12, rew=0.00]
Epoch #3:   0%|          | 36/10000 [00:00<00:36, 273.20it/s, env_step=20052, len=0, loss=0.120, n/ep=0, n/st=12, rew=0.00]

Epoch #2: test_reward: 26.908166 ± 10.083828, best_reward: 28.439432 ± 4.876546 in #0


Epoch #3: 10008it [00:52, 190.95it/s, env_step=30024, len=0, loss=0.117, n/ep=0, n/st=12, rew=0.00]
Epoch #4:   0%|          | 36/10000 [00:00<00:33, 297.51it/s, env_step=30060, len=0, loss=0.118, n/ep=0, n/st=12, rew=0.00]

Epoch #3: test_reward: 26.178679 ± 10.517769, best_reward: 28.439432 ± 4.876546 in #0


Epoch #4: 10008it [00:48, 205.76it/s, env_step=40032, len=0, loss=0.067, n/ep=0, n/st=12, rew=0.00]
Epoch #5:   0%|          | 36/10000 [00:00<00:53, 187.23it/s, env_step=40056, len=0, loss=0.066, n/ep=0, n/st=12, rew=0.00]

Epoch #4: test_reward: 26.387496 ± 10.614929, best_reward: 28.439432 ± 4.876546 in #0


Epoch #5: 10008it [00:48, 207.75it/s, env_step=50040, len=0, loss=0.031, n/ep=0, n/st=12, rew=0.00]


KeyboardInterrupt: 