##### Copyright 2024 Google LLC.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/pytorch_gemma"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

# Gemma in PyTorch

This is a quick demo of running Gemma inference in PyTorch.
For more details, see the [GitHub repo of the official PyTorch implementation](https://github.com/google/gemma_pytorch).

**Note that**:
 * The free Kaggle CPU Python runtime and GPU Python runtime are sufficient for running the Gemma 2B models and 7B int8 quantized models.
 * For advanced use cases on other GPUs or TPUs, refer to [README.md](https://github.com/google/gemma_pytorch/blob/main/README.md) in the official repo.

### Gemma setup

To complete this tutorial, first follow the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup).

Gemma models are hosted by Kaggle. To use Gemma, request access on Kaggle:

- Sign in or register at [kaggle.com](https://www.kaggle.com)
- Open the [Gemma 2 model card](https://www.kaggle.com/models/google/gemma-2) and select _"Request Access"_
- Complete the consent form and accept the terms and conditions


## Install dependencies

In [2]:
%pip install -q -U torch==2.4.0 immutabledict sentencepiece

Note: you may need to restart the kernel to use updated packages.


## Download model weights

In [3]:
# Choose variant and machine type
VARIANT = '2b-it' #@param ['2b', '2b-it', '9b', '9b-it', '27b', '27b-it']
MACHINE_TYPE = 'cuda' #@param ['cuda', 'cpu']

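# Derive the config name from the variant, e.g. '27b-it' -> '27b'.
# (Slicing with VARIANT[:2] would truncate '27b' to '27'.)
CONFIG = VARIANT.split('-')[0]
if CONFIG == '2b':
  CONFIG = '2b-v2'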

In [4]:
weights_dir = "../models/gemma-2-pytorch-gemma-2-2b-it-v1/"

In [5]:
import os
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, 'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

## Download the model implementation

In [7]:
import sys
import torch
sys.path.append('../models/gemma-2-pytorch-gemma-2-2b-it-v1/pytorch/gemma_pytorch')

In [11]:
from gemma.config import GemmaConfig, get_model_config
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os


## Set up the model

In [9]:
# Set up model config.
model_config = get_model_config(CONFIG)
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
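# Use the machine type selected earlier; fall back to CPU if CUDA is unavailable.
device = torch.device(MACHINE_TYPE if torch.cuda.is_available() else 'cpu')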
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

## Run inference

Below are examples of generating in chat mode and of summarizing a meeting
transcript.

The instruction-tuned Gemma models were trained with a specific formatter that
annotates instruction tuning examples with extra information, both during
training and inference. The annotations (1) indicate roles in a conversation,
and (2) delineate turns in a conversation. Below we show a sample code snippet
for formatting the model prompt using the user and model chat templates in a
multi-turn conversation. The relevant tokens are:

- `user`: user turn
- `model`: model turn
- `<start_of_turn>`: beginning of dialogue turn
- `<end_of_turn><eos>`: end of dialogue turn

For details on instruction-tuning formatting and system instructions, see
[Gemma formatting and system instructions](https://ai.google.dev/gemma/docs/formatting).
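
The role and turn markers listed above can also be assembled programmatically. The following is a minimal sketch, not part of the official implementation: `build_chat_prompt` and `TURN_TEMPLATE` are hypothetical names introduced here for illustration. The helper takes a list of `(role, text)` pairs, formats each turn with the same markers this notebook's templates use (including the trailing `<eos>`), and leaves a final model turn open so that generation continues from it.

```python
# Hypothetical helper for illustration; it is not part of gemma_pytorch.
# The turn layout mirrors the chat templates used in this notebook.
TURN_TEMPLATE = "<start_of_turn>{role}\n{text}<end_of_turn><eos>\n"


def build_chat_prompt(turns):
    """Builds a multi-turn prompt from (role, text) pairs.

    Each role must be 'user' or 'model'.
    """
    parts = []
    for role, text in turns:
        if role not in ("user", "model"):
            raise ValueError(f"unknown role: {role!r}")
        parts.append(TURN_TEMPLATE.format(role=role, text=text))
    # Leave the last model turn open so the model writes the next reply.
    parts.append("<start_of_turn>model\n")
    return "".join(parts)


chat_prompt = build_chat_prompt([
    ("user", "What is a good place for travel in the US?"),
    ("model", "California."),
    ("user", "What can I do in California?"),
])
print(chat_prompt)
```

The resulting string can be passed to `model.generate` as the prompt. Keeping the turns in a list makes it straightforward to append each new reply and rebuild the prompt for the next turn.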

In [10]:
# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

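# The prompt is already formatted with the chat templates, so pass it as is
# rather than wrapping it in USER_CHAT_TEMPLATE a second time.
results = model.generate(
    prompt,
    device=device,
    output_len=128,
)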
print(results)

Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model

That's a BIG question, California is HUGE! To help me narrow down the best options for YOU, tell me:

* **What are you interested in?**  (Beaches, cities, nature, theme parks, history, culture, food, etc.) 
* **Who are you traveling with?** (Solo, couple, family, friends)
* **How much time do you have?** (A weekend trip, a week, longer?)
* **What's your budget?** (Luxurious, budget-friendly, somewhere in between?)


The more specifics you give me, the better


In [26]:
import re
sample_path = "/kaggle/input/video-transcribation-sample/Sample.txt"
cleaned_text = ''
with open(sample_path, "r") as file:
    content = file.read()
    text_without_timestamps = re.sub(r'\d{2}:\d{2}:\d{2},\d{3} --> \d{2}:\d{2}:\d{2},\d{3}', '', content)
    # Remove subtitle sequence-number lines (one or more digits on their own line).
    cleaned_text = re.sub(r'^\d+$\n', '', text_without_timestamps, flags=re.MULTILINE)
cleaned_text

"\n[SPEAKER_00]: And today, as Jim proved to us, without talking about artificial intelligence and large language models, I typically say artificial intelligence is autocorrect on steroids because all a large language model does is it predicts what's the most likely next word that you're going to use and then it extrapolates from there.\n\n\n[SPEAKER_00]: So not really very intelligent.\n\n\n[SPEAKER_00]: Obviously, the impact that it has on our lives and on the reality we live in is significant.\n\n\n[SPEAKER_00]: Do you think we will see LLM written code that is submitted to you as a progress?\n\n\n[SPEAKER_01]: I'm convinced it's going to happen, yes.\n\n\n[SPEAKER_01]: And it may well be happening already, maybe on a smaller scale where people really use it more as a help in writing code.\n\n\n[SPEAKER_01]: It's clearly something where automation has always helped people write code.\n\n\n[SPEAKER_01]: I mean, this is not anything new at all.\n\n\n[SPEAKER_01]: We don't write machin

In [28]:
prompt = """Online Meeting Summary & Insights Assistant
Task:
Analyze Discussion Structure: Identify and label key topics/subtopics discussed during the meeting.
Speaker-wise Subtopic Summaries: For each subtopic, provide a brief summary (1-2 sentences) of the key points mentioned by each relevant speaker.
Meeting Conclusion & Key Takeaways:
Short Conclusion (2-3 sentences): Outline the overall discussion and meeting outcome.
Emphasized Key Points (bullet points): Highlight the most crucial decisions, actions, or agreements from the meeting.
Input:

Meeting Transcript: 

{script}

Desired Output Format:

**Topic 1: [ Brief Topic Description ]**
* **Speaker [ID]**: [ Brief Summary of Speaker's Key Points (1-2 sentences) ]
* **Speaker [ID]**: [ Brief Summary of Speaker's Key Points (1-2 sentences) ]
*...

**Topic 2: [ Brief Topic Description ]**
* **Speaker [ID]**: [ Brief Summary of Speaker's Key Points (1-2 sentences) ]
* **Speaker [ID]**: [ Brief Summary of Speaker's Key Points (1-2 sentences) ]
*...

**Conclusion:**
* Brief summary of the overall discussion and meeting outcome (2-3 sentences)

**Key Takeaways:**
* • Crucial Decision/Action 1
* • Crucial Decision/Action 2
* •...

Answer:"""
def clean_text(input_text):
    """
    Removes lines containing only numbers from the input text.

    Args:
        input_text (str): The text to be cleaned.

    Returns:
        str: The cleaned text with number-only lines removed.
    """
    # Split the input text into lines
    lines = input_text.split('\n')
    
    # Filter out lines that contain only numbers (possibly with leading/trailing whitespace)
    cleaned_lines = [line for line in lines if not re.match(r'^\s*\d+\s*$', line)]
    
    # Join the cleaned lines back into a single string
    cleaned_text = '\n'.join(cleaned_lines)
    
    return cleaned_text
cleaned_text = clean_text(cleaned_text).replace('\n\n\n', '\n')
# print(prompt.format(script=cleaned_text))
# Generate sample
results = model.generate(
    prompt.format(script=cleaned_text),
    device=device,
    output_len=256,
)
print(results)

Online Meeting Summary & Insights Assistant
Task:
Analyze Discussion Structure: Identify and label key topics/subtopics discussed during the meeting.
Speaker-wise Subtopic Summaries: For each subtopic, provide a brief summary (1-2 sentences) of the key points mentioned by each relevant speaker.
Meeting Conclusion & Key Takeaways:
Short Conclusion (2-3 sentences): Outline the overall discussion and meeting outcome.
Emphasized Key Points (bullet points): Highlight the most crucial decisions, actions, or agreements from the meeting.
Input:

Meeting Transcript: 


[SPEAKER_00]: And today, as Jim proved to us, without talking about artificial intelligence and large language models, I typically say artificial intelligence is autocorrect on steroids because all a large language model does is it predicts what's the most likely next word that you're going to use and then it extrapolates from there.
[SPEAKER_00]: So not really very intelligent.
[SPEAKER_00]: Obviously, the impact that it has on 
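
The transcript cleaning used above (the two `re.sub` calls followed by `clean_text`) can be condensed into a single pass. The sketch below assumes the input is SRT-style, that is, each entry has a sequence number, a timestamp line, and then the text, which is what the regular expressions above imply. The `sample` string is made up for illustration and is not taken from `Sample.txt`; `strip_srt_markup` is a hypothetical name.

```python
import re

# Made-up SRT-style sample for illustration only.
sample = (
    "1\n"
    "00:00:00,000 --> 00:00:04,500\n"
    "[SPEAKER_00]: Hello everyone.\n"
    "\n"
    "12\n"
    "00:00:04,500 --> 00:00:07,000\n"
    "[SPEAKER_01]: Thanks for having me.\n"
)

# Same timestamp pattern as in the cleaning cell above.
TIMESTAMP_RE = re.compile(
    r'\d{2}:\d{2}:\d{2},\d{3} --> \d{2}:\d{2}:\d{2},\d{3}'
)


def strip_srt_markup(text):
    """Drops timestamps, number-only lines, and blank lines."""
    no_timestamps = TIMESTAMP_RE.sub('', text)
    lines = [line.strip() for line in no_timestamps.split('\n')]
    # Keep only non-empty lines that are not bare sequence numbers.
    kept = [line for line in lines if line and not line.isdigit()]
    return '\n'.join(kept)


demo = strip_srt_markup(sample)
print(demo)
```

Note that this variant also removes a spoken line that consists only of digits, the same trade-off `clean_text` makes.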

## Learn more

Now that you have learned how to use Gemma in PyTorch, you can explore the many
other things that Gemma can do at [ai.google.dev/gemma](https://ai.google.dev/gemma).
See also these other related resources:

- [Gemma model card](https://ai.google.dev/gemma/docs/model_card)
- [Gemma C++ Tutorial](https://ai.google.dev/gemma/docs/gemma_cpp)
- [Gemma formatting and system instructions](https://ai.google.dev/gemma/docs/formatting)

In [3]:
import requests

def make_predict_request(prompt, url='http://localhost:5000/predict'):
    """
    Makes a POST request to the /predict endpoint with a given prompt.
    
    Args:
    - prompt (str): The input prompt for the Gemma model.
    - url (str, optional): The URL of the /predict endpoint. Defaults to 'http://localhost:5000/predict'.
    
    Returns:
    - response (requests.Response): The response from the server.
    """
    # Passing `json=` serializes the payload and sets the
    # Content-Type: application/json header automatically.
    response = requests.post(url, json={'prompt': prompt})
    
    return response

# Example usage
if __name__ == '__main__':
    prompt = "Hello, how are you?"  # Your input prompt here
    response = make_predict_request(prompt)
    
    # Check the response status code
    if response.status_code == 202:
        print("Request successfully queued. Response:", response.json())
    else:
        print("Failed to queue request. Status code:", response.status_code)
        print("Response content:", response.text)

Request successfully queued. Response: {'message': 'Request queued for processing'}
