In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import networkx as nx
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict, deque

import torch
import torch.nn as nn

from melon_clustering import PatternExtractor

class GCN(torch.nn.Module):
    """Two-layer graph convolutional network: GCNConv -> ReLU -> GCNConv.

    Args:
        input_dim: Size of each input node feature vector.
        hidden_dim: Output size of the first convolution.
        output_dim: Output size of the second convolution, i.e. the size of
            the node embeddings returned by ``forward``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        """Return the second layer's output for node features ``x``.

        No activation is applied after the second convolution, so the result
        is the raw output of ``conv2``.
        """
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

class PatternExtractorWithGNN(PatternExtractor):
    """PatternExtractor extended with GCN node embeddings and sentence clustering.

    The methods below read ``self.node_counter``, ``self.preceding_tree`` and
    ``self.following_tree``; none of them is assigned in this class, so they
    are presumably provided by ``PatternExtractor`` (verify against the base
    class). Tree nodes are expected to expose a ``children`` mapping keyed by
    word and an integer ``id``.
    """

    def __init__(self):
        super().__init__()
        self.graph = nx.DiGraph()  # Graph to store word relations
        self.node_embeddings = {}  # To store GCN embeddings

    def initialize_node_features(self, feature_dim=100):
        """Return a random ``(self.node_counter, feature_dim)`` feature tensor.

        Values are drawn with ``torch.randn`` and the tensor is created with
        ``requires_grad=True``.
        """
        num_nodes = self.node_counter
        node_features = torch.randn((num_nodes, feature_dim), requires_grad=True)
        return node_features

    def train_gnn(self, edge_index, node_features, hidden_dim=64, output_dim=100, epochs=200):
        """Train a two-layer GCN to reconstruct the node features.

        The loss is the mean squared error between the model output and
        ``node_features`` themselves, so ``output_dim`` is expected to equal
        ``node_features.shape[1]``.

        Args:
            edge_index: Edge index tensor passed to the GCN layers.
            node_features: Node feature tensor; its second dimension sets the
                model's input size.
            hidden_dim: Width of the hidden GCN layer.
            output_dim: Width of the output GCN layer.
            epochs: Number of optimisation steps. The loss is printed every
                10th epoch.

        Returns:
            NumPy array holding the model output of the last training forward
            pass (computed before the final optimiser step), or the output of
            the untrained model when ``epochs <= 0``.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = GCN(input_dim=node_features.shape[1], hidden_dim=hidden_dim, output_dim=output_dim).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        # Move data to device
        edge_index = edge_index.to(device)
        node_features = node_features.to(device)

        model.train()
        out = None
        for epoch in range(epochs):
            optimizer.zero_grad()
            out = model(node_features, edge_index)
            loss = F.mse_loss(out, node_features)  # Unsupervised loss
            loss.backward()
            optimizer.step()
            if epoch % 10 == 0:
                print(f'Epoch {epoch}, Loss: {loss.item()}')

        if out is None:
            # The loop did not run (epochs <= 0); previously `out` was unbound
            # here. Fall back to a single forward pass of the untrained model.
            with torch.no_grad():
                out = model(node_features, edge_index)

        return out.detach().cpu().numpy()

    def generate_sentence_embeddings(self, gnn_node_embeddings, sentence, morphology, steepness=None):
        """Embed ``sentence`` from the tree nodes its words match around ``morphology``.

        The sentence is lower-cased and split on whitespace. Words before the
        key word are walked right-to-left through ``self.preceding_tree`` and
        words after it left-to-right through ``self.following_tree``. The ids
        of matched nodes are collected around a fixed centre id ``0``.

        Args:
            gnn_node_embeddings: 2-D array indexed by node id.
            sentence: Sentence text.
            morphology: Key word that must occur in the sentence as a
                whitespace-separated token (compared case-insensitively).
            steepness: Forwarded to ``compute_weighted_sentence_embedding``.

        Returns:
            The weighted sentence embedding as a 1-D NumPy array.

        Raises:
            ValueError: If ``morphology`` is not a token of ``sentence``.
        """
        words = sentence.lower().split()
        node_path = deque([0])
        key_word = morphology.lower()
        try:
            key_word_index = words.index(key_word)
        except ValueError:
            raise ValueError(
                f"Key word {key_word!r} not found in sentence {sentence!r}"
            ) from None
        words_before = words[:key_word_index]
        words_after = words[key_word_index + 1:]

        # A word with no matching child is skipped; the walk then continues
        # from the same node with the next word.
        current_node = self.preceding_tree
        for word in reversed(words_before):
            if word in current_node.children:
                current_node = current_node.children[word]
                node_path.appendleft(current_node.id)

        current_node = self.following_tree
        for word in words_after:
            if word in current_node.children:
                current_node = current_node.children[word]
                node_path.append(current_node.id)

        return self.compute_weighted_sentence_embedding(gnn_node_embeddings, node_path, steepness=steepness)

    def compute_weighted_sentence_embedding(self, gnn_node_embeddings, sentence_path, steepness=None):
        """Return the position-weighted average of the embeddings along a path.

        The node at zero-based position ``i`` of ``sentence_path`` gets weight
        ``1 - sigmoid(steepness * (i + 1))``. The weighted sum is divided by
        the path length.

        Args:
            gnn_node_embeddings: 2-D array indexed by node id.
            sentence_path: Sequence of node ids.
            steepness: Slope of the sigmoid weighting. ``None`` (the default)
                is treated as ``1.0``; it previously raised ``TypeError``.

        Returns:
            1-D NumPy array of length ``gnn_node_embeddings.shape[1]``.
        """
        if steepness is None:
            steepness = 1.0
        weighted_embedding = np.zeros(gnn_node_embeddings.shape[1])
        for i, node_id in enumerate(sentence_path):
            weight = 1 - (1 / (1 + np.exp(-steepness * (i + 1))))
            weighted_embedding += weight * gnn_node_embeddings[node_id]
        return weighted_embedding / len(sentence_path)

    def cluster_embeddings(self, embeddings, n_clusters=3):
        """
        Apply K-Means clustering to the embeddings and return one cluster
        label per row.
        """
        kmeans = KMeans(n_clusters=n_clusters)
        clusters = kmeans.fit_predict(embeddings)
        return clusters

    def reduce_dimensionality(self, embeddings, method='pca', n_components=2):
        """
        Reduce dimensionality of embeddings for visualization using PCA or t-SNE.

        Args:
            embeddings: 2-D array with one row per sample.
            method: ``'pca'`` or ``'tsne'``.
            n_components: Number of output dimensions.

        Returns:
            The reduced embeddings produced by the reducer's ``fit_transform``.

        Raises:
            ValueError: If ``method`` is neither ``'pca'`` nor ``'tsne'``.
        """
        if method == 'pca':
            reducer = PCA(n_components=n_components)
        elif method == 'tsne':
            perplexity = min(30, len(embeddings) - 1)  # Ensure perplexity is smaller than the number of samples
            reducer = TSNE(n_components=n_components, perplexity=perplexity)
        else:
            # Previously fell through and failed on an unbound `reducer`.
            raise ValueError(f"Unknown dimensionality reduction method: {method!r} (expected 'pca' or 'tsne')")

        reduced_embeddings = reducer.fit_transform(embeddings)
        return reduced_embeddings

    def visualize_clusters(self, reduced_embeddings, clusters, sentence_list=None, method='pca', appendix = None):
        """
        Visualize clusters using a 2D scatter plot.

        Args:
            reduced_embeddings: 2-D array; columns 0 and 1 are plotted.
            clusters: Per-point values used to colour the points.
            sentence_list: Optional labels; when given, point ``i`` is
                annotated with ``sentence_list[i]``.
            method: Name shown in the title and axis labels.
            appendix: Optional value appended to the title when truthy.
        """
        plt.figure(figsize=(15, 15))
        plt.scatter(reduced_embeddings[:, 0], reduced_embeddings[:, 1], c=clusters, cmap='viridis')

        # Annotate sentences for better interpretability
        if sentence_list:
            for i, sentence in enumerate(sentence_list):
                plt.annotate(sentence, (reduced_embeddings[i, 0], reduced_embeddings[i, 1]), fontsize=8, alpha=0.7)

        plt.title(f"Sentence Clustering Visualization ({method}{' ' + str(appendix) if appendix else ''})")
        plt.xlabel(f"{method.upper()} Component 1")
        plt.ylabel(f"{method.upper()} Component 2")
        plt.show()


