In [None]:
from loguru import logger

from iqrah.morphology import QuranicArabicCorpus
from iqrah.graph import QuranGraphBuilder
from iqrah.quran_api import QuranAPIClient, fetch_quran
from iqrah.quran_api import fetch_quran

# Path to the morphology CSV, relative to the notebook's working directory.
QURAN_MORPHOLOGY_DATASET = '../data/quran-morphology-v0.5.csv'
# Detach loguru's registered sinks (presumably to keep library logging out of
# the notebook output).
logger.remove()

# Initialize the corpus
corpus = QuranicArabicCorpus()
corpus.load_data(QURAN_MORPHOLOGY_DATASET)

# Initialize Quran client
# NOTE(review): the cache directory is an absolute, machine-specific path;
# adjust it before running this cell elsewhere.
client = QuranAPIClient(
    cache_dir="/home/samiisd/ws/iqrah/iqrah-knowledge-graph/.cache"
)
# Top-level `await`: this cell only runs where an event loop is already
# available (e.g. a Jupyter/IPython session), not as a plain script.
quran = await fetch_quran(
    client = client,
    language="en",
    words=True,
    show_progress=False,
    fields=["text_uthmani"],
    word_fields=["text_uthmani"],
)

# Create graph builder
builder = QuranGraphBuilder()

# Build full graph
# G = builder.build_graph(quran, corpus)

# Or build graph for specific chapters
chapters = quran.chapters[:1] + quran.chapters[105:]  # First chapter plus everything from list index 105 onward (presumably chapters 106-114)
G = builder.build_graph(quran, corpus, chapters=chapters)

# Store the graph (requires `import networkx as nx`, which this cell does not do)
# nx.write_graphml(G, "graph_data_last_ones.graphml")
# G = nx.read_graphml("graph_data_last_ones.graphml")

In [None]:
# Scratch lookup of chapter 1 for interactive inspection; `a` is not
# referenced by any of the other visible cells.
a = quran.get_chapter(1)

In [None]:
import networkx as nx
import plotly.graph_objects as go
from IPython.display import display, clear_output
import ipywidgets as widgets
import json

