In [2]:
import os
import pickle
import numpy as np
import rdflib
import torch
from pykeen.models import DistMult
from pykeen.pipeline import pipeline
from pykeen.triples import TriplesFactory
from sklearn.cluster import KMeans
from pykeen.evaluation import RankBasedEvaluator
from sklearn.metrics import accuracy_score, adjusted_rand_score
import plotly.express as px
from sklearn.metrics.pairwise import cosine_similarity

os.environ['LOKY_MAX_CPU_COUNT'] = '8'

try:
    import torch_directml
except ImportError:
    torch_directml = None

try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None

try:
    from sklearn.manifold import TSNE
    from sklearn.decomposition import PCA
except ImportError:
    TSNE = None
    PCA = None

  from .autonotebook import tqdm as notebook_tqdm


In [1]:
def train_graph_embeddings(rdf_file_path: str, save_dir: str = "genshin_embedding_results"):
    """Train a DistMult embedding model on an RDF graph and persist the artifacts.

    The file is parsed as RDF/XML with rdflib. Every triple becomes one
    labelled ``[subject, predicate, object]`` row, and ``rdf:type`` statements
    are additionally collected per subject. The triples are split 80/20 into
    training and testing sets, and a 300-dimensional DistMult model is trained
    for 300 epochs on the CPU with a fixed random seed.

    Args:
        rdf_file_path: Path of the RDF/XML file to load.
        save_dir: Directory that receives the pipeline results plus the pickled
            triples factory, entity types, and training/testing splits.

    Returns:
        Tuple ``(model, tf, entity_types, training, testing)``.
    """
    rdf_type_uri = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"

    print("Загрузка данных")
    graph = rdflib.Graph()
    graph.parse(rdf_file_path, format="xml")
    print(f"Загружено триплетов: {len(graph)}")

    rows = []
    entity_types = {}
    for subject, predicate, obj in graph:
        row = [str(subject), str(predicate), str(obj)]
        rows.append(row)
        # Remember every declared type of a subject for later filtering/labelling.
        if row[1] == rdf_type_uri:
            entity_types.setdefault(row[0], set()).add(row[2])
    labeled_triples = np.array(rows, dtype=str)

    tf = TriplesFactory.from_labeled_triples(labeled_triples)
    training, testing = tf.split([0.8, 0.2], random_state=42)
    print(f"Сущностей: {len(tf.entity_to_id)}, отношений: {len(tf.relation_to_id)}")
    print("Обучение")

    os.makedirs(save_dir, exist_ok=True)
    result = pipeline(
        training=training,
        testing=testing,
        model="DistMult",
        model_kwargs={"embedding_dim": 300},
        training_kwargs={"num_epochs": 300},
        random_seed=42,
        device="cpu",
    )
    print("Обучение завершено")

    result.save_to_directory(save_dir)
    # Pickle the companions of the model next to the pipeline output.
    artifacts = (
        ("triples_factory.pkl", tf),
        ("entity_types.pkl", entity_types),
        ("training.pkl", training),
        ("testing.pkl", testing),
    )
    for file_name, payload in artifacts:
        with open(os.path.join(save_dir, file_name), "wb") as handle:
            pickle.dump(payload, handle)
    return result.model, tf, entity_types, training, testing


def load_saved_model(directory: str = "genshin_embedding_results"):
    """Load a previously trained model and its pickled companions from disk.

    The model file is looked up under the name PyKEEN's
    ``PipelineResult.save_to_directory`` uses (``trained_model.pkl``) and, for
    backward compatibility, under ``model.pkl``. It is read with
    ``torch.load``, which is how PyKEEN documents restoring a saved model.

    Args:
        directory: Directory that was passed as ``save_dir`` during training.

    Returns:
        Tuple ``(model, tf, entity_types, training, testing)``. ``entity_types``
        is an empty dict when ``entity_types.pkl`` is absent.

    Raises:
        FileNotFoundError: If no model file, ``triples_factory.pkl``,
            ``training.pkl`` or ``testing.pkl`` exists in ``directory``.
    """
    # SECURITY NOTE: torch.load(weights_only=False) and pickle.load execute
    # arbitrary code embedded in the files. Only load directories you created.
    model_path = None
    for candidate in ("trained_model.pkl", "model.pkl"):
        candidate_path = os.path.join(directory, candidate)
        if os.path.exists(candidate_path):
            model_path = candidate_path
            break
    if model_path is None:
        raise FileNotFoundError(f"Файл модели не найден в каталоге {directory}")

    # The whole model object was pickled, so weights_only must be disabled.
    # map_location="cpu" keeps the tensors on the CPU because the embedding
    # helpers in this notebook convert them with .numpy() directly.
    model = torch.load(model_path, map_location="cpu", weights_only=False)

    def _unpickle(file_name: str):
        with open(os.path.join(directory, file_name), "rb") as handle:
            return pickle.load(handle)

    tf = _unpickle("triples_factory.pkl")
    # Entity types are optional: fall back to "no type information".
    entity_types = {}
    if os.path.exists(os.path.join(directory, "entity_types.pkl")):
        entity_types = _unpickle("entity_types.pkl")
    training = _unpickle("training.pkl")
    testing = _unpickle("testing.pkl")
    return model, tf, entity_types, training, testing


