Skip to content

Commit

Permalink
fix(pu): fix the cuda_cache of cal_dormant_ratio and FeatureAndGradie…
Browse files Browse the repository at this point in the history
…ntHook
  • Loading branch information
puyuan1996 committed Apr 23, 2024
1 parent 1af5f92 commit dfc47e1
Show file tree
Hide file tree
Showing 19 changed files with 254 additions and 137 deletions.
4 changes: 4 additions & 0 deletions lzero/entry/train_muzero_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,10 @@ def train_muzero_context(
logging.info(f'eval offline finished!')
break

# 训练结束后移除钩子
policy._collect_model.encoder_hook.remove_hooks()
policy._target_model.encoder_hook.remove_hooks()

# Learner's after_run hook.
learner.call_hook('after_run')
return policy
5 changes: 5 additions & 0 deletions lzero/entry/train_unizero.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def train_unizero(
policy._collect_model.world_model.past_keys_values_cache_recurrent_infer.clear() # very important
policy._collect_model.world_model.keys_values_wm_list.clear() # TODO: 只适用于recurrent_inference() batch_pad
torch.cuda.empty_cache() # TODO: NOTE



# if collector.envstep > 0:
Expand All @@ -273,6 +274,10 @@ def train_unizero(
if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break

# 训练结束后移除钩子
policy._collect_model.encoder_hook.remove_hooks()
policy._target_model.encoder_hook.remove_hooks()

# Learner's after_run hook.
learner.call_hook('after_run')
return policy
9 changes: 4 additions & 5 deletions lzero/mcts/tree_search/mcts_ctree.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,11 +551,10 @@ def search(
reward_latent_state_batch = network_output.reward_hidden_state
# reset the hidden states in LSTM every ``lstm_horizon_len`` steps in one search.
# which enable the model only need to predict the value prefix in a range (e.g.: [s0,...,s5])
# assert self._cfg.context_length_in_search > 0
reset_idx = (np.array(search_lens) % self._cfg.context_length_in_search == 0)
assert len(reset_idx) == batch_size
reward_latent_state_batch[0][:, reset_idx, :] = 0
reward_latent_state_batch[1][:, reset_idx, :] = 0
# reset_idx = ( (model.timestep + np.array(search_lens)) % self._cfg.context_length_in_search == 0)
# # reset_idx = (np.array(search_lens) % self._cfg.context_length_in_search == 0)
# reward_latent_state_batch[0][:, reset_idx, :] = 0
# reward_latent_state_batch[1][:, reset_idx, :] = 0
# is_reset_list = reset_idx.astype(np.int32).tolist()
reward_hidden_state_c_batch.append(reward_latent_state_batch[0])
reward_hidden_state_h_batch.append(reward_latent_state_batch[1])
Expand Down
45 changes: 29 additions & 16 deletions lzero/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,36 +187,49 @@ def __init__(self):
self.grads_after = []

def setup_hooks(self, model):
# Hook to capture features before and after SimNorm
model.sim_norm.register_forward_hook(self.forward_hook)
model.sim_norm.register_full_backward_hook(self.backward_hook)
# Hooks to capture features and gradients at SimNorm
self.forward_handler = model.sim_norm.register_forward_hook(self.forward_hook)
self.backward_handler = model.sim_norm.register_full_backward_hook(self.backward_hook)

def forward_hook(self, module, input, output):
# input[0] is the input to SimNorm, output is the output of SimNorm
self.features_before.append(input[0].detach())
self.features_after.append(output.detach())
with torch.no_grad():
self.features_before.append(input[0])
self.features_after.append(output)

def backward_hook(self, module, grad_input, grad_output):
# grad_input[0] is the gradient at the input to SimNorm
# grad_output[0] is the gradient at the output of SimNorm
self.grads_before.append(grad_input[0].detach())
self.grads_after.append(grad_output[0].detach())
with torch.no_grad():
self.grads_before.append(grad_input[0] if grad_input[0] is not None else None)
self.grads_after.append(grad_output[0] if grad_output[0] is not None else None)

def analyze(self):
# Calculate L2 norms of features
l2_norm_before = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_before]))
l2_norm_after = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_after]))

# Calculate norms of gradients
grad_norm_before = torch.mean(torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_before]))
grad_norm_after = torch.mean(torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_after]))
grad_norm_before = torch.mean(torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_before if g is not None]))
grad_norm_after = torch.mean(torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_after if g is not None]))

# Clear stored data and delete tensors to free memory
self.clear_data()

# Optionally clear CUDA cache
if torch.cuda.is_available():
torch.cuda.empty_cache()

