In [1]:
# Step 1 — Check and install transformers if missing
#
# Looks up the package with importlib.util.find_spec (which returns None when
# no importable module of that name is found) and, only if it is missing,
# installs it with pip for the interpreter that is running this kernel.

import importlib
import importlib.util  # `import importlib` alone does not guarantee the `util` submodule is loaded
import sys
import subprocess

package_name = "transformers"

spec = importlib.util.find_spec(package_name)
if spec is None:
    print(f"'{package_name}' not found — installing now...")
    # sys.executable ensures pip installs into the environment of this kernel
    subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
    # Refresh the import system's finder caches so the freshly installed
    # package can be imported without restarting the kernel.
    importlib.invalidate_caches()
    print("Installation complete!")
else:
    print(f"'{package_name}' is already installed.")


'transformers' is already installed.


In [None]:
# Only needed once per environment: uncomment the line below if
# sentence-transformers is not installed yet.
# %pip install sentence-transformers

In [2]:
# Step 2 — Imports
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

In [3]:
# Step 3 — Create a Knowledge Base (KB)
#
# The KB is a single block of text; paragraphs are separated by blank lines
# and each paragraph becomes one retrievable chunk.

kb_text = """
There are over 70 species of possums found across Australia and nearby islands.
Most possums weigh between 1 and 5 kilograms, though some species can reach up to 10 kilograms.

In urban areas, possums are reported in 3 out of 5 households with roof cavities,
showing their strong adaptation to human environments.

Ecologically, possums disperse thousands of seeds each year and reduce insect populations by up to 40%.
Despite sometimes being considered pests, they remain protected under Australian wildlife laws.

Important distinction: possums (Australia) vs. opossums (North America).
Opossums include about 100 species, while possums number closer to 70.
"""

# Split the KB into chunks on blank lines. Each chunk is stripped and empty
# pieces are dropped, so an extra blank line in the text above cannot produce
# an empty or whitespace-padded chunk that would later be embedded and retrieved.
kb_chunks = [
    chunk.strip()
    for chunk in kb_text.split("\n\n")
    if chunk.strip()
]
kb_chunks


['There are over 70 species of possums found across Australia and nearby islands.\nMost possums weigh between 1 and 5 kilograms, though some species can reach up to 10 kilograms.',
 'In urban areas, possums are reported in 3 out of 5 households with roof cavities,\nshowing their strong adaptation to human environments.',
 'Ecologically, possums disperse thousands of seeds each year and reduce insect populations by up to 40%.\nDespite sometimes being considered pests, they remain protected under Australian wildlife laws.',
 'Important distinction: possums (Australia) vs. opossums (North America).\nOpossums include about 100 species, while possums number closer to 70.']

In [4]:
# Step 4 - Embedding Model
# One sentence-embedding model is shared by the KB chunks (here) and by the
# queries (in the retrieval step), so both are encoded the same way.
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Encode every KB chunk in a single call and keep the result as a NumPy array.
# The retrieval step indexes kb_chunks by position in this array, so the order
# of kb_embeddings must stay aligned with kb_chunks.
kb_embeddings = np.array(embedder.encode(kb_chunks))


In [5]:
# Step 5 — Retrieval Function

from sklearn.metrics.pairwise import cosine_similarity

def retrieve(query, top_k=2):
    """Return the knowledge-base chunks most similar to ``query``.

    The query is embedded with the module-level ``embedder`` and compared
    against ``kb_embeddings`` using cosine similarity; the best-scoring
    entries of ``kb_chunks`` are returned.

    Args:
        query: The question or search text to embed.
        top_k: Maximum number of chunks to return. ``0`` yields an empty
            list, ``None`` returns every chunk, and a value larger than the
            number of chunks simply returns all of them.

    Returns:
        A list of ``(chunk_text, score)`` tuples ordered from highest to
        lowest similarity, where ``score`` is a plain Python ``float``.

    Raises:
        ValueError: If ``top_k`` is negative. A negative value would
            otherwise be treated as a slice bound and silently return
            "all but the last k" chunks.
    """
    if top_k is not None:
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")
        if top_k == 0:
            # Nothing requested: skip the embedding work entirely.
            return []

    # Embed the query (encode() receives a one-element batch)
    query_embedding = embedder.encode([query])

    # Compute similarity scores; [0] selects the row for the single query
    scores = cosine_similarity(query_embedding, kb_embeddings)[0]

    # argsort is ascending, so reverse it to rank best-first, then keep top_k
    top_indices = scores.argsort()[::-1][:top_k]

    # Return retrieved chunks with their scores
    return [(kb_chunks[i], float(scores[i])) for i in top_indices]


