In [22]:
import plotly.graph_objects as go
import numpy as np
import os

# Configuration: model directory, embedding range, and input data files
model_path = '/home/lab/eight/KGE-HAKE/models/AdjustHAKE_wn18rr_0'
embedding_range = 0.01
entities_dict_text_file = '/home/lab/eight/KGE-HAKE/data/wn18rr_text/entities.dict'
entities_dict_number_file = '/home/lab/eight/KGE-HAKE/data/wn18rr/entities.dict'
train_file = '/home/lab/eight/KGE-HAKE/data/wn18rr/train.txt'

# Load the entity embedding array stored in the model directory
entity_embedding = np.load(os.path.join(model_path, 'entity_embedding.npy'))
# Length of the first axis; used below as the exclusive upper bound for entity ids
embedding_size = entity_embedding.shape[0]

def load_entities_dict_number(file_path):
    """Read a tab-separated entity dictionary and map entity key -> integer id.

    Every non-blank line is expected to look like ``<integer id><TAB><entity key>``.

    Args:
        file_path: Path of the UTF-8 encoded dictionary file.

    Returns:
        dict mapping the second column (str) to the first column parsed as int.
        When a key occurs on several lines, the last line wins.

    Raises:
        IndexError: If a non-blank line has fewer than two tab-separated fields.
        ValueError: If the first field of a line is not an integer.
    """
    entities_dict_number = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline at end of file) carry no entry.
                continue
            parts = stripped.split('\t')
            entities_dict_number[parts[1]] = int(parts[0])
    return entities_dict_number

# Load the entity dictionary from the numeric entities file (second column -> integer first column)
entities_dict_number = load_entities_dict_number(entities_dict_number_file)

def extract_relationships(file_path, relation_type):
    """Collect the (head, tail) pairs of every triple with the given relation.

    Every non-blank line is expected to look like ``head<TAB>relation<TAB>tail``.

    Args:
        file_path: Path of the UTF-8 encoded triple file.
        relation_type: Relation name to select; compared for exact equality
            with the second column.

    Returns:
        list of (head, tail) string tuples, in file order.

    Raises:
        IndexError: If a non-blank line is missing the relation column, or a
            matching line is missing the tail column.
    """
    relationships = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline at end of file) hold no triple.
                continue
            parts = stripped.split('\t')
            if parts[1] == relation_type:
                relationships.append((parts[0], parts[2]))  # keep head and tail only
    return relationships

def calculate_mod_values(head_tail_pairs, entities_dict_number, entity_embedding, embedding_size, max_abs_diff=0.4):
    """Compute head-minus-tail differences of the mean absolute modulus.

    For each (head, tail) pair, both entities are looked up in
    ``entities_dict_number``, their embedding rows are split into two equal
    halves, and the mean of the absolute values of the second half is taken
    for each. The difference ``mean(head) - mean(tail)`` is kept when its
    magnitude is below ``max_abs_diff``.

    Args:
        head_tail_pairs: Iterable of (head key, tail key) tuples.
        entities_dict_number: Mapping from entity key to integer row index.
        entity_embedding: Array indexed by entity id along its first axis;
            each row must split evenly in two.
        embedding_size: Exclusive upper bound for valid row indices.
        max_abs_diff: Differences with ``abs(diff) >= max_abs_diff`` are
            dropped. Pass ``None`` to keep every difference. Defaults to 0.4,
            the threshold that was previously hard-coded.

    Returns:
        list of differences (NumPy scalars), in input order, for the pairs
        that were resolvable and passed the filter.
    """
    mod_diffs = []
    for head_id_str, tail_id_str in head_tail_pairs:
        head_id = entities_dict_number.get(head_id_str)
        tail_id = entities_dict_number.get(tail_id_str)

        if head_id is None or tail_id is None:
            continue  # Skip pairs with an unknown entity

        # Reject out-of-range ids on both sides: a negative id would otherwise
        # wrap around and silently read a row from the end of the array.
        if not (0 <= head_id < embedding_size and 0 <= tail_id < embedding_size):
            continue

        # The second half of each row is the part used as the modulus.
        _, head_mod = np.split(entity_embedding[head_id], 2)
        _, tail_mod = np.split(entity_embedding[tail_id], 2)

        mod_diff = np.mean(np.abs(head_mod)) - np.mean(np.abs(tail_mod))
        if max_abs_diff is None or np.abs(mod_diff) < max_abs_diff:
            mod_diffs.append(mod_diff)

    return mod_diffs

