In [2]:
import tkinter as tk
from tkinter import messagebox, scrolledtext, simpledialog
import pandas as pd
import time
import config
import webbrowser


class OracleLabeler:
    def __init__(self):
        """Load the training set and label list, then build the labeling window.

        Reads the CSV at ``config.CURRENT_TRAINING_SET_PATH`` into
        ``self.text_df`` and the ``label`` column of the CSV at
        ``config.PROMPT_TEMPLATE_LABELS`` into ``self.button_labels``.
        The bookkeeping columns ``human_label``, ``edge_case`` and ``comments``
        are created when absent, and ``self.index`` starts at the first row
        without a human label (row 0 when every row is labeled).
        """
        # Load DataFrame and button labels from config paths
        self.text_df = pd.read_csv(config.CURRENT_TRAINING_SET_PATH)
        self.button_labels = pd.read_csv(config.PROMPT_TEMPLATE_LABELS)['label'].tolist()

        # Add the bookkeeping columns that do not exist yet
        if 'human_label' not in self.text_df.columns:
            self.text_df['human_label'] = pd.NA
        if 'edge_case' not in self.text_df.columns:
            self.text_df['edge_case'] = False
        if 'comments' not in self.text_df.columns:
            self.text_df['comments'] = pd.Series([""] * len(self.text_df), dtype="object")

        # A column that is completely empty in the CSV is read back as float64
        # (all NaN). Labels and comments are written as strings later on, and
        # recent pandas versions warn about (and are slated to reject) writing
        # a string into a float column, so keep both columns as object dtype.
        for column in ('human_label', 'comments'):
            if self.text_df[column].dtype != object:
                self.text_df[column] = self.text_df[column].astype(object)

        # Set the index to the first unlabeled row or start from the beginning
        unlabeled = self.text_df['human_label'].isna()
        self.index = self.text_df.index[unlabeled.to_numpy()][0] if unlabeled.any() else 0
        self.update_counter = 0
        self.last_save_time = time.time()

        # Initialize the root window and main frame
        self.root = tk.Tk()
        self.root.title("Oracle Labeling Tool")
        self.setup_main_window()
        
    def setup_main_window(self):
        """Create the five container frames, bind key presses and build all widgets.

        The frames are stacked in the root window in this order: info,
        navigation, text, comments, buttons.
        """
        frame_names = (
            "info_frame",
            "navigation_frame",
            "text_frame",
            "comments_frame",
            "buttons_frame",
        )
        for frame_name in frame_names:
            container = tk.Frame(self.root)
            container.pack(pady=10)
            setattr(self, frame_name, container)

        # Route every key press on the root window through on_key_press
        self.root.bind('<KeyPress>', self.on_key_press)

        # Fill the frames; the builders run in the same order as before
        builders = (
            self.create_info_widgets,
            self.create_navigation_widgets,
            self.create_text_widget,
            self.create_buttons_widget,
            self.create_label_counts_widget,
            self.setup_comments_box,
        )
        for build in builders:
            build()

    def setup_comments_box(self):
        """Build the comment caption and editor inside ``self.comments_frame``.

        The editor is pre-filled with the stored comment of the current row
        (when it is not missing), and its focus events switch the keyboard
        shortcuts off while typing and back on afterwards.
        """
        caption = tk.Label(self.comments_frame, text="Add Comments:")
        caption.pack(anchor="w", padx=10, pady=5)
        self.comment_label = caption

        editor = scrolledtext.ScrolledText(self.comments_frame, wrap=tk.WORD, width=60, height=5)
        editor.pack(padx=10, pady=5)
        self.comment_box = editor

        # Show the comment already stored for this row, if any
        stored_comment = self.text_df.loc[self.index, 'comments']
        if pd.notna(stored_comment):
            editor.insert(tk.END, stored_comment)

        # Shortcuts are suspended while the editor has focus and restored after
        focus_handlers = (
            ('<FocusIn>', self.disable_key_bindings),
            ('<FocusOut>', self.on_focus_out),
        )
        for sequence, handler in focus_handlers:
            editor.bind(sequence, handler)
        
    def on_focus_out(self, event):
        # Re-enable key bindings and set focus back to root window
        self.enable_key_bindings(event)
        self.root.focus_set()  # Return focus to the main window   

    def disable_key_bindings(self, event):
        # Disable all navigation key bindings when comment box is focused
        self.root.unbind('<KeyPress>')
        self.root.unbind('p')
        self.root.unbind('n')
        self.root.unbind('c')
        self.root.unbind('g')
        self.root.unbind('s')

    def enable_key_bindings(self, event=None):
        # Re-bind all navigation keys when comment box loses focus
        self.root.bind('<KeyPress>', self.on_key_press)
        self.root.bind('p', lambda event: self.go_to_previous())
        self.root.bind('n', lambda event: self.go_to_next())
        self.root.bind('c', lambda event: self.go_to_current())
        self.root.bind('g', lambda event: self.go_to_row())
        self.root.bind('s', lambda event: self.save_and_exit())     
        
    def create_info_widgets(self):
        """Fill ``self.info_frame`` with the article ID label, the clickable
        article link and the keyword summary label (packed in that order)."""
        parent = self.info_frame

        # Article ID display
        id_label = tk.Label(parent, text="Article ID:")
        id_label.pack(pady=5)
        self.article_id_label = id_label

        # Blue, hand-cursor label that opens the article when clicked
        link = tk.Label(parent, text="Click here to view the article", fg="blue", cursor="hand2")
        link.pack()
        link.bind("<Button-1>", lambda _event: self.open_article_link())
        self.link_label = link

        # Keyword counts display
        keywords = tk.Label(parent, text="Keywords:")
        keywords.pack(pady=5)
        self.keyword_counts_label = keywords

    def create_navigation_widgets(self):
        """Fill ``self.navigation_frame`` with the row indicator followed by the
        navigation buttons, laid out left to right."""
        parent = self.navigation_frame

        self.position_label = tk.Label(parent, text="Row: 1")
        self.position_label.pack(side=tk.LEFT, padx=5)

        actions = (
            ("Previous", self.go_to_previous),
            ("Next", self.go_to_next),
            ("Current", self.go_to_current),
            ("Go to Row", self.go_to_row),
            ("Save and Exit", self.save_and_exit),
        )
        for caption, handler in actions:
            tk.Button(parent, text=caption, command=handler).pack(side=tk.LEFT, padx=5)

    def create_text_widget(self):
        """Create the scrollable, word-wrapped text area that shows the sample.

        The widget is left in the ``disabled`` state so it cannot be edited by
        the user.
        """
        viewer = scrolledtext.ScrolledText(self.text_frame, wrap=tk.WORD, width=80, height=30)
        viewer.pack()
        viewer.configure(state='disabled')
        self.text_area = viewer

    def create_buttons_widget(self):
        """Build the labeling controls.

        ``self.buttons_frame`` receives, left to right: the flag checkbox, one
        numbered button per entry of ``self.button_labels``, the Confirm button
        and the edge-case checkbox. The label-counts label is packed directly
        into the root window.
        """
        parent = self.buttons_frame

        def place_left(widget):
            # Every control in the frame shares the same left-to-right layout
            widget.pack(side=tk.LEFT, padx=5)
            return widget

        # Flag checkbox
        self.flag_var = tk.BooleanVar()
        self.flag_checkbox = place_left(
            tk.Checkbutton(parent, text="Flag Article", variable=self.flag_var, command=self.on_flag_toggle)
        )

        # One button per label, numbered from 1 to match the digit shortcuts
        self.label_buttons = {}
        for number, label in enumerate(self.button_labels, start=1):
            self.label_buttons[label] = place_left(
                tk.Button(parent, text=f"{number}: {label}", command=lambda l=label: self.on_label_select(l))
            )

        # Confirm button
        place_left(tk.Button(parent, text="Confirm", command=self.on_confirm))

        # Label counts live in the root window, not in the buttons frame
        self.label_counts_label = tk.Label(self.root, text="Label Counts:")
        self.label_counts_label.pack(pady=(10, 0))

        # Edge case checkbox
        self.edge_case_var = tk.BooleanVar()
        self.edge_case_checkbox = place_left(
            tk.Checkbutton(parent, text="Mark as Edge Case", variable=self.edge_case_var, command=self.on_edge_case_toggle)
        )

    def on_edge_case_toggle(self):
        # Update the edge case status for the current sample
        self.text_df.at[self.index, "edge_case"] = self.edge_case_var.get()      
        
    def load_sample(self):
        # Load the sample text and update widgets
        article_id = self.text_df.iloc[self.index]["article_ID"]
        self.flag_var.set(article_id.endswith("_FLAGGED"))
        self.edge_case_var.set(bool(self.text_df.iloc[self.index]["edge_case"]))
        sample = self.text_df.iloc[self.index]["analyze_text"]
        self.position_label.config(text=f"Row: {self.index + 1} of {len(self.text_df)}")

        # Update Article ID label
        self.article_id_label.config(text=f"Article ID: {article_id}")

        # Update link visibility based on the presence of a link
        self.article_link = self.generate_link_from_id(article_id)
        if self.article_link:
            self.link_label.pack()
        else:
            self.link_label.pack_forget()

        # Update the keyword counts display
        keyword_counts = self.text_df.iloc[self.index]['keyword_counts']
        if isinstance(keyword_counts, str):
            keyword_counts = eval(keyword_counts)  # Convert string representation of dictionary to actual dictionary
        keywords_text = ", ".join([f"{keyword}: {count}" for keyword, count in keyword_counts.items()])
        self.keyword_counts_label.config(text=f"Keywords: {keywords_text}")

        # Update text area with the sample text
        self.text_area.configure(state='normal')
        self.text_area.delete(1.0, tk.END)
        self.text_area.insert(tk.END, sample)
        self.text_area.configure(state='disabled')

        # Highlight the appropriate label button if already labeled
        current_label = self.text_df.iloc[self.index]['human_label']
        for label, button in self.label_buttons.items():
            if pd.isna(current_label):
                button.config(relief=tk.RAISED, bg="SystemButtonFace")
            elif current_label == label:
                button.config(relief=tk.SUNKEN, bg="lightblue")
            else:
                button.config(relief=tk.RAISED, bg="SystemButtonFace")

        # Highlight keywords in the text area
        self.highlight_keywords_in_text(self.text_area, sample, keyword_counts.keys())
        
        # Clear and load comments
        self.comment_box.delete('1.0', tk.END)
        if pd.notna(self.text_df.loc[self.index, 'comments']):
            self.comment_box.insert(tk.END, self.text_df.loc[self.index, 'comments'])
            
        # Return focus to main window
        self.on_focus_out(None)

    def highlight_keywords_in_text(self, text_widget, text, keywords):
        """Highlight all instances of each keyword in the text_widget."""
        text_widget.tag_remove("highlight", "1.0", tk.END)  # Clear previous highlights
        for keyword in keywords:
            start_idx = "1.0"
            while True:
                # Search for the keyword in the text (case insensitive)
                start_idx = text_widget.search(keyword, start_idx, tk.END, nocase=True)
                if not start_idx:
                    break
                end_idx = f"{start_idx}+{len(keyword)}c"
                # Apply the highlight tag to the found keyword
                text_widget.tag_add("highlight", start_idx, end_idx)
                text_widget.tag_config("highlight", background="yellow", foreground="black")
                start_idx = end_idx

    def go_to_previous(self):
        if self.index > 0:
            self.save_current_state()  # Save current state before moving
            self.index -= 1
            self.load_sample()
            
    def go_to_current(self):
        # Navigate to the first unlabeled row
        if self.text_df['human_label'].isna().any():
            self.save_current_state()  # Save current state before moving
            self.index = self.text_df[self.text_df['human_label'].isna()].index[0]
            self.load_sample()
            
    def update_flag_status(self):
        article_id = self.text_df.at[self.index, "article_ID"]
        if self.flag_var.get() and not article_id.endswith("_FLAGGED"):
            self.text_df.at[self.index, "article_ID"] = f"{article_id}_FLAGGED"
        elif not self.flag_var.get() and article_id.endswith("_FLAGGED"):
            self.text_df.at[self.index, "article_ID"] = article_id.replace("_FLAGGED", "")

    def on_flag_toggle(self):
        self.update_flag_status()

    def go_to_row(self):
        """Ask for a 1-based row number and jump to that row.

        The dialog restricts the input to ``1 .. len(self.text_df)``. When the
        dialog returns ``None`` nothing changes; otherwise the current row's
        state is stored before moving.
        """
        total_rows = len(self.text_df)
        requested_row = tk.simpledialog.askinteger(
            "Go to Row",
            "Enter the row number:",
            parent=self.root,
            minvalue=1,
            maxvalue=total_rows,
        )
        if requested_row is None:
            return
        # Keep the comment typed for the row we are leaving
        self.save_current_state()
        self.index = requested_row - 1  # rows are shown 1-based, stored 0-based
        self.load_sample()

    def go_to_next(self):
        if self.index < len(self.text_df) - 1:
            self.save_current_state()  # Save current state before moving
            self.index += 1
            self.load_sample()

    def on_label_select(self, label):
        # Update the label for the current sample
        self.text_df.at[self.index, "human_label"] = label
        self.update_label_counts()

        # Highlight the selected label button
        for lbl, button in self.label_buttons.items():
            if lbl == label:
                button.config(relief=tk.SUNKEN, bg="lightblue")
            else:
                button.config(relief=tk.RAISED, bg="SystemButtonFace")

    def create_label_counts_widget(self):
        self.update_label_counts()

    def update_label_counts(self):
        # Update the label counts display
        label_counts = self.text_df['human_label'].value_counts(dropna=False)
        counts_text = ' | '.join([f"{label}: {count}" for label, count in label_counts.items() if pd.notna(label)])
        nan_count = self.text_df['human_label'].isna().sum()
        nan_text = f"Unlabeled (NaN): {nan_count}"

        if hasattr(self, 'label_counts_label'):
            self.label_counts_label.config(text=f"Label Counts: {counts_text} | {nan_text}")
        # Update the label counts display
        label_counts = self.text_df['human_label'].value_counts(dropna=False)
        counts_text = '\n'.join([f"{label}: {count}" for label, count in label_counts.items() if pd.notna(label)])
        nan_count = self.text_df['human_label'].isna().sum()
        nan_text = f"Unlabeled (NaN): {nan_count}"

        if hasattr(self, 'label_counts_label'):
            self.label_counts_label.config(text=f"Label Counts: {counts_text} | {nan_text}")
        else:
            self.label_counts_label = tk.Label(self.root, text=f"Label Counts:\n{counts_text}\n{nan_text}")
            self.label_counts_label.pack(pady=5)

    def save_and_exit(self):
        # Save current progress and exit
        self.save_current_state()  # Save current state before moving
        self.save_progress()
        self.root.destroy()

    def on_key_press(self, event):
        if event.keysym.isdigit() and int(event.keysym) in range(1, len(self.button_labels) + 1):
            label = self.button_labels[int(event.keysym) - 1]
            self.on_label_select(label)
        elif event.keysym == 'Return':
            self.on_confirm()

    def on_confirm(self):
        # Check if a label has been selected
        current_label = self.text_df.iloc[self.index]['human_label']
        if pd.isna(current_label):
            confirm = messagebox.askyesno("No Label Selected", "No label has been selected for this article. Do you want to proceed without labeling?")
            if not confirm:
                return
            
        self.save_current_state()  # Save current state before moving        
            
        # Increment the update counter and save progress if needed
        self.update_counter += 1
        if self.update_counter >= 10 or (time.time() - self.last_save_time) >= 300:
            self.save_progress()

        # Move to the next sample
        self.go_to_next()

    def save_current_state(self):
        # Save the current comment and label to the DataFrame
        self.text_df.at[self.index, "comments"] = str(self.comment_box.get('1.0', tk.END).strip())
        current_label = self.text_df.at[self.index, "human_label"]
        if pd.isna(current_label):
            # Optional: Handle the case where a label hasn't been selected
            pass      
        
    def save_progress(self):
        """Write ``self.text_df`` to ``config.CURRENT_TRAINING_SET_PATH`` as CSV.

        On success the autosave counter and timer are reset. Any exception is
        reported to the user in an error dialog instead of being raised.
        """
        try:
            destination = config.CURRENT_TRAINING_SET_PATH
            self.text_df.to_csv(destination, index=False)
            print("Progress saved successfully.")
        except Exception as e:
            messagebox.showerror("Save Error", f"An error occurred while saving progress: {e}")
        else:
            # Restart the autosave bookkeeping only after a successful write
            self.update_counter = 0
            self.last_save_time = time.time()

    def generate_link_from_id(self, article_id):
        bbox = self.text_df.iloc[self.index]['bbox']
        if isinstance(bbox, str):
            bbox = eval(bbox)  # Convert string representation of list to actual list
        if len(bbox) == 4:
            parts = article_id.split('_')
            if len(parts) >= 5:
                page_number = parts[2].split('p')[-1]
                date_part = parts[1]
                newspaper_id = parts[3]
                clip_coords = ','.join(map(str, bbox))
                return f"https://www.loc.gov/resource/{newspaper_id}/{date_part}/ed-1/?sp={page_number}&clip={clip_coords}"
        return None

    def open_article_link(self):
        if self.article_link:
            webbrowser.open_new(self.article_link)

    def run(self):
        # Initially, bind only the global key presses to navigation functions
        self.enable_key_bindings()  # Bind all navigation keys when starting
        self.root.mainloop()

# Sample usage: build the labeling window and run it when executed as a script.
if __name__ == "__main__":
    # The constructor reads the CSV paths named in config and creates the Tk window
    labeler = OracleLabeler()
    # Binds the navigation keys and enters the Tk main loop
    labeler.run()

Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved successfully.
Progress saved