In [1]:
# Retrieval-Augmented Generation (RAG), Lewis et al., 2020: https://arxiv.org/pdf/2005.11401

import copy
import torch
from transformers import AutoTokenizer, GPT2Model, AutoModelForSeq2SeqLM

In [2]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [3]:
# GPT-2 tokenizer: turns chunks and queries into token ids for the embedding model.
embedding_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [4]:
# Toy knowledge base for retrieval. Each string is one retrievable chunk and is
# returned verbatim by the retriever. A chunk's list position is the index that
# pairs it with its entry in `embeddings`, so order must be preserved.
chunks = [
    "The town of Alderwick is known for its glassblowing tradition, which dates back to the late 1700s. Families there still pass down specialized techniques for producing colored glass using local minerals.",
    "The River Kelburn flows only 46 kilometers before reaching the sea, yet it sustains a surprisingly diverse ecosystem of freshwater mussels, herons, and a rare species of trout found nowhere else.",
    "In 1923, the engineer Clara Whitcombe designed a modular bridge system made of interlocking steel plates. Though few were built, the concept influenced modern prefabricated bridge design.",
    "The mineral veridite is notable for its deep green color and waxy luster. Small deposits have been found in northern Canada, where it is sometimes carved into ornamental beads.",
    "Researchers at the Lydmere Observatory created a star catalog specifically focused on binary systems with highly eccentric orbits. Their findings suggested unusual stability in pairs once thought unstable.",
    "A traditional dessert in the coastal region of Marovia is a layered pastry filled with almond cream and soaked in citrus syrup. Locals prepare it during midsummer festivals to symbolize abundance.",
    "The Selwick Dial is a mechanical device built in the 1890s that calculates tidal movements. It relies on a set of rotating brass disks engraved with lunar phases and coastal calibration data.",
    "Unlike most orchids, the Frostpetal variety blooms only during cold months. Its white blossoms appear after the first snowfall and are pollinated by winter moths active in freezing temperatures.",
    "In early computing history, the Bronswick-12 was a failed attempt at building a portable typewriter-sized computer. Weighing 14 kilograms, it was too heavy to succeed but inspired later laptops.",
    "The city of Durness implemented a community-run power grid in 2008 using tidal turbines. Today nearly 70 percent of its households rely on locally generated renewable energy.",
    "The Ashgrove Manuscript contains herbal remedies described by a 14th-century physician. Many recipes combine medicinal plants with ritual instructions, suggesting a blend of science and folklore.",
    "Lysium clay is a natural pigment that shifts from blue to violet depending on the angle of sunlight. Ancient potters prized it for ceramics that seemed to change color through the day.",
    "The Polar Survey of 1974 reported unexpected volcanic vents beneath the Arctic ice sheet. The discovery reshaped theories of geothermal activity in polar regions.",
    "In the festival of Lantern Rest, villagers float small reed boats with candles downriver. The custom is believed to honor ancestors and mark the transition from winter to spring.",
    "The Mavricon Cipher is an encryption method invented during the early telegraph era. It substitutes letters with shifting geometric symbols, making it difficult to decode without a key grid.",
    "Researchers in experimental acoustics found that sound waves under 15 hertz can influence balance perception. Though inaudible, these infrasonic vibrations create a sense of dizziness in test subjects.",
    "The fruit of the candleberry tree produces a wax that burns cleanly without smoke. Before petroleum-based candles, entire villages harvested the berries for household lighting.",
    "During excavation near the city of Brackford, archaeologists uncovered a series of underground halls carved with spiral motifs. Radiocarbon dating placed the site around 900 BCE.",
    "The flightless bird known as the Silver Rail survives on a single island in the South Atlantic. Its feathers shimmer with a metallic sheen, which helps it blend into rocky coastlines.",
    "In 1961, the mathematician Orla Jensen proposed a number system based on repeating sequences of primes. Though impractical for calculation, it fascinated theorists studying abstract structures."
]

In [5]:
# GPT-2 encoder used only to embed text. Downstream code reads only
# `last_hidden_state`, so per-layer hidden states (`output_hidden_states=True`)
# are not requested: they would add memory and time to every forward pass for
# nothing. `.eval()` makes inference mode (no dropout) explicit.
embedding_model = GPT2Model.from_pretrained("gpt2").eval()

