This repository has been archived by the owner on Oct 31, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 41
/
bc.py
136 lines (115 loc) · 4.41 KB
/
bc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import copy
import math
import time
import d4rl
import gym
import hydra
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
import salina
import salina.rl.functional as RLF
from salina import Workspace, get_arguments, get_class, instantiate_class
from salina.agents import Agents, NRemoteAgent, TemporalAgent
from salina.agents.gyma import AutoResetGymAgent, GymAgent
from salina.logger import TFLogger
from salina.rl.replay_buffer import ReplayBuffer
from salina_examples import weight_init
from salina_examples.offline_rl.d4rl import *
def _state_dict(agent, device):
sd = agent.state_dict()
for k, v in sd.items():
sd[k] = v.to(device)
return sd
def _log_evaluation(evaluation_workspace, env, logger, epoch):
    """Log rewards of the evaluation episodes marked done in ``evaluation_workspace``.

    Logs the mean D4RL-normalized score (via ``env.get_normalized_score``) and
    the mean raw cumulated reward of those episodes at step ``epoch``. Nothing
    is logged when no episode is marked done.
    """
    creward, done = evaluation_workspace["env/cumulated_reward", "env/done"]
    creward = creward[done]
    if creward.size()[0] == 0:
        return
    normalized_scores = [env.get_normalized_score(r) for r in creward.tolist()]
    logger.add_scalar("evaluation/normalized_", np.mean(normalized_scores), epoch)
    logger.add_scalar("evaluation/reward", creward.mean().item(), epoch)


def run_bc(action_agent, logger, cfg):
    """Train ``action_agent`` by behavioral cloning with asynchronous evaluation.

    Each epoch samples a batch from the replay buffer built by
    ``d4rl_transition_buffer(env)``. It then re-runs ``action_agent`` over the
    batch and minimizes the squared error between the action the agent writes
    into the workspace and the ``"action"`` that was stored there beforehand.
    The error is summed over the last dimension and averaged over the rest.

    In parallel, ``NRemoteAgent`` workers roll out a copy of the policy.
    Whenever they are idle:
      * the rewards of finished episodes are logged;
      * the copy is refreshed with the current weights (moved to CPU);
      * a new evaluation round starts from the last step of the previous one.

    Args:
        action_agent: salina agent that writes the ``"action"`` variable. It is
            renamed ``"action_agent"`` and moved to ``cfg.algorithm.loss_device``
            (through the wrapping ``TemporalAgent``).
        logger: object exposing ``message`` and ``add_scalar``.
        cfg: config; reads ``cfg.algorithm.{env, evaluation, loss_device,
            optimizer, max_epoch, batch_size, clip_grad}``.

    Raises:
        ValueError: if ``evaluation.n_envs`` is smaller than
            ``evaluation.n_processes``, which would give zero environments
            per evaluation process.
    """
    eval_cfg = cfg.algorithm.evaluation
    loss_device = cfg.algorithm.loss_device

    # Evaluation environments are split across the evaluation processes.
    n_envs_per_process = int(eval_cfg.n_envs / eval_cfg.n_processes)
    if n_envs_per_process < 1:
        raise ValueError(
            f"evaluation.n_envs ({eval_cfg.n_envs}) must be >= "
            f"evaluation.n_processes ({eval_cfg.n_processes})"
        )

    action_agent.set_name("action_agent")
    env_evaluation_agent = AutoResetGymAgent(
        get_class(cfg.algorithm.env),
        get_arguments(cfg.algorithm.env),
        n_envs=n_envs_per_process,
    )
    # Copied before the training agent is moved to ``loss_device``.
    action_evaluation_agent = copy.deepcopy(action_agent)

    env = instantiate_class(cfg.algorithm.env)
    train_temporal_action_agent = TemporalAgent(action_agent)
    train_temporal_action_agent.to(loss_device)
    replay_buffer = d4rl_transition_buffer(env)

    # NOTE(review): epsilon=0.0 is forwarded to the agents; presumably it
    # disables exploration noise during evaluation -- confirm.
    evaluation_agent, evaluation_workspace = NRemoteAgent.create(
        TemporalAgent(Agents(env_evaluation_agent, action_evaluation_agent)),
        num_processes=eval_cfg.n_processes,
        t=0,
        n_steps=eval_cfg.n_timesteps,
        epsilon=0.0,
    )
    evaluation_agent.seed(eval_cfg.env_seed)
    evaluation_agent._asynchronous_call(
        evaluation_workspace,
        t=0,
        n_steps=eval_cfg.n_timesteps,
        epsilon=0.0,
    )

    logger.message("Learning")
    optimizer_args = get_arguments(cfg.algorithm.optimizer)
    optimizer_action = get_class(cfg.algorithm.optimizer)(
        action_agent.parameters(), **optimizer_args
    )
    batch_size = cfg.algorithm.batch_size

    for epoch in range(cfg.algorithm.max_epoch):
        if not evaluation_agent.is_running():
            _log_evaluation(evaluation_workspace, env, logger, epoch)
            # Push the latest weights to the evaluation copies.
            for a in evaluation_agent.get_by_name("action_agent"):
                a.load_state_dict(_state_dict(action_agent, "cpu"))
            # Continue the rollout from the last step of the previous round.
            evaluation_workspace.copy_n_last_steps(1)
            evaluation_agent._asynchronous_call(
                evaluation_workspace,
                t=1,
                n_steps=eval_cfg.n_timesteps - 1,
                epsilon=0.0,
            )

        replay_workspace = replay_buffer.select_batch_n(batch_size).to(loss_device)
        # The agent writes "action" into the workspace, so read the dataset
        # actions (without gradient) before running it.
        target_action = replay_workspace["action"].detach()
        train_temporal_action_agent(
            replay_workspace, t=0, n_steps=replay_workspace.time_size()
        )
        action = replay_workspace["action"]
        mse_loss = ((target_action - action) ** 2).sum(-1).mean()
        logger.add_scalar("loss/mse_1", mse_loss.item(), epoch)

        optimizer_action.zero_grad()
        mse_loss.backward()
        if cfg.algorithm.clip_grad > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                action_agent.parameters(), cfg.algorithm.clip_grad
            )
            logger.add_scalar("monitor/grad_norm", grad_norm.item(), epoch)
        optimizer_action.step()
@hydra.main(config_path=".", config_name="gym.yaml")
def main(cfg):
    """Hydra entry point: build the logger and action agent from ``cfg``, then run BC.

    The configuration is saved via ``logger.save_hps`` before training starts.
    """
    logger = instantiate_class(cfg.logger)
    logger.save_hps(cfg)
    action_agent = instantiate_class(cfg.action_agent)
    run_bc(action_agent, logger, cfg)
import os
if __name__ == "__main__":
    # Select the "spawn" start method before main() creates any worker
    # processes (presumably the NRemoteAgent evaluation workers).
    from torch import multiprocessing as mp

    mp.set_start_method("spawn")
    main()