In [1]:
import os
del os.environ['CUDA_VISIBLE_DEVICES']

KeyError: 'CUDA_VISIBLE_DEVICES'

# Simple phi-2 chatbot

This is a simple demo that you can use to chat with [microsoft/phi-2](https://huggingface.co/microsoft/phi-2) model. It is based on [this sample code](https://www.gradio.app/guides/creating-a-chatbot-fast#example-using-a-local-open-source-llm-with-hugging-face), but it's been adapted to run phi-2. For more information on changes and adaptations, please consult the README file.

In [1]:
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", device_map="auto", torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

print("Your device is", device)

Loading checkpoint shards: 100%|██████████████████| 2/2 [00:00<00:00,  2.23it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Your device is cuda


## Configurations

Set your preferrend language below. Allowed values are:

+ `EN` for English, and
+ `ES` for Spanish

If the variable is set to any other value, it'll default to English. Mind that as stated on the [model page](https://huggingface.co/microsoft/phi-2), phi2 has been designed to work primarily with English.

> Language Limitations: The model is primarily designed to understand standard English. Informal English, slang, or any other languages might pose challenges to its comprehension, leading to potential misinterpretations or errors in response.

In [19]:
LANG = "EN" # either EN and ES are valid codes

In [20]:
if LANG == "ES":
    HUMAN_NAME = "Usuario"
    BOT_NAME = "Asistente"
else:
    HUMAN_NAME = "User"
    BOT_NAME = "Assistant"

If you want, you can give the model some context using the `CONTEXT` variable. This will be prepended to the whole conversation.

In [25]:
if LANG == "ES":
    CONTEXT = f"El siguiente texto es una conversación amistosa entre {HUMAN_NAME} y {BOT_NAME} en español."
else:
    CONTEXT = f"The following is a friendly conversation between {HUMAN_NAME} and {BOT_NAME} in English."

## Run and try the model

The following cell contains the main code for the model and runs the Gradio interface.

In [30]:
from typing import List

class StopOnTokens(StoppingCriteria):
    """Stops the model if it produces an 'end of text' token"""
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [50256] # <|endoftext|>
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

class StopOnNames(StoppingCriteria):
    """
    Stops the model when it starts hallucinating future turns of the conversation

    It stops the token generation when we find a token sequence "\n<name>:", for
    example "\nUser:" or "\nAssistant:".
    """

    EOL_TOKEN = 198
    COLON_TOKEN = 25
    
    def __init__(self, tokenized_names: List[List[int]]):
        self.tokenized_names = tokenized_names
    
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for tokens in self.tokenized_names:
            template = [self.EOL_TOKEN, *tokens, self.COLON_TOKEN]
            if input_ids[0][-len(template):].tolist() == template:
                return True
        return False

chat_name_pattern_end = r'\n.+:$' # matches substrings like `\nUser:` at the end of the string

def predict(message, history):
    history_transformer_format = history + [[message, ""]]
    stop_on_tokens = StopOnTokens()
    stop_on_names = StopOnNames([tokenizer.encode(HUMAN_NAME), tokenizer.encode(BOT_NAME)])

    messages = "".join(["".join(
        [f"\n{HUMAN_NAME}:"+item[0], f"\n{BOT_NAME}:"+item[1]]
    ) for item in history_transformer_format]).strip()
    messages = CONTEXT + '\n' + messages

    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=1.0,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop_on_tokens, stop_on_names])
        )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        match = re.search(chat_name_pattern_end, partial_message)
        if match:
            partial_message = partial_message[:-len(match.group())]
        yield partial_message
        
gr.ChatInterface(predict).queue().launch()

Running on local URL:  http://127.0.0.1:7875

Thanks for being a Gradio user! If you have questions or feedback, please join our Discord server and chat with us: https://discord.gg/feTf9x3ZSB

To create a public link, set `share=True` in `launch()`.


