# In[1]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import networkx as nx

import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict, deque

import torch
import torch.nn as nn

from melon_clustering import PatternExtractor

class GCN(torch.nn.Module):
    """Two-layer graph convolutional network.

    Computes ``conv2(relu(conv1(x)))`` over the graph described by
    ``edge_index``, producing one ``output_dim``-wide vector per node. The
    final layer has no activation.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        """Run both graph convolutions and return per-node outputs."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

class Node:
    """A single vertex of a word tree.

    Attributes:
        word: the token stored at this node.
        id: identifier of the node (the parameter name shadows the builtin
            but is kept for callers that pass it by keyword).
        count: counter starting at 0; not updated in this class.
        children: mapping of child nodes, initially empty (presumably keyed
            by the child's word — confirm against the tree builder).
    """

    def __init__(self, word, id):
        self.word, self.id = word, id
        self.count = 0
        self.children = {}

class PatternExtractorWithGNN(PatternExtractor):
    """PatternExtractor extended with GCN node embeddings and sentence embeddings.

    The word trees built by the parent class are copied into a directed
    ``networkx`` graph. A two-layer ``GCN`` is trained to reconstruct random
    node features over that graph. A sentence is embedded as a
    position-weighted average of the embeddings of the tree nodes it passes
    through.
    """

    def __init__(self):
        super().__init__()
        self.graph = nx.DiGraph()  # parent -> child edges copied from the word trees
        self.node_embeddings = {}  # NOTE(review): not used elsewhere in this class

    def set_up_digraph(self, node_id, parsed_nodes):
        """Copy the subtree rooted at ``node_id`` into ``self.graph``.

        The walk uses an explicit stack rather than recursion, so deep trees
        cannot exceed Python's recursion limit. It visits nodes in the same
        depth-first pre-order as a recursive walk and adds one
        parent -> child edge per child link. An edge is added even when the
        child has already been visited.

        Args:
            node_id: id of the starting node, looked up in
                ``self.id_to_node`` (provided by the parent class).
            parsed_nodes: set of already-expanded node ids. It is updated in
                place, and ids already in it are not expanded again.

        Raises:
            KeyError: if a visited id is missing from ``self.id_to_node``.
        """
        if node_id in parsed_nodes:
            return
        # Register the start node even if it has no children, so a leaf-only
        # tree is not silently absent from the graph. This does not affect edges.
        self.graph.add_node(node_id)
        # Stack entries are (parent_id, node_id); the start node has no parent.
        stack = [(None, node_id)]
        while stack:
            parent_id, current_id = stack.pop()
            if parent_id is not None:
                self.graph.add_edge(parent_id, current_id)
            if current_id in parsed_nodes:
                continue
            parsed_nodes.add(current_id)
            children = list(self.id_to_node[current_id].children.values())
            # Push in reverse so children are popped in their original order.
            for child in reversed(children):
                stack.append((current_id, child.id))

    def build_graph(self):
        """Return the edges of ``self.graph`` as a ``(2, num_edges)`` long tensor.

        Row 0 holds source (parent) ids and row 1 holds target (child) ids,
        which is the ``edge_index`` layout passed to ``GCNConv``. An edgeless
        graph yields an empty ``(2, 0)`` tensor rather than a 1-D ``(0,)`` one.
        """
        edges = list(self.graph.edges)
        edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2)
        # NOTE(review): edges point parent -> child only, so GCN messages never
        # flow from children back to parents. Add reverse edges if upward
        # propagation is wanted.
        return edge_index.t().contiguous()

    def initialize_node_features(self, feature_dim=100, seed=None):
        """Return random normal node features of shape ``(node_counter, feature_dim)``.

        Args:
            feature_dim: width of each node's feature vector.
            seed: optional integer seed for a private generator, which makes
                the features reproducible without touching torch's global RNG.
                ``None`` (the default) draws from the global RNG.

        The row count is ``self.node_counter`` (from the parent class). This
        presumably assumes node ids run 0..node_counter-1 — TODO confirm.
        """
        num_nodes = self.node_counter
        generator = None
        if seed is not None:
            generator = torch.Generator().manual_seed(seed)
        # requires_grad is kept for backward compatibility. train_gnn treats
        # these features as a fixed input/target and does not optimise them.
        return torch.randn((num_nodes, feature_dim), generator=generator, requires_grad=True)

    def train_gnn(self, edge_index, node_features, hidden_dim=64, output_dim=100, epochs=200,
                  lr=0.01, log_every=10):
        """Train a GCN to reconstruct ``node_features`` and return node embeddings.

        Args:
            edge_index: ``(2, num_edges)`` tensor as returned by ``build_graph``.
            node_features: ``(num_nodes, feature_dim)`` input features. They
                also serve as the reconstruction target.
            hidden_dim: width of the hidden GCN layer.
            output_dim: width of the output layer. It must equal
                ``feature_dim`` because the loss compares output to input.
            epochs: number of full-graph optimisation steps (may be 0).
            lr: Adam learning rate.
            log_every: print the loss every this many epochs. A falsy value
                disables logging.

        Returns:
            NumPy array of shape ``(num_nodes, output_dim)``. It holds the
            model's output computed *after* the final optimisation step.

        Raises:
            ValueError: if ``output_dim`` differs from the feature width.
        """
        input_dim = node_features.shape[1]
        if output_dim != input_dim:
            raise ValueError(
                f"output_dim ({output_dim}) must equal the node feature width "
                f"({input_dim}) for the reconstruction loss"
            )
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = GCN(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        edge_index = edge_index.to(device)
        # Detach: the features are a fixed input/target rather than parameters.
        # This also stops gradients accumulating in the caller's tensor.
        node_features = node_features.detach().to(device)

        model.train()
        for epoch in range(epochs):
            optimizer.zero_grad()
            out = model(node_features, edge_index)
            loss = F.mse_loss(out, node_features)  # Unsupervised reconstruction loss
            loss.backward()
            optimizer.step()
            if log_every and epoch % log_every == 0:
                print(f'Epoch {epoch}, Loss: {loss.item()}')

        # Recompute with the final weights. The last in-loop `out` predates the
        # final optimizer.step(), and it does not exist at all when epochs == 0.
        model.eval()
        with torch.no_grad():
            embeddings = model(node_features, edge_index)
        return embeddings.cpu().numpy()

    @staticmethod
    def _follow_words(root, words):
        """Return ids of nodes reached by following ``words`` down from ``root``.

        An unmatched word is skipped, and the next word is tried against the
        same node.
        """
        matched_ids = []
        current_node = root
        for word in words:
            if word in current_node.children:
                current_node = current_node.children[word]
                matched_ids.append(current_node.id)
        return matched_ids

    def generate_sentence_embeddings(self, gnn_node_embeddings, sentence, morphology, steepness=1.0):
        """Embed ``sentence`` by walking the word trees on both sides of its key word.

        The sentence is lower-cased and split on whitespace. The words before
        the key word ``morphology`` are matched, nearest first, against
        ``self.preceding_tree``. The words after it are matched against
        ``self.following_tree``. The matched node ids are placed in sentence
        order around node id 0 and combined by
        ``compute_weighted_sentence_embedding``.

        Raises:
            ValueError: if ``morphology`` is not a whitespace-separated token
                of ``sentence`` (case-insensitive).
        """
        words = sentence.lower().split()
        try:
            key_word_index = words.index(morphology.lower())
        except ValueError:
            raise ValueError(
                f"key word {morphology!r} not found as a token in sentence {sentence!r}"
            ) from None

        # NOTE(review): an unmatched word does not end the walk (see
        # _follow_words). Confirm that skipping words is intended.
        preceding_ids = self._follow_words(self.preceding_tree, reversed(words[:key_word_index]))
        following_ids = self._follow_words(self.following_tree, words[key_word_index + 1:])

        # NOTE(review): node 0 is assumed to represent the key word itself
        # (presumably the preceding-tree root) — confirm against PatternExtractor.
        node_path = preceding_ids[::-1] + [0] + following_ids
        return self.compute_weighted_sentence_embedding(gnn_node_embeddings, node_path, steepness=steepness)

    def compute_weighted_sentence_embedding(self, gnn_node_embeddings, sentence_path, steepness=1.0):
        """Return a position-weighted average of node embeddings.

        The node at 0-based position ``i`` of ``sentence_path`` gets weight
        ``1 - sigmoid(steepness * (i + 1))``, so weights shrink along the path.
        The weighted sum is divided by the path length, not by the sum of the
        weights.

        Args:
            gnn_node_embeddings: array indexed by node id, one row per node.
            sentence_path: iterable of node ids in sentence order.
            steepness: slope of the sigmoid used for the weights.

        Raises:
            ValueError: if ``sentence_path`` is empty, where the result would
                otherwise be NaN.
        """
        node_ids = list(sentence_path)
        if not node_ids:
            raise ValueError("sentence_path must contain at least one node id")
        embeddings = np.asarray(gnn_node_embeddings)
        positions = np.arange(1, len(node_ids) + 1)
        # NOTE(review): weights depend on position from the sentence start,
        # not on distance from the key word. Confirm this is intended.
        weights = 1 - (1 / (1 + np.exp(-steepness * positions)))
        return (weights @ embeddings[node_ids]) / len(node_ids)

# Seed the global RNGs so random node features, GCN weight init and any
# NumPy-driven clustering / t-SNE randomness are reproducible across runs.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Step 1: build the word trees from example sentences for three inflections
# of "erinnern".
extractor = PatternExtractorWithGNN()
sentences_dict = {
    'erinnere': [
        'Ich erinnere mich gut',
        'ich erinnere mich nicht',
        'nochmal erinnere ich mich nicht',
        'ich erinnere mich an das Treffen gestern',
        'erinnere mich bitte daran, morgen früh aufzustehen',
        'ich erinnere mich an die schöne Zeit'
    ],
    'erinnert': [
        'wie erinnert man sich nochmal',
        'wo erinnert man sich nochmal',
        'vielleicht erinnert man sich dann nochmal',
        'erinnert mich an meine Kindheit',
        'sie erinnert sich nicht mehr an das Gespräch',
        'er erinnert sich nicht gerne an die Vergangenheit'
    ],
    'erinnerte': [
        'er erinnerte sich plötzlich an den Vorfall',
        'ich erinnerte mich an mein erstes Auto',
        'sie erinnerte sich, dass sie etwas vergessen hatte',
        'erinnerte ich mich an den alten Freund',
        'er erinnerte sich an die Worte seiner Mutter',
        'ich erinnerte mich an den letzten Urlaub'
    ]
}
extractor.initialize(sentences_dict, overlap_threshold=0.4)
extractor.print_trees()

# Step 2: copy both trees into the directed graph.
# NOTE(review): 0 and 1 are assumed to be the root ids of the preceding and
# following trees — confirm against PatternExtractor.
extractor.set_up_digraph(0, set())
extractor.set_up_digraph(1, set())

# Step 3: convert the graph into an edge_index tensor.
edge_index = extractor.build_graph()

# Step 4: initialize random node features.
node_features = extractor.initialize_node_features(feature_dim=100)

# Step 5: train the GNN and get node embeddings (done once, shared by all
# steepness settings below).
gnn_node_embeddings = extractor.train_gnn(edge_index, node_features)

# (morphology, sentence) pairs in a fixed order. The sentence list does not
# depend on the steepness, so it is built once.
labelled_sentences = [
    (morphology, sentence)
    for morphology, sentences in sentences_dict.items()
    for sentence in sentences
]
sentence_list = [sentence for _, sentence in labelled_sentences]

for sigmoid_steepness in [0.5, 1, 2]:
    # Step 6: embed every sentence with the current weighting steepness.
    sentence_embeddings = np.array([
        extractor.generate_sentence_embeddings(
            gnn_node_embeddings, sentence, morphology, steepness=sigmoid_steepness
        )
        for morphology, sentence in labelled_sentences
    ])

    # Step 7: cluster the embeddings.
    clusters = extractor.cluster_embeddings(sentence_embeddings, n_clusters=3)

    # Step 8: reduce dimensionality and visualize the clusters.
    for method in ['pca', 'tsne']:
        reduced_embeddings = extractor.reduce_dimensionality(sentence_embeddings, method=method)
        extractor.visualize_clusters(reduced_embeddings, clusters, sentence_list, method=method)