In [6]:
# Step 6 — Load Generation Model (FLAN-T5 Small)
# Tokenizer and seq2seq model are loaded from the same checkpoint name so the
# two always stay in sync.
GENERATION_MODEL_NAME = "google/flan-t5-small"
gen_tokenizer = AutoTokenizer.from_pretrained(GENERATION_MODEL_NAME)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(GENERATION_MODEL_NAME)

In [7]:
# Step 7 — RAG: Combine Context + Query and Generate Answer
def rag_answer(query, top_k=2, max_length=256):
    """Answer ``query`` with retrieval-augmented generation.

    Retrieves the most relevant KB chunks via ``retrieve``, places them in a
    prompt together with the question, and lets the module-level
    ``gen_model`` / ``gen_tokenizer`` generate the answer.

    Args:
        query: The user's question.
        top_k: Number of KB chunks to retrieve and put into the prompt.
            Defaults to 2, the value ``retrieve`` itself defaults to.
        max_length: Forwarded to ``gen_model.generate`` as ``max_length``.

    Returns:
        A tuple ``(answer, retrieved)`` where ``answer`` is the decoded
        model output and ``retrieved`` is the list of ``(chunk, score)``
        pairs returned by ``retrieve``.
    """
    # Retrieve relevant chunks
    retrieved = retrieve(query, top_k=top_k)
    context_text = "\n\n".join(chunk for chunk, _score in retrieved)

    # Build augmented prompt
    prompt = f"""
Use the following context to answer the question.

CONTEXT:
{context_text}

QUESTION:
{query}

ANSWER:
"""

    # Tokenize and generate.
    # NOTE(review): with truncation=True an over-long prompt is cut by the
    # tokenizer; since the question comes after the context, a very large
    # context could presumably push the question out — confirm before raising
    # top_k substantially.
    inputs = gen_tokenizer(prompt, return_tensors="pt", truncation=True)
    output = gen_model.generate(**inputs, max_length=max_length)
    # generate() returns one sequence per input; only one prompt was passed
    answer = gen_tokenizer.decode(output[0], skip_special_tokens=True)

    return answer, retrieved


In [8]:
# Step 8 — Test Case 1 (Factual)
# Runs the full pipeline on one question and shows the generated answer
# followed by the (chunk, score) pairs that were fed to the model.
# Note: the KB text defined earlier says nothing about diet or tails, so the
# retrieved context cannot actually answer this question.
query1 = "What do possums eat and how do they use their tails?"
result1 = rag_answer(query1)
answer1 = result1[0]
retrieved1 = result1[1]
print("ANSWER:", answer1)
print("\nRETRIEVED:", retrieved1)

ANSWER: (iii).

RETRIEVED: [('Ecologically, possums disperse thousands of seeds each year and reduce insect populations by up to 40%.\nDespite sometimes being considered pests, they remain protected under Australian wildlife laws.', 0.5258525013923645), ('There are over 70 species of possums found across Australia and nearby islands.\nMost possums weigh between 1 and 5 kilograms, though some species can reach up to 10 kilograms.', 0.5241583585739136)]


In [9]:
# Step 9 — Test Case 2 (Out-of-Scope / General Knowledge)
# Same pipeline, different question. The KB states that possums are found
# across Australia but never uses the word "native", so the model has to
# infer the answer from related context.
query2 = "Are possums native to Australia?"
result2 = rag_answer(query2)
answer2 = result2[0]
retrieved2 = result2[1]
print("ANSWER:", answer2)
print("\nRETRIEVED:", retrieved2)

ANSWER: b).

RETRIEVED: [('There are over 70 species of possums found across Australia and nearby islands.\nMost possums weigh between 1 and 5 kilograms, though some species can reach up to 10 kilograms.', 0.7462597489356995), ('Important distinction: possums (Australia) vs. opossums (North America).\nOpossums include about 100 species, while possums number closer to 70.', 0.713020920753479)]


In [10]:
# Step 10 — Test Case 3 (Synthesis)
# A two-part question: the KB covers the environmental half (seed dispersal,
# insect reduction) but contains nothing about defence against predators.
query3 = "How do possums help the environment and protect themselves from predators?"
result3 = rag_answer(query3)
answer3 = result3[0]
retrieved3 = result3[1]
print("ANSWER:", answer3)
print("\nRETRIEVED:", retrieved3)

ANSWER: b).

RETRIEVED: [('Ecologically, possums disperse thousands of seeds each year and reduce insect populations by up to 40%.\nDespite sometimes being considered pests, they remain protected under Australian wildlife laws.', 0.7011035680770874), ('In urban areas, possums are reported in 3 out of 5 households with roof cavities,\nshowing their strong adaptation to human environments.', 0.6443817019462585)]
