In [None]:
!pip install transformers accelerate bitsandbytes gradio -U -q

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Hub checkpoint used for both the tokenizer and the model weights.
model_id = "HuggingFaceH4/zephyr-7b-beta"

# The system prompt is always supplied explicitly by `generate`, so the
# tokenizer's default system prompt is switched off.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.use_default_system_prompt = False

# Weights are loaded in float16; device placement is delegated to
# `device_map="auto"`.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)

In [None]:
# Example persona payload: a top-level "characters" list whose entries carry the
# fields that `get_system_prompt` reads. "name" and "physicalDescription" are
# plain strings; every other field is a list of sentences.
SAMPLE_PERSONALITY = dict(
    characters=[
        dict(
            name="Alice",
            physicalDescription=(
                "Alice is a young woman with long, wavy brown hair and hazel eyes. "
                "She is of average height and has a slim build. "
                "Her most distinctive feature is her warm, friendly smile."
            ),
            personalityTraits=[
                (
                    "Alice is a kind, compassionate, and intelligent woman. "
                    "She is always willing to help others and is a great listener. "
                    "She is also very creative and has a great sense of humor."
                ),
            ],
            likes=[
                "Alice loves spending time with her friends and family.",
                "She enjoys reading, writing, and listening to music.",
                "She is also a big fan of traveling and exploring new places.",
            ],
            dislikes=[
                "Alice dislikes rudeness and cruelty.",
                "She also dislikes being lied to or taken advantage of.",
                "She is not a fan of heights or roller coasters.",
            ],
            background=[
                "Alice grew up in a small town in the Midwest.",
                "She was always a good student and excelled in her studies.",
                "After graduating from high school, she moved to the city to attend college.",
                "She is currently working as a social worker.",
            ],
            goals=[
                "Alice wants to make a difference in the world.",
                "She hopes to one day open her own counseling practice.",
                "She also wants to travel the world and experience different cultures.",
            ],
            relationships=[
                "Alice is very close to her family and friends.",
                "She is also in a loving relationship with her partner, Ben.",
                "She has a good relationship with her colleagues and is well-respected by her clients.",
            ],
        )
    ]
)

In [None]:
# Show the field names of the first (and only) sample character; these are the
# keys that `get_system_prompt` looks up.
SAMPLE_PERSONALITY["characters"][0].keys()

In [None]:
def get_system_prompt(personality_json_dict: dict) -> str:
    """Build the role-play system prompt for a single character.

    Args:
        personality_json_dict: One character entry (not the whole
            ``{"characters": [...]}`` payload). It must provide the keys
            ``"name"`` and ``"physicalDescription"`` (interpolated as-is) and
            ``"personalityTraits"``, ``"likes"``, ``"dislikes"``,
            ``"background"``, ``"goals"`` and ``"relationships"`` (iterables of
            strings, each joined with ``", "``).

    Returns:
        The prompt text: an instruction line, one bullet per character field,
        and a closing reminder. It starts and ends with a newline.

    Raises:
        KeyError: If any of the keys listed above is missing.
    """
    name = personality_json_dict["name"]
    physical_description = personality_json_dict["physicalDescription"]
    personality_traits = ", ".join(personality_json_dict["personalityTraits"])
    likes = ", ".join(personality_json_dict["likes"])
    # Dislikes were previously read but never written into the prompt, so the
    # model had no way to know about them.
    dislikes = ", ".join(personality_json_dict["dislikes"])
    background = ", ".join(personality_json_dict["background"])
    goals = ", ".join(personality_json_dict["goals"])
    relationships = ", ".join(personality_json_dict["relationships"])

    system_prompt = f"""
You are acting as the character detailed below. The details of the character contain different traits, starting from its inherent personality traits to its background.

* Name: {name}
* Physical description: {physical_description}
* Personality traits: {personality_traits}
* Likes: {likes}
* Dislikes: {dislikes}
* Background: {background}
* Goals: {goals}
* Relationships:  {relationships}

While generating your responses, you must consider the information above.
"""
    return system_prompt

