In [None]:
import asyncio
import panel as pn
import param

from panel.custom import JSComponent, ESMEvent

# Initialise Panel and select the 'material' template for the served app.
pn.extension(template='material')

This example demonstrates how to wrap an external library (specifically [WebLLM](https://github.com/mlc-ai/web-llm)) as a `JSComponent` and interface it with the `ChatInterface`.

In [None]:

# Display name (as offered by the model selector) -> WebLLM model identifier.
MODELS = dict([
    ('Mistral-7b-Instruct', 'Mistral-7B-Instruct-v0.3-q4f16_1-MLC'),
    ('SmolLM', 'SmolLM-360M-Instruct-q4f16_1-MLC'),
    ('Gemma-2b', 'gemma-2-2b-it-q4f16_1-MLC'),
    ('Llama-3.1-8b-Instruct', 'Llama-3.1-8B-Instruct-q4f32_1-MLC-1k'),
])

class WebLLM(JSComponent):
    """
    Wraps the WebLLM JavaScript library as a Panel component.

    The model runs in the browser. Python asks the frontend to load a model
    or to create a chat completion by sending custom messages; the frontend
    streams completion chunks back, which are buffered in ``_buffer`` and
    consumed by :meth:`create_completion`.
    """

    loaded = param.Boolean(default=False, doc="""
        Whether the model is loaded.""")

    model = param.Selector(default='SmolLM-360M-Instruct-q4f16_1-MLC', objects=MODELS, doc="""
        Identifier of the WebLLM model to load and query.""")

    temperature = param.Number(default=1, bounds=(0, 2), doc="""
        Sampling temperature passed to each completion request.""")

    load_model = param.Event(doc="""
        Triggers loading of the currently selected model.""")

    _esm = """
    import * as webllm from "https://esm.run/@mlc-ai/web-llm";

    // One engine per model identifier, so a model is only created once.
    const engines = new Map()

    export async function render({ model }) {
      model.on("msg:custom", async (event) => {
        if (event.type === 'load') {
          if (!engines.has(model.model)) {
            engines.set(model.model, await webllm.CreateMLCEngine(model.model))
          }
          model.loaded = true
        } else if (event.type === 'completion') {
          const engine = engines.get(model.model)
          if (engine == null) {
            // Report the failure and stop: there is no engine to query.
            model.send_msg({'finish_reason': 'error'})
            return
          }
          const chunks = await engine.chat.completions.create({
            messages: event.messages,
            temperature: model.temperature,
            stream: true,
          })
          for await (const chunk of chunks) {
            model.send_msg(chunk.choices[0])
          }
        }
      })
    }
    """

    def __init__(self, **params):
        super().__init__(**params)
        # Messages received from the frontend. New messages are inserted at
        # the front and consumed from the back, i.e. first-in first-out.
        self._buffer = []

    @param.depends('load_model', watch=True)
    def _load_model(self):
        """Show the loading indicator and ask the frontend to load the model."""
        self.loading = True
        self._send_msg({'type': 'load'})

    @param.depends('loaded', watch=True)
    def _loaded(self):
        """Clear the loading indicator and lock the load event once loaded."""
        self.loading = False
        # Only lock the event while a model is actually loaded; when `loaded`
        # is reset to False the event must stay available.
        self.param.load_model.constant = self.loaded

    @param.depends('model', watch=True)
    def _update_load_model(self):
        """Mark the newly selected model as not loaded and re-enable loading."""
        # Resetting `loaded` ensures the next successful load produces a change
        # event (which clears the loading indicator) and stops `callback` from
        # requesting completions for a model that has no engine yet.
        self.loaded = False
        self.param.load_model.constant = False

    def _handle_msg(self, msg):
        """Buffer a message sent by the frontend."""
        self._buffer.insert(0, msg)

    async def create_completion(self, msgs):
        """
        Request a streamed chat completion and yield its chunks as they arrive.

        Parameters
        ----------
        msgs : list of dict
            Chat messages forwarded to the frontend as the `messages` of the
            completion request.

        Yields
        ------
        dict
            The messages sent back by the frontend, in order, up to and
            including the first one with a truthy `finish_reason`.

        Raises
        ------
        RuntimeError
            If the frontend reports `finish_reason == 'error'`, which it does
            when no engine exists for the selected model.
        """
        # Discard anything left over from a previous, abandoned completion so
        # it cannot leak into this response.
        self._buffer.clear()
        self._send_msg({'type': 'completion', 'messages': msgs})
        while True:
            # Poll the buffer, yielding control to the event loop in between.
            await asyncio.sleep(0.01)
            if not self._buffer:
                continue
            choice = self._buffer.pop()
            reason = choice.get('finish_reason')
            # Raise before yielding: the error message carries no 'delta', so
            # handing it to the consumer would fail there with a KeyError.
            if reason == 'error':
                raise RuntimeError('Model not loaded')
            yield choice
            if reason:
                return

    async def callback(self, contents: str, user: str):
        """
        Chat callback that streams the model's reply to `contents`.

        Yields a status text and stops if the model is not loaded; otherwise
        yields the accumulated reply each time a new chunk arrives.
        """
        if not self.loaded:
            if self.loading:
                yield f'Model `{self.model}` is loading.'
            else:
                yield 'Load the model'
            return
        message = ""
        async for chunk in self.create_completion([{'role': 'user', 'content': contents}]):
            # A chunk may carry no content (missing or None); treat it as empty.
            message += chunk['delta'].get('content') or ''
            yield message

    def menu(self):
        """Return a column of widgets to pick the model, set the temperature and load the model."""
        return pn.Column(
            pn.widgets.Select.from_param(self.param.model, sizing_mode='stretch_width'),
            pn.widgets.FloatSlider.from_param(self.param.temperature, sizing_mode='stretch_width'),
            pn.widgets.Button.from_param(
                self.param.load_model, sizing_mode='stretch_width',
                loading=self.param.loading
            )
        )

Having implemented the `WebLLM` component, we can render its controls and the component itself in the sidebar:

In [None]:
# Create the component, then serve its menu followed by the component itself
# in the template's sidebar area.
llm = WebLLM()

pn.Column(llm.menu(), llm).servable(area='sidebar')

Finally, we connect it to a `ChatInterface`:

In [None]:
# Use the component's `callback` as the chat callback.
chat_interface = pn.chat.ChatInterface(callback=llm.callback)
# Seed the chat with an instruction posted as the "System" user.
chat_interface.send(
    "Load a model and start chatting.",
    user="System",
    respond=False,
)

chat_interface.servable(title='WebLLM')