# In [1]:
import json
import os
import chromadb
import google.generativeai as genai
from statistics import median
from typing import List, Dict
from dotenv import load_dotenv
import gradio as gr
# ==========================================
# 0. CONFIGURATION
# ==========================================
# Provide the Gemini key through the environment or a .env file, e.g.:
#   export GEMINI_API_KEY="your_key"
# GOOGLE_API_KEY is accepted as a fallback so either name works.
load_dotenv()
API_KEY = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
DB_PATH = "./spark_rag_db"  # Where vectors will be saved on disk

if not API_KEY:
    # Not fatal here: parsing and ingestion work without a key; only the
    # Gemini call needs it.
    print("⚠️  WARNING: neither GEMINI_API_KEY nor GOOGLE_API_KEY environment variable is set.")
else:
    genai.configure(api_key=API_KEY)

# ==========================================
# 1. PARSING LOGIC (Your Code)
# ==========================================

class SparkLogAnalysis:
    """Accumulates the summaries produced while parsing one Spark event log.

    Job and stage summaries are kept in arrival order in two lists; every
    other summary is stored in a dict keyed by its ``event_type`` (so a later
    summary of the same type replaces an earlier one).
    """

    def __init__(self, log_file_path):
        # Only the file name (not the full path) is retained.
        self.log_file = os.path.basename(log_file_path)
        # Placeholder until set_application_id() receives a real ID.
        self.application_id = "unknown_app"
        self.application_summaries = {}
        self.stage_summaries = []
        self.job_summaries = []

    def set_application_id(self, app_id):
        """Record the application ID; falsy values are ignored."""
        if not app_id:
            return
        self.application_id = app_id

    def add_summary(self, summary_dict):
        """File a summary under its ``event_type``.

        The passed dict is stamped in place with the application ID known at
        the time of the call. Empty or ``None`` summaries are ignored.
        """
        if not summary_dict:
            return

        kind = summary_dict.get("event_type")
        summary_dict["application_id"] = self.application_id

        # Summary types that accumulate in a list; everything else is keyed.
        list_targets = {
            "JobPerformanceSummary": self.job_summaries,
            "StagePerformanceSummary": self.stage_summaries,
        }
        target = list_targets.get(kind)
        if target is None:
            self.application_summaries[kind] = summary_dict
        else:
            target.append(summary_dict)

    def to_dict(self):
        """Return the collected state as a plain dict (containers are not copied)."""
        exported = {}
        exported["application_id"] = self.application_id
        exported["log_file"] = self.log_file
        exported["job_summaries"] = self.job_summaries
        exported["stage_summaries"] = self.stage_summaries
        exported["application_summaries"] = self.application_summaries
        return exported

def _metric(source, key, default=0):
    """Return ``source[key]``, using ``default`` when the key is missing or null."""
    value = source.get(key)
    return default if value is None else value


def _extract_task_record(event):
    """Collect the timing and spill numbers of one SparkListenerTaskEnd event."""
    # "or {}" covers both a missing key and an explicit JSON null.
    task_info = event.get("Task Info") or {}
    task_metrics = event.get("Task Metrics") or {}
    shuffle_read = task_metrics.get("Shuffle Read Metrics") or {}
    return {
        "duration_ms": _metric(task_info, "Finish Time") - _metric(task_info, "Launch Time"),
        "jvm_gc_time_ms": _metric(task_metrics, "JVM GC Time"),
        # The fallback of 1 (rather than 0) is kept from the original code.
        "executor_run_time_ms": _metric(task_metrics, "Executor Run Time", 1),
        "executor_deserialize_time_ms": _metric(task_metrics, "Executor Deserialize Time"),
        "shuffle_fetch_wait_time_ms": _metric(shuffle_read, "Fetch Wait Time"),
        "memory_spilled_bytes": _metric(task_metrics, "Memory Bytes Spilled"),
        "disk_spilled_bytes": _metric(task_metrics, "Disk Bytes Spilled"),
    }