In [None]:
from pprint import pprint

# Preview the system prompt generated for the sample character.
pprint(get_system_prompt(SAMPLE_PERSONALITY["characters"][0]))

In [None]:
import os

# Initial position and upper bound of the "Max new tokens" slider in the chat UI.
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_MAX_NEW_TOKENS = 2048

# Prompt-length cap enforced by `generate`; overridable through the
# MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

In [None]:
# Fetch the stylesheet that the UI cell loads via `gr.Blocks(css="style.css")`.
!wget https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/raw/main/style.css --content-disposition

In [None]:
import gradio as gr
from threading import Thread
from transformers import TextIteratorStreamer

def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
):
    """Stream the character's reply to ``message``.

    The prompt consists of the persona system prompt (built from the first
    entry of ``SAMPLE_PERSONALITY``), the previous turns, and the new user
    message, rendered with the tokenizer's chat template.

    Args:
        message: The new user message.
        chat_history: Earlier ``(user, assistant)`` turns, oldest first.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        repetition_penalty: Penalty applied to already-generated tokens.

    Yields:
        The reply accumulated so far; each yielded string extends the
        previous one with the newly streamed text.
    """
    system_prompt = get_system_prompt(SAMPLE_PERSONALITY["characters"][0])
    history = list(chat_history)  # local copy; the caller's list is not mutated
    trimmed = False

    while True:
        conversation = [{"role": "system", "content": system_prompt}]
        for user, assistant in history:
            conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
        conversation.append({"role": "user", "content": message})

        # add_generation_prompt=True makes the template end with the assistant
        # header, so the model starts answering instead of first having to
        # produce the header itself.
        input_ids = tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_tensors="pt"
        )
        if input_ids.shape[1] <= MAX_INPUT_TOKEN_LENGTH or not history:
            break
        # Too long: forget the oldest turn and rebuild. Slicing token ids from
        # the left would cut away the persona system prompt first.
        history.pop(0)
        trimmed = True

    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Even without any history the prompt does not fit; fall back to
        # keeping only the most recent tokens.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        trimmed = True
    if trimmed:
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # model.generate blocks until generation is finished, so it runs in a
    # background thread while this generator drains the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        # Safety net: strip the role marker should the model still emit it.
        outputs.append(text.replace("<|assistant|>", ""))
        yield "".join(outputs)


# Chat UI around `generate`. The sliders are listed in the same order as
# `generate`'s tuning parameters (max_new_tokens, temperature, top_p, top_k,
# repetition_penalty) and their initial values mirror that function's defaults.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(label=label, minimum=low, maximum=high, step=step, value=initial)
        for label, low, high, step, initial in (
            ("Max new tokens", 1, MAX_MAX_NEW_TOKENS, 1, DEFAULT_MAX_NEW_TOKENS),
            ("Temperature", 0.1, 4.0, 0.1, 0.6),
            ("Top-p (nucleus sampling)", 0.05, 1.0, 0.05, 0.9),
            ("Top-k", 1, 1000, 1, 50),
            ("Repetition penalty", 1.0, 2.0, 0.05, 1.2),
        )
    ],
    stop_btn=None,
    # Each example is a single-element row holding the message text.
    examples=[
        [prompt]
        for prompt in (
            "Hello there! How are you doing?",
            "Recite me a short poem.",
            "Explain the plot of Cinderella in a sentence.",
            "Write a 100-word article on 'Benefits of Open-Source in AI research'",
        )
    ],
)

# Page layout: heading, duplicate-space button, then the chat interface that
# was configured above. Styling comes from the local style.css file.
demo = gr.Blocks(css="style.css")
with demo:
    gr.Markdown("## Demo of Vid2Persona chat component")
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
    )
    chat_interface.render()

# Start the Gradio app.
demo.launch()