In [1]:
!pip install sentence-transformers
!pip install chromadb



In [2]:
pip install langchain langchain-community

Note: you may need to restart the kernel to use updated packages.


In [3]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

In [4]:
import keras
import keras_nlp

In [5]:
import json

data = []
with open('/kaggle/input/ml-dataset/ML_dataset (1).json') as file:
    feature_list = json.load(file)
    # Filter out examples with context, to keep it simple.
    for features in feature_list:
        template = "Instruction:\n{line}"
        data.append(template.format(**features))

In [6]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [7]:
instruction = "What is data science?"
prompt = template.format(line=instruction, response="")

print(gemma_lm.generate(prompt, max_length=100))

Instruction:
What is data science?
Data science is the process of collecting, cleaning, and analyzing data to extract useful information. Data science is a combination of statistics, mathematics, and computer science.
Data science is a process of collecting, cleaning, and analyzing data to extract useful information. Data science is a combination of statistics, mathematics, and computer science.
Data science is a process of collecting, cleaning, and analyzing data to extract useful information. Data science is a combination of


In [8]:
prompt = template.format(
    line="What is data engineering?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=150))

Instruction:
What is data engineering?
Data engineering is the process of collecting, storing, and analyzing data to extract insights and make informed decisions. It involves the design, development, and maintenance of data systems, as well as the creation of data-driven applications and services. Data engineering is a rapidly growing field, as organizations increasingly rely on data to drive decision-making and improve business performance.
Data engineering is a complex and multifaceted discipline, requiring a deep understanding of data structures, algorithms, and software engineering principles. It involves a wide range of skills, including data modeling, data warehousing, data mining, and data visualization. Data engineers work closely with data scientists, software developers, and business analysts to design and implement data-driven solutions


In [9]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

In [10]:
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

[1m1554/1554[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1155s[0m 731ms/step - loss: 0.1760 - sparse_categorical_accuracy: 0.4284


<keras.src.callbacks.history.History at 0x7c727807de40>

In [11]:
template = "Instruction:\n{line}\n\nResponse:\n{response}"
prompt = template.format(
    line="What is data science?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=150))

Instruction:
What is data science?

Response:
Data science is the process of collecting, cleaning, and analyzing data to extract insights and make predictions.


In [19]:
prompt = template.format(
    line="What is data engineering?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=150))

Instruction:
What is data engineering?

Response:
Data engineering is the process of building, managing, and maintaining data pipelines. It involves designing, building, and maintaining data pipelines that move data from its source to its destination. Data engineers are responsible for ensuring that data is accurate, complete, and timely.


In [20]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.document_loaders import JSONLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

from IPython.display import display, Markdown
import json
from langchain.schema import Document

In [21]:
def create_prompt(self, query, context):
        # prompt template
        prompt = f"""
        You are a ML Chatbot specialized to answer and provide learning resources to data enthusiast.
        Explain the concept or answer their question about Data space.
        In order to create the answer, please only use the information from the
        context provided (Context). Do not include other information.
        Answer with simple words.
        If needed, include also explanations.
        Question: {query}
        Context: {context}
        Answer:
        """
        return prompt

In [22]:
class RAGSystem:
    """Retrieval-Augmented Generation System with Keras Model."""
    def __init__(self, gemma_lm, num_retrieved_docs=2, data_path="/kaggle/input/ml-dataset/ML_dataset (1).json"):
        # Set up the model and retriever parameters
        self.num_docs = num_retrieved_docs
        self.ai_agent = gemma_lm
        self.template = "\n\nQuestion:\n{question}\n\nPrompt:\n{prompt}\n\nAnswer:\n{answer}\n\nContext:\n{context}"

         # Load JSON data without jq
        with open(data_path, 'r') as file:
            raw_data  = json.load(file)
            
        # Format JSON data to a list of Document objects with `page_content`
        documents = [Document(page_content=doc.get("line", ""), metadata={}) for doc in raw_data]
        
        # Text splitting
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
        all_splits = text_splitter.split_documents(documents)

        # Embeddings for retrieval
        embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        self.vector_db = Chroma.from_documents(documents=all_splits, embedding=embeddings, persist_directory="chroma_db")
        self.retriever = self.vector_db.as_retriever()

    def retrieve(self, query):
        """Retrieve top k similar documents based on the query."""
        docs = self.retriever.get_relevant_documents(query)
        return docs

    def query(self, query):
        """Generate an answer based on retrieved documents and query."""
        # Retrieve context documents
        context_docs = self.retrieve(query)
        context = "".join([doc.page_content for doc in context_docs[:self.num_docs]])
        
        # Generate response with the model
        prompt = f"{query}\n\nContext:\n{context}"
        generated_output = self.ai_agent.generate(prompt, max_length=150)

        return self.template.format(question=query, prompt=prompt, answer=generated_output, context=context)

In [23]:
def colorize_text(text):
    for word, color in zip(["Question", "Prompt", "Answer", "Context"], ["blue", "magenta", "red", "green"]):
        text = text.replace(f"\n\n{word}:", f"\n\n**<font color='{color}'>{word}:</font>**")
    return text

In [24]:
rag_system = RAGSystem(gemma_lm)



In [25]:
# testing the query
question = "What is data science?"
response = rag_system.query(question)
display(Markdown(colorize_text(response)))



**<font color='blue'>Question:</font>**
What is data science?

**<font color='magenta'>Prompt:</font>**
What is data science?

**<font color='green'>Context:</font>**
What is data science?What is data science?

**<font color='red'>Answer:</font>**
What is data science?

**<font color='green'>Context:</font>**
What is data science?What is data science?

Data science is the process of extracting knowledge from data. It involves collecting, cleaning, and analyzing data to answer specific questions or solve problems. Data scientists use a variety of techniques, including machine learning, statistics, and data visualization, to extract insights from data.

Data science is a rapidly growing field, with a wide range of applications in various industries, including healthcare, finance, and marketing. It is a highly technical field, requiring a combination of skills in programming, data analysis, and problem-solving.

Data science is a rapidly growing field, with a wide range of applications in various industries, including healthcare, finance, and marketing.

**<font color='green'>Context:</font>**
What is data science?What is data science?

In [26]:
# trying another question
question = "What is Machine learning?"
response = rag_system.query(question)
display(Markdown(colorize_text(response)))



**<font color='blue'>Question:</font>**
What is Machine learning?

**<font color='magenta'>Prompt:</font>**
What is Machine learning?

**<font color='green'>Context:</font>**
What is Machine Learning?What is Machine Learning?

**<font color='red'>Answer:</font>**
What is Machine learning?

**<font color='green'>Context:</font>**
What is Machine Learning?What is Machine Learning?

Machine learning is a branch of artificial intelligence (AI) that focuses on the development of computer systems that can learn and improve from experience without being explicitly programmed.

Machine learning algorithms are used to extract patterns and insights from data, and can be applied to a wide range of tasks, including image classification, natural language processing, and recommendation systems.

Machine learning is often used in conjunction with other AI technologies, such as deep learning and natural language processing, to create more advanced and powerful systems.

The goal of machine learning is to develop systems that can learn and improve without being explicitly programmed, and to create systems that can learn from experience and improve

**<font color='green'>Context:</font>**
What is Machine Learning?What is Machine Learning?

In [18]:
gemma_lm.save('gemma_lm.keras')

KeyboardInterrupt: 