def get_entity_embedding(entity_uri: str, model, tf):
    """Return the embedding vector of one entity as a NumPy array.

    Args:
        entity_uri: Label of the entity as stored in ``tf.entity_to_id``.
        model: Model whose first entity representation is queried.
        tf: Triples factory providing the ``entity_to_id`` mapping.

    Returns:
        The embedding row for the entity, or ``None`` when the entity is not
        present in ``tf.entity_to_id``.
    """
    # Only the label lookup may legitimately miss. A KeyError raised inside the
    # model is a real error and must not be reported as "unknown entity".
    entity_id = tf.entity_to_id.get(entity_uri)
    if entity_id is None:
        return None
    indices = torch.tensor([entity_id])
    return model.entity_representations[0](indices=indices).detach().numpy()[0]


def get_relation_embedding(relation_uri: str, model, tf):
    """Return the embedding vector of one relation as a NumPy array.

    Args:
        relation_uri: Label of the relation as stored in ``tf.relation_to_id``.
        model: Model whose first relation representation is queried.
        tf: Triples factory providing the ``relation_to_id`` mapping.

    Returns:
        The embedding row for the relation, or ``None`` when the relation is
        not present in ``tf.relation_to_id``.
    """
    # Only the label lookup may legitimately miss. A KeyError raised inside the
    # model is a real error and must not be reported as "unknown relation".
    relation_id = tf.relation_to_id.get(relation_uri)
    if relation_id is None:
        return None
    indices = torch.tensor([relation_id])
    return model.relation_representations[0](indices=indices).detach().numpy()[0]


def find_similar_entities(entity_uri: str, model, tf, entity_types: dict | None = None, top_k: int = 5):
    """Find the entities whose embeddings are most cosine-similar to one entity.

    Entities sharing at least one type with the query entity are preferred.
    If fewer than ``top_k`` such entities exist (or no type information is
    available), the result is filled up with the remaining entities in order
    of decreasing similarity.

    Args:
        entity_uri: Label of the query entity.
        model: Model whose first entity representation holds the embeddings.
        tf: Triples factory providing the ``entity_to_id`` mapping.
        entity_types: Optional mapping from entity label to a set of type labels.
        top_k: Maximum number of neighbours to return.

    Returns:
        List of ``(entity_label, similarity)`` tuples, at most ``top_k`` long.
        Empty when the entity is unknown or ``top_k`` is not positive.
    """
    entity_embedding = get_entity_embedding(entity_uri, model, tf)
    if entity_embedding is None or top_k <= 0:
        return []

    query_id = tf.entity_to_id[entity_uri]
    all_entity_ids = torch.arange(len(tf.entity_to_id))
    all_embeddings = model.entity_representations[0](indices=all_entity_ids).detach().numpy()
    similarities = cosine_similarity([entity_embedding], all_embeddings)[0]

    # Exclude the query entity by id. Simply dropping rank 0 is wrong when
    # another entity ties with it (e.g. identical embeddings): the query could
    # stay in the results while a genuine neighbour is discarded.
    ranked = [int(idx) for idx in np.argsort(similarities)[::-1] if int(idx) != query_id]
    id_to_entity = {entity_id: uri for uri, entity_id in tf.entity_to_id.items()}

    selected = []
    target_types = entity_types.get(entity_uri) if entity_types else None
    if target_types:
        # First pass: neighbours that share at least one type with the query.
        for idx in ranked:
            if target_types & entity_types.get(id_to_entity[idx], set()):
                selected.append(idx)
                if len(selected) >= top_k:
                    break

    if len(selected) < top_k:
        # Second pass: back-fill with the best remaining entities of any type.
        already_selected = set(selected)
        for idx in ranked:
            if idx in already_selected:
                continue
            selected.append(idx)
            already_selected.add(idx)
            if len(selected) >= top_k:
                break

    return [(id_to_entity[idx], similarities[idx]) for idx in selected[:top_k]]