# Relations to compare; the commented-out entries are currently excluded
relation_types = [
    '_hypernym',
    # '_derivationally_related_form',
    '_instance_hypernym',
    # '_also_see',
    '_member_meronym',
    # '_synset_domain_topic_of',
    # '_has_part',
    # '_member_of_domain_usage',
    # '_member_of_domain_region',
    # '_verb_group',
    '_similar_to'
]

# Start from an empty figure
fig = go.Figure()

# Keyword arguments shared by every box trace
box_style = dict(boxpoints=False, jitter=0.3, pointpos=-1.8)

# One box per relation, showing the head-minus-tail modulus differences
for relation in relation_types:
    pairs = extract_relationships(train_file, relation)
    mod_diffs = calculate_mod_values(pairs, entities_dict_number, entity_embedding, embedding_size)
    fig.add_trace(go.Box(y=mod_diffs, name=relation, **box_style))

# Title, axis labels and figure size
# NOTE(review): the example in the title reads as head=mammal, tail=dog;
# confirm that this matches the head/tail order of the triples in train_file.
fig.update_layout(
    title=dict(
        text="relation ごとの head と tail エンティティの modulus 差分の箱ひげ図 (head relation tail [ex mammal hypernym dog])",
        x=0.5,
        y=0.9,
        xanchor='center',
        yanchor='top',
    ),
    xaxis_title="relation",
    yaxis_title="head と tail の modulus の差",
    showlegend=True,
    width=1200,
    height=600,
)

# Render the figure
fig.show()


In [24]:
import plotly.graph_objects as go
import numpy as np
import os

# Configuration: model directory, embedding range, and input data files
model_path = '/home/lab/eight/KGE-HAKE/models/HAKE_wn18rr_0'
embedding_range = 0.01
entities_dict_text_file = '/home/lab/eight/KGE-HAKE/data/wn18rr_text/entities.dict'
entities_dict_number_file = '/home/lab/eight/KGE-HAKE/data/wn18rr/entities.dict'
train_file = '/home/lab/eight/KGE-HAKE/data/wn18rr/train.txt'
# Entities to analyse, given as [entity label, color name] pairs
# (the commented-out line below is a larger, currently disabled selection)
# target_values = [["genus.n.02", 'red'], ["dog.n.01", 'blue'], ["plant_genus.n.01", 'green'], ["mammal_genus.n.01", 'yellow'], ["bird_genus.n.01", 'purple'], ["corgi.n.01", 'gray']]
target_values = [["genus.n.02", 'red'], ["mammal_genus.n.01", 'yellow'], ["dog.n.01", 'blue']]

# Load the entity embedding array stored in the model directory
entity_embedding = np.load(os.path.join(model_path, 'entity_embedding.npy'))

def load_entities_dict_number(file_path):
    """Read a tab-separated entity dictionary and map integer id -> value.

    Every non-blank line is expected to look like ``<integer id><TAB><value>``.

    Args:
        file_path: Path of the UTF-8 encoded dictionary file.

    Returns:
        dict mapping the first column parsed as int to the second column (str).
        When an id occurs on several lines, the last line wins.

    Raises:
        IndexError: If a non-blank line has fewer than two tab-separated fields.
        ValueError: If the first field of a line is not an integer.
    """
    entities_dict_number = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline at end of file) carry no entry.
                continue
            parts = stripped.split('\t')
            entities_dict_number[int(parts[0])] = parts[1]
    return entities_dict_number

# Load both entity dictionaries (integer first column -> second column):
# one from the text entities file and one from the numeric entities file
entities_dict_text = load_entities_dict_number(entities_dict_text_file)
entities_dict_number = load_entities_dict_number(entities_dict_number_file)

