Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test(luyd): add model test code #728

Merged
merged 7 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ding/entry/tests/test_serial_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ def test_discrete_dt():
from ding.utils import set_pkg_seed
from ding.data import create_dataset
from ding.config import compile_config
from ding.model.template.dt import DecisionTransformer
from ding.model.template.decision_transformer import DecisionTransformer
from ding.policy import DTPolicy
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \
OfflineMemoryDataFetcher, offline_logger, termination_checker
Expand Down
2 changes: 1 addition & 1 deletion ding/example/dt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import gym
from ditk import logging
from ding.model.template.dt import DecisionTransformer
from ding.model.template.decision_transformer import DecisionTransformer
from ding.policy import DTPolicy
from ding.envs import DingEnvWrapper, BaseEnvManager, BaseEnvManagerV2
from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper
Expand Down
2 changes: 1 addition & 1 deletion ding/model/template/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .maqac import DiscreteMAQAC, ContinuousMAQAC
from .madqn import MADQN
from .vae import VanillaVAE
from .dt import DecisionTransformer
from .decision_transformer import DecisionTransformer
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
from .bcq import BCQ
from .edac import EDAC
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def forward(self, timesteps, states, actions, returns_to_go, tar=None):
action_preds = self.predict_action(h[:, 1]) # predict action given r, s
else:
state_embeddings = self.state_encoder(
states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous()
states.reshape(-1, *self.state_dim).type(torch.float32).contiguous()
) # (batch * block_size, h_dim)
state_embeddings = state_embeddings.reshape(B, T, self.h_dim) # (batch, block_size, h_dim)
returns_embeddings = self.embed_rtg(returns_to_go.type(torch.float32))
Expand Down
2 changes: 1 addition & 1 deletion ding/model/template/q_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,7 @@ def forward(self, x: torch.Tensor) -> Dict:
>>> # Init input's Keys:
>>> obs_dim, seq_len, bs, action_dim = 128, 64, 32, 4
>>> obs = torch.rand(seq_len, bs, obs_dim)
>>> model = GTrXLDiscreteHead(obs_dim, action_dim)
>>> model = GTrXLDQN(obs_dim, action_dim)
>>> outputs = model(obs)
>>> assert isinstance(outputs, dict)
"""
Expand Down
41 changes: 41 additions & 0 deletions ding/model/template/tests/test_acer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
import pytest
from itertools import product

from ding.model.template import ACER
from ding.torch_utils import is_differentiable

# Leading dimension of the random input tensors built in the test below.
B = 4
# Observation shapes under test: a bare int, a 1-tuple and a 3-tuple.
obs_shape = [4, (8, ), (4, 64, 64)]
# Action shapes under test: a bare int and a 1-tuple.
act_shape = [3, (6, )]
# Every (obs_shape, act_shape) combination, consumed by ``pytest.mark.parametrize``.
args = list(product(obs_shape, act_shape))


@pytest.mark.unittest
class TestACER:
    """Output-shape checks for the ``ACER`` model in critic and actor mode."""

    @pytest.mark.parametrize('obs_shape, act_shape', args)
    def test_ACER(self, obs_shape, act_shape):
        """Run both forward modes once and validate the returned dicts.

        Arguments:
            - obs_shape: Observation shape, either an int or a tuple of ints.
            - act_shape: Action shape, either an int or a tuple of ints.
        """
        # A bare int is treated as a single trailing dimension.
        obs_dims = (obs_shape, ) if isinstance(obs_shape, int) else obs_shape
        inputs = torch.randn(B, *obs_dims)
        model = ACER(obs_shape, act_shape)

        # Expected head shape; stays None for action shapes this test does not cover.
        if isinstance(act_shape, int):
            expected_shape = (B, act_shape)
        elif len(act_shape) == 1:
            expected_shape = (B, *act_shape)
        else:
            expected_shape = None

        outputs_c = model(inputs, mode='compute_critic')
        assert isinstance(outputs_c, dict)
        if expected_shape is not None:
            assert outputs_c['q_value'].shape == expected_shape

        outputs_a = model(inputs, mode='compute_actor')
        assert isinstance(outputs_a, dict)
        if expected_shape is not None:
            assert outputs_a['logit'].shape == expected_shape

        # Merge both heads so a single scalar depends on every output tensor.
        merged = {**outputs_a, **outputs_c}
        loss = sum(tensor.sum() for tensor in merged.values())
        is_differentiable(loss, model)
75 changes: 75 additions & 0 deletions ding/model/template/tests/test_bcq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pytest
from itertools import product
import torch
from ding.model.template import BCQ
from ding.torch_utils import is_differentiable

# Leading dimension of the random observation/action tensors built in the test below.
B = 4
# Observation shapes under test: a bare int and a 1-tuple.
obs_shape = [4, (8, )]
# Action shapes under test: a bare int and a 1-tuple.
act_shape = [3, (6, )]
# Every (obs_shape, act_shape) combination, consumed by ``pytest.mark.parametrize``.
args = list(product(obs_shape, act_shape))


@pytest.mark.unittest
class TestBCQ:
    """Output-shape checks for the ``BCQ`` model in its four forward modes."""

    def output_check(self, model, outputs):
        """Reduce ``outputs`` to a scalar and hand it to ``is_differentiable``.

        Arguments:
            - model: Module (or sub-module) forwarded to ``is_differentiable``.
            - outputs: A tensor, a list/tuple of tensors, or a dict of tensors.

        Raises:
            - TypeError: If ``outputs`` is none of the supported containers. \
                Without this branch ``loss`` would be unbound and the failure \
                would surface as an unrelated ``UnboundLocalError``.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, (list, tuple)):
            loss = sum(t.sum() for t in outputs)
        elif isinstance(outputs, dict):
            loss = sum(v.sum() for v in outputs.values())
        else:
            raise TypeError("unsupported outputs type for output_check: {}".format(type(outputs)))
        is_differentiable(loss, model)

    @pytest.mark.parametrize('obs_shape, act_shape', args)
    def test_BCQ(self, obs_shape, act_shape):
        """Run critic, actor, VAE and eval modes and validate the returned dicts.

        Arguments:
            - obs_shape: Observation shape, either an int or a tuple of ints.
            - act_shape: Action shape, either an int or a tuple of ints.
        """
        # A bare int is treated as a single trailing dimension.
        obs_dims = (obs_shape, ) if isinstance(obs_shape, int) else tuple(obs_shape)
        act_dims = (act_shape, ) if isinstance(act_shape, int) else tuple(act_shape)
        # Action-shape assertions are only made for one-dimensional action shapes.
        check_action = len(act_dims) == 1

        inputs = {'obs': torch.randn(B, *obs_dims), 'action': torch.randn(B, *act_dims)}
        model = BCQ(obs_shape, act_shape)

        outputs_c = model(inputs, mode='compute_critic')
        assert isinstance(outputs_c, dict)
        # 'q_value' is a sequence of tensors; stack it once and reuse the result.
        q_value = torch.stack(outputs_c['q_value'])
        assert q_value.shape == (2, B)
        self.output_check(model.critic, q_value)

        outputs_a = model(inputs, mode='compute_actor')
        assert isinstance(outputs_a, dict)
        if check_action:
            assert outputs_a['action'].shape == (B, *act_dims)
        self.output_check(model.actor, outputs_a)

        outputs_vae = model(inputs, mode='compute_vae')
        assert isinstance(outputs_vae, dict)
        if check_action:
            # The latent tensors are asserted to be twice as wide as the action.
            latent_shape = (B, act_dims[0] * 2)
            assert outputs_vae['recons_action'].shape == (B, *act_dims)
            assert outputs_vae['mu'].shape == latent_shape
            assert outputs_vae['log_var'].shape == latent_shape
            assert outputs_vae['z'].shape == latent_shape
        assert outputs_vae['prediction_residual'].shape == (B, *obs_dims)

        outputs_eval = model(inputs, mode='compute_eval')
        assert isinstance(outputs_eval, dict)
        if check_action:
            assert outputs_eval['action'].shape == (B, *act_dims)
