# Artificial intelligence & machine learning Q&A Enhanced with Gemma 2 2b-it

We begin the Kaggle notebook by installing several Python packages using pip:

* The first line installs the PyTorch library, which is widely used for deep learning tasks, particularly for neural networks. The -q flag suppresses output except for errors, and -U upgrades PyTorch to the latest release available from the chosen index. The --index-url flag replaces the default package index, in this case with PyTorch's wheel index for CUDA 11.7 builds, which enables GPU acceleration.

* Next, we install the bitsandbytes package from the Python Package Index (PyPI). Similar to the previous line, -q suppresses installation output, -U updates the package if it's already installed, and -i specifies the package index URL.

* After that, the transformers library is installed, which offers state-of-the-art natural language processing (NLP) models like BERT, GPT, and others. As before, the -q flag suppresses output, and -U ensures that the latest version is installed.

* The following line installs the accelerate library, which lets the same PyTorch code run across different hardware setups (CPU, single GPU, multiple GPUs) and handles device placement when loading large models. The -q and -U flags are once again used to suppress output and update the package.

* We then install the datasets library, which simplifies access to a wide range of datasets for machine learning tasks. Again, the -q and -U flags are employed for a quiet installation and to ensure the latest version is installed.

* Next, we install the trl library (Transformer Reinforcement Learning), a comprehensive library by Hugging Face that provides tools for training transformer-based models with reinforcement learning, from Supervised Fine-Tuning (SFT) and Reward Modeling (RM) to Proximal Policy Optimization (PPO). The -q and -U flags are used as before.

* We also install the peft (Parameter-Efficient Fine-Tuning) library from Hugging Face, which enables efficient adaptation of pre-trained language models (PLMs) to various downstream applications without needing to fine-tune all the model’s parameters. PEFT methods significantly reduce the computational and storage costs by only fine-tuning a small subset of the model's parameters.

* Finally, the last line installs the wikipedia-api library, which provides a simple interface to interact with Wikipedia data. As with the other installations, -q suppresses output, and -U ensures the package is up-to-date.

In [1]:
!pip install -q -U torch --index-url https://download.pytorch.org/whl/cu117
!pip install -q -U -i https://pypi.org/simple/ bitsandbytes
!pip install -q -U transformers
!pip install -q -U accelerate
!pip install -q -U datasets
!pip install -q -U trl
!pip install -q -U peft
!pip install -q -U wikipedia-api
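Because every install uses -U without pinning, the versions you get depend on when the notebook runs. The following is a minimal sketch, using only the standard library, for recording which versions were actually installed; the helper name installed_versions is ours and is not part of the original notebook.

```python
from importlib import metadata

# Distribution names as passed to pip in the cell above
PACKAGES = ["torch", "bitsandbytes", "transformers", "accelerate",
            "datasets", "trl", "peft", "wikipedia-api"]

def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing"""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

# Print a small report so the run can be reproduced later
for name, version in installed_versions(PACKAGES).items():
    print(f"{name:>15}: {version or 'not installed'}")
```

Saving this report next to the fine-tuned weights makes it easier to rebuild the same environment later.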

We then proceed by importing the os module and setting two environment variables:

* CUDA_VISIBLE_DEVICES: This variable controls which GPUs are visible to CUDA applications, PyTorch included. Setting it to 0 exposes only the first GPU, so PyTorch runs all computations on it.

* TOKENIZERS_PARALLELISM: This variable controls whether the Hugging Face tokenizers library parallelizes tokenization. Setting it to false keeps tokenization single-threaded, which avoids the warnings the library emits when the process is forked after parallel tokenization has been used.

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
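CUDA_VISIBLE_DEVICES takes a comma-separated list, and the GPUs it exposes are renumbered from zero inside the process. The sketch below illustrates that renumbering for simple integer lists only; visible_gpu_mapping is a hypothetical helper, and the variable must be set before CUDA is initialized for it to take effect.

```python
def visible_gpu_mapping(value):
    """Map logical CUDA indices (as seen by the process) to physical GPU ids"""
    ids = [item.strip() for item in value.split(",") if item.strip()]
    return dict(enumerate(ids))

# With "0", logical device 0 is physical GPU 0 (the notebook's setting)
print(visible_gpu_mapping("0"))

# With "2,3", the process sees two devices: cuda:0 is physical GPU 2, cuda:1 is GPU 3
print(visible_gpu_mapping("2,3"))
```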

The following code snippet imports the warnings module and configures it to suppress all warnings. This prevents any warnings from being displayed. While these warnings typically don’t affect the fine-tuning process, they can be distracting and may cause unnecessary concern during training.

In [3]:
import warnings
warnings.filterwarnings("ignore")

The next cell contains all the main imports needed to run the notebook:

In [4]:
import re
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import wikipediaapi

import torch
import torch.nn as nn

import transformers
from transformers import (AutoModelForCausalLM, 
                          AutoTokenizer,
                          AutoConfig,
                          BitsAndBytesConfig, 
                          TrainingArguments,
                          )

from datasets import Dataset
from peft import LoraConfig, PeftConfig
import bitsandbytes as bnb
from trl import SFTTrainer

We also define a function that determines the appropriate device (CPU, GPU, or macOS with MPS) for mapping the model and data when using PyTorch (as utilized by Hugging Face libraries). This ensures compatibility whether you're working on a CPU, GPU, or macOS system with Metal Performance Shaders (MPS).

In [5]:
def define_device():
    """Define the device to be used by PyTorch"""

    # Get the PyTorch version
    torch_version = torch.__version__

    # Print the PyTorch version
    print(f"PyTorch version: {torch_version}", end=" -- ")

    # Check if the MPS (Metal Performance Shaders) backend is available on macOS
    if torch.backends.mps.is_available():
        # If MPS is available, print a message indicating its usage
        print("using MPS device on MacOS")
        # Define the device as MPS
        defined_device = torch.device("mps")
    else:
        # If MPS is not available, determine the device based on GPU availability
        defined_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Print a message indicating the selected device
        print(f"using {defined_device}")

    # Return the defined device
    return defined_device

# Step 1: get the knowledge base

The first two functions below clean the text of tags and formatting. The extract_wikipedia_pages function then lists the members of a Wikipedia category (pages and subcategories), and the get_wikipedia_pages function crawls the pages reachable from a set of initial categories and collects their text.

In [6]:
# Pre-compile the regular expression pattern for better performance
BRACES_PATTERN = re.compile(r'\{.*?\}|\}')

def remove_braces_and_content(text):
    """Remove all occurrences of curly braces and their content from the given text"""
    return BRACES_PATTERN.sub('', text)

def clean_string(input_string):
    """Clean the input string."""
    
    # Remove extra spaces by splitting the string by spaces and joining back together
    cleaned_string = ' '.join(input_string.split())
    
    # Collapse repeated carriage returns (a no-op in practice: split() above already removed all whitespace runs)
    cleaned_string = re.sub(r'\r+', '\r', cleaned_string)
    
    # Remove all occurrences of curly braces and their content from the cleaned string
    cleaned_string = remove_braces_and_content(cleaned_string)
    
    # Return the cleaned string
    return cleaned_string
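To see what the cleaning step does in practice, here is a small standalone demo. It condenses the notebook's clean_string into one function so the block runs on its own; the sample strings are made up.

```python
import re

BRACES_PATTERN = re.compile(r'\{.*?\}|\}')

def clean_string(input_string):
    """Same steps as the notebook's clean_string, condensed for a standalone demo"""
    collapsed = ' '.join(input_string.split())
    collapsed = re.sub(r'\r+', '\r', collapsed)
    return BRACES_PATTERN.sub('', collapsed)

# A formula remnant of the kind that shows up in Wikipedia text
sample = "Softmax   {\\displaystyle x}\n maps scores\tto probabilities."
print(repr(clean_string(sample)))

# Nested braces: the non-greedy pattern stops at the first closing brace
nested = "{a{b}c} text"
print(repr(clean_string(nested)))
```

Two quirks are visible here. A double space survives in the first result because braces are removed after whitespace has been collapsed, and nested braces leave the inner remainder ("c") behind. Neither matters much for building training text, but both are worth knowing when inspecting the output.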

In [7]:
def get_wikipedia_pages(categories):
    """Retrieve Wikipedia pages from a list of categories and extract their content"""
    
    # Create a Wikipedia object
    wiki_wiki = wikipediaapi.Wikipedia('Gemma AI Assistant (gemma@example.com)', 'en')
    
    # Initialize lists to store explored categories and Wikipedia pages
    explored_categories = []
    wikipedia_pages = []

    # Iterate through each category
    print("- Processing Wikipedia categories:")
    for category_name in categories:
        print(f"\tExploring {category_name} on Wikipedia")
        
        # Get the Wikipedia page corresponding to the category
        category = wiki_wiki.page("Category:" + category_name)
        
        # Extract Wikipedia pages from the category and extend the list
        wikipedia_pages.extend(extract_wikipedia_pages(wiki_wiki, category_name))
        
        # Add the explored category to the list
        explored_categories.append(category_name)

    # Extract subcategories and remove duplicate categories
    categories_to_explore = [item.replace("Category:", "") for item in wikipedia_pages if "Category:" in item]
    wikipedia_pages = list(set([item for item in wikipedia_pages if "Category:" not in item]))
    
    # Explore the subcategories found above (one level deep: categories discovered here are recorded but not queued)
    while categories_to_explore:
        category_name = categories_to_explore.pop()
        print(f"\tExploring {category_name} on Wikipedia")
        
        # Extract more references from the subcategory
        more_refs = extract_wikipedia_pages(wiki_wiki, category_name)

        # Iterate through the references
        for ref in more_refs:
            # Check if the reference is a category
            if "Category:" in ref:
                new_category = ref.replace("Category:", "")
                # Add the new category to the explored categories list
                if new_category not in explored_categories:
                    explored_categories.append(new_category)
            else:
                # Add the reference to the Wikipedia pages list
                if ref not in wikipedia_pages:
                    wikipedia_pages.append(ref)

    # Initialize a list to store extracted texts
    extracted_texts = []
    
    # Iterate through each Wikipedia page
    print("- Processing Wikipedia pages:")
    for page_title in tqdm(wikipedia_pages):
        try:
            # Make a request to the Wikipedia page
            page = wiki_wiki.page(page_title)

            # Check if the page summary does not contain certain keywords
            if "Biden" not in page.summary and "Trump" not in page.summary:
                # Append the page title and summary to the extracted texts list
                if len(page.summary) > len(page.title):
                    extracted_texts.append(page.title + " : " + clean_string(page.summary))

                # Iterate through the sections in the page
                for section in page.sections:
                    # Append the page title and section text to the extracted texts list
                    if len(section.text) > len(page.title):
                        extracted_texts.append(page.title + " : " + clean_string(section.text))
                        
        except Exception as e:
            print(f"Error processing page {page_title}: {e}")
                    
    # Return the extracted texts
    return extracted_texts

In [8]:
def extract_wikipedia_pages(wiki_wiki, category_name):
    """Extract all references from a category on Wikipedia"""
    
    # Get the Wikipedia page corresponding to the provided category name
    category = wiki_wiki.page("Category:" + category_name)
    
    # Initialize an empty list to store page titles
    pages = []
    
    # Check if the category exists
    if category.exists():
        # Iterate through each article in the category and append its title to the list
        for article in category.categorymembers.values():
            pages.append(article.title)
    
    # Return the list of page titles
    return pages
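The crawl in get_wikipedia_pages visits the starting categories and then the subcategories found in that first pass. If a deeper crawl is wanted, one option is a breadth-first traversal with a visited set and a depth limit. The sketch below shows the idea on a hypothetical in-memory graph (FAKE_WIKI) instead of the real API; crawl_categories and get_members are illustrative names, not part of the notebook or of wikipedia-api.

```python
from collections import deque

# Hypothetical category graph standing in for Wikipedia: each category maps to its members
FAKE_WIKI = {
    "Machine_learning": ["Category:Deep_learning", "Overfitting"],
    "Deep_learning": ["Category:Transformers", "Backpropagation", "Overfitting"],
    "Transformers": ["Attention (machine learning)", "Category:Machine_learning"],
}

def crawl_categories(start_categories, get_members, max_depth=2):
    """Breadth-first crawl returning page titles found within max_depth levels of subcategories"""
    queue = deque((name, 0) for name in start_categories)
    explored = set(start_categories)
    pages = []
    seen_pages = set()
    while queue:
        category, depth = queue.popleft()
        for title in get_members(category):
            if title.startswith("Category:"):
                sub = title[len("Category:"):]
                # Queue each subcategory once, and only while under the depth limit
                if sub not in explored and depth < max_depth:
                    explored.add(sub)
                    queue.append((sub, depth + 1))
            elif title not in seen_pages:
                seen_pages.add(title)
                pages.append(title)
    return pages

pages = crawl_categories(["Machine_learning"], lambda name: FAKE_WIKI.get(name, []))
print(pages)
```

The visited set prevents loops (here Transformers links back to Machine_learning), and max_depth keeps the crawl bounded, which matters on Wikipedia because category trees can grow very large.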

To gather the information necessary to answer the trickiest questions about AI and machine learning, I’ve listed a few key categories, with an emphasis on generative AI topics.

In [9]:
categories = ["Machine_learning", "Data_science", "Statistics", "Deep_learning", "Artificial_intelligence",
              "Neural_network_architectures", "Large_language_models", "OpenAI", "Generative_pre-trained_transformers",
              "Artificial_neural_networks", "Generative_artificial_intelligence", "Natural_language_processing"]
extracted_texts = get_wikipedia_pages(categories)
print("Found", len(extracted_texts), "Wikipedia pages")

- Processing Wikipedia categories:
	Exploring Machine_learning on Wikipedia
	Exploring Data_science on Wikipedia
	Exploring Statistics on Wikipedia
	Exploring Deep_learning on Wikipedia
	Exploring Artificial_intelligence on Wikipedia
	Exploring Neural_network_architectures on Wikipedia
	Exploring Large_language_models on Wikipedia
	Exploring OpenAI on Wikipedia
	Exploring Generative_pre-trained_transformers on Wikipedia
	Exploring Artificial_neural_networks on Wikipedia
	Exploring Generative_artificial_intelligence on Wikipedia
	Exploring Natural_language_processing on Wikipedia
	Exploring Tasks of natural language processing on Wikipedia
	Exploring Statistical natural language processing on Wikipedia
	Exploring Natural language processing researchers on Wikipedia
	Exploring Optical character recognition on Wikipedia
	Exploring Natural language processing software on Wikipedia
	Exploring Natural language generation on Wikipedia
	Exploring Machine translation on Wikipedia
	Exploring Fin

100%|██████████| 3819/3819 [27:41<00:00,  2.30it/s]

Found 17879 Wikipedia pages





# Step 2: convert the knowledge base into a Q&A dataset

Now, having collected our knowledge base on AI, we need to leverage Gemma to convert it into something more useful for training a model. The idea is to use a Q&A approach.

First, let’s load Gemma 2 2b-it into memory, quantizing it to 4 bits with BitsAndBytes.

In [10]:
model_name = "/kaggle/input/gemma-2/transformers/gemma-2-2b-it/2"

compute_dtype = getattr(torch, "float16")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

config = AutoConfig.from_pretrained(model_name)
config.final_logit_softcapping = None  # Disable soft-capping

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    config=config,
    attn_implementation="eager",
    quantization_config=bnb_config, 
)

model.config.use_cache = False
model.config.pretraining_tp = 1

max_seq_length = 2304
tokenizer = AutoTokenizer.from_pretrained(model_name, max_seq_length=max_seq_length)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
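A quick back-of-the-envelope calculation shows why 4-bit loading helps on a Kaggle GPU. The sketch assumes roughly 2.6 billion parameters for Gemma 2 2B (an assumption to check against the model card) and counts weight storage only.

```python
def weight_memory_gb(n_params, bits_per_param):
    """Approximate storage for the weights alone, in gigabytes (1 GB = 1e9 bytes)"""
    return n_params * bits_per_param / 8 / 1e9

# Assumption: roughly 2.6 billion parameters for Gemma 2 2B (check the model card)
N_PARAMS = 2.6e9

for label, bits in [("float32", 32), ("float16", 16), ("nf4 (4-bit)", 4)]:
    print(f"{label:>12}: ~{weight_memory_gb(N_PARAMS, bits):.1f} GB")
```

The real footprint will be somewhat higher than the 4-bit figure: typically only the linear layers are quantized, the quantization constants take extra space, and activations, gradients, and optimizer state are not counted here.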

A simple function wraps all the steps needed to query Gemma about a topic or ask it a question. The function accepts a sampling temperature and can either print the answer to stdout or return it as a string.

In [11]:
def question_gemma(question, model=model, tokenizer=tokenizer, temperature=0.0, return_answer=False):
    input_ids = tokenizer(question, return_tensors="pt").to("cuda") #.to(torch.float16)
    if temperature > 0:
        do_sample=True
    else:
        do_sample=False
    outputs = model.generate(**input_ids, 
                             max_new_tokens=2048, 
                             do_sample=do_sample, 
                             temperature=temperature)
    result = str(tokenizer.decode(outputs[0])).replace("<bos>", "").replace("<eos>", "").strip()
    if return_answer:
        return result
    else:
        print(result)
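To illustrate what the temperature argument changes, the sketch below applies a temperature-scaled softmax to three made-up logits. It is a plain-Python illustration of the general technique, not the code path that model.generate uses internally.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities after dividing them by the temperature"""
    scaled = [value / temperature for value in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(value - peak) for value in scaled]
    total = sum(exps)
    return [value / total for value in exps]

logits = [2.0, 1.0, 0.1]  # hypothetical scores for three candidate tokens
for temperature in (0.5, 0.9, 1.5):
    probs = softmax_with_temperature(logits, temperature)
    print(temperature, [round(p, 3) for p in probs])
```

Lower temperatures concentrate probability on the top token, and higher ones flatten the distribution. A temperature of exactly 0 would mean dividing by zero, which is why question_gemma switches to greedy decoding (do_sample=False) in that case.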

Using the whole knowledge base and posing multiple questions derived from the same text helps build the fine-tuning training data. Asking multiple questions is necessary because Gemma picks just one topic from the text and tends to answer briefly.

We can control how Gemma returns the question and answer by asking it for JSON in the form {"question": "...", "answer": "..."}. This makes it easy to retrieve the data from the output text with a regex.

In [12]:
qa_data = []

def extract_json(text, word):
    pattern = fr'"{word}": "(.*?)"'
    match = re.search(pattern, text)
    if match:
        return match.group(1)
    else:
        return ""

no_extracted_texts = 300 # increment this number up to len(extracted_texts)
question_ratio = 24 # decrement this number to produce more questions (suggested: 24)

for i in tqdm(range(len(extracted_texts[:no_extracted_texts]))):

    question_text = f"""Create a question and its answer from the following piece of information,
    put all the necessary information into the question (do not assume the reader knows the text),
    and return it exclusively in JSON format in the format {'{"question": "...", "answer": "..."}'}

    Here is the piece of information to elaborate:
    {extracted_texts[i]}

    OUTPUT JSON:
    """

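    # NOTE: min() caps this at one question per text (and zero for texts shorter than question_ratio);
    # use max(1, ...) instead to scale the number of questions with the text length
    no_questions = min(1, len(extracted_texts[i]) // question_ratio)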
    for j in range(no_questions):
    
        result = question_gemma(question_text, model=model, temperature=0.9, return_answer=True)
        result = result.split("OUTPUT JSON:")[-1]

        question = extract_json(result, "question")
        answer = extract_json(result, "answer")

        qa_data.append(f"{question}\n{answer}")
        
    # Demo limit: stop after the first five texts; remove this check to process all no_extracted_texts
    if i > 3:
        break

  1%|▏         | 4/300 [03:22<4:09:54, 50.66s/it]
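The regex in extract_json stops at the first double quote, so an answer containing an escaped quote gets cut short. As a possible alternative, the sketch below parses the first JSON object with the standard json module. parse_qa is a hypothetical helper and the sample output is invented; it assumes the model returns a single JSON object, possibly surrounded by extra text.

```python
import json
import re

def parse_qa(raw_output):
    """Return (question, answer) parsed from a JSON object in raw_output, or ("", "") on failure"""
    # Grab everything from the first "{" to the last "}" (assumes a single JSON object)
    match = re.search(r"\{.*\}", raw_output, flags=re.DOTALL)
    if not match:
        return "", ""
    try:
        data = json.loads(match.group(0))
    except json.JSONDecodeError:
        return "", ""
    if not isinstance(data, dict):
        return "", ""
    return str(data.get("question", "")), str(data.get("answer", ""))

# Invented model output containing escaped quotes inside the question
sample = 'Sure! {"question": "What does \\"LoRA\\" stand for?", "answer": "Low-Rank Adaptation"} <end_of_turn>'
question, answer = parse_qa(sample)
print(question)
print(answer)
```

Returning empty strings on failure also makes it simple to skip bad generations before appending to qa_data, instead of storing an empty training example.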


Now that the dataset has been gathered, it is time to turn it into a Hugging Face Dataset.

In [13]:
train_data = (pd.DataFrame(qa_data, columns=["text"])
              .sample(frac=1, random_state=5)
              .drop_duplicates()
             )
train_data = Dataset.from_pandas(train_data)

# Step 3: fine-tune the Gemma model

In the following cells, LoRA is configured and the training parameters are defined. Afterward, the fine-tuning can start.

In [14]:
output_dir = "gemma_assistant"

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj",],
)

training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=1,
    gradient_checkpointing=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    optim="paged_adamw_8bit",
    save_steps=0,
    logging_steps=25,
    learning_rate=5e-4,
    weight_decay=0.001,
    fp16=True, ###
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=False,
    evaluation_strategy='steps',
    eval_steps = 500,
    eval_accumulation_steps=1,
    lr_scheduler_type="cosine",
    report_to="tensorboard",
)
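Two numbers follow directly from the LoRA settings above: the scaling factor applied to the adapter update (lora_alpha / r) and the number of trainable parameters, which is r * (in_features + out_features) for each adapted matrix. The sketch below computes both. The layer shapes and layer count are assumptions about Gemma 2 2B, not values taken from the notebook, so check them against the model's config.json.

```python
# Assumed Gemma 2 2B layer shapes (in_features, out_features); verify against the model's config.json
ASSUMED_SHAPES = {
    "q_proj": (2304, 2048),
    "k_proj": (2304, 1024),
    "v_proj": (2304, 1024),
    "o_proj": (2048, 2304),
    "gate_proj": (2304, 9216),
    "up_proj": (2304, 9216),
    "down_proj": (9216, 2304),
}
ASSUMED_NUM_LAYERS = 26  # assumption as well

def lora_param_count(shapes, rank, num_layers):
    """Trainable parameters added by LoRA: rank * (in + out) per adapted matrix"""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * num_layers

rank, alpha = 64, 16  # r and lora_alpha from the LoraConfig above
print("scaling factor alpha / r:", alpha / rank)
print("trainable LoRA parameters:", lora_param_count(ASSUMED_SHAPES, rank, ASSUMED_NUM_LAYERS))
```

Under these assumptions the adapters add tens of millions of trainable parameters, a small fraction of the full model, and the update is scaled by 0.25 before being added to the frozen weights.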

In [15]:
trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    peft_config=peft_config,
    dataset_text_field="text",
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    args=training_arguments,
    packing=False,
)

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

In [16]:
trainer.train()

Step,Training Loss,Validation Loss


TrainOutput(global_step=1, training_loss=1.5988729000091553, metrics={'train_runtime': 4.829, 'train_samples_per_second': 1.035, 'train_steps_per_second': 0.207, 'total_flos': 3376350798336.0, 'train_loss': 1.5988729000091553, 'epoch': 1.0})
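The single optimizer step reported above follows from the batch settings: with per_device_train_batch_size=1 and gradient_accumulation_steps=8, each update consumes eight examples, and the demo run has only five. The sketch below is a rough estimate with a hypothetical helper; the exact rounding of the last partial window depends on the transformers version.

```python
def estimate_optimizer_steps(n_examples, per_device_batch_size, grad_accum_steps,
                             epochs=1, n_devices=1):
    """Rough count of optimizer updates; exact rounding depends on the transformers version"""
    examples_per_update = per_device_batch_size * grad_accum_steps * n_devices
    # Assumption: a final partial accumulation window still counts as one update (ceiling division)
    updates_per_epoch = max(1, -(-n_examples // examples_per_update))
    return examples_per_update, updates_per_epoch * epochs

# The demo run: 5 examples, batch size 1, 8 accumulation steps
print(estimate_optimizer_steps(5, 1, 8))

# The full setting suggested earlier: 300 texts with one question each
print(estimate_optimizer_steps(300, 1, 8))
```

Since logging_steps is 25, a run with a single update never reaches a logging step, which is consistent with the empty loss table above.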

After we finish, we simply save the model and try to reload it in order to check if everything works as expected.

In [17]:
trainer.save_model()
tokenizer.save_pretrained(output_dir)

('gemma_assistant/tokenizer_config.json',
 'gemma_assistant/special_tokens_map.json',
 'gemma_assistant/tokenizer.model',
 'gemma_assistant/added_tokens.json',
 'gemma_assistant/tokenizer.json')

In [18]:
from peft import AutoPeftModelForCausalLM

finetuned_model = output_dir
compute_dtype = getattr(torch, "float16")
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoPeftModelForCausalLM.from_pretrained(
     finetuned_model,
     torch_dtype=compute_dtype,
     return_dict=False,
     low_cpu_mem_usage=True,
     device_map="auto",
)

model = model.to('cuda')

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Step 4: save the LoRA weights and merge them into Gemma

Now, the tricky part is combining the trained LoRA weights with the original Gemma model. The result is our new fine-tuned Gemma!

This cell cleans up the CPU and GPU memory.

In [19]:
import gc

del [model, tokenizer, peft_config, trainer, train_data, bnb_config, training_arguments]
del [TrainingArguments, SFTTrainer, LoraConfig, BitsAndBytesConfig]

for _ in range(10):
    torch.cuda.empty_cache()
    gc.collect()

Now we proceed to the merging procedure:


In [20]:
from peft import AutoPeftModelForCausalLM

finetuned_model = output_dir
compute_dtype = getattr(torch, "float16")
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoPeftModelForCausalLM.from_pretrained(
     finetuned_model,
     torch_dtype=compute_dtype,
     return_dict=False,
     low_cpu_mem_usage=True,
     device_map="auto",
)

merged_model = model.merge_and_unload()
merged_model.save_pretrained("./gemma_assistant_merged",
                             safe_serialization=True, 
                             max_shard_size="2GB")
tokenizer.save_pretrained("./gemma_assistant_merged")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

('./gemma_assistant_merged/tokenizer_config.json',
 './gemma_assistant_merged/special_tokens_map.json',
 './gemma_assistant_merged/tokenizer.model',
 './gemma_assistant_merged/added_tokens.json',
 './gemma_assistant_merged/tokenizer.json')
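Conceptually, merging folds each adapter into its base weight matrix: W_merged = W + (alpha / r) * B @ A. The toy example below checks, with plain Python lists and made-up numbers, that the merged matrix gives the same output as keeping the adapter separate. It sketches the arithmetic only and is not the PEFT implementation.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows"""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def merge_lora(weight, lora_a, lora_b, alpha, rank):
    """Return W + (alpha / rank) * B @ A as a new matrix"""
    scale = alpha / rank
    delta = matmul(lora_b, lora_a)
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(weight, delta)]

# Toy 2x2 layer with a rank-1 adapter (all values are made up)
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 2.0]]        # shape (rank, in_features)
B = [[0.5], [-1.0]]     # shape (out_features, rank)
x = [[3.0], [4.0]]      # column vector input

merged = merge_lora(W, A, B, alpha=2, rank=1)

# Adapter kept separate: W x + scale * B (A x), with scale = alpha / rank = 2
separate = [[base[0] + 2.0 * extra[0]]
            for base, extra in zip(matmul(W, x), matmul(B, matmul(A, x)))]

print(merged)
print(matmul(merged, x), separate)
```

Because both paths produce the same output, the merged model no longer needs the PEFT wrapper at inference time and can be loaded like any other checkpoint.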

Again, memory cleaning.

In [21]:
import gc

del [model, tokenizer, merged_model, AutoPeftModelForCausalLM]

for _ in range(10):
    torch.cuda.empty_cache()
    gc.collect()

In [22]:
for _ in range(10):
    torch.cuda.empty_cache()
    gc.collect()

The final step is to reload the fine-tuned model and try it out!

In [23]:
from transformers import (AutoModelForCausalLM, 
                          AutoTokenizer, 
                          BitsAndBytesConfig)

model_name = "./gemma_assistant_merged"

compute_dtype = getattr(torch, "float16")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config, 
)

model.config.use_cache = False
model.config.pretraining_tp = 1

max_seq_length = 1024
tokenizer = AutoTokenizer.from_pretrained(model_name, max_seq_length=max_seq_length)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

We start with a series of data science and machine learning questions, focused on LLMs:

In [24]:
questions = ["In simple terms, what is a Large Language Model (LLM)?", 
             "What differentiates LLMs from traditional chatbots?", 
             "How are LLMs typically trained? (e.g., pre-training, fine-tuning)",
             "What are some of the typical applications of LLMs? (e.g., text generation, translation)", 
             "What is the role of transformers in LLM architecture?", 
             "Explain the concept of bias in LLM training data and its potential consequences.", 
             "How can prompt engineering be used to improve LLM outputs?", 
             "Describe some techniques for evaluating the performance of LLMs. (e.g., perplexity, BLEU score)", 
             "Discuss the limitations of LLMs, such as factual accuracy and reasoning abilities.", 
             "What are some ethical considerations surrounding the use of LLMs?", 
             "How do LLMs handle out-of-domain or nonsensical prompts?", 
             "Explain the concept of few-shot learning and its applications in fine-tuning LLMs.", 
             "What are the challenges associated with large-scale deployment of LLMs in real-world applications?", 
             "Discuss the role of LLMs in the broader field of artificial general intelligence (AGI).", 
             "How can the explainability and interpretability of LLM decisions be improved?", 
             "Compare and contrast LLM architectures, such as GPT-3 and LaMDA.", 
             "Explain the concept of self-attention and its role in LLM performance.", 
             "Discuss the ongoing research on mitigating bias in LLM training data and algorithms.", 
             "How can LLMs be leveraged to create more human-like conversations?", 
             "Explore the potential future applications of LLMs in various industries.", 
             "You are tasked with fine-tuning an LLM to write creative content. How would you approach this?", 
             "An LLM you’re working on starts generating offensive or factually incorrect outputs. How would you diagnose and address the issue?",
             "A client wants to use an LLM for customer service interactions. What are some critical considerations for this application?", 
             "How would you explain the concept of LLMs and their capabilities to a non-technical audience?", 
             "Imagine a future scenario where LLMs are widely integrated into daily life. What ethical concerns might arise?", 
             "Discuss some emerging trends in LLM research and development.", 
             "What are the potential societal implications of widespread LLM adoption?", 
             "How can we ensure the responsible development and deployment of LLMs?",]

In [25]:
for i, question in enumerate(questions):
    print(f"QUESTION {i}")
    question_gemma(question, model=model, tokenizer=tokenizer)
    print("-" * 64)

QUESTION 0
In simple terms, what is a Large Language Model (LLM)?

**A Large Language Model (LLM) is a type of artificial intelligence (AI) that can understand and generate human-like text.**

Here's a breakdown:

* **Artificial Intelligence (AI):**  This refers to machines that can perform tasks that typically require human intelligence, like learning, problem-solving, and decision-making.
* **Large Language Model (LLM):**  This is a specific type of AI that's trained on massive amounts of text data. This training allows them to:
    * **Understand and interpret human language:** They can analyze the meaning of words, sentences, and paragraphs.
    * **Generate human-like text:** They can write stories, poems, articles, summaries, and even code.
    * **Answer questions and provide information:** They can access and process information from their training data to answer your questions.

**Examples of LLMs:**

* ChatGPT
* Bard
* LaMDA
* GPT-3

**How LLMs work:**

LLMs are based on a te

Our final test is asking for help in understanding the Gemma 2 paper:

Team, Gemma, et al. "Gemma 2: Improving open language models at a practical size." arXiv preprint arXiv:2408.00118 (2024). https://arxiv.org/pdf/2408.00118

In [26]:
arch = """
Similar to previous Gemma models (Gemma
Team, 2024), the Gemma 2 models are based on a
decoder-only transformer architecture (Vaswani et al., 2017). 
A few architectural elements are similar to the
first version of Gemma models; namely, a context
length of 8192 tokens, the use of Rotary Position Embeddings (RoPE) (Su et al., 2021), and
the approximated GeGLU non-linearity (Shazeer,
2020). A few elements differ between Gemma 1
and Gemma 2, including using deeper networks.
We summarize the key differences below.
Local Sliding Window and Global Attention.
We alternate between a local sliding window attention (Beltagy et al., 2020a,b) 
and global attention (Luong et al., 2015) in every other layer.
The sliding window size of local attention layers
is set to 4096 tokens, while the span of the global
attention layers is set to 8192 tokens.
Logit soft-capping. We cap logits (Bello et al.,
2016) in each attention layer and the final layer
such that the value of the logits stays between
−soft_cap and +soft_cap. More specifically, we
cap the logits with the following function:
logits ← soft_cap ∗ tanh(logits/soft_cap).
We set the soft_cap parameter to 50.0 for the self-attention layers and to 30.0 for the final layer.
Post-norm and pre-norm with RMSNorm. To
stabilize training, we use RMSNorm (Zhang and
Sennrich, 2019) to normalize the input and output of each transformer sub-layer, the attention
layer, and the feedforward layer.
Grouped-Query Attention (Ainslie et al., 2023).
We use GQA with num_groups = 2, based on ablations showing increased speed at inference time
while maintaining downstream performance.
"""

prompt = """You are acting as a valuable study assistant for AI/ML topics. 
Please explain the following technical excerpt, providing information on the mentioned technical topics: """
prompt += arch

question_gemma(prompt, model=model, tokenizer=tokenizer)

You are acting as a valuable study assistant for AI/ML topics. 
Please explain the following technical excerpt, providing information on the mentioned technical topics: 
Similar to previous Gemma models (Gemma
Team, 2024), the Gemma 2 models are based on a
decoder-only transformer architecture (Vaswani et al., 2017). 
A few architectural elements are similar to the
first version of Gemma models; namely, a context
length of 8192 tokens, the use of Rotary Position Embeddings (RoPE) (Su et al., 2021), and
the approximated GeGLU non-linearity (Shazeer,
2020). A few elements differ between Gemma 1
and Gemma 2, including using deeper networks.
We summarize the key differences below.
Local Sliding Window and Global Attention.
We alternate between a local sliding window attention (Beltagy et al., 2020a,b) 
and global attention (Luong et al., 2015) in every other layer.
The sliding window size of local attention layers
is set to 4096 tokens, while the span of the global
attention layers is set 
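The logit soft-capping formula quoted in the excerpt, logits ← soft_cap ∗ tanh(logits/soft_cap), is easy to try numerically. The sketch below applies it to a few made-up logits with the final-layer cap of 30.0 mentioned in the paper; it illustrates the formula only and is not the model's code. Note that the loading cell earlier in this notebook sets config.final_logit_softcapping to None, which disables the final-layer cap.

```python
import math

def soft_cap(logit, cap):
    """Squash a logit smoothly so its magnitude never exceeds cap"""
    return cap * math.tanh(logit / cap)

# Small logits pass through almost unchanged; large ones saturate near +/- cap
for raw in (-200.0, -10.0, 0.0, 10.0, 200.0):
    print(raw, round(soft_cap(raw, 30.0), 3))
```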

We conclude the tutorial here. By following the same steps, you can fine-tune Gemma for any topic.

Enjoy fine-tuning with Google Gemma!