<a href="https://colab.research.google.com/github/menouarazib/InformationRetrievalInNLP/blob/master/PyGemma_Assistant.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **PyGemma Assistant**

In this notebook, we demonstrate how to build an AI assistant that answers Python-related questions. Its responses are generated by `pygemma`, a fine-tuned version of the [Gemma 2b model](https://www.kaggle.com/models/google/gemma/) trained on a publicly available Python dataset. `pygemma` is designed to help developers by answering common questions about the Python programming language, which makes it a practical aid for developers looking for Python help on Kaggle.

More details about `pygemma` can be found [on Kaggle](https://www.kaggle.com/models/menouarazib/pygemma-2b).

To create an AI Assistant using `pygemma`, we need to follow these major steps:

1. **Install the required libraries**
2. **Load `pygemma` from its Kaggle location**
3. **Create the AI assistant which utilizes `pygemma` to generate responses**
4. **Test the AI assistant with some Python questions**

## 1. **Install the required libraries**

The `transformers` library is used to load `pygemma` from its Kaggle location.

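The `accelerate` library makes it easy to run a model on a single GPU or on multiple GPUs, and it simplifies distributed and mixed-precision inference. `transformers` also relies on it for the `device_map="auto"` and `low_cpu_mem_usage=True` options used when loading the model.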

The `rich` library renders rich text and formatted output in the terminal. It can colorize console output, create tables, render Markdown, syntax-highlight source code, and more, which makes console output easier to read. Here it is used to render the assistant's Markdown responses.

In [None]:
!pip install -q transformers
!pip install -q accelerate
!pip install -q rich

## 2. **Load `pygemma` from its Kaggle location**

1. **Import necessary libraries**: The `transformers` library is imported for its `AutoTokenizer` and `AutoModelForCausalLM` classes. The `torch` library is imported for its CUDA and tensor functionalities.

2. **Check CUDA availability**: The code checks whether CUDA is available and, if so, sets the device to `'cuda'`. It also checks whether the GPU has a new architecture, i.e. compute capability >= 8 (NVIDIA Ampere or newer).

3. **Set data type**: The default tensor data type is `float32`. If the device is `'cuda'`, it is changed to `float16`. If the GPU has a new architecture, it is changed to `bfloat16`, which these GPUs support natively and which has the same exponent range as `float32`.

4. **Load pretrained model and tokenizer**: The path to the pretrained model is defined. The tokenizer and the model for causal language modeling are loaded from the pretrained model. The model is automatically mapped to the available device, and the tensor data type is set according to the device. The model is loaded with settings optimized for low CPU memory usage, and only local files are used (no downloads).

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Initialize a flag for CUDA's new architecture
cuda_new_architecture = False

# Check if CUDA is available
if torch.cuda.is_available():
    print("CUDA is available.")
    # Check if the CUDA device has a new architecture (compute capability >= 8)
    cuda_new_architecture = torch.cuda.get_device_capability()[0] >= 8
    # Set the device to 'cuda'
    device = 'cuda'
    # Report when the CUDA device has a new architecture
    if cuda_new_architecture:
        print('The CUDA device has a new architecture (compute capability >= 8).')
else:
    # If CUDA is not available, set the device to 'cpu'
    device = 'cpu'
    print("CUDA is not available. Using the CPU.")

# Set the default torch dtype to float32
torch_dtype = torch.float32

# If the device is 'cuda', change the torch dtype to float16
if device == 'cuda':
    torch_dtype = torch.float16
    # If the CUDA device has a new architecture, change the torch dtype to bfloat16
    if cuda_new_architecture:
        torch_dtype = torch.bfloat16

# Define the path to the pretrained model
kaggle_model = "/kaggle/input/pygemma-2b/transformers/v1/2"

# Load the tokenizer from the pretrained model
pygemma_tokenizer = AutoTokenizer.from_pretrained(kaggle_model, local_files_only=True)
# Load the model for causal language modeling from the pretrained model
# Set the device map to 'auto' to automatically map the model to the available device
# Set the torch dtype according to the device
# Set local_files_only to True to only use local files and not download anything
# Set low_cpu_mem_usage to True to optimize for low CPU memory usage
pygemma_model = AutoModelForCausalLM.from_pretrained(kaggle_model, device_map="auto", torch_dtype=torch_dtype,
                                                     local_files_only=True, low_cpu_mem_usage=True)
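
The device and dtype rules used above can be expressed as a small pure-Python function, which also shows why half precision matters for memory. This is a minimal sketch, not part of the notebook's pipeline: dtypes are represented as strings so it runs without `torch`, the helper names `select_device_and_dtype` and `approx_weight_memory_gb` are hypothetical, and the parameter count of about 2.5 billion for a Gemma 2b-sized model is an approximate assumption. The estimate covers the weights only, not activations or the KV cache.

```python
from typing import Optional, Tuple

# Bytes needed to store one parameter in each dtype.
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def select_device_and_dtype(cuda_available: bool,
                            compute_capability_major: Optional[int] = None) -> Tuple[str, str]:
    """Mirror the notebook's rules: CPU -> float32, CUDA -> float16,
    CUDA with compute capability >= 8 -> bfloat16."""
    if not cuda_available:
        return "cpu", "float32"
    if compute_capability_major is not None and compute_capability_major >= 8:
        return "cuda", "bfloat16"
    return "cuda", "float16"


def approx_weight_memory_gb(num_params: float, dtype: str) -> float:
    """Rough size of the model weights alone, in gigabytes (1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


if __name__ == "__main__":
    # Assumed, approximate parameter count for a Gemma 2b-sized model.
    num_params = 2.5e9
    for cuda, major in [(False, None), (True, 7), (True, 8)]:
        device, dtype = select_device_and_dtype(cuda, major)
        print(device, dtype, f"{approx_weight_memory_gb(num_params, dtype):.1f} GB")
```

Under this assumption, the weights take about 10 GB in `float32` and about 5 GB in `float16` or `bfloat16`, since halving the bytes per parameter halves the weight memory.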

## 3. **Create the AI assistant which utilizes `pygemma` to generate responses**

### We create utility classes for the PyGemma Assistant:

In [None]:
from transformers import StoppingCriteria
import torch
from typing import List

class StoppingCriteriaCustom(StoppingCriteria):
    """
    Custom stopping criteria that stops text generation once a given token sequence has been
    generated, to prevent the model from producing unnecessary content.

    Args:
        input_len (int): The length of the input (prompt) sequence. Only tokens after it are checked.
        stops_ids (List[int]): The token IDs of the stop sequence.
    """

    def __init__(self, input_len: int, stops_ids: List[int]):
        super().__init__()
        self.input_len = input_len
        # If this token sequence is generated, generation stops.
        # This prevents the generation of unnecessary content.
        # Example: the token IDs of " any further questions."
        self.stops_ids = stops_ids

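    def _check_sequence(self, tensor: torch.LongTensor) -> bool:
        """Returns True if `stops_ids` occurs anywhere in the generated (non-prompt) tokens."""
        # Keep only the tokens generated after the prompt
        tensor_list = tensor.tolist()[self.input_len:]
        sequence_length = len(self.stops_ids)
        try:
            # Start searching at the first occurrence of the first stop token
            start_index = tensor_list.index(self.stops_ids[0])
        except ValueError:
            return False
        for i in range(start_index, len(tensor_list) - sequence_length + 1):
            if tensor_list[i:i + sequence_length] == self.stops_ids:
                return True
        return False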

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """
        Checks whether the stop sequence appears in the generated text after the input sequence.

        Args:
            input_ids (torch.LongTensor): The IDs of the prompt tokens followed by the tokens generated so far.
            scores (torch.FloatTensor): The prediction scores for the current step (not used here).
            **kwargs: Additional keyword arguments passed by `generate` (not used here).

        Returns:
            bool: True if the stop sequence is found after the input sequence, False otherwise.
        """
        # Only the first sequence is checked, so this assumes a batch size of 1.
        return self._check_sequence(input_ids[0])
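
Because `generate` evaluates the stopping criterion after every new token, the criterion only needs to look at the tail of the sequence: the first time the stop phrase is complete, it necessarily ends at the last generated token. The sketch below shows that logic on plain Python lists, using the `stops_ids` values from this notebook (the token IDs of `" any further questions."`). The function name `ends_with_stop_sequence` and the prompt and stream token IDs are hypothetical, chosen only for illustration; this is not a `transformers` API.

```python
from typing import List


def ends_with_stop_sequence(token_ids: List[int], input_len: int, stop_ids: List[int]) -> bool:
    """Return True if the generated part (after input_len) ends with stop_ids."""
    generated = token_ids[input_len:]
    if len(generated) < len(stop_ids):
        return False
    return generated[-len(stop_ids):] == stop_ids


# Token IDs of " any further questions." as used by PyGemmaResponseGenerator.
STOP_IDS = [1089, 4024, 3920, 235265]

if __name__ == "__main__":
    prompt = [2, 100, 200]  # hypothetical prompt token IDs
    # Simulate generation one token at a time and stop as soon as the phrase completes.
    stream = [5, 6, 1089, 4024, 3920, 235265, 7, 8]
    sequence = list(prompt)
    for token in stream:
        sequence.append(token)
        if ends_with_stop_sequence(sequence, len(prompt), STOP_IDS):
            break
    print(sequence[len(prompt):])  # prints [5, 6, 1089, 4024, 3920, 235265]
```

Checking only the tail keeps each call proportional to the length of the stop phrase, rather than rescanning the whole reply on every generation step.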

In [None]:
from transformers import StoppingCriteriaList
import torch
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast

from typing import List, Tuple, Dict


class PyGemmaResponseGenerator:
    """
    A class that uses a PyGemma model to generate responses to user queries in Python.
    """

    def __init__(self, pygemma: GemmaForCausalLM, pygemma_tokenizer: GemmaTokenizerFast, available_device: str,
                 do_sample: bool = False,
                 temperature: float = 0.2, max_new_tokens: int = 768, top_k: int = 50, top_p: float = 0.99):
        """
        Initializes the PyGemmaResponseGenerator.

        Args:
            pygemma (GemmaForCausalLM): The PyGemma model used to generate replies.
            pygemma_tokenizer (GemmaTokenizerFast): The PyGemma tokenizer used to preprocess text.
            available_device (str): The device used for computations ('cuda' if CUDA is available, otherwise 'cpu').
            do_sample (bool, optional): Controls the decoding strategy used for text generation. Defaults to False.
                If False, the model uses greedy decoding, which picks the token with the highest probability
                as the next token.
                If True, the model samples the next token from its output distribution. The choice of
                decoding strategy can significantly affect the generated text: greedy decoding tends to produce
                repetitive, overly confident text, while multinomial sampling can produce more diverse and
                natural text, at the risk of making more mistakes.
            temperature (float, optional): The temperature used when sampling from the model's output distribution.
                Only used when do_sample is True. Defaults to 0.2.
            max_new_tokens (int, optional): The maximum number of tokens to generate. Defaults to 768.
            top_k (int, optional): The number of highest-probability vocabulary tokens kept for top-k filtering.
                Only used when do_sample is True. Defaults to 50.
            top_p (float, optional): Only the most probable tokens whose probabilities add up to top_p or higher are
                kept for generation. Only used when do_sample is True. Defaults to 0.99.
        """
        # Initialize the PyGemma model and tokenizer
        self.pygemma = pygemma
        self.pygemma_tokenizer = pygemma_tokenizer

        # Set the device for computations
        self.available_device = available_device

        # Set the parameters for text generation
        self.do_sample = do_sample
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self.top_k = top_k
        self.top_p = top_p

        # Set the system prompt for the PyGemma model
        self.system_prompt = """Welcome to PyGemma, your AI-powered Python assistant. I'm here to assist you with
        common questions about the Python programming language. Let's dive into Python!"""

        # Token IDs of the phrase " any further questions.". If this sequence is generated,
        # generation stops, which prevents the model from producing unnecessary content.
        self.stops_ids = [1089, 4024, 3920, 235265]

        # The model's maximum input length used during the fine-tuning process.
        self.max_input_token_length = 3072

    def _add_prompt_to_chat(self, chat_history: List[Tuple[str, str]], prompt: str) -> List[Dict[str, str]]:
        """
        Appends the new user prompt to the chat history and formats the result as the list of messages required by PyGemma.

        Args:
            chat_history (List[Tuple[str, str]]): The chat history between the user and the assistant.
            prompt (str): The new user prompt to which a reply is needed.

        Returns:
            List[Dict[str, str]]: The formatted conversation.
        """
        conversation = [{"role": "system", "content": self.system_prompt}]
        for user, assistant in chat_history:
            conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])

        conversation.append({"role": "user", "content": prompt})

        return conversation

    def _tokenize(self, conversation: List[Dict[str, str]]) -> torch.Tensor:
        """
        Tokenizes the conversation.

        Args:
            conversation (List[Dict[str, str]]): The conversation to tokenize.

        Returns:
            torch.Tensor: The tokenized conversation.
        """

        input_ids = self.pygemma_tokenizer.apply_chat_template(conversation, return_tensors="pt",
                                                               add_generation_prompt=True).to(self.available_device)

        # If the conversation is too long, keep only its last max_input_token_length tokens.
        # This truncates from the left, so the system prompt and the oldest turns are dropped first.
        if input_ids.shape[1] > self.max_input_token_length:
            input_ids = input_ids[:, -self.max_input_token_length:]
            print(f"Trimmed input from conversation as it exceeded {self.max_input_token_length} tokens.")

        return input_ids

    def generate_response(self, chat_history: List[Tuple[str, str]], prompt: str) -> str:
        """
        Generates a response to the given prompt.

        Args:
            chat_history (List[Tuple[str, str]]): The chat history between the user and the assistant.
            prompt (str): The prompt to which a reply is needed.

        Returns:
            str: The generated reply.
        """
        # Format the chat history and the new user prompt
        conversation = self._add_prompt_to_chat(chat_history=chat_history, prompt=prompt)

        # Tokenize the conversation
        input_ids = self._tokenize(conversation=conversation)

        stopping_criteria = StoppingCriteriaList([StoppingCriteriaCustom(len(input_ids[0]), self.stops_ids)])

        # Generate a reply. The sampling parameters are only passed when sampling is enabled,
        # because they have no effect with greedy decoding. Stopping is handled by stopping_criteria.
        generation_kwargs = dict(max_new_tokens=self.max_new_tokens, stopping_criteria=stopping_criteria)
        if self.do_sample:
            generation_kwargs.update(do_sample=True, temperature=self.temperature,
                                     top_p=self.top_p, top_k=self.top_k)
        outputs = self.pygemma.generate(input_ids, **generation_kwargs)

        # The 'outputs' tensor contains both the input tokens and the newly generated tokens.
        # To get only the generated tokens, we slice the tensor to exclude the input tokens.
        # The slicing operation 'outputs[0][input_ids.shape[-1]:]' removes the input tokens from the output.
        encoded_reply = outputs[0][input_ids.shape[-1]:]

        # Decode the generated tokens to convert them back into text.
        # The 'decode' method takes the token ids and converts them into human-readable text.
        # 'skip_special_tokens=True' removes any special tokens that were added during tokenization.
        reply = self.pygemma_tokenizer.decode(encoded_reply, skip_special_tokens=True)

        # Return the generated reply
        return reply
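
The `do_sample`, `temperature`, `top_k`, and `top_p` parameters documented above can be illustrated on a toy next-token distribution. The sketch below is a simplified, pure-Python illustration rather than the `transformers` implementation: it applies temperature scaling, then top-k filtering, then top-p (nucleus) filtering, and either picks the highest-scoring token (greedy) or samples from what remains. The vocabulary, the logits, and the helper names (`filter_distribution`, `next_token`) are made up for the example; the defaults mirror those of `PyGemmaResponseGenerator`.

```python
import math
import random
from typing import Dict, Optional


def softmax(logits: Dict[str, float]) -> Dict[str, float]:
    # Subtract the max logit for numerical stability.
    m = max(logits.values())
    exps = {tok: math.exp(v - m) for tok, v in logits.items()}
    total = sum(exps.values())
    return {tok: e / total for tok, e in exps.items()}


def filter_distribution(logits: Dict[str, float], temperature: float, top_k: int,
                        top_p: float) -> Dict[str, float]:
    """Return the renormalised distribution after temperature, top-k and top-p."""
    # 1. Temperature: values < 1 sharpen the distribution, values > 1 flatten it.
    probs = softmax({tok: v / temperature for tok, v in logits.items()})
    # 2. Top-k: keep only the k most probable tokens.
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    # 3. Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for tok, p in ranked:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break
    total = sum(p for _, p in kept)
    return {tok: p / total for tok, p in kept}


def next_token(logits: Dict[str, float], do_sample: bool, temperature: float = 0.2,
               top_k: int = 50, top_p: float = 0.99, rng: Optional[random.Random] = None) -> str:
    if not do_sample:
        # Greedy decoding: always the highest-scoring token.
        return max(logits, key=logits.get)
    dist = filter_distribution(logits, temperature, top_k, top_p)
    rng = rng or random.Random()
    return rng.choices(list(dist), weights=list(dist.values()), k=1)[0]


if __name__ == "__main__":
    toy_logits = {"list": 2.0, "tuple": 1.5, "dict": 0.5, "set": -1.0}
    print(next_token(toy_logits, do_sample=False))  # prints "list"
    print(filter_distribution(toy_logits, temperature=0.2, top_k=50, top_p=0.99))
    print(next_token(toy_logits, do_sample=True, rng=random.Random(0)))
```

With a low temperature such as 0.2, almost all of the probability mass moves to the top one or two tokens, so top-p = 0.99 leaves only `"list"` and `"tuple"` in this toy case. With the notebook's default `do_sample=False`, only the greedy branch runs, so `temperature`, `top_k`, and `top_p` have no effect.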

### **We create a class that represents the PyGemma AI assistant:**

In [None]:
class PyGemmaAssistant:
    """A class that represents the PyGemma AI assistant.

    Attributes:
        _chat_history (list): A list to store the chat history.
        pygemma_response_generator (PyGemmaResponseGenerator): An object of PyGemmaResponseGenerator.
        _use_chat_history (bool): A flag indicating whether to use chat history.

    Args:
        pygemma_response_generator (PyGemmaResponseGenerator): An object of PyGemmaResponseGenerator.
        use_chat_history (bool, optional): A flag indicating whether to use chat history. Defaults to False.
    """

    def __init__(self, pygemma_response_generator: PyGemmaResponseGenerator, use_chat_history: bool = False):
        self._chat_history = []
        self.pygemma_response_generator = pygemma_response_generator
        self._use_chat_history = use_chat_history

    def use_chat_history(self, use_chat: bool):
        """Enables or disables the use of chat history and clears any stored history."""
        self._chat_history = []
        self._use_chat_history = use_chat

    def ask(self, message: str) -> str:
        """Generates a response to a given message.

        If chat history is enabled, the previous turns are passed to the model along with the new message.
        The user's message and the assistant's response are always appended to the stored chat history.

        Args:
            message (str): The user's message.

        Returns:
            str: The assistant's response.
        """
        chat_history_ = (self._chat_history if self._use_chat_history else [])
        assistant_answer = self.pygemma_response_generator.generate_response(chat_history=chat_history_, prompt=message)
        self._chat_history.append((message, assistant_answer))
        return assistant_answer
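
To see what the model receives when chat history is enabled, the sketch below reproduces the message-building logic of `_add_prompt_to_chat` and the history handling of `PyGemmaAssistant.ask`, but replaces the model with a hypothetical stub generator (`EchoGenerator`) so it runs without `transformers` or a GPU. `SketchAssistant`, the shortened system prompt, and the stub's reply text are all invented for illustration; only the structure of the conversation list is the point.

```python
from typing import Dict, List, Tuple

# Shortened, illustrative system prompt (the real one is longer).
SYSTEM_PROMPT = "Welcome to PyGemma, your AI-powered Python assistant."


def build_conversation(chat_history: List[Tuple[str, str]], prompt: str) -> List[Dict[str, str]]:
    """Same layout as PyGemmaResponseGenerator._add_prompt_to_chat."""
    conversation = [{"role": "system", "content": SYSTEM_PROMPT}]
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user},
                             {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": prompt})
    return conversation


class EchoGenerator:
    """Hypothetical stand-in for PyGemmaResponseGenerator that records what it was sent."""

    def __init__(self) -> None:
        self.last_conversation: List[Dict[str, str]] = []

    def generate_response(self, chat_history: List[Tuple[str, str]], prompt: str) -> str:
        self.last_conversation = build_conversation(chat_history, prompt)
        return f"(reply to: {prompt})"


class SketchAssistant:
    """Mirrors the ask/use_chat_history behaviour of PyGemmaAssistant."""

    def __init__(self, generator: EchoGenerator, use_chat_history: bool = False) -> None:
        self._chat_history: List[Tuple[str, str]] = []
        self._use_chat_history = use_chat_history
        self.generator = generator

    def use_chat_history(self, enabled: bool) -> None:
        self._chat_history = []
        self._use_chat_history = enabled

    def ask(self, message: str) -> str:
        history = self._chat_history if self._use_chat_history else []
        answer = self.generator.generate_response(chat_history=history, prompt=message)
        self._chat_history.append((message, answer))
        return answer


if __name__ == "__main__":
    gen = EchoGenerator()
    assistant = SketchAssistant(gen)
    assistant.use_chat_history(True)
    assistant.ask("Write a function that sums a list.")
    assistant.ask("Now make it return the average.")
    print([m["role"] for m in gen.last_conversation])  # prints ['system', 'user', 'assistant', 'user']
```

As in the real `ask`, every turn is stored even when history is disabled; the stored turns are simply not sent to the model, and toggling `use_chat_history` clears them.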

## 4. **Test the PyGemma Assistant with some Python questions**

In [None]:
# Instantiate the response generator and the assistant (chat history is disabled by default)
pygemma_response_generator = PyGemmaResponseGenerator(pygemma=pygemma_model, pygemma_tokenizer=pygemma_tokenizer, available_device=device)
pygemma_assistant = PyGemmaAssistant(pygemma_response_generator)

In [None]:
from rich.markdown import Markdown

response = pygemma_assistant.ask("What is the difference between a list and a tuple in Python?")
Markdown(response)

In [None]:
from rich.markdown import Markdown

response = pygemma_assistant.ask("How do you handle exceptions in Python?")
Markdown(response)

In [None]:
from rich.markdown import Markdown

response = pygemma_assistant.ask("What are decorators in Python?")
Markdown(response)

In [None]:
from rich.markdown import Markdown

response = pygemma_assistant.ask("What is a lambda function in Python?")
Markdown(response)

In [None]:
from rich.markdown import Markdown

question = """Can you write a Python function that takes a string as input and returns a dictionary where the keys are the
unique characters in the string and the values are the number of times each character appears in the string?"""

response = pygemma_assistant.ask(question)
Markdown(response)

In [None]:
from rich.markdown import Markdown

question = """Can you write a Python function that takes a list of integers as input, and returns a list of those
integers sorted in descending order, but with all the odd numbers at the front of the list and even numbers at the
back? The odd numbers should be sorted in descending order, and the even numbers should be sorted in ascending order."""

response = pygemma_assistant.ask(question)
Markdown(response)

### **Enable Chat History**

In [None]:
pygemma_assistant.use_chat_history(True)

In [None]:
from rich.markdown import Markdown

question1 = """Can you write a Python function that takes a list of numbers as input and returns the sum of all the numbers in the list?"""

response = pygemma_assistant.ask(question1)
Markdown(response)

In [None]:
question2 = """After getting the function for the sum, can you modify it to return the average of the numbers in the list instead?"""

response = pygemma_assistant.ask(question2)
Markdown(response)

In [None]:
response = pygemma_assistant.ask("Thank you")
Markdown(response)