In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer,TextStreamer
import os
import keras
import keras_nlp
# Hide warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

  from .autonotebook import tqdm as notebook_tqdm
2024-06-13 21:21:33.200300: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
model_path='./Gemma-1.1-2b-instruct-dollyfinetuned/'
tokenizer = AutoTokenizer.from_pretrained(model_path, device_map='auto')
Gemma = AutoModelForCausalLM.from_pretrained(model_path,device_map= "auto")#low_cpu_mem_usage = True,

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.76s/it]


Type 1 inference

In [15]:
instruction="Hello"
context='Context: You are a pirate, thus you speak in pirate!'
text = f"### Instruction\n{instruction}\n\n ### Context\n{context}\n\n ### Answer\n"
inputs = tokenizer(text, return_tensors="pt").to("cuda")

outputs = Gemma.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))

<bos>### Instruction
Hello

 ### Context
Context: You are a pirate, thus you speak in pirate!

 ### Answer
Ahoy there mateys! I be Captain Bartholomew Blackbeard, and I be the most fearsome pirate in the seven seas! I be known for my fearsome beard, my fearsome treasure, and my fearsome ship, the Black Pearl!


Type 2 inference

In [16]:
from transformers import pipeline
streamer = TextStreamer(tokenizer)
# Example chat in the specified format
chat = [
    {"role": "user", "content": "Hello"},
]
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

pipe = pipeline(
    "text-generation",
    model=Gemma,
    tokenizer=tokenizer,
    streamer=streamer,
)

output = pipe(chat, max_new_tokens=50)
# Extract generated text from pipeline output
generated_text = output[0]['generated_text']

<bos><start_of_turn>user
Hello<end_of_turn>
<start_of_turn>model
Hello! 👋 I'm glad you're here. What can I do for you today? 😊<eos>


UI

In [32]:
from IPython.display import Markdown
import textwrap

class ChatState:
    """
    Manages the conversation history for a turn-based chatbot
    Follows the turn-based conversation guidelines for the Gemma family of models
    documented at https://ai.google.dev/gemma/docs/formatting
    """

    __START_TURN_USER__ = "<start_of_turn>user\n"
    __START_TURN_MODEL__ = "<start_of_turn>model\n"
    __END_TURN__ = "<end_of_turn>\n"

    def __init__(self, model: Gemma, tokenizer: tokenizer, system=""):
        """
        Initializes the chat state.

        Args:
            model: The language model to use for generating responses.
            tokenizer: The tokenizer associated with the model.
            system: (Optional) System instructions or bot description.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.system = system
        self.history = []

    def add_to_history_as_user(self, message):
        """
        Adds a user message to the history with start/end turn markers.
        """
        self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)

    def add_to_history_as_model(self, message):
        """
        Adds a model response to the history with start/end turn markers.
        """
        self.history.append(self.__START_TURN_MODEL__ + message + self.__END_TURN__)

    def get_history(self):
        """
        Returns the entire chat history as a single string.
        """
        return "".join(self.history)

    def get_full_prompt(self):
        """
        Builds the prompt for the language model, including history and system description.
        """
        prompt = self.get_history() + self.__START_TURN_MODEL__
        if len(self.system) > 0:
            prompt = self.system + "\n" + prompt
        return prompt

    def send_message(self, message):
        """
        Handles sending a user message and getting a model response.

        Args:
            message: The user's message.

        Returns:
            The model's response.
        """
        self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()
        inputs = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)
        outputs = self.model.generate(inputs,max_length=1024, num_return_sequences=1)
        result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        self.add_to_history_as_model(result)
        return result
    

def display_chat(prompt, text):
    # Remove user and model tags from the text
    text = text.replace('<start_of_turn>user\n', '').replace('<start_of_turn>model\n', '')

    # Format the prompt and text for display in Markdown
    formatted_prompt = f"<font size='+1' color='brown'>🙋‍♂️<blockquote>{prompt}</blockquote></font>"
    formatted_text = f"<font size='+1' color='teal'>🤖\n\n{text}\n</font>"

    # Return the formatted Markdown
    return Markdown(formatted_prompt + formatted_text)

Type 3 inference

In [33]:
# Initialize ChatState with Gemma
chat = ChatState(Gemma,tokenizer)
message = "Hello there"
display_chat(message, chat.send_message(message))

<font size='+1' color='brown'>🙋‍♂️<blockquote>Hello there</blockquote></font><font size='+1' color='teal'>🤖

user
Hello there
model
Hello! 👋 I'm glad you're here. What can I do for you today? 😊
</font>

Type 4, UI

In [30]:
import tkinter as tk
from tkinter import scrolledtext
model=Gemma
class ChatApp(tk.Tk):
    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.chat_history = ""
        
        self.title("Chatbot UI")
        self.geometry("600x400")
        
        # Create chat history display
        self.chat_display = scrolledtext.ScrolledText(self, width=60, height=20)
        self.chat_display.pack(pady=10)
        
        # Create message input field
        self.message_entry = tk.Entry(self, width=50)
        self.message_entry.pack(pady=10)
        
        # Create send button
        self.send_button = tk.Button(self, text="Send", command=self.send_message)
        self.send_button.pack()
        
        # Bind Enter key to send message
        self.bind("<Return>", lambda event: self.send_message())
        
        # Start conversation loop
        self.conversation_loop()
    
    def conversation_loop(self):
        self.display_message("Bot", "Hello! 👋 I'm glad you're here. What can I do for you today?")
    
    def send_message(self):
        user_message = self.message_entry.get().strip()
        if user_message:
            # Display user message in chat history
            self.display_message("User", user_message)
            
            if user_message.lower() == "exit":
                self.display_message("Bot", "Goodbye! 👋")
                self.quit()  # Exit the application
            
            # Generate bot's response
            bot_response = self.generate_response(user_message)
            
            # Display bot's response in chat history
            self.display_message("Bot", bot_response)
            
            # Clear message entry field
            self.message_entry.delete(0, tk.END)
    
    def generate_response(self, message):
        # Format chat template
        chat = [{"role": "user", "content": message}]
        prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        
        # Generate response using pipeline
        output = pipe(prompt, max_new_tokens=100)
        generated_text = output[0]['generated_text']
        
        return generated_text
    
    def display_message(self, role, message):
        # Remove special tokens from the message before displaying
        cleaned_message = message.replace("<bos>", "").replace("<start_of_turn>", "").replace("<end_of_turn>", "").replace("<eos>", "").strip()
        formatted_message = f"{role}: {cleaned_message}\n"
        self.chat_display.insert(tk.END, formatted_message)
        self.chat_display.see(tk.END)  # Scroll to the end of the chat display

# Create the ChatApp instance
app = ChatApp(model, tokenizer)
app.mainloop()

<bos><start_of_turn>user
Howdy partner<end_of_turn>
<start_of_turn>model
Howdy partner! 👋 How can I help you today? 😊 What's on your mind? 😊<eos>
<bos><start_of_turn>user
What the dog doin?<end_of_turn>
<start_of_turn>model
I do not have access to real-time information, therefore I am unable to provide you with the current activity of a dog. For the most up to date information, please check the official website of the dog or check the news for the latest updates.<eos>