59 changes: 42 additions & 17 deletions ding/model/template/tests/test_decision_transformer.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,49 @@
import pytest
from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F

from ding.model.template import DecisionTransformer
from ding.torch_utils import is_differentiable

args = ['continuous', 'discrete']
action_space = ['continuous', 'discrete']
state_encoder = [None, nn.Sequential(nn.Flatten(), nn.Linear(8, 8), nn.Tanh())]
args = list(product(*[action_space, state_encoder]))
args.pop(1)


@pytest.mark.unittest
@pytest.mark.parametrize('action_space', args)
def test_decision_transformer(action_space):
@pytest.mark.parametrize('action_space, state_encoder', args)
def test_decision_transformer(action_space, state_encoder):
B, T = 4, 6
state_dim = 3
if state_encoder:
state_dim = (2, 2, 2)
else:
state_dim = 3
act_dim = 2
DT_model = DecisionTransformer(
state_dim=state_dim,
act_dim=act_dim,
state_encoder=state_encoder,
n_blocks=3,
h_dim=8,
context_len=T,
n_heads=2,
drop_p=0.1,
continuous=(action_space == 'continuous')
)
DT_model.configure_optimizers(1.0, 0.0003)

is_continuous = True if action_space == 'continuous' else False
timesteps = torch.randint(0, 100, [B, T], dtype=torch.long) # B x T
states = torch.randn([B, T, state_dim]) # B x T x state_dim
if state_encoder:
timesteps = torch.randint(0, 100, [B, 3 * T - 1, 1], dtype=torch.long) # B x T
else:
timesteps = torch.randint(0, 100, [B, T], dtype=torch.long) # B x T
if isinstance(state_dim, int):
states = torch.randn([B, T, state_dim]) # B x T x state_dim
else:
states = torch.randn([B, T, *state_dim]) # B x T x state_dim
if action_space == 'continuous':
actions = torch.randn([B, T, act_dim]) # B x T x act_dim
action_target = torch.randn([B, T, act_dim])
Expand All @@ -51,12 +66,19 @@ def test_decision_transformer(action_space):
state_preds, action_preds, return_preds = DT_model.forward(
timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go
)
assert state_preds.shape == (B, T, state_dim)
if state_encoder:
assert state_preds == None
assert return_preds == None
else:
assert state_preds.shape == (B, T, state_dim)
assert return_preds.shape == (B, T, 1)
assert action_preds.shape == (B, T, act_dim)
assert return_preds.shape == (B, T, 1)