class KnowledgeGraphApp:
    def __init__(self, G, chapter_ids, progress_file='progress.json'):
        self.G = G
        self.chapter_ids = chapter_ids
        self.edge_sequence = ['has_verse', 'has_word_instance', 'is_word', 'has_lemma', 'has_root']
        self.progress_file = progress_file
        self.subG = self.build_subgraph()
        self.question_history = []
        self.total_nodes = len(self.subG.nodes)
        self.node_type_colors = {
            'chapter': 'blue',
            'verse': 'green',
            'word_instance': 'red',
            'word': 'orange',
            'lemma': 'purple',
            'root': 'brown'
        }
        # Initialize node positions
        # self.pos = nx.spring_layout(self.subG, dim=3, seed=42)
        self.pos = nx.spring_layout(self.subG, seed=42)
        # Initialize the figure
        self.fig = go.FigureWidget()
        self.update_visualization(initial=True)
        # Initialize question counter
        self.question_count = 0
        # Initialize output widget
        self.output = widgets.Output()
        # Display the output widget
        display(self.output)
        # Load progress if available
        self.load_progress()

    def build_subgraph(self):
        def build_hierarchical_subgraph(G, start_node, edge_sequence):
            subG = nx.DiGraph()
            current_nodes = start_node if isinstance(start_node, list) else [start_node]
            for edge_type in edge_sequence:
                next_nodes = []
                for node in current_nodes:
                    for _, neighbor, data in G.out_edges(node, data=True):
                        if data.get('type') == edge_type:
                            subG.add_node(node, **G.nodes[node])
                            subG.add_node(neighbor, **G.nodes[neighbor])
                            subG.add_edge(node, neighbor, **data)
                            next_nodes.append(neighbor)
                current_nodes = next_nodes
            return subG
        subG = build_hierarchical_subgraph(self.G, self.chapter_ids, self.edge_sequence)
        # Initialize understanding levels
        for node in subG.nodes:
            subG.nodes[node]['understanding'] = 0  # 0 means not understood
        return subG

    def get_node_color(self, node):
        node_type = self.subG.nodes[node].get('type', 'unknown')
        return self.node_type_colors.get(node_type, 'gray')

    def get_node_size(self, node):
        understanding = self.subG.nodes[node].get('understanding', 0)
        return 1 + (understanding + 1) * 5  # Base size 10, increases with understanding

    def propagate_understanding(self, node, increment):
        visited = set()
        queue = [(node, increment)]

        while queue:
            current_node, current_increment = queue.pop(0)
            if current_node in visited:
                continue
            visited.add(current_node)

            # Update understanding with diminishing returns
            self.subG.nodes[current_node]['understanding'] += current_increment
            if self.subG.nodes[current_node]['understanding'] > 3:
                self.subG.nodes[current_node]['understanding'] = 3  # Cap at 3

            # Propagate to connected nodes based on edge types
            for neighbor in self.subG.successors(current_node):
                edge_data = self.subG.get_edge_data(current_node, neighbor)
                edge_type = edge_data.get('type', '')
                weight = edge_data.get('weight', 1)
                factor = self.get_propagation_factor(edge_type)
                queue.append((neighbor, current_increment * factor * weight))

            for neighbor in self.subG.predecessors(current_node):
                edge_data = self.subG.get_edge_data(neighbor, current_node)
                edge_type = edge_data.get('type', '')
                weight = edge_data.get('weight', 1)
                factor = self.get_propagation_factor(edge_type)
                queue.append((neighbor, current_increment * factor * weight))

    def get_propagation_factor(self, edge_type):
        # Define propagation factors based on edge type
        if edge_type in ['has_verse', 'has_word_instance']:
            return 0.2
        elif edge_type in ['is_word', 'has_lemma']:
            return 0.5
        elif edge_type == 'has_root':
            return 0.7
        else:
            return 0.1

    def select_next_question(self):
        # Group nodes by type
        nodes_by_type = {}
        for node in self.subG.nodes:
            node_type = self.subG.nodes[node].get('type', 'unknown')
            if node_type not in nodes_by_type:
                nodes_by_type[node_type] = []
            understanding = self.subG.nodes[node]['understanding']
            if understanding < 3 and node not in self.question_history:
                nodes_by_type[node_type].append(node)

        # Prioritize types
        type_priority = ['root', 'lemma', 'word', 'word_instance', 'verse', 'chapter']

        for node_type in type_priority:
            if node_type in nodes_by_type and nodes_by_type[node_type]:
                # Sort nodes by understanding and usefulness
                candidates = []
                for node in nodes_by_type[node_type]:
                    understanding = self.subG.nodes[node]['understanding']
                    usefulness = self.subG.degree(node)
                    candidates.append((usefulness, -understanding, node))
                candidates.sort(reverse=True)
                return candidates[0][2]  # Return the top candidate
        return None

    def generate_question(self, node):
        node_type = self.subG.nodes[node].get('type', 'unknown')
        identifier = node.split(':', 1)[1]
        if node_type == 'word':
            question = f"Translate the word '{identifier}'"
        elif node_type == 'lemma':
            question = f"Translate the lemma '{identifier}'"
        elif node_type == 'root':
            question = f"Translate the root '{identifier}'"
        elif node_type == 'word_instance':
            verse_key = self.subG.nodes[node].get('verse_key', 'unknown')
            position = self.subG.nodes[node].get('position', 'unknown')
            question = f"Translate the word at position {position} in verse {verse_key}"
        elif node_type == 'verse':
            question = f"Recite or translate verse '{identifier}'"
        elif node_type == 'chapter':
            question = f"Summarize or recite chapter '{identifier}'"
        else:
            question = f"Explain the {node_type} '{identifier}'"
        return question

    def update_visualization(self, initial=False):
        pos = nx.spring_layout(self.subG, dim=3, seed=42)

        # Prepare data for nodes and edges
        node_x = []
        node_y = []
        node_z = []
        node_size = []
        node_color = []
        node_text = []
        node_types = []

        for node in self.subG.nodes:
            x, y, z = pos[node]
            node_x.append(x)
            node_y.append(y)
            node_z.append(z)
            node_size.append(self.get_node_size(node))
            node_color.append(self.get_node_color(node))
            node_text.append(f"{node}\nUnderstanding: {self.subG.nodes[node]['understanding']:.2f}")
            node_types.append(self.subG.nodes[node].get('type', 'unknown'))

        # Edge data
        edge_x = []
        edge_y = []
        edge_z = []
        for edge in self.subG.edges:
            x0, y0, z0 = pos[edge[0]]
            x1, y1, z1 = pos[edge[1]]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])
            edge_z.extend([z0, z1, None])

        if initial:
            # Create traces for the initial plot
            self.fig.data = []  # Clear any existing data
            # Edge trace
            edge_trace = go.Scatter3d(
                x=edge_x, y=edge_y, z=edge_z,
                line=dict(width=1, color='#888'),
                hoverinfo='none',
                mode='lines',
                showlegend=False
            )
            self.fig.add_trace(edge_trace)

            # Node traces by type
            unique_types = set(node_types)
            for node_type in unique_types:
                indices = [i for i, t in enumerate(node_types) if t == node_type]
                trace = go.Scatter3d(
                    x=[node_x[i] for i in indices],
                    y=[node_y[i] for i in indices],
                    z=[node_z[i] for i in indices],
                    mode='markers',
                    marker=dict(
                        size=[node_size[i] for i in indices],
                        color=[node_color[i] for i in indices],
                        symbol='circle',
                        line_width=1
                    ),
                    text=[node_text[i] for i in indices],
                    hoverinfo='text',
                    name=node_type
                )
                self.fig.add_trace(trace)

            # Update layout
            self.fig.update_layout(
                showlegend=True,
                legend=dict(title='Node Types'),
                margin=dict(l=0, r=0, b=0, t=0)
            )

            # Display the figure
            display(self.fig)

        else:
            # Update existing traces
            # Update edge trace
            self.fig.data[0].x = edge_x
            self.fig.data[0].y = edge_y
            self.fig.data[0].z = edge_z

            # Update node traces
            unique_types = list(set(node_types))
            for idx, node_type in enumerate(unique_types):
                trace_idx = idx + 1  # Edge trace is at index 0
                indices = [i for i, t in enumerate(node_types) if t == node_type]
                self.fig.data[trace_idx].x = [node_x[i] for i in indices]
                self.fig.data[trace_idx].y = [node_y[i] for i in indices]
                self.fig.data[trace_idx].z = [node_z[i] for i in indices]
                self.fig.data[trace_idx].marker.size = [node_size[i] for i in indices]
                self.fig.data[trace_idx].marker.color = [node_color[i] for i in indices]
                self.fig.data[trace_idx].text = [node_text[i] for i in indices]

    def save_progress(self):
        progress = {node: data['understanding'] for node, data in self.subG.nodes(data=True)}
        with open(self.progress_file, 'w') as f:
            json.dump(progress, f)

    def load_progress(self):
        try:
            with open(self.progress_file, 'r') as f:
                progress = json.load(f)
            for node, understanding in progress.items():
                if node in self.subG.nodes:
                    self.subG.nodes[node]['understanding'] = understanding
            print("Progress loaded successfully.")
            # Update visualization to reflect loaded progress
            self.update_visualization()
        except FileNotFoundError:
            print("No progress file found. Starting fresh.")

    def recompute_layout(self):
        self.pos = nx.spring_layout(self.subG, dim=3, seed=42)
        self.update_visualization()

    def run(self):
        # Start the interaction loop
        self.ask_next_question()

    def ask_next_question(self):
        node_to_ask = self.select_next_question()
        if not node_to_ask:
            with self.output:
                clear_output()
                print("Congratulations! You have mastered all nodes.")
            return

        question = self.generate_question(node_to_ask)
        # Create input widget
        understanding_slider = widgets.IntSlider(
            value=2,
            min=0,
            max=3,
            step=1,
            description='Understanding:',
            style={'description_width': 'initial'},
            continuous_update=False
        )

        # Create a button to submit the response
        submit_button = widgets.Button(description='Submit')

        # Define the button click event handler
        def on_submit(b):
            score = understanding_slider.value
            # Update understanding
            self.subG.nodes[node_to_ask]['understanding'] = score
            # Update question history
            self.question_history.append(node_to_ask)
            if len(self.question_history) > 3:
                self.question_history.pop(0)  # Keep only the last 3 questions
            # Propagate understanding
            self.propagate_understanding(node_to_ask, score / 3.0)  # Normalize increment
            # Increment question counter
            self.question_count += 1
            # Provide feedback to the user
            mastered_nodes = sum(1 for node in self.subG.nodes if self.subG.nodes[node]['understanding'] >= 3)
            total_percentage = (mastered_nodes / self.total_nodes) * 100

            with self.output:
                clear_output()
                print(f"Question: {question}")
                print(f"Your understanding level for '{node_to_ask}' is now {self.subG.nodes[node_to_ask]['understanding']:.2f}")
                print(f"You have mastered {mastered_nodes}/{self.total_nodes} nodes ({total_percentage:.2f}%)")

            # Update visualization
            self.update_visualization()
            # Recompute layout every 5 questions
            if self.question_count % 5 == 0:
                self.recompute_layout()
            # Save progress after each question
            self.save_progress()
            # Ask the next question
            self.ask_next_question()

        # Attach the event handler to the button
        submit_button.on_click(on_submit)

        # Display the question and widgets
        with self.output:
            clear_output()
            print(f"Question: {question}")
            display(understanding_slider, submit_button)

    # ... [Rest of the methods remain unchanged] ...