def _summarize_stage(stage_id, tasks, failed_task_count):
    """Build the StagePerformanceSummary dict for a stage with at least one task."""
    task_durations = [t["duration_ms"] for t in tasks]
    total_duration = sum(task_durations)
    total_gc_time = sum(t["jvm_gc_time_ms"] for t in tasks)
    total_runtime = sum(t["executor_run_time_ms"] for t in tasks)
    total_deserialize_time = sum(t["executor_deserialize_time_ms"] for t in tasks)
    total_fetch_wait_time = sum(t["shuffle_fetch_wait_time_ms"] for t in tasks)
    total_spilled_bytes = sum(t["memory_spilled_bytes"] + t["disk_spilled_bytes"] for t in tasks)
    max_duration, median_duration = max(task_durations), median(task_durations)

    summary_parts = [f"Stage {stage_id} completed with {len(tasks)} tasks."]
    potential_issues = []

    if failed_task_count > 0:
        potential_issues.append("TASK_FAILURES")
        summary_parts.append(f"It experienced {failed_task_count} task failures that required retries.")
    # Skew: slowest task is over 3x the median AND longer than 20 seconds.
    if max_duration > 3 * median_duration and max_duration > 20000:
        potential_issues.append("DATA_SKEW")
        summary_parts.append(f"Detected potential data skew. Max task duration: {max_duration/1000:.2f}s, median: {median_duration/1000:.2f}s.")
    # GC: more than 10% of executor run time.
    if total_runtime > 0 and (total_gc_time / total_runtime) > 0.10:
        potential_issues.append("HIGH_GC_PRESSURE")
        summary_parts.append(f"High JVM GC pressure detected ({(total_gc_time / total_runtime) * 100:.1f}% of runtime).")
    if total_spilled_bytes > 0:
        potential_issues.append("DATA_SPILL")
        summary_parts.append(f"Detected data spilling to disk ({total_spilled_bytes / (1024*1024):.2f} MB).")
    # Shuffle: fetch wait is more than 25% of summed task duration.
    if total_duration > 0 and (total_fetch_wait_time / total_duration) > 0.25:
        potential_issues.append("SHUFFLE_BOTTLENECK")
        summary_parts.append(f"Significant shuffle bottleneck, tasks spent {(total_fetch_wait_time / total_duration) * 100:.1f}% of time waiting for data.")
    # Deserialization: more than 15% of executor run time.
    if total_runtime > 0 and (total_deserialize_time / total_runtime) > 0.15:
        potential_issues.append("DESERIALIZATION_BOTTLENECK")
        summary_parts.append(f"High deserialization time ({(total_deserialize_time / total_runtime) * 100:.1f}% of runtime).")

    return {
        "event_type": "StagePerformanceSummary",
        "stage_id": stage_id,
        "metrics": {
            "task_count": len(tasks),
            "failed_task_count": failed_task_count,
            "max_task_duration_ms": max_duration,
            "median_task_duration_ms": median_duration,
            "total_spilled_mb": total_spilled_bytes / (1024*1024),
        },
        "summary": " ".join(summary_parts),
        "potential_issues": potential_issues if potential_issues else ["NONE"],
    }


