In [None]:
import numpy as np
import torch
import os

import init_path

# Natural-language instructions of the ten tasks. The position of an entry in
# this list is the index shown in the reference note below.
# NOTE(review): presumably row i of the embedding arrays loaded later belongs to
# entry i of this list (both are permuted with the same index list) — confirm.
all_possible_lang = [
    "open the middle drawer of the cabinet",
    "put the bowl on the stove",
    "put the wine bottle on top of the cabinet",
    "open the top drawer and put the bowl inside",
    "put the bowl on top of the cabinet",
    "push the plate to the front of the stove",
    "put the cream cheese in the bowl",
    "turn on the stove",
    "put the bowl on the plate",
    "put the wine bottle on the rack",
]

# Reference note kept as a bare string (it is not assigned to any name).
# Each row reads: "<task tag> <index in all_possible_lang> <instruction>".
'''
    languages:
T1 0 open the middle drawer of the cabinet
T5 1 put the bowl on the stove
T9 2 put the wine bottle on top of the cabinet
T2 3 open the top drawer and put the bowl inside
T6 4 put the bowl on top of the cabinet
T3 5 push the plate to the front of the stove
T7 6 put the cream cheese in the bowl
T10 7 turn on the stove
T4 8 put the bowl on the plate
T8 9 put the wine bottle on the rack
    '''

In [None]:
# Load the two sets of language embeddings. The file names suggest one set was
# produced with CLIP and the other with a one-hot encoding — TODO confirm.
# NOTE(review): these are absolute, machine-specific paths; np.load fails on any
# machine where the files are not present at exactly this location.
CLIP_emb = np.load('/home/zhaoyixiu/ISR_project/ACT/output_lang_emb_9tasks-CLIP.npy')
onehot_emb = np.load('/home/zhaoyixiu/ISR_project/ACT/output_lang_emb_9tasks-onehot.npy')

# Print both array shapes as a quick sanity check.
print(CLIP_emb.shape, onehot_emb.shape)

def cal_emb_similarity(emb1, emb2):
    """
    Cosine similarity between two embedding vectors.

    Args:
        emb1: 1-D vector (numpy array or sequence of numbers)
        emb2: 1-D vector with the same length as emb1

    Returns:
        The dot product of the two L2-normalised vectors. If either vector has
        zero norm the similarity is undefined and 0.0 is returned, instead of
        the NaN (plus RuntimeWarning) a division by zero would produce.
    """
    norm1 = np.linalg.norm(emb1)
    norm2 = np.linalg.norm(emb2)
    # A zero vector has no direction; bail out before dividing by zero.
    if norm1 == 0 or norm2 == 0:
        return 0.0
    return np.dot(emb1 / norm1, emb2 / norm2)

In [3]:
# Permutation of the original list positions that puts the entries into
# T1, T2, ..., T10 order (see the task-tag note next to all_possible_lang).
new_order_index = [0, 3, 5, 8, 1, 4, 6, 9, 2, 7]

# NOTE: this cell overwrites its own inputs, so running it a second time
# applies the permutation again. Re-run the loading cells first if needed.
all_possible_lang = list(map(all_possible_lang.__getitem__, new_order_index))
CLIP_emb = np.take(CLIP_emb, new_order_index, axis=0)
onehot_emb = np.take(onehot_emb, new_order_index, axis=0)

In [4]:
import numpy as np
import plotly.graph_objects as go

def plot_embedding_similarity(embeddings, midpoint=0.7, axis_labels=None):
    """
    Plot the pairwise cosine-similarity matrix of a set of embeddings using plotly

    Args:
        embeddings: sequence of n embedding vectors (e.g. a numpy array of
            shape (n, d)); every row is compared with every row
        midpoint: similarity value placed at the centre of the diverging
            colour scale (passed to the heatmap as ``zmid``)
        axis_labels: optional sequence of tick labels used on both axes;
            defaults to 'T1', 'T2', ..., 'Tn'

    Returns:
        None (displays the plot in notebook)
    """
    # Compute the cosine-similarity matrix.
    n = len(embeddings)
    similarity_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            similarity_matrix[i, j] = cal_emb_similarity(embeddings[i], embeddings[j])

    # Both axes share the same tick labels, so build them once.
    if axis_labels is None:
        tick_labels = ['T' + str(i + 1) for i in range(n)]
    else:
        tick_labels = list(axis_labels)
    tick_positions = list(range(n))

    # Build the heatmap.
    fig = go.Figure(data=go.Heatmap(
        z=similarity_matrix,
        text=np.round(similarity_matrix, 3),  # show the values, rounded to 3 decimals
        texttemplate='%{text}',
        textfont={"size": 10},
        hoverongaps=False,
        hovertemplate='Row: %{y}<br>Column: %{x}<br>Similarity: %{z}<extra></extra>',
        colorscale='RdBu',  # red-blue diverging scale; 'Viridis' etc. also work
        zmid=midpoint,  # centre the colour scale on `midpoint`
        colorbar=dict(
            # Nested title dict instead of the deprecated `titleside` /
            # `titlefont` shortcuts, which recent plotly releases reject.
            title=dict(
                text='Cosine Similarity',
                side='right',
                font=dict(size=14)
            )
        )
    ))

    # Figure layout.
    fig.update_layout(
        title=dict(
            text='Embedding Similarity Matrix',
            x=0.5,
            font=dict(size=16)
        ),
        width=800,
        height=800,
        xaxis=dict(
            title='Embedding Index',
            tickmode='array',
            ticktext=tick_labels,
            tickvals=tick_positions
        ),
        yaxis=dict(
            title='Embedding Index',
            tickmode='array',
            ticktext=tick_labels,
            tickvals=tick_positions
        )
    )

    # Display the figure in the Jupyter notebook.
    fig.show()