# print(f"L2 Norm of features before SimNorm: {l2_norm_before}")
# print(f"L2 Norm of features after SimNorm: {l2_norm_after}")
# print(f"Gradient Norm before SimNorm: {grad_norm_before}")
# print(f"Gradient Norm after SimNorm: {grad_norm_after}")
return l2_norm_before, l2_norm_after, grad_norm_before, grad_norm_after

def clear_data(self):
del self.features_before[:]
del self.features_after[:]
del self.grads_before[:]
del self.grads_after[:]

def remove_hooks(self):
self.forward_handler.remove()
self.backward_handler.remove()




class RepresentationNetworkGPT(nn.Module):
Expand Down
32 changes: 16 additions & 16 deletions lzero/model/gpt_models/cfg_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,26 @@
# "gru_gating": False,
# # "gru_gating": True,

'tokens_per_block': 2,
'max_blocks': 10,
"max_tokens": 2 * 10, # TODO: horizon:8
# "context_length": 20,
# "context_length_for_recurrent": 20,
# 'tokens_per_block': 2,
# 'max_blocks': 10,
# "max_tokens": 2 * 10, # TODO: horizon:8
# # "context_length": 20,
# # "context_length_for_recurrent": 20,
# "context_length": 2 * 4, # TODO
# "context_length_for_recurrent": 2 * 4,
# "recurrent_keep_deepth": 100,
# "gru_gating": False,
# # "gru_gating": True,

'tokens_per_block': 2,
'max_blocks': 20,
"max_tokens": 2 * 20, # TODO: horizon:8
# "context_length": 2*20,
# "context_length_for_recurrent": 2*20,
"context_length": 2 * 4, # TODO
"context_length_for_recurrent": 2 * 4,
"recurrent_keep_deepth": 100,
"gru_gating": False,
# "gru_gating": True,

# 'tokens_per_block': 2,
# 'max_blocks': 20,
# "max_tokens": 2 * 20, # TODO: horizon:8
# # "context_length": 2*20,
# # "context_length_for_recurrent": 2*20,
# "context_length": 2 * 4, # TODO
# "context_length_for_recurrent": 2 * 4,
# "recurrent_keep_deepth": 100,
# "gru_gating": False,