def reduce_embeddings(embeddings, method: str = "tsne", random_state: int = 42, perplexity: int = 30):
    """Project embeddings down to two dimensions.

    Args:
        embeddings: Sequence/array of embedding vectors, one row per point.
        method: ``"pca"`` selects PCA; any other value selects t-SNE.
        random_state: Seed handed to the reducer.
        perplexity: Requested t-SNE perplexity; it is capped at
            ``len(embeddings) - 1``.

    Returns:
        The result of the reducer's ``fit_transform`` on ``embeddings``.

    Raises:
        ImportError: If the required scikit-learn class could not be imported.
        ValueError: If t-SNE is requested for fewer than two points.
    """
    if method == "pca":
        if PCA is None:
            raise ImportError("Скрипт запущен без scikit-learn. Установи scikit-learn для PCA.")
        return PCA(n_components=2, random_state=random_state).fit_transform(embeddings)

    # Everything that is not "pca" is handled by t-SNE.
    if TSNE is None:
        raise ImportError("Скрипт запущен без scikit-learn. Установи scikit-learn для t-SNE.")
    point_count = len(embeddings)
    if point_count < 2:
        raise ValueError("Слишком мало точек для t-SNE: нужно хотя бы 2 сущности.")

    tsne_options = {
        "n_components": 2,
        "random_state": random_state,
        # Perplexity is limited by the number of available points.
        "perplexity": min(perplexity, point_count - 1),
        "init": "pca",
        "learning_rate": "auto",
    }
    return TSNE(**tsne_options).fit_transform(embeddings)


def collect_all_embeddings(model, tf):
    """Return all entity embeddings together with their labels.

    Args:
        model: Model whose first entity representation holds the embeddings.
        tf: Triples factory providing the ``entity_to_id`` mapping.

    Returns:
        Tuple ``(embeddings, labels)`` where ``embeddings`` is a NumPy array
        with one row per entity id and ``labels[i]`` is the label of id ``i``.
    """
    entity_count = len(tf.entity_to_id)
    entity_representation = model.entity_representations[0]
    vectors = entity_representation(indices=torch.arange(entity_count)).detach().numpy()

    # Invert the label -> id mapping so that labels line up with the rows.
    uri_by_id = dict((entity_id, uri) for uri, entity_id in tf.entity_to_id.items())
    ordered_labels = list(map(uri_by_id.__getitem__, range(len(vectors))))
    return vectors, ordered_labels


def plot_embeddings_2d(points_2d, labels, title: str = "Embedding map", max_labels: int = 75, save_path: str | None = None, figsize=(12, 10), cluster_labels=None):
    """Draw a 2-D scatter plot of projected embeddings with matplotlib.

    Args:
        points_2d: Array indexed as ``points_2d[:, 0]`` / ``points_2d[:, 1]``
            holding the x and y coordinates of every point.
        labels: Labels aligned with the rows of ``points_2d``; the part after
            the last ``#`` is used as the annotation text.
        title: Figure title.
        max_labels: Only the first ``max_labels`` points are annotated.
        save_path: If given, the figure is written there; otherwise it is shown.
        figsize: Figure size passed to ``plt.figure``.
        cluster_labels: Optional per-point cluster ids used for colouring.

    Raises:
        ImportError: If matplotlib could not be imported.
    """
    if plt is None:
        raise ImportError("Скрипт запущен без matplotlib. Установи matplotlib для построения графиков.")

    def _shorten(uri: str) -> str:
        # Keep only the fragment of "…#fragment" URIs; other labels stay as-is.
        return uri.rsplit("#", 1)[-1]

    plt.figure(figsize=figsize)
    if cluster_labels is not None:
        unique_clusters = len(np.unique(cluster_labels))
        # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap takes
        # the same (name, lut) arguments and exists in older releases too.
        cmap = plt.get_cmap('viridis', unique_clusters)
        scatter = plt.scatter(points_2d[:, 0], points_2d[:, 1], s=12, alpha=0.6, c=cluster_labels, cmap=cmap)
        plt.colorbar(scatter, label='Cluster')
    else:
        plt.scatter(points_2d[:, 0], points_2d[:, 1], s=12, alpha=0.6)

    for i, label in enumerate(labels[:max_labels]):
        plt.annotate(_shorten(label), (points_2d[i, 0], points_2d[i, 1]), fontsize=8, alpha=0.7)
    plt.title(title)
    plt.tight_layout()
    if save_path:
        # savefig does not create missing directories on its own.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        plt.savefig(save_path, dpi=200)
        print(f"График сохранён в {save_path}")
    else:
        plt.show()


