<a href="https://colab.research.google.com/github/neomatrix369/learning-path-index/blob/lpi-gemma-model/app/llm-poc-variant-03/lpi_keras_finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Table of Contents

1. **Environment Setup**
   - 1.1 System Check: `!nvidia-smi`
   - 1.2 Package Installations: Localtunnel, FastAPI, Uvicorn, Kaggle, Keras-NLP, KaggleHub

2. **Dataset and Model Download**
   - 2.1 Kaggle Data Source Import
   - 2.2 Downloading Learning Path Index Dataset and Gemma2 Model

3. **Project Title and Introduction**
   - Title: "Fine-tuning Gemma 2 model using LoRA and Keras"
   - 3.1 Introduction Content
     - 3.1.1 Overview of fine-tuning with LoRA on the Gemma2 model
     - 3.1.2 Gemma2 model details
     - 3.1.3 Description of LoRA

4. **Fine-Tuning Process Overview**
   - 4.1 Objective Outline: Fine-tuning steps and goals

5. **Prerequisites**
   - 5.1 Package Imports: Libraries like `keras`, `numpy`, `pandas`, `seaborn`

6. **Model Configuration**
   - 6.1 Configuration Settings: `Config` class for fine-tuning parameters

7. **Data Loading and Preprocessing**
   - 7.1 Data Import and Display
   - 7.2 Dataset Pre-check: Shape and columns display
   - 7.3 Q&A Generation: Function to create question-answer pairs from dataset
   - 7.4 Data Preprocessing for Fine-Tuning

8. **Generate Question and Answer Pairs**
   - Initialize LoRA and Gemma Model
   - Model Summary and Training Compilation

9. **Running Fine-Tuning**
   - Setting Sequence Length and Batch Size
   - Training Execution

10. **Model Testing**
    - Instantiating GemmaQA Class
    - Sample Queries and Testing
    - Testing with New Questions

11. **Saving and Exporting the Model**
    - Save Model as Preset
    - Exporting the Model to Kaggle

12. **Web Interface Setup**
    - UI Setup: HTML and JavaScript for Chat Interface
    - Backend Setup: FastAPI endpoints and query function for Gemma model
    - Running FastAPI Application

13. **Exposing Web Application**
    - Uvicorn Server Launch
    - Using LocalTunnel for External Access

14. **Conclusion**
    - Summary of Fine-Tuning Process
    - Model Deployment as a Kaggle Model


### Credits: https://medium.com/@gabi.preda/fine-tuning-gemma-2-model-with-role-playing-dataset-b8ec399a2e17


### Kaggle Model: https://www.kaggle.com/models/moronfoluwa/gemma2_2b_en_lpi/keras/gemma2_2b_en_lpi

## 1. Environment Setup

### 1.1 System Check

In [None]:
!nvidia-smi

Sat Nov  2 21:40:01 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA L4                      Off | 00000000:00:03.0 Off |                    0 |
| N/A   65C    P8              14W /  72W |      1MiB / 23034MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

### 1.2 Package Installations

Localtunnel, FastAPI, Uvicorn, Kaggle, Keras-NLP, KaggleHub

In [None]:
!npm install -g localtunnel
!pip install fastapi
!pip install uvicorn
!pip install kaggle

changed 22 packages, and audited 23 packages in 538ms

3 packages are looking for funding
  run `npm fund` for details

1 moderate severity vulnerability

To address all issues (including breaking changes), run:
  npm audit fix --force

Run `npm audit` for details.


Installing `keras-nlp` and `keras` packages.

In [None]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
# Quote the requirement so the shell does not treat ">=3" as an output redirect.
!pip install -q -U "keras>=3"
!pip install -q -U kagglehub

## 2. Dataset and Model Download


### 2.1 Kaggle Data Source Import

Since our dataset is hosted on Kaggle, we need to log in with `kagglehub` before downloading it, especially when running outside a Kaggle environment.

In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
import kagglehub
kagglehub.login()


### 2.2 Downloading Learning Path Index Dataset and Gemma2 Model

In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

neomatrix369_learning_path_index_dataset_path = kagglehub.dataset_download('neomatrix369/learning-path-index-dataset')
keras_gemma2_keras_gemma2_2b_en_1_path = kagglehub.model_download('keras/gemma2/Keras/gemma2_2b_en/1')
!wget https://raw.githubusercontent.com/neomatrix369/learning-path-index/refs/heads/main/data/Learning_Pathway_Index.csv

print('Data source import complete.')

Data source import complete.



# 3. Fine-tuning Gemma 2 model using LoRA and Keras

### 3.1 Introduction Content


<center><h1>Fine-tuning Gemma 2 model using LoRA and Keras</h1></center>

<center><img src="https://res.infoq.com/news/2024/02/google-gemma-open-model/en/headerimage/generatedHeaderImage-1708977571481.jpg" width="400"></center>


### Introduction

This notebook will demonstrate three things:

1. How to fine-tune a Gemma 2 model using LoRA
2. How to create a specialised class for querying the fine-tuned model
3. Some results of querying the model, both with prompts related to the fine-tuning data and with new questions about the Learning Path Index courses.



#### 3.1.1 What is Gemma 2?

Gemma is a collection of lightweight, advanced open models developed by Google, leveraging the same research and technology behind the Gemini models. These models are text-to-text, decoder-only large language models available in English, with open weights provided for both pre-trained and instruction-tuned versions. Gemma models excel in a range of text generation tasks, such as question answering, summarization, and reasoning. Their compact size allows for deployment in resource-constrained environments like laptops, desktops, or personal cloud infrastructure, making state-of-the-art AI models more accessible and encouraging innovation for all.

Gemma 2 represents the second generation of Gemma models. These models were trained on a dataset of text data that includes a wide variety of sources. The **27B** model was trained with **13 trillion** tokens, the **9B** model with **8 trillion** tokens, and the **2B** model with **2 trillion** tokens. Here is a summary of the key components of the training data:
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. Primarily English-language content.
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code or understand code-related questions.
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning, symbolic representation, and to address mathematical queries.

To learn more about Gemma 2, follow this link: [Gemma 2 Model Card](https://www.kaggle.com/models/google/gemma-2).




#### 3.1.2 What is LoRA?  

**LoRA** stands for **Low-Rank Adaptation**. It is a method for fine-tuning large language models (LLMs) that freezes the pre-trained weights and injects trainable rank-decomposition matrices into the model's layers. This considerably reduces the number of trainable parameters during fine-tuning. According to the **LoRA** paper, compared with full fine-tuning of GPT-3 175B, LoRA can reduce the number of trainable parameters by **10,000 times** and the GPU memory requirement by 3 times.
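
To make the parameter saving concrete, here is a small arithmetic sketch. It assumes a single dense weight matrix of shape `d x k` that is adapted with two low-rank factors, `B` of shape `d x r` and `A` of shape `r x k`. The layer size below is illustrative only and is not taken from the Gemma 2 architecture.

```python
def full_param_count(d: int, k: int) -> int:
    """Parameters in a dense d x k weight matrix (all trainable in full fine-tuning)."""
    return d * k


def lora_param_count(d: int, k: int, rank: int) -> int:
    """Trainable parameters when the frozen matrix gets a low-rank update B @ A,
    with B of shape (d, rank) and A of shape (rank, k)."""
    return rank * (d + k)


# Illustrative layer size (an assumption, not a Gemma 2 dimension).
d, k, rank = 2048, 2048, 4
full = full_param_count(d, k)
lora = lora_param_count(d, k, rank)
print(f"full: {full:,}  lora: {lora:,}  reduction: {full / lora:.0f}x")
```

Because the LoRA count grows linearly with the rank, doubling the rank doubles the number of trainable parameters for that layer.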

# 4. Fine-Tuning Process Overview

For fine-tuning with LoRA, we will follow these steps:

1. Install prerequisites
2. Load and process the data for fine-tuning
3. Initialize the Gemma causal language model (Gemma Causal LM)
4. Perform fine-tuning so that the model learns to answer questions about the courses in the Learning Path Index
5. Test the fine-tuned model with questions from the data used for fine-tuning and with additional questions

# 5. Prerequisites
## Import packages

Now we can import the packages we just installed. We will also import `os`, so that we can set the environment variables needed for the Keras backend. We will use `jax` as `KERAS_BACKEND`.

Because we want to publish the model from the notebook, we also import `kagglehub`.

In [None]:
import re
import os
# you can also use tensorflow or torch
os.environ["KERAS_BACKEND"] = "jax"
# avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"
os.environ["JAX_PLATFORMS"] = ""
import keras
import keras_nlp
import kagglehub
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
tqdm.pandas()

import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Markdown
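
Keras reads `KERAS_BACKEND` at import time, so the variable has to be set before `import keras` runs, as in the cell above. The helper below is a minimal sketch of that ordering check. `select_keras_backend` is a hypothetical name used for illustration and is not part of Keras.

```python
import os
import sys

SUPPORTED_BACKENDS = {"jax", "tensorflow", "torch"}


def select_keras_backend(name: str) -> str:
    """Set KERAS_BACKEND, refusing if Keras has already been imported."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(f"Unsupported backend: {name!r}")
    if "keras" in sys.modules:
        # Keras picks its backend at import time; changing the variable now has no effect.
        raise RuntimeError("Set KERAS_BACKEND before importing keras.")
    os.environ["KERAS_BACKEND"] = name
    return name


# Only attempt the switch when Keras has not been imported yet.
if "keras" not in sys.modules:
    select_keras_backend("jax")
```

In a notebook, this means restarting the runtime if a different backend is needed after Keras has been imported.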

# 6. Configurations

We use a `Config` class to group the information needed to control the fine-tuning process:
* random seed
* dataset path
* qa_dataset_path
* preset - name of pretrained Gemma 2
* sequence length - this is the maximum size of input sequence for training
* batch size - size of the input batch in training
* lora rank - rank for LoRA, higher means more trainable parameters
* learning rate - learning rate used in training
* epochs - number of epochs for training

In [3]:
class Config:
    seed = 42
    dataset_path = "./Learning_Pathway_Index.csv"
    qa_dataset_path = "./qa_pairs.csv" # Question pair dataset
    preset = "hf://google/gemma-2-2b" # name of pretrained Gemma 2
    sequence_length = 512 # max size of input sequence for training
    batch_size = 1 # size of the input batch in training
    lora_rank = 4 # rank for LoRA, higher means more trainable parameters
    learning_rate=8e-5 # learning rate used in train
    epochs = 20 # number of epochs to train

Set a random seed for results reproducibility.

In [None]:
keras.utils.set_random_seed(Config.seed)

# 7. Data Loading and Preprocessing

We load the data we will use for fine-tuning.

### 7.1 Data Import and Display

In [4]:
df = pd.read_csv(f"{Config.dataset_path}")
df.head()


Let's check the number of rows and the column names in this dataset.

In [None]:
df.shape, df.columns

((1446, 10),
 Index(['Module_Code', 'Course_Learning_Material', 'Source', 'Course_Level',
        'Type_Free_Paid', 'Module', 'Duration', 'Difficulty_Level',
        'Keywords_Tags_Skills_Interests_Categories', 'Links'],
       dtype='object'))

# 8. Generate Question and Answer pairs

### 8.1  Function to create question-answer pairs from dataset

In [None]:
def generate_qa_pairs(row):
    qa_pairs = []

    # Basic data extraction for clarity
    module_code = row['Module_Code']
    course_title = row['Course_Learning_Material']
    source = row['Source']
    level = row['Course_Level']
    cost_type = row['Type_Free_Paid']
    duration = row['Duration']
    difficulty = row['Difficulty_Level']
    skills = row['Keywords_Tags_Skills_Interests_Categories']
    link = row['Links']

    # Context information for input
    context = f"Course Title: {course_title}; Module Code: {module_code}; Level: {level}; Type: {cost_type}; Duration: {duration}"

    # Generate Question Types
    questions = [
        # Factual Recall (20%)
        {"instruction": f"What is the title of the course with Module Code '{module_code}'?",
         "input": context,
         "output": course_title,
         "metadata": {"type": "Factual Recall", "difficulty": difficulty, "cognitive_level": "Recall", "source_reference": course_title}},

        {"instruction": f"Is the course '{course_title}' free or paid?",
         "input": f"Module Code: {module_code}; Cost Status: {cost_type}",
         "output": cost_type,
         "metadata": {"type": "Factual Recall", "difficulty": difficulty, "cognitive_level": "Recall", "source_reference": course_title}},

        # Comprehension (25%)
        {"instruction": f"Explain the main purpose of the course '{course_title}' provided by {source}.",
         "input": f"Course Source: {source}; Skills Covered: {skills}",
         "output": f"The course '{course_title}' by {source} focuses on {skills}.",
         "metadata": {"type": "Comprehension", "difficulty": difficulty, "cognitive_level": "Comprehension", "source_reference": course_title}},

        {"instruction": f"What are the primary skills or topics covered in the '{course_title}' course?",
         "input": f"Keywords: {skills}",
         "output": skills,
         "metadata": {"type": "Comprehension", "difficulty": difficulty, "cognitive_level": "Comprehension", "source_reference": skills}},

        # Application (25%)
        {"instruction": f"How much time will it take to complete the '{course_title}' course?",
         "input": f"Duration: {duration}",
         "output": f"The estimated time to complete the course is {duration}.",
         "metadata": {"type": "Application", "difficulty": difficulty, "cognitive_level": "Application", "source_reference": duration}},

        {"instruction": f"Provide an example of a project where skills from '{course_title}' would be useful.",
         "input": f"Skills: {skills}",
         "output": f"Skills from '{course_title}' could be applied in a project involving {skills}, such as developing machine learning models.",
         "metadata": {"type": "Application", "difficulty": difficulty, "cognitive_level": "Application", "source_reference": course_title}},

        # Analysis (20%)
        {"instruction": f"Compare the '{course_title}' course with other {level} courses offered by {source}.",
         "input": f"Source: {source}; Level: {level}",
         "output": f"The course '{course_title}' is rated {difficulty} and covers topics such as {skills}, making it comparable to other {level} courses by {source}.",
         "metadata": {"type": "Analysis", "difficulty": difficulty, "cognitive_level": "Analysis", "source_reference": course_title}},

        {"instruction": f"Where can you access the '{course_title}' course online?",
         "input": f"Course Link: {link}",
         "output": link,
         "metadata": {"type": "Analysis", "difficulty": difficulty, "cognitive_level": "Analysis", "source_reference": link}},

        # Synthesis/Evaluation (10%)
        {"instruction": f"Evaluate the usefulness of '{course_title}' for learners interested in {skills}.",
         "input": f"Level: {level}; Skills: {skills}; Difficulty: {difficulty}",
         "output": f"The course '{course_title}' by {source} is beneficial for learners interested in {skills}, providing material at a {difficulty} level.",
         "metadata": {"type": "Evaluation", "difficulty": difficulty, "cognitive_level": "Evaluation", "source_reference": course_title}},

        {"instruction": f"How could the course '{course_title}' be improved for {level} learners?",
         "input": f"Course Level: {level}; Duration: {duration}",
         "output": f"The course '{course_title}' could be improved with hands-on projects to enhance practical understanding for {level} learners.",
         "metadata": {"type": "Synthesis", "difficulty": difficulty, "cognitive_level": "Synthesis", "source_reference": course_title}}
    ]

    # Append questions to the qa_pairs list with balanced distribution
    for question in questions:
        qa_pairs.append({
            "instruction": question["instruction"],
            "input": question["input"],
            "output": question["output"],
            "metadata": question["metadata"]
        })

    return qa_pairs

# Generate Q&A pairs for each row in the dataset and combine them into a single list
all_qa_pairs = []
for _, row in df.iterrows():
    all_qa_pairs.extend(generate_qa_pairs(row))

# Convert the list of Q&A pairs into a DataFrame for further analysis or export
qa_pairs_df = pd.DataFrame(all_qa_pairs)

# Optional: Save the generated Q&A pairs to a CSV file
qa_pairs_df.to_csv(Config.qa_dataset_path, index=False)
# print(f"Q&A pairs generated and saved to '{Config.qa_dataset_path}'")
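
The percentage comments in `generate_qa_pairs` can be checked by counting the `type` field of the generated metadata. The sketch below uses a hand-written list that mirrors the ten `type` labels emitted per row (it does not call the function itself), so the shares it prints hold for any number of rows.

```python
from collections import Counter

# The ten "type" labels produced for every row, in the order they appear in the function.
types_per_row = [
    "Factual Recall", "Factual Recall",
    "Comprehension", "Comprehension",
    "Application", "Application",
    "Analysis", "Analysis",
    "Evaluation", "Synthesis",
]

counts = Counter(types_per_row)
shares = {label: count / len(types_per_row) for label, count in counts.items()}

for label, share in shares.items():
    print(f"{label:15s} {share:.0%}")

# With 1446 rows in the dataset, the generated DataFrame should hold this many pairs.
expected_pairs = 1446 * len(types_per_row)
print(expected_pairs)
```

Each of the first four categories ends up with a 20% share, and Evaluation and Synthesis with 10% each, so the 25% figures in the code comments are approximate.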

### 8.2 Data Preprocessing for Fine-Tuning

We preprocess the data so that each row of the dataset yields a `<|system|>` prompt built from the course metadata, followed by a {`<|user|>`, `<|assistant|>`} pair. Together they form a triplet of {`<|system|>`, `<|user|>`, `<|assistant|>`} for each entry in the fine-tuning data.

In [None]:
def extract_dialogue_components(row):
    # Ensure all relevant fields are strings and handle NaN or other invalid types
    module_code = str(row['Module_Code']) if pd.notna(row['Module_Code']) else "Unknown Module"
    source = str(row['Source']) if pd.notna(row['Source']) else "Unknown Source"
    difficulty_level = str(row['Difficulty_Level']) if pd.notna(row['Difficulty_Level']) else "Unknown Level"
    module = str(row['Module']) if pd.notna(row['Module']) else "Unknown Module"
    course_material = str(row['Course_Learning_Material']) if pd.notna(row['Course_Learning_Material']) else ""
    keywords = str(row['Keywords_Tags_Skills_Interests_Categories']) if pd.notna(row['Keywords_Tags_Skills_Interests_Categories']) else "No keywords available"
    duration = str(row['Duration']) if pd.notna(row['Duration']) else "Unknown duration"

    # Extract system prompt from course metadata
    system_prompt = f"<|system|> Module: {module_code}, Source: {source}, Level: {difficulty_level}. This is an introduction to {module}. </s>"

    # Extract user input as Course Learning Material (if available)
    user_input = f"<|user|> {course_material} </s>" if course_material else "<|user|> No course learning material provided. </s>"

    # Extract assistant response from other relevant columns
    assistant_response = f"<|assistant|> This module covers the following topics: {keywords}. Duration: {duration}. </s>"

    # Combine user and assistant exchanges as dialogue pairs
    dialogue_pair = f"{user_input}\n{assistant_response}"

    return system_prompt, [dialogue_pair]

We process the data. For fine-tuning, we only include the rows that fit within the configured maximum length.

In [None]:
# Initialize an empty list to store processed data
data = []

# Function to simulate token length estimation
def estimate_token_length(text):
    return len(text.split())

# Iterate over each row in the dataframe
for index, row in df.iterrows():
    try:
        # Estimate the length of the text in terms of tokens
        token_length = estimate_token_length(row["Course_Learning_Material"])

        # Filter rows based on max token length constraint
        if token_length <= Config.sequence_length:
            system_prompt, dialogue_pairs = extract_dialogue_components(row)

            # Prepare prompt samples from dialogue pairs
            for pair in dialogue_pairs:
                prompt_sample = f"{system_prompt}\n\n{pair}"
                data.append(prompt_sample)
    except Exception as ex:
        print(f"Error at row {index}: {ex}")

# Display the number of processed data points
len(data)


1446
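
Note that `estimate_token_length` counts whitespace-separated words, which is only a rough proxy for the number of tokens a subword tokenizer produces, and that the check above measures `Course_Learning_Material` alone rather than the full prompt. A more conservative sketch, which measures the whole prompt and applies a safety factor, could look like this. The factor of 1.5 tokens per word and the name `fits_sequence_length` are assumptions chosen for illustration, not measured values or notebook code.

```python
import math


def estimate_token_length(text: str) -> int:
    """Rough proxy used in the notebook: number of whitespace-separated words."""
    return len(text.split())


def fits_sequence_length(prompt: str, max_tokens: int, tokens_per_word: float = 1.5) -> bool:
    """Return True if the padded estimate for the full prompt fits in max_tokens.

    tokens_per_word is a hypothetical safety factor, not a property of the tokenizer.
    """
    estimated = math.ceil(estimate_token_length(prompt) * tokens_per_word)
    return estimated <= max_tokens


# Hypothetical prompt in the same layout as the processed samples.
sample_prompt = "<|system|> Module: 1, Source: Example. </s>\n\n<|user|> Intro course </s>"
print(estimate_token_length(sample_prompt), fits_sequence_length(sample_prompt, 512))
```

For an exact count, the model's own tokenizer would have to be applied to each prompt, which requires the model preset to be loaded first.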

### 8.3 Template utility function


We use this function to reformat the output of our queries, so that it is more user friendly.

We replace and highlight the initial special tokens with more human-readable text (Instruction, Question, Answer).

In [None]:
def colorize_text(text):
    for word, formatted_word, color in zip(["<|system|>:", "<|user|>:", "<|assistant|>:"],
                                           ["Instruction:", "Question:", "Answer:"],
                                           ["blue", "red", "green"]):
        text = text.replace(f"\n\n{word}", f"\n\n**<font color='{color}'>{formatted_word}</font>**")
    return text
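
As a quick illustration of what `colorize_text` does, the sketch below repeats the function and applies it to a hypothetical model response written in the query template format. The sample text is made up for the example.

```python
def colorize_text(text):
    # Same replacement logic as the notebook cell above.
    for word, formatted_word, color in zip(
        ["<|system|>:", "<|user|>:", "<|assistant|>:"],
        ["Instruction:", "Question:", "Answer:"],
        ["blue", "red", "green"],
    ):
        text = text.replace(f"\n\n{word}", f"\n\n**<font color='{color}'>{formatted_word}</font>**")
    return text


# Hypothetical model output in the template format used for queries.
raw = "\n\n<|system|>:\nBe concise.\n\n<|user|>:\nWhat is LoRA?\n\n<|assistant|>:\nA fine-tuning method."
formatted = colorize_text(raw)
print(formatted)
```

The markers are only replaced when they follow a blank line and end with a colon, so a marker that appears in another position is left unchanged.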

# 9. Fine-Tuning Preparation


We define a specialized class to query Gemma. But first, we need to initialize an object of GemmaCausalLM class.

### 9.1 Initialize the Gemma Causal LM

We load the pretrained Gemma 2 model from the preset set in `Config` and print its summary.

In [None]:
gemma_causal_lm = keras_nlp.models.GemmaCausalLM.from_preset(Config.preset)
gemma_causal_lm.summary()

config.json:   0%|          | 0.00/818 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/481M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

### 9.2 Define the specialized class

Here we define the specialized class `GemmaQA`.
Its `__init__` stores a reference to the `GemmaCausalLM` object created before, together with the prompt template and the maximum generation length.
The `query` member function uses the `generate` member function of `GemmaCausalLM` to generate the answer, based on a prompt that includes the instruction and the question.

In [None]:
template = "\n\n<|system|>:\n{instruct}\n\n<|user|>:\n{question}\n\n<|assistant|>:\n{answer}"
class GemmaQA:
    def __init__(self, max_length=512):
        self.max_length = max_length
        self.prompt = template
        self.gemma_causal_lm = gemma_causal_lm

    def query(self, instruct, question):
        response = self.gemma_causal_lm.generate(
            self.prompt.format(
                instruct=instruct,
                question=question,
                answer=""),
            max_length=self.max_length)
        display(Markdown(colorize_text(response)))
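
To see the exact prompt that `query` sends to the model, the template can be formatted on its own. This is a minimal sketch with a hypothetical instruction and question; it does not call the model.

```python
template = "\n\n<|system|>:\n{instruct}\n\n<|user|>:\n{question}\n\n<|assistant|>:\n{answer}"

# Hypothetical instruction and question, used only to show the prompt layout.
prompt = template.format(
    instruct="Answer questions about the Learning Path Index.",
    question="Which courses are free?",
    answer="",  # left empty so the model generates the continuation
)
print(repr(prompt))
```

The prompt ends right after the `<|assistant|>:` marker, so whatever the model generates next is read as the answer.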


### 9.3 Gemma preprocessor


This preprocessing layer will take in batches of strings, and return outputs in a ```(x, y, sample_weight)``` format, where the y label is the next token id in the x sequence.

From the code below, we can see that, after the preprocessor, the data shape is ```(num_samples, sequence_length)```.

In [None]:
x, y, sample_weight = gemma_causal_lm.preprocessor(data[0:2])

In [None]:
print(x, y)

{'token_ids': Array([[     2, 235322, 235371, ...,      0,      0,      0],
       [     2, 235322, 235371, ...,      0,      0,      0]],      dtype=int32), 'padding_mask': Array([[ True,  True,  True, ..., False, False, False],
       [ True,  True,  True, ..., False, False, False]], dtype=bool)} [[235322 235371   9020 ...      0      0      0]
 [235322 235371   9020 ...      0      0      0]]
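
The relation between `x` and `y` in the output above can be reproduced with plain lists. The following is a simplified sketch of the shift-by-one labeling, using the first few token ids from the printed output plus two padding positions. It is not the preprocessor's actual implementation.

```python
# First few token ids from the preprocessor output above, plus padding (0).
token_ids = [2, 235322, 235371, 9020, 0, 0]
padding_mask = [True, True, True, True, False, False]

# Features are every position except the last; labels are the same sequence shifted left by one.
x_ids = token_ids[:-1]
y_ids = token_ids[1:]

# Label positions that correspond to padding get zero weight so they do not affect the loss.
sample_weight = [1.0 if keep else 0.0 for keep in padding_mask[1:]]

for x_tok, y_tok, w in zip(x_ids, y_ids, sample_weight):
    print(f"input {x_tok:>6} -> target {y_tok:>6} (weight {w})")
```

This matches the printed arrays: `x` starts with `2, 235322, 235371` while `y` starts with `235322, 235371, 9020`.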


### 9.4 Perform fine-tuning with LoRA

### 9.5 Enable LoRA for the model

The LoRA rank controls the number of trainable parameters: a larger rank results in more parameters to train.

In [None]:
# Enable LoRA for the model and set the LoRA rank to the lora_rank as set in Config (4).
gemma_causal_lm.backbone.enable_lora(rank=Config.lora_rank)
gemma_causal_lm.summary()

We see that only a small part of the parameters is trainable: about 2.9 million out of 2.6 billion in total.
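
The share of trainable parameters follows directly from these two rounded figures. The sketch below also estimates the count at a different rank, under the assumption that all trainable parameters come from the LoRA factors and therefore scale linearly with the rank. `estimate_trainable_params` is a hypothetical helper, not part of Keras.

```python
# Rounded figures from the model summary discussed above.
total_params = 2.6e9
trainable_params = 2.9e6
current_rank = 4  # Config.lora_rank

trainable_fraction = trainable_params / total_params
print(f"Trainable share at rank {current_rank}: {trainable_fraction:.2%}")


def estimate_trainable_params(new_rank: int) -> float:
    """Linear-in-rank estimate (assumes every trainable parameter is a LoRA parameter)."""
    return trainable_params * new_rank / current_rank


print(f"Estimated trainable parameters at rank 8: {estimate_trainable_params(8):,.0f}")
```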

### 9.6 Run the training sequence

We set the `sequence_length` for the `GemmaCausalLM` preprocessor (from the configuration, 512).
We compile the model with the loss, optimizer, and metric.
The metric is `SparseCategoricalAccuracy`, which calculates how often predictions match integer labels.

In [None]:
# Set the sequence length from the config (512)
gemma_causal_lm.preprocessor.sequence_length = Config.sequence_length

# Compile the model with loss, optimizer, and metric
gemma_causal_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=Config.learning_rate),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train model. Note: epochs is hard-coded to 2 here, so Config.epochs (20) is not used.
gemma_causal_lm.fit(data, epochs=2, batch_size=Config.batch_size)

Epoch 1/2
1446/1446 ━━━━━━━━━━━━━━━━━━━━ 630s 403ms/step - loss: 0.2376 - sparse_categorical_accuracy: 0.7342
Epoch 2/2
1446/1446 ━━━━━━━━━━━━━━━━━━━━ 584s 390ms/step - loss: 0.1318 - sparse_categorical_accuracy: 0.8433


<keras.src.callbacks.history.History at 0x794ae07055d0>

We obtained a rather good accuracy after the 2 epochs of fine-tuning: the sparse categorical accuracy on the training data reached about 0.84.
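
Because the metric is passed as a weighted metric, padded positions do not count towards the reported accuracy. The sketch below illustrates the idea in plain Python. It takes already-decoded predicted token ids rather than logits, and the labels and predictions are hypothetical, so it is an illustration of the computation rather than the Keras implementation.

```python
def weighted_sparse_categorical_accuracy(y_true, y_pred_ids, sample_weight):
    """Share of positions where the predicted id equals the label, weighted per position."""
    matched = sum(w for t, p, w in zip(y_true, y_pred_ids, sample_weight) if t == p)
    total = sum(sample_weight)
    return matched / total if total else 0.0


# Hypothetical labels and predictions for one short sequence; 0 is padding.
y_true = [235322, 235371, 9020, 0, 0]
y_pred = [235322, 235371, 1234, 0, 7]
weights = [1.0, 1.0, 1.0, 0.0, 0.0]

accuracy = weighted_sparse_categorical_accuracy(y_true, y_pred, weights)
print(f"{accuracy:.3f}")
```

Only the three non-padding positions are scored here, two of which match, so the padded tail neither helps nor hurts the result.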

# 10. Model Testing

We instantiate an object of class GemmaQA. Because `gemma_causal_lm` was fine-tuned using LoRA, `gemma_qa` defined here will use the fine-tuned model.

### 10.1 Instantiating GemmaQA Class

In [None]:
gemma_qa = GemmaQA()

### 10.2 Sample Queries and Testing

To start, we test the model with some of the data from the training set itself.

#### Sample 1

In [None]:
gemma_qa = GemmaQA(max_length=96)
instruct = "Sherlock the renowned detective from Baker Street is known for his astute logical reasoning disguise ability and use of forensic science to solve perplexing crimes"
question = "What's Sherlock secret to solving crimes?"
gemma_qa.query(instruct, question)



**<font color='blue'>Instruction:</font>**
Sherlock the renowned detective from Baker Street is known for his astute logical reasoning disguise ability and use of forensic science to solve perplexing crimes

**<font color='red'>Question:</font>**
What's Sherlock secret to solving crimes?

**<font color='green'>Answer:</font>**
Sherlock's secret to solving crimes is his ability to think outside the box and use his deductive reasoning skills to analyze evidence and draw logical conclusions. He is also known for his ability to disguise himself

#### Sample 2

In [None]:
gemma_qa = GemmaQA(max_length=96)
instruct = "What are the primary skills or topics covered in the 'Introduction to Machine Learning' course?"
question = "What are the primary skills or topics covered in the 'Introduction to Machine Learning' course?"
gemma_qa.query(instruct, question)

#### Sample 3

In [None]:
# gemma_qa = GemmaQA(max_length=96)
# instruct = ""
# question = "What courses are Beginner?"
# gemma_qa.query(instruct, question)

### 10.3 Testing with New Questions

In [None]:
gemma_qa = GemmaQA(max_length=128)
instruct = ""
question = "What courses belong to Google Developers?"
gemma_qa.query(instruct, question)



**<font color='blue'>Instruction:</font>**


**<font color='red'>Question:</font>**
What courses belong to Google Developers?

**<font color='green'>Answer:</font>**
Here are some courses that belong to Google Developers: 


- Google Cloud Professional Data Engineer
- Google Cloud Professional Data Engineer Specialization
- Google Cloud Professional Data Engineer Specialization
- Google Cloud Professional Data Engineer Specialization
- Google Cloud Professional Data Engineer Specialization
- Google Cloud Professional Data Engineer Specialization
- Google Cloud Professional Data Engineer Specialization
- Google Cloud Professional Data Engineer Specialization
- Google Cloud Professional Data Engineer Specialization
- Google Cloud Professional Data Engineer Specialization. 


**<font color='red'>Question:</font>** What are

In [None]:
instruct = ""
question = "List 10 courses that are more than 20 minutes?"
gemma_qa.query(instruct,question)



**<font color='blue'>Instruction:</font>**


**<font color='red'>Question:</font>**
List 10 courses that are more than 20 minutes?

**<font color='green'>Answer:</font>**
Here are the top 10 courses that are more than 20 minutes:

<|document|>:
[Top 10 Courses that are More Than 20 Minutes](https://www.google.com/search?q=top+10+courses+that+are+more+than+20+minutes&tbm=vid).


In [None]:
instruct = "List courses that belongs to Intermediate category"
question = "List courses that are Intermediate"

gemma_qa.query(instruct,question)



**<font color='blue'>Instruction:</font>**
List courses that belongs to Intermediate category

**<font color='red'>Question:</font>**
List courses that are Intermediate

**<font color='green'>Answer:</font>**
Create a new course with the name "Intermediate Course" and the description "This is an intermediate course". 

<|data|>: 
[courses, course_ids].


# 11. Saving and Exporting the Model

### 11.1 Save Model as Preset

In [None]:
preset_dir = "gemma2_2b_en_lpi"
gemma_causal_lm.save_to_preset(preset_dir)

### 11.2 Exporting the Model to Kaggle

We now publish the saved model as a Kaggle Model.

In [None]:
# kaggle_username = os.environ["KAGGLE_USERNAME"]

# kaggle_uri = f"kaggle://{kaggle_username}/gemma2_2b_en_lpi/keras/gemma2_2b_en_lpi"
# keras_nlp.upload_preset(kaggle_uri, preset_dir)

# Conclusions



We demonstrated how to fine-tune a **Gemma 2** model using LoRA.

We also created a class to run queries against the **Gemma 2** model and tested it with some examples from the existing training data as well as with some new, unseen questions.

At the end, we saved the model as a preset and published it as a Kaggle Model (via `keras_nlp.upload_preset`).

# 12. Web Interface Setup

### 12.1 UI and Backend Setup

The chat interface is built with HTML and JavaScript, and the backend uses FastAPI endpoints with a query function for the Gemma model.

In [None]:
%%writefile main.py

import os
os.environ["JAX_PLATFORMS"] = "cpu"
# Disable GPU to prevent CUDA errors if using CPU-only environment
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# optimize TensorFlow’s performance for non-AVX-512 CPUs by disabling OneDNN optimizations:
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
import uvicorn
import json
import tensorflow as tf
# import os
import keras_nlp

# Template for the GemmaQA prompt
template = "\n\n<|system|>:\n{instruct}\n\n<|user|>:\n{question}\n\n<|assistant|>:\n{answer}"

# Ensure memory growth for GPUs if available
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

app = FastAPI()

# Load Gemma model and define the GemmaQA class
class GemmaQA:
    def __init__(self, max_length=512):
        self.max_length = max_length
        self.prompt = template
        self.gemma_causal_lm = keras_nlp.models.GemmaCausalLM.from_preset("./gemma2_2b_en_lpi")  # Replace with actual path

    def query(self, instruct, question):
        input_text = self.prompt.format(instruct=instruct, question=question, answer="")
        response = self.gemma_causal_lm.generate(input_text, max_length=self.max_length)
        return response.split("<|assistant|>:")[-1].strip()

# Instantiate the GemmaQA model
gemma_qa = GemmaQA(max_length=128)

# Define a Pydantic model for the chat request
class ChatRequest(BaseModel):
    instruction: str
    question: str

# Define the chat endpoint
@app.post("/chat")
async def chat(request: ChatRequest):
    response_text = gemma_qa.query(request.instruction, request.question)
    return JSONResponse(content={"response": response_text})

# Serve the HTML frontend
@app.get("/", response_class=HTMLResponse)
async def get_ui():
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Your Chat UI</title>
        <style>
            /* Compiled CSS from style.sass */
            /* Add compiled CSS here */
            .--dark-theme {
                --chat-background: rgba(10, 14, 14, 0.95);
                --chat-panel-background: #131719;
                --chat-bubble-background: #14181a;
                --chat-add-button-background: #212324;
                --chat-send-button-background: #8147fc;
                --chat-text-color: #a3a3a3;
                --chat-options-svg: #a3a3a3;
            }

            body {
                background: url(https://images.unsplash.com/photo-1495808985667-ba4ce2ef31b3?ixlib=rb-1.2.1&ixid=eyJhcHBfaWQiOjEyMDd9&auto=format&fit=crop&w=1050&q=80);
                background-size: cover;
            }

            #chat {
                background: var(--chat-background);
                max-width: 600px;
                margin: 25px auto;
                box-sizing: border-box;
                padding: 1em;
                border-radius: 12px;
                position: relative;
                overflow: hidden;
            }

            #chat::before {
                content: "";
                position: absolute;
                top: 0;
                left: 0;
                width: 100%;
                height: 100%;
                background: url(https://images.unsplash.com/photo-1495808985667-ba4ce2ef31b3?ixlib=rb-1.2.1&ixid=eyJhcHBfaWQiOjEyMDd9&auto=format&fit=crop&w=1050&q=80) fixed;
                z-index: -1;
            }

            #chat .btn-icon {
                position: relative;
                cursor: pointer;
            }

            #chat .btn-icon svg {
                stroke: #FFF;
                fill: #FFF;
                width: 50%;
                height: auto;
                position: absolute;
                top: 50%;
                left: 50%;
                transform: translate(-50%, -50%);
            }

            #chat .chat__conversation-board {
                padding: 1em 0 2em;
                height: calc(100vh - 55px - 2em - 25px * 2 - .5em - 3em);
                overflow: auto;
            }

            #chat .chat__conversation-board__message-container.reversed {
                flex-direction: row-reverse;
            }

            #chat .chat__conversation-board__message-container.reversed .chat__conversation-board__message__bubble {
                position: relative;
            }

            #chat .chat__conversation-board__message-container.reversed .chat__conversation-board__message__bubble span:not(:last-child) {
                margin: 0 0 2em 0;
            }

            #chat .chat__conversation-board__message-container.reversed .chat__conversation-board__message__person {
                margin: 0 0 0 1.2em;
            }

            #chat .chat__conversation-board__message-container.reversed .chat__conversation-board__message__options {
                align-self: center;
                position: absolute;
                left: 0;
                display: none;
            }

            #chat .chat__conversation-board__message-container {
                position: relative;
                display: flex;
                flex-direction: row;
            }

            #chat .chat__conversation-board__message-container:hover .chat__conversation-board__message__options {
                display: flex;
                align-items: center;
            }

            #chat .chat__conversation-board__message-container:hover .option-item:not(:last-child) {
                margin: 0 .5em 0 0;
            }

            #chat .chat__conversation-board__message-container:not(:last-child) {
                margin: 0 0 2em 0;
            }

            #chat .chat__conversation-board__message__person {
                text-align: center;
                margin: 0 1.2em 0 0;
            }

            #chat .chat__conversation-board__message__person__avatar {
                height: 35px;
                width: 35px;
                overflow: hidden;
                border-radius: 50%;
                user-select: none;
                -ms-user-select: none;
                position: relative;
            }

            #chat .chat__conversation-board__message__person__avatar::before {
                content: "";
                position: absolute;
                height: 100%;
                width: 100%;
            }

            #chat .chat__conversation-board__message__person__avatar img {
                height: 100%;
                width: auto;
            }

            #chat .chat__conversation-board__message__person__nickname {
                font-size: 9px;
                color: #484848;
                user-select: none;
                display: none;
            }

            #chat .chat__conversation-board__message__context {
                max-width: 55%;
                align-self: flex-end;
            }

            #chat .chat__conversation-board__message__options {
                align-self: center;
                position: absolute;
                right: 0;
                display: none;
            }

            #chat .chat__conversation-board__message__options .option-item {
                border: 0;
                background: 0;
                padding: 0;
                margin: 0;
                height: 16px;
                width: 16px;
                outline: none;
            }

            #chat .chat__conversation-board__message__options .emoji-button svg {
                stroke: var(--chat-options-svg);
                fill: transparent;
                width: 100%;
            }

            #chat .chat__conversation-board__message__options .more-button svg {
                stroke: var(--chat-options-svg);
                fill: transparent;
                width: 100%;
            }

            #chat .chat__conversation-board__message__bubble span {
                width: fit-content;
                display: inline-table;
                word-wrap: break-word;
                background: var(--chat-bubble-background);
                font-size: 13px;
                color: var(--chat-text-color);
                padding: .5em .8em;
                line-height: 1.5;
                border-radius: 6px;
                font-family: 'Lato', sans-serif;
            }

            #chat .chat__conversation-board__message__bubble:not(:last-child) {
                margin: 0 0 .3em;
            }

            #chat .chat__conversation-board__message__bubble:active {
                background: var(--chat-bubble-active-background);
            }

            #chat .chat__conversation-panel {
                background: var(--chat-panel-background);
                border-radius: 12px;
                padding: 0 1em;
                height: 55px;
                margin: .5em 0 0;
            }

            #chat .chat__conversation-panel__container {
                display: flex;
                flex-direction: row;
                align-items: center;
                height: 100%;
            }

            #chat .chat__conversation-panel__container .panel-item:not(:last-child) {
                margin: 0 1em 0 0;
            }

            #chat .chat__conversation-panel__button {
                background: grey;
                height: 20px;
                width: 30px;
                border: 0;
                padding: 0;
                outline: none;
                cursor: pointer;
            }

            #chat .chat__conversation-panel .add-file-button {
                height: 23px;
                min-width: 23px;
                width: 23px;
                background: var(--chat-add-button-background);
                border-radius: 50%;
            }

            #chat .chat__conversation-panel .add-file-button svg {
                width: 70%;
                stroke: #54575c;
            }

            #chat .chat__conversation-panel .emoji-button {
                min-width: 23px;
                width: 23px;
                height: 23px;
                background: transparent;
                border-radius: 50%;
            }

            #chat .chat__conversation-panel .emoji-button svg {
                width: 100%;
                fill: transparent;
                stroke: #54575c;
            }

            #chat .chat__conversation-panel .send-message-button {
                background: var(--chat-send-button-background);
                height: 30px;
                min-width: 30px;
                border-radius: 50%;
                transition: .3s ease;
            }

            #chat .chat__conversation-panel .send-message-button:active {
                transform: scale(0.97);
            }

            #chat .chat__conversation-panel .send-message-button svg {
                margin: 1px -1px;
            }

            #chat .chat__conversation-panel__input {
                width: 100%;
                height: 100%;
                outline: none;
                position: relative;
                color: var(--chat-text-color);
                font-size: 13px;
                background: transparent;
                border: 0;
                font-family: 'Lato', sans-serif;
                resize: none;
            }

            @media only screen and (max-width: 600px) {
                #chat {
                    margin: 0;
                    border-radius: 0;
                }

                #chat .chat__conversation-board {
                    height: calc(100vh - 55px - 2em - .5em - 3em);
                }

                #chat .chat__conversation-board__message__options {
                    display: none !important;
                }
            }
        </style>
    </head>
    <body>
        <div class="--dark-theme" id="chat">
            <div class="chat__conversation-board" id="chatBoard"></div>
            <div class="chat__conversation-panel">
                <div class="chat__conversation-panel__container">
                    <input class="chat__conversation-panel__input panel-item" id="messageInput" placeholder="Type a message..."/>
                    <button class="chat__conversation-panel__button panel-item btn-icon send-message-button" id="sendButton">
                        <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                            <line x1="22" y1="2" x2="11" y2="13"></line>
                            <polygon points="22 2 15 22 11 13 2 9 22 2"></polygon>
                        </svg>
                    </button>
                </div>
            </div>
        </div>
        <script>
            document.addEventListener("DOMContentLoaded", () => {
                const messageInput = document.getElementById("messageInput");
                const sendButton = document.getElementById("sendButton");
                const chatBoard = document.getElementById("chatBoard");

                // Load chat history from localStorage
                const chatHistory = JSON.parse(localStorage.getItem("chatHistory")) || [];
                chatHistory.forEach(({ user, assistant }) => addMessageToBoard(user, assistant));

                // Add loading animation while awaiting response
                function addLoadingAnimation() {
                    const loadingElem = document.createElement("div");
                    loadingElem.classList.add("chat__conversation-board__message-container", "loading");
                    loadingElem.innerHTML = `
                        <div class="chat__conversation-board__message__bubble">
                            <span>Typing...</span>
                        </div>
                    `;
                    chatBoard.appendChild(loadingElem);
                    chatBoard.scrollTop = chatBoard.scrollHeight;
                    return loadingElem;
                }

                // Build one chat bubble; textContent keeps message text from being parsed as HTML
                function createMessageElem(text, isUser) {
                    const container = document.createElement("div");
                    container.classList.add("chat__conversation-board__message-container");
                    if (isUser) container.classList.add("reversed");
                    container.innerHTML = `
                        <div class="chat__conversation-board__message__context">
                            <div class="chat__conversation-board__message__bubble">
                                <span></span>
                            </div>
                        </div>
                    `;
                    container.querySelector("span").textContent = text;
                    return container;
                }

                // Add message to chat board; pass null to skip either bubble
                function addMessageToBoard(userMessage, assistantMessage) {
                    if (userMessage) {
                        chatBoard.appendChild(createMessageElem(userMessage, true));
                    }
                    if (assistantMessage) {
                        chatBoard.appendChild(createMessageElem(assistantMessage, false));
                    }
                    chatBoard.scrollTop = chatBoard.scrollHeight;
                }

                // Send message to bot and get response
                async function sendMessage() {
                    const userMessage = messageInput.value.trim();
                    if (!userMessage) return;

                    // Display user message and clear input
                    addMessageToBoard(userMessage, null);
                    messageInput.value = "";

                    // Show loading animation
                    const loadingElem = addLoadingAnimation();

                    // Send request to server
                    try {
                        const response = await fetch("/chat", {
                            method: "POST",
                            headers: { "Content-Type": "application/json" },
                            body: JSON.stringify({ instruction: "", question: userMessage })
                        });
                        if (!response.ok) {
                            throw new Error("Server responded with status " + response.status);
                        }
                        const data = await response.json();

                        // Display assistant's response
                        addMessageToBoard(null, data.response);

                        // Save to chat history
                        chatHistory.push({ user: userMessage, assistant: data.response });
                        localStorage.setItem("chatHistory", JSON.stringify(chatHistory));
                    } catch (error) {
                        console.error("Error sending message:", error);
                    } finally {
                        // Remove loading animation whether or not the request succeeded
                        loadingElem.remove();
                    }
                }

                // Event listener for sending message
                sendButton.addEventListener("click", sendMessage);
                messageInput.addEventListener("keydown", (e) => {
                    if (e.key === "Enter") sendMessage();
                });
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)


Writing main.py
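
The prompt-building and answer-extraction steps of `GemmaQA.query` can be checked in isolation, without loading the model. The sketch below reuses the template and the split on `<|assistant|>:` from `main.py`, but swaps the Gemma model for a hypothetical `fake_generate` stub. The stub assumes the generated text contains the prompt followed by the completion, which is why the split is needed. The helper names `build_prompt` and `extract_answer` are illustrative and not part of the notebook.

```python
# Same template as in main.py
template = "\n\n<|system|>:\n{instruct}\n\n<|user|>:\n{question}\n\n<|assistant|>:\n{answer}"


def build_prompt(instruct: str, question: str) -> str:
    """Fill the template, leaving the answer slot empty for the model to complete."""
    return template.format(instruct=instruct, question=question, answer="")


def extract_answer(generated: str) -> str:
    """Keep only the text after the last assistant marker, as GemmaQA.query does."""
    return generated.split("<|assistant|>:")[-1].strip()


def fake_generate(prompt: str) -> str:
    """Hypothetical stand-in for the model: returns the prompt plus a canned answer."""
    return prompt + "Start with the introductory course."


prompt = build_prompt("You are a course advisor.", "Which course should I take first?")
answer = extract_answer(fake_generate(prompt))
print(answer)
```

Because the extraction keeps only the text after the last marker, a question that itself contains `<|assistant|>:` would confuse the split; this sketch does not guard against that case.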


### 12.2 Getting server IP address

In [None]:
!curl ipinfo.io

{
  "ip": "34.16.221.9",
  "hostname": "9.221.16.34.bc.googleusercontent.com",
  "city": "Las Vegas",
  "region": "Nevada",
  "country": "US",
  "loc": "36.1750,-115.1372",
  "org": "AS396982 Google LLC",
  "postal": "89111",
  "timezone": "America/Los_Angeles",
  "readme": "https://ipinfo.io/missingauth"
}
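
The `ip` field is the value of interest here: LocalTunnel's landing page usually asks for a tunnel password, which is expected to be the public IP of the machine running the tunnel. A minimal sketch for pulling that field out of the response in Python follows. `extract_ip` is a hypothetical helper, and the payload is abbreviated from the output above rather than fetched live, so the snippet runs without network access.

```python
import json

# Abbreviated copy of the `curl ipinfo.io` output shown above
sample_output = """
{
  "ip": "34.16.221.9",
  "city": "Las Vegas",
  "org": "AS396982 Google LLC"
}
"""


def extract_ip(payload: str) -> str:
    """Return the public IP address from an ipinfo.io JSON payload."""
    return json.loads(payload)["ip"]


public_ip = extract_ip(sample_output)
print(public_ip)
```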

### 12.3 Running FastAPI Application

In [None]:
!uvicorn main:app --host 0.0.0.0 --port 8000 &

# Use LocalTunnel to expose the FastAPI app
!lt --port 8000

2024-11-02 22:07:38.225539: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1730585258.246021    8390 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1730585258.252149    8390 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-02 22:07:42.586807: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
INFO:     Started server process [8390]
INFO:     Waiting for application startup.
INFO:     Ap

# 13. Conclusion

### 13.1 Summary of Fine-Tuning Process