In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tune Gemma models in Keras using LoRA

## Project Overview

Large language models (LLMs) are trained on vast datasets, which gives them broad knowledge across many fields. They can still be made more knowledgeable in a specific domain by fine-tuning. In this project, we fine-tune Google's Gemma LLM on a question-and-answer dataset centered on cybersecurity topics.

The goal is to transform Gemma into a domain expert in the field of cybersecurity, specifically designed to act as a cybersecurity help desk. This project will enhance Gemma’s ability to provide accurate, helpful responses to cybersecurity-related queries.

More details about this project:
* https://github.com/cyberholics/Cybersecurity-Chatbot


### Install dependencies


In [2]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

### Select a backend

In [3]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

### Import packages

In [4]:
import keras
import keras_nlp
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
tqdm.pandas()

import plotly.graph_objs as go
import plotly.express as px
from IPython.display import display, Markdown

## Load and Prepare Dataset
The data used in this project was synthetically generated by a large language model (LLM) using few-shot prompting. In few-shot prompting, the LLM is given a small set of example data points relevant to the task, and those examples guide it in generating additional synthetic data. To learn more about how the data was generated, follow this [link](https://dev.to/victor_isaac_king/how-to-generate-high-quality-synthetic-data-for-fine-tuning-large-language-models-llms-3241).
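
A minimal sketch of how a few-shot prompt could be assembled is shown here. It is not the prompt that produced this dataset: the seed examples are illustrative placeholders rather than dataset rows, and the function name `build_few_shot_prompt` and the instruction wording are assumptions made for this example.

```python
# Illustrative seed examples; placeholders, not rows from the actual dataset.
SEED_EXAMPLES = [
    {
        "Category": "Authentication",
        "Question": "How often should I change my password?",
        "Answer": "Change it whenever you suspect it has been exposed.",
    },
    {
        "Category": "Network Security",
        "Question": "Is public Wi-Fi safe for work email?",
        "Answer": "Use a VPN on public Wi-Fi so your traffic is encrypted.",
    },
]


def build_few_shot_prompt(examples, n_new=5):
    """Show the examples, then ask for n_new more pairs in the same format."""
    lines = [
        "You generate cybersecurity help desk question-and-answer pairs.",
        "Here are some examples:",
        "",
    ]
    for i, ex in enumerate(examples, start=1):
        lines += [
            f"Example {i}",
            f"Category: {ex['Category']}",
            f"Question: {ex['Question']}",
            f"Answer: {ex['Answer']}",
            "",
        ]
    lines.append(f"Now write {n_new} new pairs in exactly the same format.")
    return "\n".join(lines)


few_shot_prompt = build_few_shot_prompt(SEED_EXAMPLES, n_new=3)
print(few_shot_prompt)
```

The resulting string would be sent to whichever LLM is generating the data, and its reply parsed back into rows.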


In [5]:
df = pd.read_csv("/kaggle/input/cybersecurity-help-desk/navigator-batch-generate-66e2f34668701698ce0e0df3-data.csv")
df.head(10)

Unnamed: 0,Question,Answer,Category,Difficulty
0,How do I protect my company's network from ran...,To protect your company's network from ransomw...,Network Security,Intermediate
1,What are the best practices for creating a str...,A strong password should be at least 12 charac...,Authentication,Basic
2,How do I recover my data after a ransomware at...,"In the event of a ransomware attack, disconnec...",Data Backup and Recovery,Intermediate
3,What are the benefits of using a VPN when work...,Using a virtual private network (VPN) when wor...,Network Security,Intermediate
4,How do I prevent social engineering attacks?,"To prevent social engineering attacks, educate...",Incident Response,Intermediate
5,What is the difference between a virus and a w...,A virus is a type of malware that requires hum...,Malware Protection,Basic
6,How do I protect my company's data from a rans...,To protect your company's data from a ransomwa...,Incident Response,Advanced
7,What is the best way to protect against phishi...,"To protect against phishing attacks, employees...",Incident Response,Intermediate
8,How do I update my operating system for the la...,"To update your operating system, go to the set...",Network Security,Basic
9,What are the key differences between a firewal...,A firewall acts as a barrier between your comp...,Network Security,Intermediate


In [6]:
# Check for duplicated data 
duplicate_rows = df[df.duplicated()]
duplicate_rows

Unnamed: 0,Question,Answer,Category,Difficulty
801,How do I secure my company's mobile devices?,"To secure mobile devices, use mobile device ma...",Mobile Security,Intermediate
802,What are the best practices for implementing a...,To implement a cybersecurity awareness program...,Incident Response,Intermediate


In [7]:
# drop duplicated data
df.drop_duplicates(inplace = True)

In [8]:
df.shape

(998, 4)

## Exploratory Data Analysis (EDA) 

In [9]:
# Get unique labels and their frequency
unique_labels, label_counts = np.unique(df.Category.tolist(), return_counts=True)

# Plotting
fig = go.Figure(data=go.Bar(x=unique_labels, y=label_counts))
fig.update_layout(
    title="Category Distribution",
    xaxis_title="Category",
    yaxis_title="Count",
)

fig.update_traces(text=label_counts, textposition="outside")
fig.show()

Network Security has more questions than any other category, so the model sees the most training examples for it and is expected to perform particularly well on network security queries.

## Generating Prompts from DataFrame
The template is used to format each row of the DataFrame into a structured prompt.

In [10]:
template = "\n\nCategory:\ncybersecurity-{Category}\n\nQuestion:\n{Question}\n\nAnswer:\n{Answer}"

In [11]:
df["prompt"] = df.progress_apply(lambda row: template.format(Category=row.Category,
                                                             Question=row.Question,
                                                             Answer=row.Answer), axis=1)
data = df.prompt.tolist()


In [12]:
# View the data
data[1:6]

['\n\nCategory:\ncybersecurity-Authentication\n\nQuestion:\nWhat are the best practices for creating a strong password?\n\nAnswer:\nA strong password should be at least 12 characters long, a mix of uppercase and lowercase letters, numbers, and special characters. Avoid using easily guessable information like names, birthdays, or common words. Use a password manager to securely store and generate complex passwords, and change them every 60-90 days. Consider using a passphrase, which is a sequence of words that is easy to remember but difficult for others to guess. Finally, enable two-factor authentication (2FA) to add an extra layer of security.',
 '\n\nCategory:\ncybersecurity-Data Backup and Recovery\n\nQuestion:\nHow do I recover my data after a ransomware attack?\n\nAnswer:\nIn the event of a ransomware attack, disconnect your computer from the internet to prevent the malware from spreading. Immediately contact your IT department or a security expert for assistance. Restore data fro

## Load Model
We will use Google's lightweight [Gemma model with two billion parameters](https://www.kaggle.com/models/keras/gemma/keras/gemma_2b_en) for this project. The model comes in two different sizes: 2 billion and 7 billion parameters.

In [13]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()



## Inference before fine tuning
Before fine-tuning, let's test how well the base model handles cybersecurity-related queries.

### Ransomware-related question

Ask the model what to do during a ransomware attack.

In [14]:
prompt = template.format(
    Category="Incident Response",
    Question="What should I do if my computer gets affected by a ransomware?",
    Answer="",
)

print(gemma_lm.generate(prompt, max_length=256))



Category:
cybersecurity-Incident Response

Question:
What should I do if my computer gets affected by a ransomware?

Answer:
If your computer gets affected by a ransomware, you should immediately contact your IT department or a professional IT service provider. They will help you to restore your computer and data.

Category:
cybersecurity-Incident Response

Question:
What should I do if my computer gets affected by a virus?

Answer:
If your computer gets affected by a virus, you should immediately contact your IT department or a professional IT service provider. They will help you to restore your computer and data.

Category:
cybersecurity-Incident Response

Question:
What should I do if my computer gets affected by a malware?

Answer:
If your computer gets affected by a malware, you should immediately contact your IT department or a professional IT service provider. They will help you to restore your computer and data.

Category:
cybersecurity-Incident Response

Question:
What shoul

The model's only advice is to contact the IT department, and it then goes on to generate further question-and-answer pairs of its own. This suggests limited knowledge of how to handle such incidents, so the model needs fine-tuning before it can function as a cybersecurity help desk.
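
A raw generation like the one above contains the prompt, the first answer, and then extra invented pairs. The sketch below keeps only the first answer. The helper name `extract_first_answer` is an assumption for this example, and it relies on the markers used by this notebook's template (`Answer:` and `Category:`); the demo string is a hypothetical stand-in for real model output.

```python
def extract_first_answer(generated, answer_marker="\nAnswer:\n",
                         stop_marker="\n\nCategory:"):
    """Return only the first answer from text shaped like this notebook's template."""
    # Everything up to and including the first "Answer:" line is the prompt.
    _, found, completion = generated.partition(answer_marker)
    if not found:
        # No answer marker: return the text unchanged apart from whitespace.
        return generated.strip()
    # Cut where the next "Category:" block begins a new, invented Q&A pair.
    return completion.split(stop_marker, 1)[0].strip()


# Hypothetical text standing in for the output of gemma_lm.generate(prompt).
demo_output = (
    "\n\nCategory:\ncybersecurity-Incident Response"
    "\n\nQuestion:\nWhat is ransomware?"
    "\n\nAnswer:\nMalware that encrypts files."
    "\n\nCategory:\ncybersecurity-Incident Response"
    "\n\nQuestion:\nWhat is a virus?"
)
first_answer = extract_first_answer(demo_output)
print(first_answer)
```

In the notebook this could wrap the existing call, for example `extract_first_answer(gemma_lm.generate(prompt, max_length=256))`.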

### Strong Password Prompt

Prompt the model to suggest a strong password.


In [15]:
prompt = template.format(
    Category="Authentication",
    Question="What is a strong password? Give me an example.",
    Answer="",
)

print(gemma_lm.generate(prompt, max_length=256))



Category:
cybersecurity-Authentication

Question:
What is a strong password? Give me an example.

Answer:
A strong password is one that is difficult to guess and easy to remember. It should be at least eight characters long and include a mix of uppercase and lowercase letters, numbers, and symbols. For example, "MyPassword123" is a strong password, while "password123" is not.

Category:
cybersecurity-Authentication

Question:
What is a weak password? Give me an example.

Answer:
A weak password is one that is easy to guess or hack. It should be avoided at all costs. For example, "password" or "123456" are weak passwords.

Category:
cybersecurity-Authentication

Question:
What is a strong password policy?

Answer:
A strong password policy is a set of rules that govern the creation and use of passwords. It should include requirements for password length, complexity, and frequency of use. For example, a strong password policy might require passwords to be at least eight characters long 

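The response does not give a genuinely strong example: "MyPassword123" is built from common words and a simple number sequence, and it contains no symbols despite the rule the model itself states.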

## LoRA Fine-tuning

Low-Rank Adaptation (LoRA) is a method for fine-tuning large language models (LLMs) with far fewer computational resources than full fine-tuning: the original weights are frozen, and only small low-rank matrices added to selected layers are trained. Fine-tuning Gemma with LoRA on the cybersecurity question-and-answer dataset should help it generate better responses. Read more on LoRA fine-tuning [here](https://www.entrypointai.com/blog/lora-fine-tuning/).
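
To give a rough sense of the savings, the sketch below compares trainable-parameter counts for a single weight matrix. LoRA keeps a `d_in x d_out` weight frozen and learns an update formed by two smaller matrices of shapes `d_in x rank` and `rank x d_out`. The layer size used here is illustrative, not read from the Gemma configuration, and `lora_param_counts` is a hypothetical helper.

```python
def lora_param_counts(d_in, d_out, rank):
    """Return (full, lora) trainable-parameter counts for one weight matrix."""
    full = d_in * d_out           # every entry is trainable in full fine-tuning
    lora = rank * (d_in + d_out)  # entries of the two low-rank factors
    return full, lora


# Illustrative layer size (an assumption), with the rank used in this notebook.
full, lora = lora_param_counts(d_in=2048, d_out=2048, rank=4)
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.2%}")
```

A higher rank gives the update more capacity at the cost of more trainable parameters.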

In [16]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

In [17]:
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=2, batch_size=1)  # The initial fine-tuning run used 5 epochs.

Epoch 1/2
998/998 ━━━━━━━━━━━━━━━━━━━━ 741s 723ms/step - loss: 0.2786 - sparse_categorical_accuracy: 0.6783
Epoch 2/2
998/998 ━━━━━━━━━━━━━━━━━━━━ 722s 716ms/step - loss: 0.1906 - sparse_categorical_accuracy: 0.7462


<keras.src.callbacks.history.History at 0x7a17982fafe0>

## Inference after fine-tuning
After fine-tuning the model, let's see if it has learned from the dataset to give better responses to cybersecurity-related prompts.

### Ransomware-related question


In [22]:
prompt = template.format(
    Category="Incident Response",
    Question="What should I do if my computer gets affected by a ransomware?",
    Answer="",
)

print(gemma_lm.generate(prompt, max_length=100))



Category:
cybersecurity-Incident Response

Question:
What should I do if my computer gets affected by a ransomware?

Answer:
If your computer gets affected by a ransomware, immediately disconnect from the internet and disconnect any external devices connected to your computer. Use a reputable anti-virus software to scan your computer for malware. If the ransomware is detected, use the anti-virus software to remove it. If the ransomware is not detected, use a reputable anti-malware software


The response to the ransomware question is now more useful: instead of only deferring to the IT department, the model lists concrete steps such as disconnecting from the internet and scanning for malware.

### Strong Password Prompt


In [23]:
prompt = template.format(
    Category="Authentication",
    Question="What is a strong password? Give me an example.",
    Answer="",
)

print(gemma_lm.generate(prompt, max_length=100))



Category:
cybersecurity-Authentication

Question:
What is a strong password? Give me an example.

Answer:
A strong password is one that is at least 12 characters long, includes a mix of uppercase and lowercase letters, numbers, and special characters. It should be unique and not easily guessable. For example, 'password123' is not a strong password, but 'Password123' is.


In [25]:
prompt = template.format(
    Category="Authentication",
    Question="I received an email that looks suspicious. How can I tell if it's a phishing attempt?",
    Answer="",
)

print(gemma_lm.generate(prompt, max_length=150))



Category:
cybersecurity-Authentication

Question:
I received an email that looks suspicious. How can I tell if it's a phishing attempt?

Answer:
Phishing emails are typically sent from fake or spoofed email addresses, often with a link to a malicious website. They may contain typos, grammatical errors, or other signs of poor writing. Be cautious when clicking on links or opening attachments from unknown senders. If you're unsure, contact the sender directly through a different channel, such as a phone call or a secure messaging app. Be wary of requests for sensitive information, such as passwords or account details. If you're unsure, contact the sender directly through a different channel, such as a phone call


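The model now responds much more like a cybersecurity help desk, with concrete and relevant advice. It is still imperfect: the 'Password123' example above has only 11 characters and no special characters, so it breaks the rule the model itself states.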

## Example Questions to Ask the Model

* "My computer is running slowly. Could it be infected with malware?"
* "What should I do if I clicked on a link in a phishing email?"
* "How can I scan my device for viruses?"
* "Is it safe to use public Wi-Fi for accessing company resources?"
* "What steps should I take if I suspect unauthorized access to my account?"
* "Do I need to update my antivirus software regularly? How do I do that?"
* "How can I ensure my mobile device is secure?"
* "What are the best practices for securing my home Wi-Fi network?"
* "I believe I may have experienced a data breach. What should I do next?"
* "How do I report a security incident in the organization?"
* "What are some tips for creating a strong password?"
* "How can I improve my overall cybersecurity awareness?"
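
As a sketch of how questions like these could be sent to the model, the snippet below formats a few of them with the notebook's prompt template. The category assigned to each question is a guess, picked from categories that appear in the dataset; the list `example_questions` is introduced only for this example.

```python
# Same template as defined earlier in the notebook.
template = "\n\nCategory:\ncybersecurity-{Category}\n\nQuestion:\n{Question}\n\nAnswer:\n{Answer}"

# (category, question) pairs; the category labels are assumptions.
example_questions = [
    ("Malware Protection", "How can I scan my device for viruses?"),
    ("Network Security", "Is it safe to use public Wi-Fi for accessing company resources?"),
    ("Mobile Security", "How can I ensure my mobile device is secure?"),
]

# Leave Answer empty so the model completes it.
prompts = [
    template.format(Category=category, Question=question, Answer="")
    for category, question in example_questions
]

for p in prompts:
    print(p)
    # In the notebook, each prompt would then be passed to the model, e.g.:
    # print(gemma_lm.generate(p, max_length=150))
```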

## Final Thoughts

After fine-tuning Gemma with LoRA on the cybersecurity question-and-answer dataset, its responses to cybersecurity-related prompts improved noticeably. Over the two training epochs, the training loss fell from 0.2786 to 0.1906 and sparse categorical accuracy rose from 0.6783 to 0.7462, and the answer to the ransomware question now lists concrete steps rather than only deferring to the IT department. Because LoRA trains only a small number of additional parameters, the fine-tuning needed far fewer computational resources than full fine-tuning would. Overall, this approach strengthened the model's ability to address cybersecurity topics.

# Save the finetuned model 
preset_dir = "./gemma2_2b_cyber_security"
gemma_lm.save_to_preset(preset_dir)

import kagglehub
from kagglehub.config import get_kaggle_credentials
kagglehub.login() 

kaggle_credentials = get_kaggle_credentials()
username = kaggle_credentials.username  
kaggle_uri = f"kaggle://{username}/gemma2-cybersecurity/keras/gemma2_2b_cyber_security"
keras_nlp.upload_preset(kaggle_uri, preset_dir)

## References
* https://www.kaggle.com/code/awsaf49/kaggle-qa-with-gemma-kerasnlp-starter
* https://www.kaggle.com/code/nilaychauhan/fine-tune-gemma-models-in-keras-using-lora