In [6]:
# Embed every chunk once, up front. Each vector is the mean of GPT-2's last
# hidden layer over the chunk's tokens, kept in the same order as `chunks`.
# `torch.no_grad()` stops autograd from recording a graph for each forward pass.
# Without it, every stored embedding carries a `grad_fn` and keeps that graph
# (including intermediate activations) alive in memory.
with torch.no_grad():
  embeddings = [
      embedding_model(torch.tensor([embedding_tokenizer.encode(chunk)])).last_hidden_state[0].mean(dim=0)
      for chunk in chunks
  ]

In [7]:
# Sanity check: the size of a single chunk embedding.
embeddings[0].size()

torch.Size([768])

In [9]:
# Answer generator: FLAN-T5 (small) tokenizer and seq2seq model, placed on `device`.
model_name = "google/flan-t5-small"
generator_tokenizer = AutoTokenizer.from_pretrained(model_name)
generator_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
generator_model = generator_model.to(device)

In [14]:
def rag_retriever(prompt, top_k=3):
  """Return the `top_k` chunks whose embeddings are most cosine-similar to `prompt`.

  The prompt is embedded the same way as the chunks: the mean of GPT-2's last
  hidden layer over its tokens. It is then scored against every entry of
  `embeddings`.

  Args:
    prompt: Query text.
    top_k: Number of chunks to return. Values <= 0 return an empty list.

  Returns:
    List of up to `top_k` strings from `chunks`, most similar first. Equal
    scores keep the chunks' original order. The list is empty when `prompt`
    tokenizes to nothing or no chunk embeddings exist.

  NOTE(review): mean-pooled GPT-2 states are not trained as sentence
  embeddings, so rankings can be noisy. The candleberry example ranks an
  unrelated chunk first. Raise `top_k` if the relevant chunk is missed.
  """
  # GPT-2 cannot attend past `n_positions` tokens. Keep the first window rather
  # than crash on very long queries.
  token_ids = embedding_tokenizer.encode(prompt)[: embedding_model.config.n_positions]
  # Zero tokens cannot be meaningfully mean-pooled, so there is nothing to rank.
  if top_k <= 0 or not token_ids or not embeddings:
    return []

  # Inference only: keep autograd from recording graphs for the forward pass
  # and the scoring below.
  with torch.no_grad():
    prompt_embedding = embedding_model(torch.tensor([token_ids])).last_hidden_state[0].mean(dim=0)
    # Score all chunks at once: (num_chunks, hidden) @ (hidden,) -> (num_chunks,)
    chunk_matrix = torch.stack(embeddings)
    scores = (chunk_matrix @ prompt_embedding) / (chunk_matrix.norm(dim=1) * prompt_embedding.norm())

  score_list = scores.tolist()
  # `sorted` is stable even with reverse=True, so ties stay in chunk order.
  ranked = sorted(range(len(score_list)), key=score_list.__getitem__, reverse=True)
  return [chunks[idx] for idx in ranked[:top_k]]

In [27]:
def rag_generator(prompt, top_k=1, verbose=True, max_new_tokens=None):
  """Answer `prompt` with FLAN-T5, conditioned on the `top_k` retrieved chunks.

  Args:
    prompt: Question to answer.
    top_k: Number of chunks requested from `rag_retriever` and prepended as context.
    verbose: If True (the default, matching earlier behavior), print the
      enriched prompt so the retrieved context can be inspected.
    max_new_tokens: Optional cap on the number of generated tokens. When None,
      nothing extra is passed to `generate`, so the model's generation config
      decides the length.

  Returns:
    The decoded answer string, with special tokens removed.
  """
  # The prompt layout is left unchanged so earlier answers stay reproducible.
  # Truncation is intentionally not applied: the question comes last and would
  # be the first thing cut off.
  context = "Context:" + "\n".join(rag_retriever(prompt, top_k=top_k))
  enriched_prompt = context + "\n Question: " + prompt
  if verbose:
    print(f"{enriched_prompt=}")
  inputs = generator_tokenizer(enriched_prompt, return_tensors="pt").to(device)
  # Only forward the cap when given, so a default call is identical to before.
  generate_kwargs = {} if max_new_tokens is None else {"max_new_tokens": max_new_tokens}
  outputs = generator_model.generate(**inputs, **generate_kwargs)
  return generator_tokenizer.decode(outputs[0], skip_special_tokens=True)