# Assuming G is the graph built in the first cell.
# Create an instance rooted at CHAPTER:1 and CHAPTER:108..CHAPTER:114, then
# start the quiz.
# NOTE(review): the first cell builds G from chapters[105:], which presumably
# also covers chapters 106-107; they are not used as roots here - confirm.
# NOTE: the Dash cell further down rebinds the name `app`.
app = KnowledgeGraphApp(G, ['CHAPTER:1', *[f'CHAPTER:{i}' for i in range(108,114+1)]] )
app.run()



In [None]:
import networkx as nx
import plotly.graph_objects as go
import dash
from dash import dcc, html
from dash.dependencies import Input, Output

# Work on a copy (J) so the edits below do not modify G
J = G.copy()

# Remove next_chapter edges
edges_to_remove = [(u, v) for u, v, attrs in J.edges(data=True) if attrs.get('type', None) in ["next_chapter"]]
print("removing edges:", edges_to_remove)
J.remove_edges_from(edges_to_remove)

# nodes_to_remove = [node for node, attrs in J.nodes(data=True) if attrs.get('type', "").startswith("verse")]
# print("removing nodes:", nodes_to_remove)
# J.remove_nodes_from(nodes_to_remove)
#

# Tie chapters 106..114 to chapter 1 with artificial edges.
# NOTE: add_edge creates any endpoint that is not already in J (such nodes
# would carry no 'type' attribute).
for i in range(106, 114+1):
    J.add_edge(f"CHAPTER:{i}", "CHAPTER:1", type="fake_edge")
    # J.add_edge("CHAPTER:1", f"CHAPTER:{i}", type="fake_edge")

# Flip every edge direction before layering
J = J.reverse()


# Assign layers based on topological generations
# (presumably requires J to be acyclic at this point - TODO confirm)
for layer, nodes in enumerate(nx.topological_generations(J)):
    for node in nodes:
        J.nodes[node]['layer'] = layer

# Get the multipartite layout: one subset per 'layer' value
pos = nx.multipartite_layout(J, subset_key='layer')

# Define colors for node types
# NOTE(review): there is no 'root' entry here although KnowledgeGraphApp
# colors roots brown; root nodes fall back to the 'gray' default when
# node_data is built - confirm whether that is intended.
type_colors = {
    'word_instance': 'red',
    'verse': 'green',
    'lemma': 'purple',
    'chapter': 'blue',
    'word': 'orange'
}

# Define colors for edge types
edge_type_colors = {
    'fake_edge': 'gray',
    'has_root': 'brown',
    'has_lemma': 'blue',
    'has_verse': 'green',
    'has_word_instance': 'red',
    'is_word': 'orange',
    'next_chapter': 'cyan',
    'prev_verse': 'magenta',
    'prev_word_instance': 'yellow'
}

def find_connected_nodes(G, start_node):
    """
    Collect the lineage of ``start_node`` in a directed graph.

    Parameters:
    - G: networkx.DiGraph
        The directed graph to search.
    - start_node: hashable
        The node whose lineage is wanted.

    Returns:
    - set
        ``start_node`` itself, every node reachable from it, and every node
        from which it can be reached.
    """
    downstream = nx.descendants(G, start_node)  # reachable from start_node
    upstream = nx.ancestors(G, start_node)      # able to reach start_node
    return downstream | upstream | {start_node}


# One marker record per positioned node; nodes without a 'type' attribute
# are treated as 'word'. Marker size grows with the node's total degree.
node_data = [
    {
        'id': name,
        'x': px,
        'y': py,
        'color': type_colors.get(kind, 'gray'),
        'size': 5 + 2 * (J.in_degree(name) + J.out_degree(name)),
        'text': f'Node: {name}<br>Type: {kind}',
        'type': kind
    }
    for name, (px, py) in pos.items()
    for kind in [J.nodes[name].get('type', 'word')]
]

# One line-segment record per edge; the trailing None ends the segment.
edge_data = [
    {
        'source': src,
        'target': dst,
        'x': [pos[src][0], pos[dst][0], None],
        'y': [pos[src][1], pos[dst][1], None],
        'color': edge_type_colors.get(kind, 'lightgray'),
        'type': kind
    }
    for src, dst, attrs in J.edges(data=True)
    for kind in [attrs.get('type', 'default')]
]