def analyze_spark_log(log_file_path):
    """Parse a single Spark event log into a ``SparkLogAnalysis``.

    The file is read line by line; each line is expected to be one JSON
    event. Lines that are not valid JSON objects are skipped. The events are
    condensed into:

    * one ``JobPerformanceSummary`` per job that has both a start and an end
      event with timestamps,
    * one ``StagePerformanceSummary`` per completed stage that saw at least
      one task-end event,
    * an ``ExecutorDiskFailureSummary`` when executors were removed with a
      reason mentioning "disk" or "no space left on device",
    * an ``ExecutorChurnSummary`` when executors lived for under 60 seconds.

    Args:
        log_file_path: Path of the event log to read.

    Returns:
        The populated ``SparkLogAnalysis``, or ``None`` if the file could not
        be read or processed (the error is printed).
    """
    analysis_object = SparkLogAnalysis(log_file_path)
    stage_data = {}
    job_data = {}
    executor_lifecycle = {}
    disk_related_executor_failures = {}
    blacklisted_executors = []

    try:
        # Explicit UTF-8 keeps parsing independent of the platform's default
        # encoding; undecodable bytes are replaced instead of aborting the file.
        with open(log_file_path, 'r', encoding='utf-8', errors='replace') as f:
            for line in f:
                try:
                    event = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # A bare number/string/list is valid JSON but not an event.
                if not isinstance(event, dict):
                    continue

                event_type = event.get("Event")

                if event_type == "SparkListenerApplicationStart":
                    analysis_object.set_application_id(event.get("App ID"))

                elif event_type == "SparkListenerExecutorAdded":
                    executor_id, timestamp = event.get("Executor ID"), event.get("Timestamp")
                    executor_lifecycle.setdefault(executor_id, {})['add_time'] = timestamp

                elif event_type == "SparkListenerExecutorRemoved":
                    executor_id, timestamp = event.get("Executor ID"), event.get("Timestamp")
                    reason = event.get("Removed Reason") or ""
                    executor_lifecycle.setdefault(executor_id, {})['remove_time'] = timestamp
                    lowered_reason = reason.lower()
                    if "no space left on device" in lowered_reason or "disk" in lowered_reason:
                        failure = disk_related_executor_failures.setdefault(reason, {"count": 0, "samples": []})
                        failure["count"] += 1
                        # Keep at most three sample executors per distinct reason.
                        if len(failure["samples"]) < 3:
                            failure["samples"].append({"executor_id": executor_id, "timestamp": timestamp})

                elif event_type == "SparkListenerExecutorBlacklisted":
                    # NOTE(review): this list is collected but never turned into a
                    # summary, and "Task ID" looks like the wrong field for a
                    # failure count - confirm against Spark's event schema.
                    blacklisted_executors.append({"executor_id": event.get("Executor ID"), "timestamp": event.get("Timestamp"), "reason": f"Blacklisted for {event.get('Task ID')} task failures"})

                elif event_type == "SparkListenerJobStart":
                    job_id = event.get("Job ID")
                    job_data[job_id] = {"start_time": event.get("Submission Time")}

                elif event_type == "SparkListenerJobEnd":
                    job_id = event.get("Job ID")
                    started = job_data.pop(job_id, None)
                    if started is not None:
                        start_time, end_time = started["start_time"], event.get("Completion Time")
                        # Without both timestamps no duration can be computed;
                        # skip this job rather than abandoning the whole file.
                        if start_time is not None and end_time is not None:
                            duration_ms = end_time - start_time
                            job_result = (event.get("Job Result") or {}).get("Result")
                            analysis_object.add_summary({
                                "event_type": "JobPerformanceSummary",
                                "job_id": job_id,
                                "status": job_result,
                                "duration_s": duration_ms / 1000,
                                "summary": f"Job {job_id} finished with status {job_result} in {duration_ms / 1000:.2f} seconds.",
                            })

                elif event_type == "SparkListenerTaskEnd":
                    stage_id = event.get("Stage ID")
                    if stage_id is not None:
                        stage_entry = stage_data.setdefault(stage_id, {"tasks": [], "failed_task_count": 0})
                        end_reason = event.get("Task End Reason") or {}
                        if end_reason.get("Reason") != "Success":
                            stage_entry["failed_task_count"] += 1
                        stage_entry["tasks"].append(_extract_task_record(event))

                elif event_type == "SparkListenerStageCompleted":
                    stage_id = (event.get("Stage Info") or {}).get("Stage ID")
                    stage_entry = stage_data.get(stage_id)
                    if stage_entry and stage_entry["tasks"]:
                        analysis_object.add_summary(_summarize_stage(stage_id, stage_entry["tasks"], stage_entry["failed_task_count"]))

        if disk_related_executor_failures:
            total_failures = sum(v['count'] for v in disk_related_executor_failures.values())
            analysis_object.add_summary({
                "event_type": "ExecutorDiskFailureSummary",
                "total_failures": total_failures,
                "summary": f"A critical error occurred where {total_failures} executors failed to launch due to running out of disk space on the nodes.",
                "failures_by_reason": disk_related_executor_failures,
            })

        short_lived_executors = []
        churn_threshold_ms = 60000
        for exec_id, times in executor_lifecycle.items():
            added, removed = times.get('add_time'), times.get('remove_time')
            # Executors missing either timestamp cannot be measured.
            if added is None or removed is None:
                continue
            if (removed - added) < churn_threshold_ms:
                short_lived_executors.append({"executor_id": exec_id, "lifespan_s": (removed - added) / 1000})
        if short_lived_executors:
            avg_lifespan_s = sum(e['lifespan_s'] for e in short_lived_executors) / len(short_lived_executors)
            analysis_object.add_summary({
                "event_type": "ExecutorChurnSummary",
                "total_churn_events": len(short_lived_executors),
                "summary": f"Detected {len(short_lived_executors)} executors with an average lifespan of just {avg_lifespan_s:.2f}s, indicating severe instability.",
                "average_lifespan_s": round(avg_lifespan_s, 2),
                "sample_churned_executors": short_lived_executors[:3],
            })

    except Exception as e:
        # Best-effort boundary: callers treat None as "skip this file".
        print(f"Error reading file: {e}")
        return None

    return analysis_object