def print_central_and_edge_sentences(sentence_embeddings, clusters, sentence_list, n=3):
    """
    Prints n central and n edge sentences for each cluster.

    Each cluster's centroid is the mean of the embeddings that carry that
    label in ``clusters``. Re-fitting a separate K-Means model here (as was
    done before) can number its clusters differently from ``clusters``, which
    would measure sentences against another cluster's centroid.

    Args:
        sentence_embeddings: 2-D array-like with one embedding per sentence.
        clusters: Numeric cluster label per sentence, aligned with
            ``sentence_embeddings``.
        sentence_list: Sentence text per row, aligned with
            ``sentence_embeddings``.
        n: Number of central and of edge sentences to print per cluster.
            When a cluster has fewer than ``2 * n`` members, some sentences
            appear in both lists.
    """
    embeddings = np.asarray(sentence_embeddings)
    labels = np.asarray(clusters)
    # Guard the slices below: a negative n would select unintended ranges.
    n = max(n, 0)

    # Iterate over the labels actually present rather than assuming 0..k-1.
    for cluster_id in np.unique(labels):
        member_indices = np.flatnonzero(labels == cluster_id)
        cluster_sentences = [sentence_list[i] for i in member_indices]
        cluster_embeddings = embeddings[member_indices]

        # Centroid of exactly the points that carry this label.
        centroid = cluster_embeddings.mean(axis=0)
        distances = np.linalg.norm(cluster_embeddings - centroid, axis=1)

        # Sort once; nearest first. `order[-n:]` would return everything for
        # n == 0, so compute the start of the tail explicitly.
        order = np.argsort(distances, kind="stable")
        closest_indices = order[:n]
        farthest_indices = order[max(len(order) - n, 0):]

        print(f"\nCluster {cluster_id + 1}:")
        print(f"Centroid: {centroid}")

        print("\n  Central Sentences (closest to the centroid):")
        for idx in closest_indices:
            print(f"  - {cluster_sentences[idx]}")

        print("\n  Edge Sentences (farthest from the centroid):")
        for idx in farthest_indices:
            print(f"  - {cluster_sentences[idx]}")