def extract_and_search_data(target_value, train_file_path, entities_dict_text_file, entities_dict_number_file):
    """Return the dictionary indices of all heads linked to the target by ``_hypernym``.

    Steps:
      1. Find ``target_value`` in the second column of the text dictionary and
         take the first column of that line as its index; fetch the identifier
         stored under the same index in the number dictionary.
      2. Scan the training file for triples whose relation is exactly
         ``_hypernym`` and whose tail is exactly that identifier, and collect
         their heads. Fields are compared exactly, so relations that merely
         end in ``_hypernym`` (such as ``_instance_hypernym``) are not matched.
      3. Translate the collected head identifiers back to the first-column
         index of the number dictionary.

    Args:
        target_value: Entity label as it appears in the text dictionary.
        train_file_path: Path of the ``head<TAB>relation<TAB>tail`` triple file.
        entities_dict_text_file: Path of the ``index<TAB>label`` dictionary.
        entities_dict_number_file: Path of the ``index<TAB>identifier`` dictionary.

    Returns:
        list of str: first-column values of the number dictionary, in the
        order of that file, one entry per matching triple.

    Raises:
        ValueError: If ``target_value`` is not present in the text dictionary.
        KeyError: If the index found for ``target_value`` is missing from the
            number dictionary.
    """
    def load_entity_rows(file_path):
        """Read (index string, value) pairs from a tab-separated dictionary file."""
        rows = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue  # tolerate blank lines
                parts = stripped.split('\t')
                rows.append((parts[0], parts[1]))
        return rows

    text_rows = load_entity_rows(entities_dict_text_file)
    number_rows = load_entity_rows(entities_dict_number_file)
    identifier_by_index = {int(index): identifier for index, identifier in number_rows}

    # Step 1: label -> index -> identifier used in the training triples.
    for index, label in text_rows:
        if label == target_value:
            target_identifier = identifier_by_index[int(index)]
            break
    else:
        raise ValueError(f"entity {target_value!r} not found in {entities_dict_text_file}")

    # Step 2: count the heads of every exact (head, _hypernym, target) triple.
    head_counts = {}
    with open(train_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) >= 3 and parts[1] == '_hypernym' and parts[2] == target_identifier:
                head_counts[parts[0]] = head_counts.get(parts[0], 0) + 1

    # Step 3: walk the number dictionary once, emitting each matching index
    # as many times as its identifier occurred as a head.
    searched_data = []
    for index, identifier in number_rows:
        searched_data.extend([index] * head_counts.get(identifier, 0))

    return searched_data

def calculate_mod_values(extracted_data, entity_embedding):
    """Compute the mean absolute value of the second half of each entity's embedding row.

    Args:
        extracted_data: Iterable of entity ids given as strings; each is
            converted with ``int`` and used as a row index.
        entity_embedding: Array indexed by entity id along its first axis;
            each row must split evenly in two.

    Returns:
        list with one mean value per input id, in input order. The list and
        its length are also printed.
    """
    def mean_abs_second_half(entity_id_str):
        row = entity_embedding[int(entity_id_str)]
        second_half = np.split(row, 2)[1]
        return np.mean(np.abs(second_half))

    mod_values = [mean_abs_second_half(entity_id_str) for entity_id_str in extracted_data]
    print(mod_values, len(mod_values))
    return mod_values

# Per-target results, kept in the order of target_values
box_plot_data = []
x_axis_labels = []
box_colors = []

for target_value, target_color in target_values:
    extracted_data = extract_and_search_data(target_value, train_file, entities_dict_text_file, entities_dict_number_file)
    print(target_value, 'num:', len(extracted_data), extracted_data)

    mod_values = calculate_mod_values(extracted_data, entity_embedding)

    box_plot_data.append(mod_values)
    x_axis_labels.append(target_value)
    box_colors.append(target_color)

# Build the box plot: one box per target entity
fig = go.Figure()

for label, color, mod_values in zip(x_axis_labels, box_colors, box_plot_data):
    # Name each trace after its entity so the x-axis category and the hover
    # text carry the real label, and apply the color configured for the
    # target (previously unpacked but never used).
    fig.add_trace(
        go.Box(y=mod_values, name=label, marker_color=color, boxpoints=False, jitter=0.3, pointpos=-1.8)
    )

fig.update_layout(
    title={
        'text': "Box Plot of Modulus Values (HAKE)",
        'y':0.9,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top'
    },
    # The plotted values are means of absolute modulus components; no
    # logarithm is applied anywhere, so the axis title must not claim one.
    yaxis_title="Mean Absolute Modulus",
    xaxis_title="Entities",
    showlegend=False,
    width=800,  # figure width in pixels
    height=600  # figure height in pixels
)

# Render the figure
fig.show()




genus.n.02 num: 19 ['21', '81', '265', '284', '427', '542', '692', '716', '883', '1244', '1929', '2110', '4550', '9626', '10288', '13047', '17252', '17504', '23662']



Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.



IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices