# Key Detection Result Visualizer

In [1]:
%pip install yfiles_jupyter_graphs --quiet
try:
  import google.colab
  from google.colab import output
  output.enable_custom_widget_manager()
except:
  pass


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
from neo4j import GraphDatabase
from typing import Dict
from yfiles_jupyter_graphs import GraphWidget
from ipywidgets import widgets, Output
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import numpy as np

uri      = "neo4j://localhost:7689" 
user     = "neo4j"          # your user name 
                              # default is always "neo4j" 
                              # unless you have changed it. 
password = "testtest"

driver = GraphDatabase.driver(uri=uri,auth=(user,password),database='movies')
session = driver.session(database='gitschemas3')

numResults = session.run("""
            MATCH (s:Schema) 
            WHERE s.openAI_GPT_4_response IS NOT NULL
            RETURN count(s) as numSchema""")



def get_schema_name(index=1):    
    schemaName = session.run("""
                MATCH (s:Schema) 
                WHERE s.openAI_GPT_4_response IS NOT NULL
                WITH s
                SKIP $index - 1
                LIMIT 1
                RETURN s.filename as filename""", parameters={'index': index })
    result = schemaName.to_df()['filename'][0]
    return result


def get_schema_f1_score(index=1):    
    schemaName = session.run("""
                MATCH (s:Schema) 
                WHERE s.openAI_GPT_4_response IS NOT NULL
                WITH s
                SKIP $index - 1
                LIMIT 1
                OPTIONAL MATCH (s)-->(:Table)-->(:Column)-[:PK_COLUMN|FK_COLUMN]-(key:PrimaryKey|ForeignKey)
                WITH s, key
                WITH s, collect(distinct key) as schemaKeys
                WITH   s,
                    [key in schemaKeys WHERE "PrimaryKey" IN labels(key)] as primaryKeys,
                    [key in schemaKeys WHERE "ForeignKey" IN labels(key)] as foreignKeys
                WITH s.url as url,
                    size([key in primaryKeys WHERE key.groundTruth = TRUE AND "LLM" IN coalesce(key.detectedBy,[])]) as pk_TP,
                    size([key in primaryKeys WHERE key.groundTruth = FALSE AND "LLM" IN coalesce(key.detectedBy,[])]) as pk_FP,
                    size([key in primaryKeys WHERE key.groundTruth = TRUE AND NOT "LLM" IN coalesce(key.detectedBy,[])]) as pk_FN,
                    size([key in foreignKeys WHERE key.groundTruth = TRUE AND "LLM" IN coalesce(key.detectedBy,[])]) as fk_TP,
                    size([key in foreignKeys WHERE key.groundTruth = FALSE AND "LLM" IN coalesce(key.detectedBy,[])]) as fk_FP,
                    size([key in foreignKeys WHERE key.groundTruth = TRUE AND NOT "LLM" IN coalesce(key.detectedBy,[])]) as fk_FN
                WITH  *,
                        pk_TP + pk_FN as pk_Total,
                        fk_TP + fk_FN as fk_Total,
                        2.0 * pk_TP/(2.0 * pk_TP + pk_FN + pk_FP) as pk_f1_score,
                        2.0 * fk_TP/(2.0 * fk_TP + fk_FN + fk_FP) as fk_f1_score
                RETURN pk_f1_score, fk_f1_score""", parameters={'index': index })
    
    result = schemaName.to_df()[:1]
    return result

def get_population_f1_score():    
    schemaName = session.run("""
                MATCH (s:Schema) 
                WHERE s.openAI_GPT_4_response IS NOT NULL
                OPTIONAL MATCH (s)-->(:Table)-->(:Column)-[:PK_COLUMN|FK_COLUMN]-(key:PrimaryKey|ForeignKey)
                WITH s, key
                WITH s, collect(distinct key) as schemaKeys
                WITH   s,
                    [key in schemaKeys WHERE "PrimaryKey" IN labels(key)] as primaryKeys,
                    [key in schemaKeys WHERE "ForeignKey" IN labels(key)] as foreignKeys
                WITH s.url as url,
                    size([key in primaryKeys WHERE key.groundTruth = TRUE AND "LLM" IN coalesce(key.detectedBy,[])]) as pk_TP,
                    size([key in primaryKeys WHERE key.groundTruth = FALSE AND "LLM" IN coalesce(key.detectedBy,[])]) as pk_FP,
                    size([key in primaryKeys WHERE key.groundTruth = TRUE AND NOT "LLM" IN coalesce(key.detectedBy,[])]) as pk_FN,
                    size([key in foreignKeys WHERE key.groundTruth = TRUE AND "LLM" IN coalesce(key.detectedBy,[])]) as fk_TP,
                    size([key in foreignKeys WHERE key.groundTruth = FALSE AND "LLM" IN coalesce(key.detectedBy,[])]) as fk_FP,
                    size([key in foreignKeys WHERE key.groundTruth = TRUE AND NOT "LLM" IN coalesce(key.detectedBy,[])]) as fk_FN
                WITH  *,
                        pk_TP + pk_FN as pk_Total,
                        fk_TP + fk_FN as fk_Total,
                        2.0 * pk_TP/(2.0 * pk_TP + pk_FN + pk_FP) as pk_f1_score,
                        2.0 * fk_TP/(2.0 * fk_TP + fk_FN + fk_FP) as fk_f1_score
                RETURN avg(pk_f1_score) as pk_f1_score, avg(fk_f1_score) as fk_f1_score""", parameters={})
    
    result = schemaName.to_df()[:1]
    return result

class GraphWidgetWrapper:
    def __init__(self, index):
        self.index = index
        self.out = widgets.Output()
        display(self.out)    
        self.fetch_and_render()

    # Green #549343
    # Red #B53A37
    def custom_styles_mapping_image(self, index: int, item: Dict):
        columnDataURI = ""
        tableDataURI = ""
        primaryKeyGreyDataURI = ""
        primaryKeyGreenDataURI = ""
        primaryKeyRedDataURI = ""
        foreignKeyGreyDataURI = ""
        foreignKeyGreenDataURI = ""
        foreignKeyRedDataURI = ""

        labelToIcon = {
            'Column': columnDataURI,
            'Table': tableDataURI,
            'PrimaryKey': primaryKeyGreyDataURI,
            'ForeignKey': foreignKeyGreyDataURI
            }

        # try:
        detectedby = ""
        label = ""
        groundTruth = None
        try:
            detectedby = item['properties']['detectedBy']
        except:
            pass
        try: 
            label = item['properties']['label']
        except:
            pass
        try:
            groundTruth = item['properties']['groundTruth']
        except:
            pass

        icon = labelToIcon[item['properties']['label']]
        if label == 'PrimaryKey' and groundTruth == True and "LLM" in detectedby:
            icon = primaryKeyGreenDataURI
        elif label == 'PrimaryKey' and groundTruth == False and "LLM" in detectedby:
            icon = primaryKeyRedDataURI
        elif label == 'ForeignKey' and groundTruth == True and "LLM" in detectedby:
            icon = foreignKeyGreenDataURI
        elif label == 'ForeignKey' and groundTruth == False and "LLM" in detectedby:
            icon = foreignKeyRedDataURI

        return {
                'image': icon
            }
        # except:
        #     return {}

    def custom_node_label_mapping(self, index, node):
        """let the label be the negated index"""
        properties = node.get('properties', {})
        return properties.get('title', properties.get('name', ''));

    def custom_relationship_label_mapping(self, index, node):
        """let the label be the negated index"""
        properties = node.get('properties', {})
        return properties.get('label', ['']);

    def updateIndex(self, index):
        self.index = index
        self.fetch_and_render()

    def fetch_and_render(self):
        result = session.run("""MATCH (s:Schema)-[r]-() 
                    WHERE s.openAI_GPT_4_response IS NOT NULL
                    WITH s, collect(r) as rs
                    SKIP $index - 1
                    LIMIT 1
                    CALL apoc.path.subgraphAll(s, {
                    minLevel: 0,
                    maxLevel: 25
                    })
                    YIELD nodes, relationships
                    RETURN apoc.coll.subtract(nodes, [s]), apoc.coll.subtract(relationships, rs);
                    """, parameters={'index': self.index })
        w = GraphWidget(graph=result.graph())
        w.node_label_mapping = self.custom_node_label_mapping
        w.edge_label_mapping = self.custom_relationship_label_mapping
        w.set_node_styles_mapping(self.custom_styles_mapping_image)
        w.organic_layout()
        w.set_sidebar(enabled=False)
        self.out.clear_output()
        with self.out:
            w.show()


class PaginationWidget:
    def __init__(self, total_pages):
        self.current_page = 1
        self.total_pages = total_pages
        
        # Create widgets
        self.page_label = widgets.Label(value=f"Page {self.current_page} of {self.total_pages}")
        self.prev_button = widgets.Button(description='Previous')
        self.next_button = widgets.Button(description='Next')
        self.schema_name = widgets.Label(value=f"{get_schema_name(index=self.current_page)}")
        self.f1_score = widgets.Output()
        
        # Connect button click events to functions
        self.prev_button.on_click(self.on_prev_button_click)
        self.next_button.on_click(self.on_next_button_click)
        
    def get_widget(self):
        # Construct layout
        widget_layout = [self.prev_button, self.page_label, self.next_button, self.schema_name, self.f1_score]

        return widget_layout

    def on_prev_button_click(self, b):
        if self.current_page > 1:
            self.current_page -= 1
            self.page_label.value = f"Page {self.current_page} of {self.total_pages}"
            self.schema_name.value = f"{get_schema_name(index=self.current_page)}"
            graph.updateIndex(self.current_page)
            self.show_f1_scores(self.current_page)
            # fetch_and_render(self.current_page)
    
    def on_next_button_click(self, b):
        if self.current_page < self.total_pages:
            self.current_page += 1
            self.page_label.value = f"Page {self.current_page} of {self.total_pages}"
            self.schema_name.value = f"{get_schema_name(index=self.current_page)}"
            graph.updateIndex(self.current_page)
            self.show_f1_scores(self.current_page)
            # fetch_and_render(self.current_page)
    
    def show_f1_scores(self, index=1):
        # F1 scores for two categories (scaled between 0 and 1)
        f1_scores = get_schema_f1_score(index).values.flatten()
        popuation_f1_scores = get_population_f1_score().values.flatten()

        categories = ['Primary', 'Foreign']  # Labels for categories

        # Calculate average F1 score for each category
        average_scores = popuation_f1_scores

        # Creating a bar plot with adjusted figure size
        fig, ax = plt.subplots(figsize=(3, 2))  # Set the figure size

        bars = ax.bar(np.arange(len(categories)), f1_scores, color='lightblue')  # Bar plot with different colors for each category

        # Plotting the average line for each category only over the respective bar
        for i, score in enumerate(average_scores):
            line = ax.axhline(y=score, color='blue', alpha=0.7, linestyle='-', linewidth=2, xmin=(i + 0.1) / len(categories), xmax=(i + 0.9) / len(categories))
            if i == 0:  # Only include the first red line in the legend
                legend_handles = [line]
                legend_labels = ['Overall schema average']

        # Adding labels and title
        ax.set_xlabel('Categories')
        ax.set_ylabel('F1 Score')
        ax.set_title(r'f1 score for Key Predictions $\it{in\ this\ example}$', fontsize=10)  # Adjust fontsize for the title

        # Setting the y-axis limit from 0 to 1
        ax.set_ylim(0, 1)

        # Displaying legend for red line only underneath the plot
        ax.legend(handles=legend_handles, labels=legend_labels, loc='upper center', bbox_to_anchor=(0.5, -0.2), fontsize='small', ncol=1, prop={'size': 8})

        # Displaying the plot
        ax.set_xticks(np.arange(len(categories)))
        ax.set_xticklabels(categories)
        plt.tight_layout()  # Ensure tight layout to prevent overlapping

        canvas = FigureCanvas(fig)  # Create FigureCanvas
          # Create an output widget
        self.f1_score.clear_output()
        with self.f1_score:  # Display the figure inside the widget
            plt.show()


total_pages = numResults.to_df()['numSchema'][0]
pagination = PaginationWidget(total_pages).get_widget()
# pagination.append(show_f1_scores())
display(widgets.HBox(pagination))
graph = GraphWidgetWrapper(1)


KeyError: 0