# ==========================================
# 2. VECTORIZATION (ChromaDB Integration)
# ==========================================

def get_chroma_collection():
    """Open (creating it on first use) the on-disk "spark_logs" collection.

    A PersistentClient rooted at DB_PATH is used so the index survives
    between runs and does not need to be rebuilt. No embedding function is
    passed, so ChromaDB's default one applies.
    """
    return chromadb.PersistentClient(path=DB_PATH).get_or_create_collection(name="spark_logs")

def ingest_logs(log_directory):
    """Parse every log file in a directory and store the summaries in ChromaDB.

    If the collection already holds documents, nothing is parsed and the
    function returns immediately. Otherwise each regular file directly inside
    ``log_directory`` (no recursion) is parsed with ``analyze_spark_log``;
    files that fail to parse are skipped. Each summary's ``summary`` text
    becomes the embedded document and its remaining fields become metadata.

    Args:
        log_directory: Directory containing Spark event log files.

    Raises:
        OSError: If ``log_directory`` cannot be listed.
    """
    collection = get_chroma_collection()

    # Check if we already have data (optional optimization)
    existing_count = collection.count()
    if existing_count > 0:
        print(f"📚 Vector DB already contains {existing_count} documents. Skipping ingestion.")
        print("To force re-ingestion, delete the 'spark_rag_db' folder.")
        return

    print(f"📂 Processing logs in: {log_directory}")
    all_analyses = []

    # Sorted so that processing order is reproducible across runs.
    for filename in sorted(os.listdir(log_directory)):
        full_path = os.path.join(log_directory, filename)
        if os.path.isfile(full_path):
            print(f"   - Parsing {filename}...")
            analysis = analyze_spark_log(full_path)
            if analysis:
                all_analyses.append(analysis)

    print("🧩 Vectorizing and storing data...")
    documents, metadatas, ids = [], [], []

    for analysis in all_analyses:
        # Flatten the object
        summaries = (analysis.job_summaries +
                     analysis.stage_summaries +
                     list(analysis.application_summaries.values()))

        for i, summary in enumerate(summaries):
            text = summary.get("summary")
            if not text:
                # Nothing to embed for this entry.
                continue

            # Metadata cleaning: keep scalar values, stringify lists/dicts, and
            # drop the embedded text itself as well as null values.
            meta = {
                key: str(val) if isinstance(val, (list, dict)) else val
                for key, val in summary.items()
                if key != "summary" and val is not None
            }

            documents.append(text)
            metadatas.append(meta)
            # The file name is part of the ID because several log files can
            # carry the same application ID (for example "unknown_app"), and
            # IDs must be unique within one add() call.
            ids.append(f"{analysis.application_id}_{analysis.log_file}_{i}")

    if documents:
        collection.add(documents=documents, metadatas=metadatas, ids=ids)
        print(f"✅ Successfully ingested {len(documents)} log events into ChromaDB.")
    else:
        print("⚠️ No valid log events found to ingest.")

# ==========================================
# 3. LLM ORCHESTRATOR (The RAG Logic) - UPDATED
# ==========================================

