# Code World Model Debugger Demo

In this demo, we use CWM's execution tracing capabilities to build an interactive debugger for Python programs.

To get started, spin up an `fgserve` instance. For example, using four GPUs:
```bash
torchrun --nproc-per-node 4 \
    -m serve.fgserve \
    gen_args.host_cache_gb=80 \
    config=serve/configs/cwm.yaml \
    checkpoint_dir=<path/to/checkpoint>
```

Afterwards, update the variables in the next cell and run the notebook to get a small UI.
Adjust the contents of the `code` variable as you like.

A few notes on usage:
- `gen_args.host_cache_gb` will speed up inference but is not required
- "Reset" will reset the debugger, prompting the model to produce a new initial step
- "Step Out" will prompt the model repeatedly until it predicts a return from the current function
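- You can change local variables by editing the `name: value` lines in the right-hand text box; the edits are applied the next time you press one of the buttons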

In [None]:
# Point these to a running fgserve instance. They are used as the default
# `host` / `port` arguments of InteractiveDebugger further down.
FGSERVE_HOST: str = "localhost"
FGSERVE_PORT: int = 5678

# Point this to the CWM tokenizer file. It is passed as `path` to
# build_tokenizer() when the debugger is constructed.
TOKENIZER_PATH: str = "/path/to/checkpoint/tokenizer.model"

In [None]:
import os
import sys
from pathlib import Path

# Put the parent of the current working directory on the import path so that
# the `cwm` and `demos` packages imported below can be resolved.
cwm_dir = Path.cwd().parent
sys.path.append(str(cwm_dir))
from cwm.text.tokenizers import build_tokenizer
from demos.cwmdbg import CWMDebugger, CWMTraceEvent

In [None]:
import ipywidgets as widgets
from IPython.display import display
from pygments import highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import PythonLexer

In [None]:
class InteractiveDebugger:
    """Small ipywidgets front-end around a ``CWMDebugger``.

    The UI consists of a row of buttons (Next, Step, Step Out, Step Back,
    Reset) above two panes: the syntax-highlighted source with the current
    line marked, and an editable text area listing the local variables of the
    current frame as ``name: value`` lines.
    """

    def __init__(self, code: str, host: str = FGSERVE_HOST, port: int = FGSERVE_PORT):
        """Create the debugger backend for ``code`` and build the widgets.

        Args:
            code: Source of the program to debug; forwarded to ``CWMDebugger``.
            host: Host name forwarded to ``CWMDebugger``.
            port: Port forwarded to ``CWMDebugger``.
        """
        tokenizer = build_tokenizer(
            name="cwm_instruct",
            path=TOKENIZER_PATH,
        )
        self.dbg = CWMDebugger(code, tokenizer, host, port, temperature=0.6)

        self._setup_ui()

    def _setup_ui(self) -> None:
        """Instantiate all widgets, lay them out and register button callbacks."""
        # Setup widgets
        self.text_area_src = widgets.HTML(
            value="",
            disabled=True,
            layout=widgets.Layout(width="50%", height="400px"),
            style={"font_family": "monospace"},
        )
        self.text_area_vars = widgets.Textarea(
            value="",
            disabled=False,
            layout=widgets.Layout(width="50%", height="400px"),
            style={"font_family": "monospace"},
        )
        # Baseline text used by _apply_vars() to detect user edits. Initialise
        # it here so the attribute exists before the first _update_ui() call.
        self.text_area_vars._orig_value = ""
        self.next_button = widgets.Button(
            description="Next",
            button_style="info",
        )
        self.step_button = widgets.Button(
            description="Step",
            button_style="info",
        )
        self.step_out_button = widgets.Button(
            description="Step Out",
            button_style="info",
        )
        self.step_back_button = widgets.Button(
            description="Step Back",
            button_style="info",
        )
        self.reset_button = widgets.Button(
            description="Reset",
            button_style="danger",
        )

        displays = widgets.HBox([self.text_area_src, self.text_area_vars])
        buttons = widgets.HBox(
            [
                self.next_button,
                self.step_button,
                self.step_out_button,
                self.step_back_button,
                self.reset_button,
            ]
        )
        self.ui_container = widgets.VBox([buttons, displays])

        # Setup callbacks
        self.next_button.on_click(self._on_next_clicked)
        self.step_button.on_click(self._on_step_clicked)
        self.step_out_button.on_click(self._on_step_out_clicked)
        self.step_back_button.on_click(self._on_step_back_clicked)
        self.reset_button.on_click(self._on_reset_clicked)

    # UI action decorator that handles setting variables, updating UI elements,
    # and catching exceptions
    def _ui_action(func):
        """Wrap a button callback with the common pre/post-processing.

        The wrapper first applies any edits made in the variables pane, then
        runs ``func``, then refreshes the UI. If any of these steps raises, the
        formatted traceback is shown in the variables pane instead.
        """
        import functools
        import traceback

        @functools.wraps(func)
        def inner(self, _):
            try:
                self._apply_vars()
                func(self, _)
                self._update_ui()
            except Exception:
                # Top-level UI boundary: surface the error to the user rather
                # than losing it in the widget callback machinery.
                self.text_area_vars.value = traceback.format_exc()
                # Record the traceback as the baseline so the next button
                # press does not try to parse it as edited variables.
                self.text_area_vars._orig_value = self.text_area_vars.value

        return inner

    # UI callbacks
    @_ui_action
    def _on_next_clicked(self, _):
        """Handle the "Next" button by calling ``dbg.next()``."""
        self.dbg.next()

    @_ui_action
    def _on_step_clicked(self, _):
        """Handle the "Step" button by calling ``dbg.step()``."""
        self.dbg.step()

    @_ui_action
    def _on_step_out_clicked(self, _):
        """Handle the "Step Out" button by calling ``dbg.step_out()``."""
        self.dbg.step_out()

    @_ui_action
    def _on_step_back_clicked(self, _):
        """Handle the "Step Back" button by calling ``dbg.step_back()``."""
        self.dbg.step_back()

    @_ui_action
    def _on_reset_clicked(self, _):
        """Handle the "Reset" button by calling ``dbg.reset()``."""
        self.dbg.reset()

    def _update_ui(self) -> None:
        """Refresh both panes from the debugger's current frame."""
        frame = self.dbg.current_frame

        # Update source code, highlight current line
        formatter = HtmlFormatter(
            style="friendly",
            linenos="table",
            hl_lines=[frame.line_no] if frame.line_no else [],
        )
        css_styles = formatter.get_style_defs(".highlight")
        html = highlight(
            self.dbg.source,
            PythonLexer(stripnl=False),
            formatter,
        )
        full_html = f'<style>{css_styles}</style><div class="highlight">{html}</div>'
        self.text_area_src.value = full_html

        # Update variables
        if frame.event in (CWMTraceEvent.RETURN, CWMTraceEvent.EXCEPTION):
            # For return/exception events the pane shows the event argument
            # instead of the locals.
            # NOTE(review): assumes frame.arg is a str here -- confirm.
            self.text_area_vars.value = frame.arg
        else:
            # One "name: value" line per local; _apply_vars() parses this
            # same format back.
            self.text_area_vars.value = "".join(
                f"{k}: {v}\n" for k, v in self.dbg.local_vars.items()
            )
        # Remember what was rendered so that user edits can be detected later.
        self.text_area_vars._orig_value = self.text_area_vars.value

    def _apply_vars(self) -> None:
        """Write edited ``name: value`` lines back to the current frame.

        Nothing happens when the text is unchanged since the last refresh.
        Blank lines are skipped. If any non-blank line is not of the form
        ``name: value`` the whole edit is ignored and the frame is left
        untouched.
        """
        if self.text_area_vars.value == self.text_area_vars._orig_value:
            # No change
            return

        # Apply variables from input to current frame
        lv = {}
        for ln in self.text_area_vars.value.splitlines():
            if not ln.strip():
                # A stray empty line (e.g. a trailing newline typed by the
                # user) must not discard all other edits.
                continue
            try:
                k, v = ln.split(": ", maxsplit=1)
            except ValueError:
                return  # invalid format, ignore
            lv[k] = v
        self.dbg.current_frame.local_vars = lv
        # Flush cache for subsequent prompting
        self.dbg.current_frame._tokens = None

    def display(self) -> None:
        """Reset the debugger, render its initial state and show the UI."""
        self.dbg.reset()
        self._update_ui()
        display(self.ui_container)

In [None]:
# This is the code we want to debug. It is passed verbatim to
# InteractiveDebugger below.
# "<< START_OF_TRACE" marks the entry point of the trace and should point to a function
# (here: `f`, which calls the two helper functions defined above it).
code = """\
def count_letters(s, letter):
    n = 0
    for c in s:
        n += int(c == letter)
    return n

def format_answer(word, letter, count):
    parts = [
        "Found",
        f"{count:04d}",
        "occurrences of the letter",
        letter,
        "in",
        word
    ]
    return " ".join(parts)

def f(c):  # << START_OF_TRACE
    word = "strawberry"
    num = count_letters(word, c)
    ans = format_answer(word, c, num)
    return ans
"""

In [None]:
# Build the debugger for the sample program (using the default fgserve host
# and port) and show the UI; display() resets the debugger before rendering.
idbg = InteractiveDebugger(code)
idbg.display()