# Usage example:
# given an array holding 5 embeddings of 512 dimensions each
# example_embeddings = np.random.randn(5, 512)
# plot_embedding_similarity(example_embeddings)

In [None]:
# List the instructions with 1-based task numbers so the printed index matches
# the 'T1'..'Tn' tick labels of the heatmap (and the other listing cells).
print("languages:")
for i, lang in enumerate(all_possible_lang):
    print(i + 1, lang)
plot_embedding_similarity(CLIP_emb, 0.7)

In [None]:
# Print each instruction with its 1-based task number, then draw the
# similarity heatmap of the embeddings stored in onehot_emb.
print("languages:")
for i, lang in enumerate(all_possible_lang):
    print(i + 1, lang)
plot_embedding_similarity(onehot_emb, midpoint=0.7)

In [7]:
def plot_embedding_similarity_bar(embeddings_n, target_embedding, labels=['T1', 'T2', 'T3', 'T4', 'T5', 'T7', 'T8', 'T9', 'T10'], target_label='T6'):
    """
    Plot bar chart of similarities between multiple embeddings and a target embedding

    Args:
        embeddings_n: sequence of n embedding vectors (e.g. a numpy array of
            shape (n, d))
        target_embedding: the embedding vector every row is compared against
        labels: one x-axis label per row of embeddings_n (the default list is
            only read, never mutated)
        target_label: name of the target embedding shown in the chart title

    Returns:
        None (displays the plot in notebook)

    Raises:
        ValueError: if embeddings_n is empty, or if the number of labels does
            not match the number of embeddings
    """
    n = len(embeddings_n)
    if n == 0:
        raise ValueError("embeddings_n must contain at least one embedding")

    # Work on a private copy so the shared default list can never be altered.
    labels = list(labels)
    if len(labels) != n:
        # Mismatched lengths would otherwise yield silently mislabelled or
        # dropped bars.
        raise ValueError(
            f"got {len(labels)} labels for {n} embeddings; pass one label per embedding"
        )

    # Cosine similarity of every embedding with the target.
    similarities = np.zeros((n,))
    for i in range(n):
        similarities[i] = cal_emb_similarity(embeddings_n[i], target_embedding)

    # Lower end of the y axis: 0.05 below the smallest similarity, rounded
    # down to a multiple of 0.01.
    min_sim = similarities.min()
    y_min = np.floor(min_sim * 100 - 5) / 100

    # Build the bar chart.
    fig = go.Figure(data=[
        go.Bar(
            x=labels,
            y=similarities,
            text=np.round(similarities, 4),  # show the values with 4 decimals
            textposition='auto',
            hovertemplate='Embedding: %{x}<br>Similarity: %{y:.4f}<extra></extra>',
            marker_color='rgb(55, 83, 109)'
        )
    ])

    # Figure layout.
    fig.update_layout(
        title=dict(
            text=f'Embedding Similarities with {target_label}',
            x=0.5,
            font=dict(size=16)
        ),
        xaxis=dict(
            title='Embeddings',
            tickangle=45 if len(similarities) > 10 else 0
        ),
        yaxis=dict(
            title='Cosine Similarity',
            range=[y_min, 1.0],  # zoom in on the high-similarity region
            tickformat='.3f'  # 3 decimals on the tick labels
        ),
        width=max(600, len(similarities) * 50),  # widen the figure with the number of bars
        height=500,
        showlegend=False
    )

    # Display the figure in the Jupyter notebook.
    fig.show()

In [None]:
# Print each instruction with its 1-based task number.
print("languages:")
for i, lang in enumerate(all_possible_lang):
    print(i + 1, lang)

# Compare every row of CLIP_emb except row 5 against row 5. The default bar
# labels of plot_embedding_similarity_bar skip 'T6' in the same way.
indexs = list(range(5)) + list(range(6, 10))
plot_embedding_similarity_bar(CLIP_emb[indexs], CLIP_emb[5])

In [None]:
# List the instructions with 1-based task numbers so the printed index matches
# the 'T…' bar labels (and the other listing cells).
print("languages:")
for i, lang in enumerate(all_possible_lang):
    print(i + 1, lang)

# Compare every row of onehot_emb except row 5 against row 5.
indexs = [0, 1, 2, 3, 4, 6, 7, 8, 9]
plot_embedding_similarity_bar(onehot_emb[indexs], onehot_emb[5])

In [None]:
# Second pair of embedding files (the "T459" variants).
# NOTE(review): unlike CLIP_emb / onehot_emb, these arrays are NOT re-indexed
# with new_order_index, yet the rows are labelled as if row k were task T(k+1)
# — confirm that the files are already stored in T1..T10 order.
CLIP_3_emb = np.load('/home/zhaoyixiu/ISR_project/ACT/output_lang_emb_T459.npy')
onehot_3_emb = np.load('/home/zhaoyixiu/ISR_project/ACT/output_lang_emb_onehot-T459.npy')

# Heatmaps over rows 3, 4, 5 and 8: CLIP file first, then the one-hot file.
task_ids = [3, 4, 5, 8]
for emb in (CLIP_3_emb, onehot_3_emb):
    plot_embedding_similarity(emb[task_ids], 0.7, axis_labels=['T4', 'T5', 'T6', 'T9'])

# Bar charts of rows 3, 4 and 8 against row 5, in the same file order.
indexs = [3, 4, 8]
for emb in (CLIP_3_emb, onehot_3_emb):
    plot_embedding_similarity_bar(emb[indexs], emb[5], labels=['T4', 'T5', 'T9'])