# 'tokens_per_block': 2,
# 'max_blocks': 30,
Expand Down
11 changes: 9 additions & 2 deletions lzero/model/gpt_models/world_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0,
obs_embeddings_or_act_tokens = {'obs_embeddings': token}
# self.keys_values_wm 会被原地改动 ===============
outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, kvcache_independent=False, is_init_infer=False)
print('keys_values_wm_size_list_current:', self.keys_values_wm_size_list_current)
# print('keys_values_wm_size_list_current:', self.keys_values_wm_size_list_current)
self.keys_values_wm_size_list_current = [i+1 for i in self.keys_values_wm_size_list_current] # NOTE: +1 ===============
if k == 0:
# 如果k==0,token是action_token,outputs_wm.logits_rewards 是有值的
Expand Down Expand Up @@ -881,6 +881,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer=None, inverse_scalar_t
shape = batch['observations'].shape # (..., C, H, W)
inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64)
dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.representation_network, inputs.detach(), percentage=self.dormant_threshold)
self.past_keys_values_cache_init_infer.clear()
self.past_keys_values_cache_recurrent_infer.clear()
self.keys_values_wm_list.clear()
torch.cuda.empty_cache()
# 假设latent_state_roots是一个tensor
latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() # 计算L2范数
# print("L2 Norms:", l2_norms)
Expand Down Expand Up @@ -958,7 +962,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer=None, inverse_scalar_t
# ========= logging for analysis =========
# calculate dormant ratio of world_model
dormant_ratio_world_model = cal_dormant_ratio(self, {'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, percentage=self.dormant_threshold)

self.past_keys_values_cache_init_infer.clear()
self.past_keys_values_cache_recurrent_infer.clear()
self.keys_values_wm_list.clear()
torch.cuda.empty_cache()
# ========== for debugging ==========
# outputs.logits_policy.shape torch.Size([2, 17, 4])
# outputs.logits_value.shape torch.Size([2, 17, 101])
Expand Down
10 changes: 6 additions & 4 deletions lzero/model/muzero_context_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
downsample: bool = False,
norm_type: Optional[str] = 'BN',
discrete_action_encoding_type: str = 'one_hot',
context_length: int = 5,
context_length_init: int = 5,
use_sim_norm: bool = False,
*args,
**kwargs
Expand Down Expand Up @@ -152,7 +152,6 @@ def __init__(
self.encoder_hook = FeatureAndGradientHook()
self.encoder_hook.setup_hooks(self.representation_network)


self.dynamics_network = DynamicsNetwork(
observation_shape,
self.action_encoding_dim,
Expand Down Expand Up @@ -214,7 +213,7 @@ def __init__(
nn.Linear(self.pred_hid, self.pred_out),
)
self.timestep = 0
self.context_length = context_length # TODO
self.context_length_init = context_length_init # TODO

def initial_inference(self, obs: torch.Tensor, action_batch=None, current_obs_batch=None) -> MZNetworkOutput:
"""
Expand Down Expand Up @@ -242,14 +241,17 @@ def initial_inference(self, obs: torch.Tensor, action_batch=None, current_obs_ba
# obs_act_dict = {'obs':obs, 'action':action_batch, 'current_obs':current_obs_batch}

if self.training or action_batch is None:
# 训练
self.latent_state = self._representation(obs)
self.timestep = 0
else:
# collect/eval
if action_batch is not None and max(action_batch) == -1: # 一集的第一步
self.latent_state = self._representation(current_obs_batch)
else:
action_batch = torch.from_numpy(np.array(action_batch)).to(self.latent_state.device)
self.recurrent_inference(self.latent_state, action_batch) # 更新self.latent_state
if self.timestep % self.context_length == 0:
if self.timestep % self.context_length_init == 0:
# print(f'self.timestep:{self.timestep}, reset latent_state')
# context reset method TODO: context recent method
self.latent_state = self._representation(current_obs_batch)
Expand Down
4 changes: 2 additions & 2 deletions lzero/model/muzero_rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def __init__(
norm_type=norm_type,
embedding_dim=768,
group_size=8,
use_sim_norm=use_sim_norm,# TODO
res_connection_in_dynamics=True,# TODO
use_sim_norm=use_sim_norm, # TODO
res_connection_in_dynamics=True, # TODO
)
# ====== for analysis ======
self.encoder_hook = FeatureAndGradientHook()
Expand Down
34 changes: 23 additions & 11 deletions lzero/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@ def forward(self, x):
def __repr__(self):
return f"SimNorm(dim={self.dim})"




class LinearOutputHook:
def __init__(self):
self.outputs = []

def __call__(self, module, module_in, module_out):
self.outputs.append(module_out)
def __call__(self, module, input, output):
self.outputs.append(output)

def cal_dormant_ratio(model, *inputs, percentage=0.025):
hooks = []
Expand All @@ -46,8 +49,7 @@ def cal_dormant_ratio(model, *inputs, percentage=0.025):
dormant_neurons = 0

for _, module in model.named_modules():
# if isinstance(module, nn.Linear):
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
if isinstance(module, (nn.Linear, nn.Conv2d, nn.LSTM)):
hook = LinearOutputHook()
hooks.append(hook)
hook_handlers.append(module.register_forward_hook(hook))
Expand All @@ -56,23 +58,33 @@ def cal_dormant_ratio(model, *inputs, percentage=0.025):
model(*inputs)

for module, hook in zip(
(module
for module in model.modules() if isinstance(module, nn.Linear)),
hooks):
(module for module in model.modules() if isinstance(module, (nn.Linear, nn.Conv2d, nn.LSTM))), hooks):

with torch.no_grad():
for output_data in hook.outputs:
mean_output = output_data.abs().mean(0)
avg_neuron_output = mean_output.mean()
dormant_indices = (mean_output < avg_neuron_output *
percentage).nonzero(as_tuple=True)[0]
total_neurons += module.weight.shape[0]
dormant_neurons += len(dormant_indices)
dormant_indices = (mean_output < avg_neuron_output * percentage).nonzero(as_tuple=True)[0]
if isinstance(module, nn.Linear):
total_neurons += module.weight.shape[0] * output_data.shape[0]
dormant_neurons += len(dormant_indices)
elif isinstance(module, nn.Conv2d):
total_neurons += module.weight.shape[0] * output_data.shape[0] * output_data.shape[2] * output_data.shape[3]
dormant_neurons += len(dormant_indices)
elif isinstance(module, nn.LSTM):
total_neurons += module.hidden_size * module.num_layers * output_data.shape[0] * output_data.shape[1]
dormant_neurons += len(dormant_indices)

for hook in hooks:
hook.outputs.clear()
del hook.outputs

for hook_handler in hook_handlers:
hook_handler.remove()
del hook_handler

if torch.cuda.is_available():
torch.cuda.empty_cache()

return dormant_neurons / total_neurons

Expand Down

0 comments on commit dfc47e1

Please sign in to comment.