In [None]:
import os
import requests
import json
import openai
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
import time
import re

# Groq exposes an OpenAI-compatible API, so the stock OpenAI client is reused
# with Groq's base URL. The key is mandatory and is read from the environment;
# a missing or empty value aborts at import time instead of on the first call.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if GROQ_API_KEY in (None, ""):
    raise ValueError("GROQ_API_KEY environment variable not set.")

client = openai.OpenAI(
    api_key=GROQ_API_KEY,
    base_url="https://api.groq.com/openai/v1",
)

# Browser-based Mermaid editor that renders the generated diagram.
MERMAID_LIVE_URL = "https://mermaid.live/edit"

# Layout requirements appended after the caller's description in every request.
_DIAGRAM_REQUIREMENTS = (
    "Create a valid Mermaid.js code representing four flowcharts arranged in a 2x2 grid. "
    "The flowcharts must adhere to the following structure:\n"
    "\n"
    "1. **Top Left Flowchart:** Summarizes all key sections of a typical research paper, "
    "including Introduction, Methodology, Results, Discussion, and Conclusion.\n"
    "2. **Top Right Flowchart:** Represents a detailed breakdown of the Introduction section, "
    "highlighting problem statement, objectives, and significance.\n"
    "3. **Bottom Left Flowchart:** Explores the Methodology section, illustrating key steps "
    "such as data collection, experimental design, and analysis methods.\n"
    "4. **Bottom Right Flowchart:** Visualizes Results and Discussion, including findings, "
    "insights, and future work suggestions.\n"
    "\n"
    # A single Mermaid source renders exactly one diagram, so the four charts
    # have to be subgraphs of one flowchart for the output to be valid.
    "- Put all four flowcharts in a single `flowchart` diagram, one `subgraph` each.\n"
    "- Do **not** use round brackets `()`.\n"
    "- Use only valid Mermaid.js syntax.\n"
    "- Respond **only** with the code. Do not add explanations or Markdown code fences.\n"
)

# An opening Markdown fence (optional language tag such as "mermaid") and its
# body, ending at the closing fence -- or at end of text when a length-limited
# reply was cut off before the closing fence was emitted.
_CODE_FENCE_RE = re.compile(r"```[\w-]*[ \t]*\n(.*?)(?:```|\Z)", re.DOTALL)


def _strip_code_fences(text):
    """Return the body of the first Markdown code fence in *text*, or *text* unchanged if none."""
    match = _CODE_FENCE_RE.search(text)
    return match.group(1) if match else text


def generate_mermaid_code(prompt, *, model="mixtral-8x7b-32768", max_tokens=500):
    """Generate Mermaid.js flowchart code with the Groq chat-completions API.

    The caller's ``prompt`` is sent as the diagram description, followed by a
    fixed set of layout requirements (four research-paper flowcharts). Round
    brackets are forbidden in the instructions and also stripped from the
    reply as a safety net; any Markdown code fence around the reply is removed.

    Args:
        prompt: Free-text description of the diagram to generate. ``None`` or
            blank is allowed; only the fixed requirements are sent then.
        model: Groq model identifier. NOTE(review): Groq retires models over
            time -- confirm the default is still served.
        max_tokens: Completion length budget. Four flowcharts can exceed a
            small budget and come back truncated.

    Returns:
        The Mermaid source as a string, or ``None`` if the API call failed or
        the reply contained no usable code. Errors are printed, not raised.
    """
    description = (prompt or "").strip()
    user_message = (
        "Generate Mermaid.js code for the following description:\n\n"
        f"{description}\n\n"
        f"{_DIAGRAM_REQUIREMENTS}"
    )

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are an AI that generates valid Mermaid.js code for flowcharts. "
                                              "Never use round brackets () in the generated code. "
                                              "Use only valid Mermaid.js syntax."},
                {"role": "user", "content": user_message},
            ],
            max_tokens=max_tokens,
        )
        # The client types `message.content` as optional; normalized below.
        content = response.choices[0].message.content
    except Exception as e:
        # API, network, auth, or malformed-response errors: report and signal
        # failure with None, which is what callers check for.
        print(f"\u274c Error communicating with Groq API: {e}")
        return None

    mermaid_code = _strip_code_fences(content or "").strip()

    # Remove any accidental round brackets in case the model ignored the rule.
    mermaid_code = re.sub(r"[()]", "", mermaid_code)

    if not mermaid_code:
        print("\u274c Groq API returned an empty response.")
        return None
    return mermaid_code

def generate_and_download_image(mermaid_code, output_path="mermaid_diagram.png"):
    """Render *mermaid_code* on Mermaid.live in headless Chrome and save a PNG.

    The code is typed into the Mermaid.live editor, the page is given time to
    re-render, and only the rendered diagram element is screenshotted.

    Args:
        mermaid_code: Mermaid source to enter into the editor.
        output_path: File path for the PNG screenshot of the rendered diagram.

    Returns:
        ``output_path`` on success, ``None`` on failure. Errors are printed
        rather than raised. The browser is always shut down.
    """
    import sys

    from selenium.webdriver.common.keys import Keys

    driver = None
    try:
        options = webdriver.ChromeOptions()
        options.add_argument("--headless")  # Run Chrome in headless mode
        # Headless Chrome starts with a small window; give large diagrams room
        # so the element screenshot is not clipped to the viewport.
        options.add_argument("--window-size=1920,1080")

        driver = webdriver.Chrome(options=options)
        driver.get(MERMAID_LIVE_URL)
        time.sleep(5)  # Wait for page to load

        editor_input = WebDriverWait(driver, 10).until(
            EC.presence_of_element_located((By.CSS_SELECTOR, "textarea"))
        )
        # The editor page starts with a sample diagram. WebElement.clear() only
        # empties a plain <textarea>; NOTE(review): Mermaid.live uses a code
        # editor whose textarea merely captures keystrokes, so select-all +
        # Delete is used, which works for both.
        select_all = Keys.COMMAND if sys.platform == "darwin" else Keys.CONTROL
        editor_input.send_keys(select_all, "a")
        editor_input.send_keys(Keys.DELETE)
        editor_input.send_keys(mermaid_code)
        time.sleep(5)  # Allow rendering time

        # Page UI icons are <svg> elements too, so a bare "svg" selector can
        # grab the wrong one. Mermaid marks its root <svg> with
        # aria-roledescription (the diagram type), which the icons lack.
        diagram = WebDriverWait(driver, 15).until(
            EC.visibility_of_element_located(
                (By.CSS_SELECTOR, "svg[aria-roledescription]")
            )
        )

        # WebElement.screenshot returns False (instead of raising) on I/O errors.
        if not diagram.screenshot(output_path):
            print(f"\u274c Could not write screenshot to {output_path}")
            return None

        print(f"\u2705 Mermaid diagram saved as {output_path}")
        return output_path

    except Exception as e:
        print(f"\u274c Error generating diagram: {e}")
        return None

    finally:
        # Always release the browser process, including on error paths.
        if driver is not None:
            driver.quit()

# Script entry point: generate Mermaid code for a fixed prompt, then render it.
# (In a notebook, __name__ is also "__main__", so this runs when the cell runs.)
if __name__ == "__main__":
    user_prompt = "Generate flowcharts for research paper components"
    mermaid_code = generate_mermaid_code(user_prompt)

    if mermaid_code:
        print("Generated Mermaid code:\n", mermaid_code)
        generate_and_download_image(mermaid_code)
    else:
        print("\u274c Failed to generate Mermaid code.")