def get_gemini_response(user_query, collection, app_id=None):
    """Answer a question about the ingested logs, streaming the reply.

    Up to five summaries closest to ``user_query`` are retrieved from
    ``collection`` (restricted to ``app_id`` when given), placed into a prompt,
    and sent to Gemini in streaming mode.

    Args:
        user_query: The user's question.
        collection: Collection object providing ``count()`` and ``query()``.
        app_id: Optional application ID used as a metadata filter.

    Yields:
        str: The cumulative answer text so far (each item supersedes the
        previous one), or a single explanatory message when nothing was
        retrieved or an error occurred. Errors are reported through the
        yielded text rather than raised.
    """
    # 1. RETRIEVAL
    if app_id:
        print(f"🔍 Searching logs specifically for App: {app_id}...")
    else:
        print("🔍 Searching across ALL application logs...")

    try:
        available = collection.count()
        if available > 0:
            query_args = {
                "query_texts": [user_query],
                # Never request more results than the collection holds.
                "n_results": min(5, available),
            }
            if app_id:
                query_args["where"] = {"application_id": app_id}
            results = collection.query(**query_args)
            retrieved_docs = (results.get('documents') or [[]])[0]
            retrieved_meta = (results.get('metadatas') or [[]])[0]
        else:
            retrieved_docs, retrieved_meta = [], []
    except Exception as e:
        yield f"❌ Error querying the vector database: {e}"
        return

    if not retrieved_docs:
        yield "I couldn't find any log events matching your question."
        return

    # 2. AUGMENTATION
    context_parts = []
    for doc, meta in zip(retrieved_docs, retrieved_meta):
        meta = meta or {}  # an entry stored without metadata comes back as None
        current_app = meta.get('application_id', 'unknown')
        info_tag = f"[App: {current_app} | {meta.get('event_type', 'Event')}]"
        issue_tag = f"[Issues: {meta.get('potential_issues', 'None')}]"
        context_parts.append(f"{info_tag} {issue_tag}\nLog Summary: {doc}")

    context_str = "\n---\n".join(context_parts)

    if app_id:
        context_intro = f"You are analyzing performance logs for a SPECIFIC Spark Application: {app_id}."
    else:
        context_intro = "You are analyzing a fleet of Spark Applications. The logs provided may belong to different applications."

    system_prompt = f"""
    You are an expert Apache Spark Log Diagnostician. 
    {context_intro}

    Here are the most relevant log summaries retrieved from the system:
    
    {context_str}

    USER QUESTION: "{user_query}"

    INSTRUCTIONS:
    1. Identify root causes based ONLY on the provided log summaries.
    2. If analyzing multiple apps, clearly state which Application ID had which issue.
    3. If the logs mention 'Data Skew', 'GC Pressure', or 'Shuffle', explain what that means.
    4. Suggest concrete configuration changes where applicable.
    5. Keep the tone professional and concise.
    """

    # 3. GENERATION (Streaming Mode)
    partial_text = ""
    try:
        model = genai.GenerativeModel('gemini-2.5-flash')
        response_stream = model.generate_content(system_prompt, stream=True)

        for chunk in response_stream:
            try:
                piece = chunk.text
            except ValueError:
                # The .text accessor raises when a chunk carries no text part;
                # skip it instead of discarding what has streamed so far.
                continue
            if piece:
                partial_text += piece
                yield partial_text  # Yield the updated full text at every step

    except Exception as e:
        error_note = f"❌ Error communicating with Gemini API: {e}"
        # Keep whatever was already streamed visible above the error.
        yield f"{partial_text}\n\n{error_note}" if partial_text else error_note
    else:
        if not partial_text:
            yield "The model returned no text for this question."

# ==========================================
# 4. GRADIO UI IMPLEMENTATION
# ==========================================

def get_all_app_ids():
    """Return the sorted, de-duplicated application IDs found in ChromaDB.

    Used to populate the application dropdown. Only the first 1000 stored
    entries are inspected, so IDs that appear solely beyond that window are
    not listed.

    Returns:
        list: Sorted unique application IDs; an empty list if the collection
        cannot be read (the error is printed).
    """
    try:
        collection = get_chroma_collection()
        # Note: limiting to 1000 for performance; adjust if you have huge datasets
        data = collection.get(limit=1000, include=['metadatas'])
        metadatas = data.get('metadatas') or []
        # Entries stored without metadata come back as None - skip those.
        return sorted({
            m['application_id']
            for m in metadatas
            if m and m.get('application_id')
        })
    except Exception as e:
        # The UI must still build when the DB is unavailable, but say why
        # the list is empty instead of failing silently.
        print(f"⚠️ Could not load application IDs: {e}")
        return []

def ui_ingest_logs(directory):
    """Run ingestion for the "Ingest Logs" button and report the outcome.

    Args:
        directory: Path typed into the log-directory textbox.

    Returns:
        str: A status message for the UI. Failures are reported in the
        message rather than raised.
    """
    directory = directory.strip() if isinstance(directory, str) else directory
    if not directory:
        return "⚠️ Please enter a valid directory path."
    # Checked up front so a mistyped path yields a clear message.
    if not os.path.isdir(directory):
        return f"⚠️ Directory not found: {directory}"
    try:
        ingest_logs(directory)
        return f"✅ Ingestion complete for: {directory}"
    except Exception as e:
        return f"❌ Error: {str(e)}"