# Example corpus: each key is a verb form and maps to sentences containing it.
sentences_dict = {
    'erinnere': [
        'Ich erinnere mich gut',
        'ich erinnere mich nicht',
        'nochmal erinnere ich mich nicht',
        'ich erinnere mich an das Treffen gestern',
        'erinnere mich bitte daran, morgen früh aufzustehen',
        'ich erinnere mich an die schöne Zeit',
    ],
    'erinnert': [
        'wie erinnert man sich nochmal',
        'wo erinnert man sich nochmal',
        'vielleicht erinnert man sich dann nochmal',
        'erinnert mich an meine Kindheit',
        'sie erinnert sich nicht mehr an das Gespräch',
        'er erinnert sich nicht gerne an die Vergangenheit',
    ],
    'erinnerte': [
        'er erinnerte sich plötzlich an den Vorfall',
        'ich erinnerte mich an mein erstes Auto',
        'sie erinnerte sich, dass sie etwas vergessen hatte',
        'erinnerte ich mich an den alten Freund',
        'er erinnerte sich an die Worte seiner Mutter',
        'ich erinnerte mich an den letzten Urlaub',
    ],
}

# Alternative corpus source (currently disabled): load the sentences from YAML.
# language = 'de'
# from melon_clustering import load_sentences, SENTENCES_DIR
# sentences_dict = load_sentences(SENTENCES_DIR / 'erinnern.yaml', language)

# Set up the extractor on the corpus and print the resulting trees.
extractor = PatternExtractorWithGNN()
extractor.initialize(sentences_dict, overlap=0)
extractor.print_trees()

# The remainder of the pipeline is disabled and kept for reference only.
# # Build graph from the sentences
# edge_index = extractor.build_graph(sentences_dict)

# # Step 2: Initialize node features
# node_features = extractor.initialize_node_features(feature_dim=100)

# # Step 3: Train GNN and get embeddings
# gnn_node_embeddings = extractor.train_gnn(edge_index, node_features)

# # Step 4: Build one embedding per sentence for several steepness values
# for steepness in [0.5, 1, 2, 3]:
#     sentence_embeddings = []
#     sentence_list = []
#     for morphology, sentences in sentences_dict.items():
#         for sentence in sentences:
#             embedding = extractor.generate_sentence_embeddings(gnn_node_embeddings, sentence, morphology, steepness=steepness)
#             sentence_embeddings.append(embedding)
#             sentence_list.append(sentence)

#     # Convert the list of embeddings to a NumPy array
#     sentence_embeddings = np.array(sentence_embeddings)

#     # Step 5: Cluster the embeddings
#     clusters = extractor.cluster_embeddings(sentence_embeddings, n_clusters=3)

#     # Step 6: Apply dimensionality reduction methods (PCA, t-SNE)
#     for method in ['pca', 'tsne']:
#         reduced_embeddings = extractor.reduce_dimensionality(sentence_embeddings, method=method)

#         # Step 7: Visualize clusters using different methods
#         extractor.visualize_clusters(reduced_embeddings, clusters, method=method, appendix=steepness)

#     # Step 8: Print 3 central and 3 edge sentences for each cluster
#     print_central_and_edge_sentences(sentence_embeddings, clusters, sentence_list, n=3)


optimize <START>
optimize <END>
Preceding Tree (before <MASK>):
<ROOT> (count: 0, id: 0)
  ich (count: 6, id: 2)
    <start> (count: 6, id: 3)
  nochmal (count: 1, id: 9)
    <start> (count: 1, id: 10)
  <start> (count: 3, id: 20)
  wie (count: 1, id: 31)
    <start> (count: 1, id: 32)
  wo (count: 1, id: 37)
    <start> (count: 1, id: 38)
  vielleicht (count: 1, id: 39)
    <start> (count: 1, id: 40)
  sie (count: 2, id: 47)
    <start> (count: 2, id: 48)
  er (count: 3, id: 56)
    <start> (count: 3, id: 57)

Following Tree (after <MASK>):
<ROOT> (count: 0, id: 1)
  mich (count: 8, id: 4)
    gut (count: 1, id: 5)
      <end> (count: 1, id: 6)
    nicht (count: 1, id: 7)
      <end> (count: 1, id: 8)
    an (count: 5, id: 15)
      das (count: 1, id: 16)
        treffen (count: 1, id: 17)
          gestern (count: 1, id: 18)
            <end> (count: 1, id: 19)
      die (count: 1, id: 27)
        schöne (count: 1, id: 28)
          zeit (count: 1, id: 29)
            <end> (count: 1