feature(pu): add attention_map visualize utils
puyuan1996 authored and jiayilee65 committed Apr 26, 2024
1 parent 3509e4e commit 3da1042
Showing 13 changed files with 243 additions and 91 deletions.
73 changes: 73 additions & 0 deletions lzero/model/gpt_models/attention_map.py
@@ -0,0 +1,73 @@
import os
from typing import Optional

import matplotlib.pyplot as plt
import seaborn as sns
import torch

# NOTE: assumed import paths for the custom Transformer and KV-cache types
# referenced in the signature below; adjust to match the actual module layout.
from lzero.model.gpt_models.kv_caching import KeysValues
from lzero.model.gpt_models.transformer import Transformer

def visualize_attention_map(model: Transformer, input_embeddings: torch.Tensor, kv_cache: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None, layer_id: int = 0, head_id: int = 0, suffix='visual_match_memlen1-0-15_v2/attn_map'):
"""
可视化attention map
参数:
model: Transformer模型
input_embeddings: 输入的token embdding序列,shape为(B, T)
kv_cache: 缓存的keys和values,用于支持长序列的推断
valid_context_lengths: 有效的上下文长度,用于处理变长上下文
layer_id: 要可视化的层的编号,从0开始
head_id: 要可视化的头的编号,从0开始
返回:
None
"""
assert 0 <= layer_id < len(model.blocks)
assert 0 <= head_id < model.config.num_heads

# B, T = input_embeddings.shape
B, T, C = input_embeddings.shape
if kv_cache is not None:
_, _, L, _ = kv_cache[layer_id].shape
else:
L = 0

with torch.no_grad():
model.eval()
# hidden_states = model.drop(model.embed(input_ids))
hidden_states = input_embeddings
input_ids = torch.arange(T).expand(B, T) # positional indices, used only as axis tick labels

for i, block in enumerate(model.blocks):
if i < layer_id:
hidden_states = block(hidden_states, None if kv_cache is None else kv_cache[i], valid_context_lengths)
elif i == layer_id:
attention_map = block.attn.get_attention_map(block.ln1(hidden_states), None if kv_cache is None else kv_cache[i], valid_context_lengths)
break

attention_map = attention_map[0, head_id].cpu().numpy() # attention map of the chosen head for the first sample

plt.figure(figsize=(10, 10))
sns.heatmap(attention_map, cmap='coolwarm', square=True, cbar_kws={"shrink": 0.5}, xticklabels=input_ids[0].cpu().numpy(), yticklabels=input_ids[0, -T:].cpu().numpy())
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.xlabel('Key')
plt.ylabel('Query')
plt.title(f'Attention Map of Layer {layer_id} Head {head_id}')
plt.show()
directory = f'/mnt/afs/niuyazhe/code/LightZero/render/{suffix}'
# Create the directory if it does not exist
if not os.path.exists(directory):
os.makedirs(directory)
plt.savefig(f'{directory}/attn_map_layer_{layer_id}_head_{head_id}.png')
plt.close()

if __name__ == "__main__":
from transformers import GPT2Tokenizer

# Load the pretrained GPT-2 tokenizer; `config` is assumed to be defined elsewhere.
model = Transformer(config)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Prepare the input.
text = "The quick brown fox jumps over the lazy dog."
input_ids = tokenizer.encode(text, return_tensors='pt')
# NOTE: visualize_attention_map expects embeddings of shape (B, T, C), so the
# token ids must first pass through an embedding layer; a fresh randomly
# initialized embedding is used here purely for illustration.
input_embeddings = torch.nn.Embedding(tokenizer.vocab_size, config.embed_dim)(input_ids)

# Visualize the attention map of head 0 in layer 0.
visualize_attention_map(model, input_embeddings, layer_id=0, head_id=0)
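A minimal smoke test of the utility above, assuming a Transformer instance `model` built with embed_dim=256 and 8 layers / 8 heads as in cfg_memory.py (the `suffix` value here is an illustrative placeholder):

import torch
from lzero.model.gpt_models.attention_map import visualize_attention_map

B, T, C = 1, 16, 256  # batch size, sequence length, embed_dim (must match the model)
input_embeddings = torch.randn(B, T, C)  # random embeddings stand in for tokenizer output

# Dump the map of every layer/head pair, mirroring the debug loop added to world_model.forward.
for layer_id in range(8):
    for head_id in range(8):
        visualize_attention_map(model, input_embeddings, layer_id=layer_id, head_id=head_id, suffix='smoke_test/attn_map')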
4 changes: 3 additions & 1 deletion lzero/model/gpt_models/cfg_atari.py
Original file line number Diff line number Diff line change
@@ -76,7 +76,9 @@


"device": 'cuda:3',

'analysis_sim_norm': False,
'analysis_dormant_ratio': False,

# 'action_shape': 18, # TODO: for multi-task
# 'action_shape': 6, # TODO: for pong qbert
# 'action_shape': 9, # TODO: for mspacman
32 changes: 21 additions & 11 deletions lzero/model/gpt_models/cfg_memory.py
Original file line number Diff line number Diff line change
@@ -18,10 +18,10 @@
# For the memory env we must keep the obs of the first frame, and running MCTS at the last frame will certainly exceed episode_length, so context_length is set longer than during training to guarantee the first frame is not dropped.
# ===================

# 'max_blocks': 16+5,
# "max_tokens": 2 * (16+5), # 1+0+15 memory_length = 0
# "context_length": 2 * (16+5),
# "context_length_for_recurrent": 2 * (16+5),
'max_blocks': 16+5,
"max_tokens": 2 * (16+5), # 1+0+15 memory_length = 0
"context_length": 2 * (16+5),
"context_length_for_recurrent": 2 * (16+5),

# 'max_blocks': 76+5,
# "max_tokens": 2 * (76+5), # 1+60+15 memory_length = 60
@@ -33,10 +33,10 @@
# "context_length": 2 * (116+5),
# "context_length_for_recurrent": 2 * (116+5),

'max_blocks': 266+5,
"max_tokens": 2 * (266+5), # 1+250+15 memory_length = 250
"context_length": 2 * (266+5),
"context_length_for_recurrent": 2 * (266+5),
# 'max_blocks': 266+5,
# "max_tokens": 2 * (266+5), # 1+250+15 memory_length = 250
# "context_length": 2 * (266+5),
# "context_length_for_recurrent": 2 * (266+5),

# 'max_blocks': 16,
# "max_tokens": 2 * 16, # 1+0+15 memory_length = 0
@@ -63,20 +63,29 @@



"device": 'cuda:4',
"device": 'cuda:0',
'analysis_sim_norm': False,
'analysis_dormant_ratio': False,

'group_size': 8, # NOTE
# 'group_size': 768, # NOTE
'attention': 'causal',
# 'num_layers': 1,
# 'num_layers': 2, # same as the <When Do Transformers Shine in RL?> paper
# 'num_layers': 4,
# 'num_layers': 8,
# 'num_heads': 8,
# # 'embed_dim': 96, # TODO: for memory # same as the <When Do Transformers Shine in RL?> paper
# # 'embed_dim': 768, # TODO: GPT-2 Base
# 'embed_dim': 256, # TODO:


'num_layers': 8,
'num_heads': 8,
# 'embed_dim': 96, # TODO: for memory # same as the <When Do Transformers Shine in RL?> paper
# 'embed_dim': 768, # TODO: GPT-2 Base
# 'embed_dim': 128, # TODO: for memlen=250/500
'embed_dim': 256, # TODO:


# 'num_layers': 12, # TODO: GPT-2 Base
# 'num_heads': 12, # TODO: GPT-2 Base
# 'embed_dim': 768, # TODO: GPT-2 Base
@@ -119,6 +128,7 @@

'obs_type': 'image_memory', # 'vector', 'image'
'gamma': 1, # 0.5, 0.9, 0.99, 0.999
'dormant_threshold': 0.025,
}
from easydict import EasyDict

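All of the commented variants in this config follow one arithmetic pattern: each block holds one observation token and one action token, and an episode is 1 goal frame + memory_length distractor frames + 15 match frames, plus 5 blocks of head-room so MCTS at the final frame never evicts the first one. A small sketch of that arithmetic (the helper name is ours, not part of the config):

def memory_env_context(memory_length: int, tokens_per_block: int = 2, headroom: int = 5) -> dict:
    # 1 goal frame + memory_length distractors + 15 match frames, plus MCTS head-room.
    max_blocks = 1 + memory_length + 15 + headroom
    return {
        'max_blocks': max_blocks,
        'max_tokens': tokens_per_block * max_blocks,
        'context_length': tokens_per_block * max_blocks,
        'context_length_for_recurrent': tokens_per_block * max_blocks,
    }

assert memory_env_context(0)['max_blocks'] == 16 + 5           # the active variant above
assert memory_env_context(250)['max_tokens'] == 2 * (266 + 5)  # the commented memlen=250 variant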
43 changes: 43 additions & 0 deletions lzero/model/gpt_models/transformer.py
Original file line number Diff line number Diff line change
@@ -174,3 +174,46 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, valid_
y = self.resid_drop(self.proj(y))

return y

@torch.no_grad()
def get_attention_map(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
获取attention map
参数:
x: 输入的序列,shape为(B, T, C)
kv_cache: 缓存的keys和values,用于支持长序列的推断
valid_context_lengths: 有效的上下文长度,用于处理变长上下文
返回:
attention_map: shape为(B, nh, T, L + T)的tensor,表示attention的分布
"""
B, T, C = x.size()
if kv_cache is not None:
b, nh, L, c = kv_cache.shape
assert nh == self.num_heads and b == B and c * nh == C
else:
L = 0

q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)

if kv_cache is not None:
kv_cache.update(k, None) # only the keys need to be updated here, not the values
k, _ = kv_cache.get()

att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

if valid_context_lengths is not None:
mask = torch.zeros(B, T, L + T, device=att.device)
for i in range(B):
mask[i] = self.mask[L:L + T, :L + T].clone()
mask[i, :, :(L - valid_context_lengths[i])] = 0
mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1)
else:
mask = self.mask[L:L + T, :L + T]

att = att.masked_fill(mask == 0, float('-inf'))
att = F.softmax(att, dim=-1)

return att
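The subtle part of get_attention_map is the masking branch: with a KV cache of length L, the first L - valid_context_lengths[i] cached positions are zeroed out per sample before the causal slice is applied. A self-contained sketch of just that mask construction, with the causal buffer built inline rather than taken from the module's registered self.mask (names are ours):

import torch

def build_attention_mask(B: int, T: int, L: int, max_tokens: int,
                         valid_context_lengths: torch.Tensor = None) -> torch.Tensor:
    # Causal buffer equivalent to the module's registered mask.
    causal = torch.tril(torch.ones(max_tokens, max_tokens))
    if valid_context_lengths is None:
        return causal[L:L + T, :L + T]  # (T, L + T)
    mask = torch.zeros(B, T, L + T)
    for i in range(B):
        mask[i] = causal[L:L + T, :L + T]
        # Invalidate stale cache entries beyond this sample's valid context.
        mask[i, :, :L - int(valid_context_lengths[i])] = 0
    return mask  # (B, T, L + T); later expanded over the head dimension

# Example: batch of 2, 3 new tokens, cache of 5, sample 0 has only 4 valid entries.
m = build_attention_mask(B=2, T=3, L=5, max_tokens=16,
                         valid_context_lengths=torch.tensor([4, 5]))
print(m[0])  # the first cached column is zeroed for sample 0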
103 changes: 56 additions & 47 deletions lzero/model/gpt_models/world_model.py
@@ -10,6 +10,7 @@
from einops import rearrange
from einops import rearrange
import gym
import os
from joblib import hash
import numpy as np
import torch
@@ -115,6 +116,7 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
# self.context_length = self.config.max_tokens # TODO
# self.context_length_for_recurrent = self.config.max_tokens # TODO
self.dormant_threshold = config.dormant_threshold
self.analysis_dormant_ratio = config.analysis_dormant_ratio

self.transformer = Transformer(config)
self.num_observations_tokens = config.tokens_per_block - 1
@@ -376,7 +378,15 @@ def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysV
else:
# x = self.transformer(sequences, past_keys_values)
x = self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths)

# ============ visualize_attention_map ================= TODO
from lzero.model.gpt_models.attention_map import visualize_attention_map
for layer_id in range(8):
for head_id in range(8):
visualize_attention_map(self.transformer, sequences, past_keys_values, valid_context_lengths, layer_id=layer_id, head_id=head_id)
import sys
sys.exit(0)
# ========== for visualize ==========

# 1,...,0,1 https://github.com/eloialonso/iris/issues/19
logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps)

@@ -877,14 +887,17 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer=None, inverse_scalar_t
obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False)

# ========= logging for analysis =========
# calculate dormant ratio of encoder
shape = batch['observations'].shape # (..., C, H, W)
inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64)
dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.representation_network, inputs.detach(), percentage=self.dormant_threshold)
self.past_keys_values_cache_init_infer.clear()
self.past_keys_values_cache_recurrent_infer.clear()
self.keys_values_wm_list.clear()
torch.cuda.empty_cache()
if self.analysis_dormant_ratio:
# calculate dormant ratio of encoder
shape = batch['observations'].shape # (..., C, H, W)
inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64)
dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.representation_network, inputs.detach(), percentage=self.dormant_threshold)
self.past_keys_values_cache_init_infer.clear()
self.past_keys_values_cache_recurrent_infer.clear()
self.keys_values_wm_list.clear()
torch.cuda.empty_cache()
else:
dormant_ratio_encoder = 0
# Assume latent_state_roots is a tensor
latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() # compute the mean L2 norm
# print("L2 Norms:", l2_norms)
@@ -924,28 +937,15 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer=None, inverse_scalar_t
# reconstructed_images.shape torch.Size([34, 3, 5, 5])
# self.visualize_reconstruction_v1(original_images, reconstructed_images)

# ========== for debugging ==========
# ========== for visualize ==========
# batch['target_policy'].shape torch.Size([2, 17, 4])
# batch['target_value'].shape torch.Size([2, 17, 101])
# batch['rewards'].shape torch.Size([2, 17, 101])

# target_policy = batch['target_policy']
# target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1,101)).reshape(batch['observations'].shape[0],batch['observations'].shape[1],1) # torch.Size([2, 17, 1])
# true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1,101)).reshape(batch['observations'].shape[0],batch['observations'].shape[1],1) # torch.Size([2, 17, 1])


# import matplotlib.pyplot as plt
# # Save the first reconstructed frame
# for i in range(1):
# plt.imshow(reconstructed_images[i][0].permute(1, 2, 0).cpu().detach().numpy()) # convert channels from (C, H, W) to (H, W, C)
# plt.axis('off') # turn off the axes
# plt.savefig(f'./render/image_frame_reconstructed_{i}.png')
# plt.close()
# # Save the first three original frames
# for i in range(3):
# plt.imshow(batch['observations'][i][0].permute(1, 2, 0).cpu().detach().numpy()) # convert channels from (C, H, W) to (H, W, C)
# plt.axis('off') # turn off the axes
# plt.savefig(f'./render/image_frame_orig_{i}.png')
# plt.close()

# Compute the reconstruction loss and the perceptual loss
latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 5, 5), reconstructed_images) # NOTE: for stack=1 TODO
@@ -959,24 +959,33 @@

# Forward pass to obtain the predicted observations, rewards, policy, etc.
outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)})

# ========= logging for analysis =========
# calculate dormant ratio of world_model
dormant_ratio_world_model = cal_dormant_ratio(self, {'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, percentage=self.dormant_threshold)
self.past_keys_values_cache_init_infer.clear()
self.past_keys_values_cache_recurrent_infer.clear()
self.keys_values_wm_list.clear()
torch.cuda.empty_cache()
# ========== for debugging ==========
if self.analysis_dormant_ratio:
# calculate dormant ratio of world_model
dormant_ratio_world_model = cal_dormant_ratio(self, {'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, percentage=self.dormant_threshold)
self.past_keys_values_cache_init_infer.clear()
self.past_keys_values_cache_recurrent_infer.clear()
self.keys_values_wm_list.clear()
torch.cuda.empty_cache()
else:
dormant_ratio_world_model = 0

# ========== for visualize ==========
# outputs.logits_policy.shape torch.Size([2, 17, 4])
# outputs.logits_value.shape torch.Size([2, 17, 101])
# outputs.logits_rewards.shape torch.Size([2, 17, 101])

# predict_policy = outputs.logits_policy
# Apply softmax over the last dimension (dim=-1)
# # Apply softmax over the last dimension (dim=-1)
# predict_policy = F.softmax(outputs.logits_policy, dim=-1)
# predict_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1,101)).reshape(batch['observations'].shape[0],batch['observations'].shape[1],1) # predict_value: torch.Size([2, 17, 1])
# predict_rewards = inverse_scalar_transform_handle(outputs.logits_rewards.reshape(-1,101)).reshape(batch['observations'].shape[0],batch['observations'].shape[1],1) # predict_rewards: torch.Size([2, 17, 1])
# self.visualize_reconstruction_v2(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy) # TODO

# # import pdb; pdb.set_trace()
# self.visualize_reconstruction_v2(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, suffix='visual_match_memlen1-0-15_v2') # TODO
# import sys
# sys.exit(0)
# ========== for visualize ==========

# For training stability, use target_tokenizer to compute the ground-truth next latent state representation
with torch.no_grad():
Expand Down Expand Up @@ -1228,7 +1237,7 @@ def visualize_reconstruction_v2(self, original_images, reconstructed_images, tar
ax[0].set_ylabel('Rewards')

ax0_twin = ax[0].twinx()
ax0_twin.plot(timesteps, target_predict_value[batch_idx, :, 0].cpu().detach().numpy(), 'b-', label='Target Predict Value')
# ax0_twin.plot(timesteps, target_predict_value[batch_idx, :, 0].cpu().detach().numpy(), 'b-', label='Target Predict Value')
ax0_twin.plot(timesteps, predict_value[batch_idx, :, 0].cpu().detach().numpy(), 'b--', label='Predict Value')
ax0_twin.legend(loc='upper right')
ax0_twin.set_ylabel('Value')
@@ -1261,31 +1270,31 @@ def visualize_reconstruction_v2(self, original_images, reconstructed_images, tar
ax[2].set_yticks([])
ax[2].set_ylabel('Reconstructed', rotation=0, labelpad=30)

# # Plot bar charts of the predict_policy and target_policy probability distributions
# Compute the bar width and offsets so that the bars do not overlap
bar_width = 0.8 / num_actions
offset = np.linspace(-0.4 + bar_width / 2, 0.4 - bar_width / 2, num_actions)
# Plot bar charts of the predict_policy and target_policy probability distributions
bar_width = 8/num_actions # TODO: varies with action_space
for i in range(num_timesteps):
for j in range(num_actions):
ax[3].bar(i + j * bar_width - (num_actions - 1) * bar_width / 2, predict_policy[batch_idx, i, j].item(), width=bar_width, color=colors[j], alpha=0.5)
ax[4].bar(i + j * bar_width - (num_actions - 1) * bar_width / 2, target_policy[batch_idx, i, j].item(), width=bar_width, color=colors[j], alpha=0.5)

ax[3].bar(i + offset[j], predict_policy[batch_idx, i, j].item(), width=bar_width, color=colors[j], alpha=0.5)
ax[4].bar(i + offset[j], target_policy[batch_idx, i, j].item(), width=bar_width, color=colors[j], alpha=0.5)
ax[3].set_xticks(timesteps)
ax[3].set_xticklabels([])
ax[3].set_ylim(0, 1)
ax[3].set_ylabel('Predict Policy')

ax[4].set_xticks(timesteps)
ax[4].set_xticklabels(timesteps)
ax[4].set_ylim(0, 1)
ax[4].set_ylabel('Target Policy')
ax[4].set_xlabel('Timestep')

ax[4].set_ylabel('Target Policy')
# Add a legend
handles = [plt.Rectangle((0, 0), 1, 1, color=colors[i], alpha=0.5) for i in range(num_actions)]
labels = [f'Action {i}' for i in range(num_actions)]
ax[4].legend(handles, labels, loc='upper right', ncol=num_actions)

plt.tight_layout()
plt.savefig(f'/mnt/afs/niuyazhe/code/LightZero/render/{suffix}/reconstruction_visualization_batch_{batch_idx}_v2.png')
directory = f'/mnt/afs/niuyazhe/code/LightZero/render/{suffix}'
# Create the directory if it does not exist
if not os.path.exists(directory):
os.makedirs(directory)
plt.savefig(f'{directory}/reconstruction_visualization_batch_{batch_idx}_v3.png')
# plt.savefig(f'./render/{suffix}/reconstruction_visualization_batch_{batch_idx}_v2.png')
plt.close()

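The policy bar-chart change above deserves a sanity check: the old `bar_width = 8/num_actions` produced bars roughly ten times too wide, while the new `np.linspace` offsets fit `num_actions` non-overlapping bars inside each unit-wide timestep slot. A standalone verification of that geometry (variable names are ours):

import numpy as np

num_actions = 4
bar_width = 0.8 / num_actions
offset = np.linspace(-0.4 + bar_width / 2, 0.4 - bar_width / 2, num_actions)

# Bars for timestep i are drawn at i + offset; verify they stay inside
# (i - 0.5, i + 0.5) and do not overlap.
edges = [(o - bar_width / 2, o + bar_width / 2) for o in offset]
assert all(lo >= -0.5 and hi <= 0.5 for lo, hi in edges)
assert all(edges[j][1] <= edges[j + 1][0] + 1e-9 for j in range(num_actions - 1))
print(edges)  # [(-0.4, -0.2), (-0.2, 0.0), (0.0, 0.2), (0.2, 0.4)]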