def create_figure(highlight_node=None):
    """Build the 2-D DAG figure, optionally emphasising one node's lineage.

    Args:
        highlight_node: Id of the node to emphasise, or ``None`` for the
            neutral view.  When given, the node, its ancestors and its
            descendants in ``J`` (and the edges among them) keep their
            colors and are enlarged; everything else is faded.

    Returns:
        A ``plotly.graph_objects.Figure`` whose last trace is the node
        marker trace; its ``customdata`` carries the node ids.

    Edges are drawn as a handful of traces (one per color/width pair, with
    ``None``-separated segments) instead of one trace per edge, which keeps
    the figure light on large graphs.  Edge traces use ``hoverinfo='skip'``
    so that only node markers emit hover events.
    """
    faded = 'rgba(200,200,200,0.3)'
    highlighting = highlight_node is not None
    connected_nodes = find_connected_nodes(J, highlight_node) if highlighting else set()

    fig = go.Figure()

    # Bucket edge segments by the line style they will be drawn with.
    edge_groups = {}
    for edge in edge_data:
        if not highlighting:
            style = (edge['color'], 0.5)
        elif edge['source'] in connected_nodes and edge['target'] in connected_nodes:
            style = (edge['color'], 1.5)
        else:
            style = (faded, 0.5)
        xs, ys = edge_groups.setdefault(style, ([], []))
        xs.extend(edge['x'])
        ys.extend(edge['y'])

    # Thin (faded/normal) groups first so emphasised edges are drawn on top.
    for (color, width), (xs, ys) in sorted(edge_groups.items(), key=lambda item: item[0][1]):
        fig.add_trace(go.Scatter(
            x=xs,
            y=ys,
            line=dict(width=width, color=color),
            hoverinfo='skip',
            mode='lines'
        ))

    # Node markers
    node_colors = []
    node_sizes = []
    for node in node_data:
        if not highlighting:
            color, size = node['color'], node['size']
        elif node['id'] == highlight_node:
            color, size = node['color'], node['size'] * 1.5  # hovered node: biggest
        elif node['id'] in connected_nodes:
            color, size = node['color'], node['size'] * 1.2  # lineage: slightly bigger
        else:
            color, size = faded, node['size']  # unrelated: faded out
        node_colors.append(color)
        node_sizes.append(size)

    fig.add_trace(go.Scatter(
        x=[node['x'] for node in node_data],
        y=[node['y'] for node in node_data],
        mode='markers',
        marker=dict(
            size=node_sizes,
            color=node_colors,
            line=dict(width=0.5, color='black')
        ),
        hoverinfo='text',
        text=[node['text'] for node in node_data],
        customdata=[node['id'] for node in node_data]
    ))

    fig.update_layout(
        title="Interactive DAG Layout in Topological Order",
        showlegend=False,
        margin=dict(l=40, r=40, t=40, b=40),
        xaxis=dict(showgrid=False, zeroline=False, visible=False),
        yaxis=dict(showgrid=False, zeroline=False, visible=False),
        height=1200,
        hovermode='closest'
    )

    return fig

# Initialize Dash app
# NOTE: this rebinds `app`, which an earlier cell assigned to the
# KnowledgeGraphApp instance.
app = dash.Dash(__name__)

# Single-graph page. clear_on_unhover=True is presumably what resets
# hoverData to None when the pointer leaves a point, which the callback
# below uses to drop the highlight - confirm against the Dash docs.
app.layout = html.Div([
    dcc.Graph(
        id='dag-graph',
        figure=create_figure(),
        clear_on_unhover=True
    )
])

@app.callback(
    Output('dag-graph', 'figure'),
    [Input('dag-graph', 'hoverData')]
)
def update_highlight(hoverData):
    """Redraw the DAG with the hovered node's lineage highlighted.

    Args:
        hoverData: The graph's ``hoverData`` property; ``None`` when nothing
            is hovered.

    Returns:
        The neutral figure when nothing (or something other than a node
        marker) is hovered, otherwise the figure highlighting that node.
    """
    points = hoverData.get('points') if hoverData else None
    if not points:
        return create_figure()

    # Only the node trace attaches customdata (the node id).  Reading it
    # instead of indexing node_data by pointIndex prevents a hover on an
    # edge segment from being mistaken for node 0 or 1.
    hovered_node = points[0].get('customdata')
    if hovered_node is None:
        return create_figure()
    return create_figure(highlight_node=hovered_node)

# Start the Dash server on port 8050 with debug mode off.
# NOTE: host='0.0.0.0' listens on every network interface, so the app is
# reachable from other machines; use '127.0.0.1' for local-only access.
# NOTE(review): newer Dash releases favour app.run(...) over run_server -
# confirm against the installed version.
app.run_server(debug=False, host='0.0.0.0', port=8050)