def visualize_embeddings(
    model: DistMult,
    triples_factory: TriplesFactory,
    method: str = "tsne",
    max_labels: int = 75,
    save_path: str = "genshin_embedding_results/embeddings.png",
    html_path: str = "genshin_embedding_results/interactive_embeddings.html",
    perplexity: int = 30,
    max_points: int = 1200,
    random_state: int = 42,
    num_clusters: int | None = None,
):
    """Project all entity embeddings to 2-D and export a static and an interactive plot.

    Args:
        model: Trained model providing the entity embeddings.
        triples_factory: Triples factory providing the ``entity_to_id`` mapping.
        method: Reduction method forwarded to ``reduce_embeddings``.
        max_labels: Number of points annotated in the static plot.
        save_path: Target file of the matplotlib figure.
        html_path: Target file of the plotly figure; a falsy value skips the
            HTML export.
        perplexity: t-SNE perplexity forwarded to ``reduce_embeddings``.
        max_points: If truthy and smaller than the number of entities, a random
            sample of this size is plotted instead of all entities.
        random_state: Seed used for sampling, projection and clustering.
        num_clusters: If truthy, KMeans with this many clusters is run on the
            2-D points and used for colouring.
    """
    embeddings, labels = collect_all_embeddings(model, triples_factory)
    if max_points and len(embeddings) > max_points:
        # Plot a reproducible random subset to keep the projection tractable.
        rng = np.random.default_rng(random_state)
        indices = rng.choice(len(embeddings), size=max_points, replace=False)
        embeddings = embeddings[indices]
        labels = [labels[i] for i in indices]
        print(f"Всего сущностей: {len(triples_factory.entity_to_id)}")
    else:
        print(f"Всего сущностей: {len(labels)}.")

    # random_state must reach the reducer, otherwise the projection ignores
    # the seed requested by the caller.
    points_2d = reduce_embeddings(embeddings, method=method, random_state=random_state, perplexity=perplexity)

    cluster_labels = None
    if num_clusters:
        print(f"Кластеризация с {num_clusters} кластерами")
        kmeans = KMeans(n_clusters=num_clusters, random_state=random_state)
        cluster_labels = kmeans.fit_predict(points_2d)

    # Neither matplotlib nor plotly creates missing output directories.
    for target in (save_path, html_path):
        parent_dir = os.path.dirname(target) if target else ""
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    plot_embeddings_2d(
        points_2d,
        labels,
        title="Genshin Ontology Embeddings",
        max_labels=max_labels,
        save_path=save_path,
        cluster_labels=cluster_labels,
    )

    if not html_path:
        return

    # Hover text shows only the fragment of "…#fragment" URIs.
    short_labels = [label.rsplit("#", 1)[-1] for label in labels]
    fig = px.scatter(
        x=points_2d[:, 0],
        y=points_2d[:, 1],
        color=cluster_labels,
        hover_name=short_labels,
        title="Interactive Genshin Ontology Embeddings",
        # Only rename the colour axis when there is one.
        labels={'color': 'Cluster'} if cluster_labels is not None else None,
    )
    fig.update_traces(marker=dict(size=12, opacity=0.6))
    fig.write_html(html_path)


def find_rdf_file():
    """Locate the RDF file to train on.

    The preferred location ``data/genshin_ontology_clean.rdf`` is returned when
    it exists. Otherwise the alphabetically first regular file ending in
    ``.rdf`` in the current working directory is returned.

    Returns:
        Relative path of the RDF file.

    Raises:
        FileNotFoundError: If neither location yields an RDF file.
    """
    rdf_path = "data/genshin_ontology_clean.rdf"
    if os.path.isfile(rdf_path):
        return rdf_path
    # os.listdir returns entries in arbitrary order; sort so the same file is
    # picked on every run, and skip directories that merely end in ".rdf".
    for file_name in sorted(os.listdir(".")):
        if file_name.endswith(".rdf") and os.path.isfile(file_name):
            return file_name
    raise FileNotFoundError("RDF файл не найден")