# only consider non padded elements
action_preds = action_preds.view(-1, act_dim)[traj_mask.view(-1, ) > 0]
if state_encoder:
action_preds = action_preds.reshape(-1, act_dim)
else:
action_preds = action_preds.view(-1, act_dim)[traj_mask.view(-1, ) > 0]

if is_continuous:
action_target = action_target.view(-1, act_dim)[traj_mask.view(-1, ) > 0]
Expand All @@ -68,11 +90,14 @@ def test_decision_transformer(action_space):
else:
action_loss = F.cross_entropy(action_preds, action_target)

# print(action_loss)
# is_differentiable(action_loss, DT_model)
is_differentiable(
action_loss, [
DT_model.transformer, DT_model.embed_action, DT_model.predict_action, DT_model.embed_rtg,
DT_model.embed_state
]
) # pass
if state_encoder:
is_differentiable(
action_loss, [DT_model.transformer, DT_model.embed_action, DT_model.embed_rtg, DT_model.state_encoder]
)
else:
is_differentiable(
action_loss, [
DT_model.transformer, DT_model.embed_action, DT_model.predict_action, DT_model.embed_rtg,
DT_model.embed_state
]
)
57 changes: 57 additions & 0 deletions ding/model/template/tests/test_edac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import torch
import pytest
from itertools import product

from ding.model.template import EDAC
from ding.torch_utils import is_differentiable

# Leading dimension of the random observation/action tensors built in the test below.
B = 4
# Observation shapes under test: a bare int and a 1-tuple.
obs_shape = [4, (8, )]
# Action shapes under test: a bare int and a 1-tuple.
act_shape = [3, (6, )]
# Every (obs_shape, act_shape) combination, consumed by ``pytest.mark.parametrize``.
args = list(product(obs_shape, act_shape))


