In [4]:
import os
import re
import subprocess
from pathlib import Path
import numpy as np
import faiss                       
from sentence_transformers import SentenceTransformer  
from langchain_groq import ChatGroq  

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def compile_arduino_code(code: str, sketch_name: str = "MySketch"):
    folder = sketch_name
    filename = f"{sketch_name}.ino"
    os.makedirs(folder, exist_ok=True)
    file_path = os.path.join(folder, filename)

    with open(file_path, "w", encoding="utf-8") as f:
        f.write(code)

    try:
        result = subprocess.run(
            ["arduino-cli", "compile", "--fqbn", "arduino:avr:uno", folder],
            capture_output=True,
            text=True,
            check=False,
        )
    except FileNotFoundError:
        return False, "arduino-cli not found or not in PATH"

    log = "\n".join(result.stderr.splitlines()[:20]).strip()
    return result.returncode == 0, log

In [6]:
import os
from langchain_groq import ChatGroq

class LLMWrapper:
    def __init__(
        self,
        model_name: str = "llama3-70b-8192",
        temperature: float = 0.0,
        api_key: str | None = None,
        verbose: bool = True,
    ):
        self.model_name = model_name
        self.temperature = temperature
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError(
                "GROQ_API_KEY is missing. Set it as an env var or pass api_key="
            )

        self.llm = ChatGroq(
            groq_api_key=self.api_key,
            model_name=self.model_name,
            temperature=self.temperature,
        )
        if verbose:
            print(
                f"[LLMWrapper] Initialized Groq model '{self.model_name}' (temp={self.temperature})"
            )

    def get_llm(self):
        return self.llm

In [7]:
os.environ["GROQ_API_KEY"] = "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"

llm = LLMWrapper().get_llm()

[LLMWrapper] Initialized Groq model 'llama3-70b-8192' (temp=0.0)


In [8]:
def trim_context_snippets(snippets, tokenizer, max_context_tokens):
    combined = ""
    for snippet in snippets:
        temp = combined + "\n\n" + snippet
        tokens = tokenizer(temp, return_tensors="pt", truncation=False).input_ids
        if tokens.shape[1] > max_context_tokens:
            break
        combined = temp
    return combined

In [9]:
def load_arduino_examples_from_repo(paths):
    # Normalise to list
    if isinstance(paths, (str, os.PathLike)):
        paths = [str(paths)]

    ino_files = []
    for root_dir in paths:
        # Walk each directory separately so os.walk gets a str, not list
        for subdir, _, files in os.walk(root_dir):
            for file in files:
                if file.endswith(".ino"):
                    try:
                        with open(os.path.join(subdir, file),
                                  "r", encoding="utf-8", errors="ignore") as f:
                            ino_files.append(f.read()[:5000])
                    except Exception:
                        pass
    return ino_files

In [10]:
def build_faiss_index_from_arduino_examples(paths):
    """
    Build a FAISS L2 index from all .ino examples under *paths*.
    *paths* can be a single folder or a list/tuple of folders.
    """
    docs = load_arduino_examples_from_repo(paths)
    print(f"[FAISS] Loaded {len(docs)} examples")

    embedder   = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = embedder.encode(docs).astype("float32")
    dim        = embeddings.shape[1]

    index = faiss.IndexFlatL2(dim)
    index.add(embeddings)
    return docs, embedder, index

In [11]:
def retrieve_context(error_log, embedder, index, docs, top_k=4):
    query_vec = embedder.encode([error_log]).astype("float32")
    _, I = index.search(query_vec, top_k * 3)   # over-fetch
    snippets = []
    for idx in I[0]:
        doc = docs[idx]
        if ("TensorFlowLite" in doc or "LSM9DS1" in doc):
            snippets.append(doc)
        if len(snippets) == top_k:
            break
    return snippets


In [12]:
from prompt_template import correct_code
import re
_FENCE_RE = re.compile(r"```(?:\w+)?\s*\n?(.*?)```", re.DOTALL | re.MULTILINE)


def _regex_extract(text: str) -> str | None:
    m = _FENCE_RE.search(text)
    return m.group(1).strip() if m else None


def extract_clean_code(raw: str, sentinel: str = "<END_OF_CODE>") -> str:
    if sentinel in raw:
        return raw.split(sentinel, 1)[0].strip()
    block = _regex_extract(raw)
    return block if block else raw.strip()


In [13]:
def fix_code(*,buggy_code: str,compiler_error: str,context_snippets: list[str],llm,) -> str:
    prompt = correct_code(
        context="\n".join(context_snippets),
        code=buggy_code,
        error=compiler_error,
    )

    response = llm.invoke(
        prompt,
        temperature=0,
        stop=["<END_OF_CODE>"],
    )

    raw_text = getattr(response, "content", response)
    return extract_clean_code(raw_text) or buggy_code

In [16]:
def iterative_fix(original_code: str, max_attempts: int = 4) -> str:
    code = original_code
    llm = LLMWrapper().get_llm()

    # Build FAISS index once
    docs, embedder, index = build_faiss_index_from_arduino_examples([
    "arduino-examples/examples",
    "arduino-examples/tflm/examples",
    "arduino-examples/lsm9ds1/examples"
    ])

    for attempt in range(1, max_attempts + 1):
        print(f"\n🔁 Attempt {attempt} — compiling…")
        ok, log = compile_arduino_code(code)
        if ok:
            print("✅ Compilation succeeded.")
            return code

        print("❌ Compilation failed.\n🧾 Error log:\n", log)

        context_snippets = retrieve_context(log, embedder, index, docs)
        print("🤖 LLM fixing…")

        fixed_code = fix_code(
            buggy_code=code,
            compiler_error=log,
            context_snippets=context_snippets,
            llm=llm,
        )

        # Early exit if model made no change
        if fixed_code.strip() == code.strip():
            print("⚠️  LLM produced no changes; aborting early.")
            break

        code = fixed_code  

    print("❌ Max attempts reached. Returning last version.")
    return code

In [17]:
from pathlib import Path

buggy_code_path = Path("buggy_code.ino")
buggy_code = buggy_code_path.read_text(encoding="utf-8")

fixed = iterative_fix(buggy_code)
if fixed:
    print("\n—— Final working sketch ——\n")
    print(fixed)
else:
    print("No valid fix found after retries.")

[LLMWrapper] Initialized Groq model 'llama3-70b-8192' (temp=0.0)
[FAISS] Loaded 88 examples

🔁 Attempt 1 — compiling…
❌ Compilation failed.
🧾 Error log:
 C:\Users\medaminekh\semester_project\MySketch\MySketch.ino:10:1: error: 'tflite' does not name a type; did you mean 'fwrite'?
 tflite::Model* model = nullptr;
 ^~~~~~
 fwrite
C:\Users\medaminekh\semester_project\MySketch\MySketch.ino:11:1: error: 'tflite' does not name a type; did you mean 'fwrite'?
 tflite::Tensor* inputTensor = nullptr;
 ^~~~~~
 fwrite
C:\Users\medaminekh\semester_project\MySketch\MySketch.ino:12:1: error: 'tflite' does not name a type; did you mean 'fwrite'?
 tflite::Tensor* outputTensor = nullptr;
 ^~~~~~
 fwrite
C:\Users\medaminekh\semester_project\MySketch\MySketch.ino: In function 'void setup()':
C:\Users\medaminekh\semester_project\MySketch\MySketch.ino:16:3: error: 'LSM9DS1' was not declared in this scope
   LSM9DS1.begin()
   ^~~~~~~
C:\Users\medaminekh\semester_project\MySketch\MySketch.ino:18:7: error: 'mo