In [16]:
def non_rag_generator(prompt, max_new_tokens=None):
  """Answer `prompt` with FLAN-T5 alone: the no-retrieval baseline.

  Args:
    prompt: Question passed to the model verbatim, with no retrieved context.
    max_new_tokens: Optional cap on the number of generated tokens. When None,
      nothing extra is passed to `generate`, so the model's generation config
      decides the length.

  Returns:
    The decoded answer string, with special tokens removed.
  """
  inputs = generator_tokenizer(prompt, return_tensors="pt").to(device)
  # Only forward the cap when given, so a default call is identical to before.
  generate_kwargs = {} if max_new_tokens is None else {"max_new_tokens": max_new_tokens}
  outputs = generator_model.generate(**inputs, **generate_kwargs)
  return generator_tokenizer.decode(outputs[0], skip_special_tokens=True)

# Test: compare answers with and without retrieved context

In [17]:
# Baseline: the question alone, with no retrieved context.
non_rag_generator(prompt="Who designed a modular bridge system with interlocking steel plates in 1923?")

'john w. w. sturgeon'

In [18]:
# Same question, with the single best-matching chunk prepended as context.
rag_generator(prompt="Who designed a modular bridge system with interlocking steel plates in 1923?")

enriched_prompt='Context:In 1923, the engineer Clara Whitcombe designed a modular bridge system made of interlocking steel plates. Though few were built, the concept influenced modern prefabricated bridge design.\n Question: Who designed a modular bridge system with interlocking steel plates in 1923?'


'Clara Whitcombe'

In [19]:
# Baseline: the question alone, with no retrieved context.
non_rag_generator(prompt="Which town is famous for its glassblowing tradition starting in the 1700s?")

'st johns'

In [20]:
# Same question, with the single best-matching chunk prepended as context.
rag_generator(prompt="Which town is famous for its glassblowing tradition starting in the 1700s?")

enriched_prompt='Context:The town of Alderwick is known for its glassblowing tradition, which dates back to the late 1700s. Families there still pass down specialized techniques for producing colored glass using local minerals.\n Question: Which town is famous for its glassblowing tradition starting in the 1700s?'


'Alderwick'

In [21]:
# Baseline: the question alone, with no retrieved context.
non_rag_generator(prompt="Which orchid blooms only in winter after the first snowfall?")

'sturgeon'

In [29]:
# Same question, with the three best-matching chunks prepended as context.
rag_generator(prompt="Which orchid blooms only in winter after the first snowfall?", top_k=3)

enriched_prompt='Context:Unlike most orchids, the Frostpetal variety blooms only during cold months. Its white blossoms appear after the first snowfall and are pollinated by winter moths active in freezing temperatures.\nThe fruit of the candleberry tree produces a wax that burns cleanly without smoke. Before petroleum-based candles, entire villages harvested the berries for household lighting.\nThe River Kelburn flows only 46 kilometers before reaching the sea, yet it sustains a surprisingly diverse ecosystem of freshwater mussels, herons, and a rare species of trout found nowhere else.\n Question: Which orchid blooms only in winter after the first snowfall?'


'Frostpetal'

In [25]:
# Baseline: the question alone, with no retrieved context.
non_rag_generator(prompt="What was candleberry wax used for before petroleum candles?")

'a candle'

In [28]:
# Same question, with the three best-matching chunks prepended as context.
rag_generator(prompt="What was candleberry wax used for before petroleum candles?", top_k=3)

enriched_prompt='Context:The flightless bird known as the Silver Rail survives on a single island in the South Atlantic. Its feathers shimmer with a metallic sheen, which helps it blend into rocky coastlines.\nThe fruit of the candleberry tree produces a wax that burns cleanly without smoke. Before petroleum-based candles, entire villages harvested the berries for household lighting.\nThe Ashgrove Manuscript contains herbal remedies described by a 14th-century physician. Many recipes combine medicinal plants with ritual instructions, suggesting a blend of science and folklore.\n Question: What was candleberry wax used for before petroleum candles?'


'household lighting'