@pytest.mark.unittest
class TestEDAC:
    """Output-shape checks for the ``EDAC`` model in critic and actor mode."""

    def output_check(self, model, outputs):
        """Reduce ``outputs`` to a scalar and hand it to ``is_differentiable``.

        Arguments:
            - model: Module (or sub-module) forwarded to ``is_differentiable``.
            - outputs: A tensor, a list of tensors, or a dict of tensors.

        Raises:
            - TypeError: If ``outputs`` is none of the supported containers. \
                Without this branch ``loss`` would be unbound and the failure \
                would surface as an unrelated ``UnboundLocalError``.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, list):
            loss = sum(t.sum() for t in outputs)
        elif isinstance(outputs, dict):
            loss = sum(v.sum() for v in outputs.values())
        else:
            raise TypeError("unsupported outputs type for output_check: {}".format(type(outputs)))
        is_differentiable(loss, model)

    @pytest.mark.parametrize('obs_shape, act_shape', args)
    def test_EDAC(self, obs_shape, act_shape):
        """Run the critic and actor modes once and validate the returned dicts.

        Arguments:
            - obs_shape: Observation shape, either an int or a tuple of ints.
            - act_shape: Action shape, either an int or a tuple of ints.
        """
        # A bare int is treated as a single trailing dimension.
        obs_dims = (obs_shape, ) if isinstance(obs_shape, int) else tuple(obs_shape)
        act_dims = (act_shape, ) if isinstance(act_shape, int) else tuple(act_shape)

        critic_inputs = {'obs': torch.randn(B, *obs_dims), 'action': torch.randn(B, *act_dims)}
        model = EDAC(obs_shape, act_shape, ensemble_num=2)

        outputs_c = model(critic_inputs, mode='compute_critic')
        assert isinstance(outputs_c, dict)
        # NOTE(review): the leading 2 presumably mirrors ensemble_num=2 - confirm against EDAC.
        assert outputs_c['q_value'].shape == (2, B)
        self.output_check(model.critic, outputs_c)

        # The actor mode is fed a fresh observation tensor rather than the critic's dict input.
        actor_inputs = torch.randn(B, *obs_dims)
        outputs_a = model(actor_inputs, mode='compute_actor')
        assert isinstance(outputs_a, dict)
        mu, sigma = outputs_a['logit'][0], outputs_a['logit'][1]
        # Shape assertions are only made for one-dimensional action shapes.
        if len(act_dims) == 1:
            assert mu.shape == (B, *act_dims)
            assert sigma.shape == (B, *act_dims)
        self.output_check(model.actor, {'mu': mu, 'sigma': sigma})
70 changes: 70 additions & 0 deletions ding/model/template/tests/test_ngu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest
from itertools import product
import torch
from ding.model.template import NGU
from ding.torch_utils import is_differentiable

# First leading dimension of the random observation and reward tensors built in the test below.
B = 4
# Second leading dimension of the random observation and reward tensors.
H = 4
# Observation shapes under test: a bare int, a 1-tuple and a 3-tuple.
obs_shape = [4, (8, ), (4, 64, 64)]
# Action shapes under test: a bare int and a 1-tuple.
act_shape = [4, (4, )]
# Every (obs_shape, act_shape) combination, consumed by ``pytest.mark.parametrize``.
args = list(product(obs_shape, act_shape))


@pytest.mark.unittest
class TestNGU:
    """Output-shape checks for the ``NGU`` model under two input layouts."""

    def output_check(self, model, outputs):
        """Reduce ``outputs`` to a scalar and hand it to ``is_differentiable``.

        Arguments:
            - model: Module forwarded to ``is_differentiable``.
            - outputs: A tensor, a list of tensors, or a dict of tensors.

        Raises:
            - TypeError: If ``outputs`` is none of the supported containers. \
                Without this branch ``loss`` would be unbound and the failure \
                would surface as an unrelated ``UnboundLocalError``.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, list):
            loss = sum(t.sum() for t in outputs)
        elif isinstance(outputs, dict):
            loss = sum(v.sum() for v in outputs.values())
        else:
            raise TypeError("unsupported outputs type for output_check: {}".format(type(outputs)))
        is_differentiable(loss, model)

    def _forward_and_check(self, obs_shape, act_shape, inputs):
        """Build a fresh ``NGU`` model, run it on ``inputs`` and validate ``'logit'``."""
        model = NGU(obs_shape, act_shape, collector_env_num=3)
        outputs = model(inputs)
        assert isinstance(outputs, dict)
        # Shape assertions are only made for int or one-dimensional action shapes.
        if isinstance(act_shape, int):
            assert outputs['logit'].shape == (B, act_shape, act_shape)
        elif len(act_shape) == 1:
            assert outputs['logit'].shape == (B, *act_shape, *act_shape)
        self.output_check(model, outputs['logit'])

    @pytest.mark.parametrize('obs_shape, act_shape', args)
    def test_ngu(self, obs_shape, act_shape):
        """Forward the model with ``prev_action`` inputs, then with ``action``/``reward`` inputs.

        Arguments:
            - obs_shape: Observation shape, either an int or a tuple of ints.
            - act_shape: Action shape, either an int or a tuple of ints.
        """
        # A bare int is treated as a single trailing dimension.
        obs_dims = (obs_shape, ) if isinstance(obs_shape, int) else tuple(obs_shape)
        act_dims = (act_shape, ) if isinstance(act_shape, int) else tuple(act_shape)

        inputs_obs = torch.randn(B, H, *obs_dims)
        inputs_prev_action = torch.ones(B, *act_dims).long()
        inputs_prev_reward_extrinsic = torch.randn(B, H, 1)
        # NOTE(review): the literal [4, 4] coincides with (B, H) today - confirm which axes
        # ``beta`` is meant to follow before replacing the literals with the constants.
        inputs_beta = 2 * torch.ones([4, 4], dtype=torch.long)

        # First layout: the previous action is supplied under 'prev_action'.
        self._forward_and_check(
            obs_shape, act_shape, {
                'obs': inputs_obs,
                'prev_state': None,
                'prev_action': inputs_prev_action,
                'prev_reward_extrinsic': inputs_prev_reward_extrinsic,
                'beta': inputs_beta
            }
        )

        # Second layout: the same tensors are supplied under 'action' and 'reward' instead.
        self._forward_and_check(
            obs_shape, act_shape, {
                'obs': inputs_obs,
                'prev_state': None,
                'action': inputs_prev_action,
                'reward': inputs_prev_reward_extrinsic,
                'prev_reward_extrinsic': inputs_prev_reward_extrinsic,
                'beta': inputs_beta
            }
        )
Loading
Loading