# ==========================================
# 5. BUILDING THE UI
# ==========================================

# Layout: a narrow settings column (ingestion + application filter) beside a
# wide chat column. Event handlers are wired after the layout is declared.
with gr.Blocks(title="Spark Log RAG AI", theme=gr.themes.Soft()) as demo:

    # Header
    gr.Markdown("# ⚡ Spark Log Diagnostic AI")
    gr.Markdown("Diagnose performance issues using RAG (Retrieval Augmented Generation).")

    with gr.Row():

        # --- LEFT COLUMN: Settings & Controls ---
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("### ⚙️ Configuration")

            # 1. Ingestion Section
            log_dir_input = gr.Textbox(
                label="Log Directory Path",
                placeholder=r"C:\path\to\spark-logs",
                value=r"C:\Users\ddev\Documents\projects\spark-event-logs\test-logs"
            )
            ingest_btn = gr.Button("📂 Ingest Logs", variant="secondary")
            ingest_status = gr.Textbox(label="Status", interactive=False)

            gr.HTML("<hr>")  # Visual Separator

            # 2. Context Locking
            gr.Markdown("### 🎯 Context Scope")
            app_dropdown = gr.Dropdown(
                # Populated once at build time; the refresh button (and a
                # finished ingestion) reloads it.
                choices=["All Applications"] + get_all_app_ids(),
                value="All Applications",
                label="Filter by Application ID",
                interactive=True
            )
            refresh_btn = gr.Button("🔄 Refresh App List", size="sm")

        # --- RIGHT COLUMN: Chat Interface ---
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(height=600, type="messages")
            msg = gr.Textbox(
                placeholder="Ask about failures, data skew, or specific stages...",
                container=False,
                scale=7
            )
            with gr.Row():
                clear = gr.ClearButton([msg, chatbot])
                submit_btn = gr.Button("Submit", variant="primary")

    # ==========================================
    # EVENT HANDLERS
    # ==========================================

    # 1. Refresh App List
    def update_dropdown():
        """Reload the application IDs and reset the selection to all apps."""
        ids = get_all_app_ids()
        return gr.Dropdown(choices=["All Applications"] + ids, value="All Applications")

    refresh_btn.click(fn=update_dropdown, outputs=[app_dropdown])

    # 2. Ingest Button Click - afterwards reload the dropdown so newly
    #    ingested applications are selectable without a manual refresh.
    ingest_btn.click(
        fn=ui_ingest_logs,
        inputs=[log_dir_input],
        outputs=[ingest_status]
    ).then(fn=update_dropdown, outputs=[app_dropdown])

    # 3. Chat Logic (Streaming)
    def chat_logic(user_msg, history, app_id):
        """Stream one chat turn.

        Yields ``(textbox_value, history)`` pairs: first with the user's
        message appended (and the textbox cleared), then once per partial
        assistant response.
        """
        # Copy so the list handed over by the UI is never mutated; it may
        # also be None.
        history = list(history or [])

        # Ignore blank submissions instead of spending a retrieval + LLM call.
        if not user_msg or not user_msg.strip():
            yield "", history
            return

        collection = get_chroma_collection()

        # "All Applications" (or no selection) means: do not filter by app.
        target_app = app_id if app_id and app_id != "All Applications" else None

        # Show the user's message immediately.
        history.append({"role": "user", "content": user_msg})
        yield "", history

        # Placeholder assistant message that is overwritten as text streams in.
        history.append({"role": "assistant", "content": ""})
        for partial_response in get_gemini_response(user_msg, collection, app_id=target_app):
            history[-1]["content"] = partial_response
            yield "", history

    # Bind Enter Key and Submit Button
    msg.submit(chat_logic, [msg, chatbot, app_dropdown], [msg, chatbot])
    submit_btn.click(chat_logic, [msg, chatbot, app_dropdown], [msg, chatbot])

# ==========================================
# 6. LAUNCH
# ==========================================
if __name__ == "__main__":
    # share=False: no public Gradio link is created.
    # NOTE(review): server_name "0.0.0.0" listens on all network interfaces,
    # so the app is reachable from other machines on the network - confirm
    # that this is intended for a "local" run.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "inbrowser": True,
        "inline": False,
    }
    demo.launch(**launch_options)

# Captured notebook output:
# * Running on local URL:  http://0.0.0.0:7860
# * To create a public link, set `share=True` in `launch()`.