def evaluate_model(model, training, testing):
    """Run rank-based evaluation on the test triples and print the metrics.

    The training triples are passed as additional filter triples. Mean rank,
    mean reciprocal rank and Hits@1/3/10 are printed with four decimals.

    Args:
        model: Trained model to evaluate.
        training: Object exposing the training ``mapped_triples``.
        testing: Object exposing the testing ``mapped_triples``.
    """
    rank_evaluator = RankBasedEvaluator()
    results = rank_evaluator.evaluate(
        model,
        mapped_triples=testing.mapped_triples,
        additional_filter_triples=training.mapped_triples,
    )

    metric_keys = ('mean_rank', 'mean_reciprocal_rank', 'hits@1', 'hits@3', 'hits@10')
    mr, mrr, hits_1, hits_3, hits_10 = [results.get_metric(key) for key in metric_keys]

    print(f"MR: {mr:.4f}")
    print(f"MRR: {mrr:.4f}")
    print(f"Hits@1: {hits_1:.4f}, \nHits@3: {hits_3:.4f}, \nHits@10: {hits_10:.4f}")


def perform_clustering(model, tf, entity_types, num_clusters=5, random_state=42):
    """Compare KMeans clusters of the projected embeddings with the entity types.

    All entity embeddings are projected to 2-D with t-SNE. One plot colours the
    points by their expected label (derived from ``entity_types``), a second
    plot colours them by the KMeans assignment, and the adjusted Rand index
    between both labelings is printed.

    Args:
        model: Trained model providing the entity embeddings.
        tf: Triples factory providing the ``entity_to_id`` mapping.
        entity_types: Mapping from entity label to a set of type labels; may be
            empty or ``None``.
        num_clusters: Number of KMeans clusters.
        random_state: Seed used for the projection and for KMeans.
    """
    embeddings, labels = collect_all_embeddings(model, tf)
    entity_types = entity_types or {}

    # Set iteration order is not stable between interpreter runs, so both the
    # type ids and the "representative" type of an entity are derived from
    # sorted order to keep the expected labels reproducible.
    all_types = sorted(set().union(*entity_types.values()))
    type_to_id = {type_uri: type_id for type_id, type_uri in enumerate(all_types)}

    true_labels = []
    for label in labels:
        types = entity_types.get(label)
        # Entities without any declared type share the label -1.
        true_labels.append(type_to_id[min(types)] if types else -1)

    points_2d = reduce_embeddings(embeddings, method="tsne", random_state=random_state)
    plot_embeddings_2d(points_2d, labels, title="Expected Clusters (by Type)", cluster_labels=np.array(true_labels))

    kmeans = KMeans(n_clusters=num_clusters, random_state=random_state)
    cluster_labels = kmeans.fit_predict(points_2d)
    plot_embeddings_2d(points_2d, labels, title="KMeans Clusters", cluster_labels=cluster_labels)

    ari = adjusted_rand_score(true_labels, cluster_labels)
    print(f"ARI: {ari:.4f}")

NameError: name 'DistMult' is not defined

In [None]:
def main():
    """Run the whole workflow: obtain a model, query neighbours, plot, evaluate, cluster.

    A model saved in ``genshin_embedding_results`` is reused when it can be
    loaded; otherwise the RDF file is located and a new model is trained.
    """
    save_dir = "genshin_embedding_results"

    # PyKEEN's save_to_directory stores the model as "trained_model.pkl";
    # "model.pkl" is still accepted for directories created another way.
    has_saved_model = any(
        os.path.exists(os.path.join(save_dir, file_name))
        for file_name in ("trained_model.pkl", "model.pkl")
    )

    loaded = None
    if has_saved_model:
        try:
            loaded = load_saved_model(save_dir)
        except Exception as exc:
            # An unreadable or incomplete cache must not abort the run:
            # report it and fall back to training from scratch.
            print(f"Не удалось загрузить сохранённую модель ({exc}). Обучаем заново.")

    if loaded is None:
        # The RDF file is only needed when a model has to be trained.
        try:
            rdf_file = find_rdf_file()
        except FileNotFoundError as e:
            print(f"Ошибка: {e}")
            return
        loaded = train_graph_embeddings(rdf_file, save_dir)

    model, tf, entity_types, training, testing = loaded

    query_entities = (
        "http://example.org/ontology/kirara",
        "http://example.org/ontology/hilichurl",
        "http://example.org/ontology/favonius_sword",
    )
    for entity_uri in query_entities:
        similar = find_similar_entities(entity_uri, model, tf, entity_types=entity_types, top_k=3)
        if similar:
            print(" Похожие сущности:")
            for sim_entity, similarity in similar:
                sim_name = sim_entity.split("/")[-1]
                print(f" • {sim_name}: {similarity:.4f}")

    visualize_embeddings(model, tf, method="tsne", max_labels=75, save_path="genshin_embedding_results/embeddings.png")

    evaluate_model(model, training, testing)

    perform_clustering(model, tf, entity_types, num_clusters=5)

In [None]:
# Entry point of the notebook: executing this cell runs main().
main()