Compared to previous version


* Attention weighted embeddings instead of simple mean pooling  
* Using `e5-base-v2` instead of `gpt2` for embeddings that is optimized for both retrieval and generation for better semantic discremination
*   Pre-normalizing embeddings for more efficient inference



In [8]:
# https://arxiv.org/pdf/2005.11401

import copy
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [35]:
embedding_model_name = "intfloat/e5-base-v2"
embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
embedding_model = AutoModel.from_pretrained(embedding_model_name, output_hidden_states=True, output_attentions=True)

In [12]:
chunks = [
    "The town of Alderwick is known for its glassblowing tradition, which dates back to the late 1700s. Families there still pass down specialized techniques for producing colored glass using local minerals.",
    "The River Kelburn flows only 46 kilometers before reaching the sea, yet it sustains a surprisingly diverse ecosystem of freshwater mussels, herons, and a rare species of trout found nowhere else.",
    "In 1923, the engineer Clara Whitcombe designed a modular bridge system made of interlocking steel plates. Though few were built, the concept influenced modern prefabricated bridge design.",
    "The mineral veridite is notable for its deep green color and waxy luster. Small deposits have been found in northern Canada, where it is sometimes carved into ornamental beads.",
    "Researchers at the Lydmere Observatory created a star catalog specifically focused on binary systems with highly eccentric orbits. Their findings suggested unusual stability in pairs once thought unstable.",
    "A traditional dessert in the coastal region of Marovia is a layered pastry filled with almond cream and soaked in citrus syrup. Locals prepare it during midsummer festivals to symbolize abundance.",
    "The Selwick Dial is a mechanical device built in the 1890s that calculates tidal movements. It relies on a set of rotating brass disks engraved with lunar phases and coastal calibration data.",
    "Unlike most orchids, the Frostpetal variety blooms only during cold months. Its white blossoms appear after the first snowfall and are pollinated by winter moths active in freezing temperatures.",
    "In early computing history, the Bronswick-12 was a failed attempt at building a portable typewriter-sized computer. Weighing 14 kilograms, it was too heavy to succeed but inspired later laptops.",
    "The city of Durness implemented a community-run power grid in 2008 using tidal turbines. Today nearly 70 percent of its households rely on locally generated renewable energy.",
    "The Ashgrove Manuscript contains herbal remedies described by a 14th-century physician. Many recipes combine medicinal plants with ritual instructions, suggesting a blend of science and folklore.",
    "Lysium clay is a natural pigment that shifts from blue to violet depending on the angle of sunlight. Ancient potters prized it for ceramics that seemed to change color through the day.",
    "The Polar Survey of 1974 reported unexpected volcanic vents beneath the Arctic ice sheet. The discovery reshaped theories of geothermal activity in polar regions.",
    "In the festival of Lantern Rest, villagers float small reed boats with candles downriver. The custom is believed to honor ancestors and mark the transition from winter to spring.",
    "The Mavricon Cipher is an encryption method invented during the early telegraph era. It substitutes letters with shifting geometric symbols, making it difficult to decode without a key grid.",
    "Researchers in experimental acoustics found that sound waves under 15 hertz can influence balance perception. Though inaudible, these infrasonic vibrations create a sense of dizziness in test subjects.",
    "The fruit of the candleberry tree produces a wax that burns cleanly without smoke. Before petroleum-based candles, entire villages harvested the berries for household lighting.",
    "During excavation near the city of Brackford, archaeologists uncovered a series of underground halls carved with spiral motifs. Radiocarbon dating placed the site around 900 BCE.",
    "The flightless bird known as the Silver Rail survives on a single island in the South Atlantic. Its feathers shimmer with a metallic sheen, which helps it blend into rocky coastlines.",
    "In 1961, the mathematician Orla Jensen proposed a number system based on repeating sequences of primes. Though impractical for calculation, it fascinated theorists studying abstract structures."
]

In [55]:
embedding_tokenizer.encode("hi")

[101, 7632, 102]

In [75]:
def text_to_embedding(prompt, model=embedding_model, tokenizer=embedding_tokenizer, pooling_type="attention_weighted"):
  output = model(torch.tensor([tokenizer.encode(prompt)]))
  embedding = output.last_hidden_state[0]
  if pooling_type == "mean":
    pooled_embedding = embedding.mean(dim=0)
  elif pooling_type == "attention_weighted":
    attention_weights = output.attentions[-1].squeeze(0).mean(dim=0)[:, 0] # for CLS token
    pooled_embedding = attention_weights @ embedding
  elif pooling_type == "cls":
    pooled_embedding = embedding[0, :]
  else:
    raise NotImplemented

  normalized_embedding = pooled_embedding / torch.norm(pooled_embedding)
  return normalized_embedding

embeddings = [text_to_embedding(chunk) for chunk in chunks]

In [76]:
embeddings[0].shape

torch.Size([768])

In [77]:
generator_model_name = "google/flan-t5-small"
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
generator_model = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name).to(device)

In [78]:
def rag_retriever(prompt, top_k=3, print_similarity_scores=False):
  prompt_embedding = text_to_embedding(prompt)

  embedding_scores = []
  for i, chunk_embedding in enumerate(embeddings):
    cosine_similarity = torch.dot(prompt_embedding, chunk_embedding).item()  # vectors are already normalized hence cosine similarity == dot product
    embedding_scores.append((i, cosine_similarity))

  if print_similarity_scores:
    print(embedding_scores)

  return [
      chunks[id]
      for id, score in sorted(embedding_scores, key=lambda x: x[1], reverse=True)[:top_k]
  ]

In [80]:
rag_retriever(chunks[0], 3, True)

[(0, 1.0000001192092896), (1, 0.7374454140663147), (2, 0.7531095743179321), (3, 0.7874753475189209), (4, 0.7534308433532715), (5, 0.7852730751037598), (6, 0.7961565256118774), (7, 0.7430126667022705), (8, 0.7461990118026733), (9, 0.7723017930984497), (10, 0.7652126550674438), (11, 0.8015544414520264), (12, 0.7438547611236572), (13, 0.773317813873291), (14, 0.7475773096084595), (15, 0.7374314069747925), (16, 0.8047183156013489), (17, 0.7511347532272339), (18, 0.762906551361084), (19, 0.7113053798675537)]


['The town of Alderwick is known for its glassblowing tradition, which dates back to the late 1700s. Families there still pass down specialized techniques for producing colored glass using local minerals.',
 'The fruit of the candleberry tree produces a wax that burns cleanly without smoke. Before petroleum-based candles, entire villages harvested the berries for household lighting.',
 'Lysium clay is a natural pigment that shifts from blue to violet depending on the angle of sunlight. Ancient potters prized it for ceramics that seemed to change color through the day.']

In [85]:
def rag_generator(prompt, top_k=1):
  context = "Use the following context to answer the question:" + "\n - ".join(rag_retriever(prompt, top_k=top_k))
  enriched_prompt = context + "\n Question: " + prompt + "\n Answer:"
  print(f"{enriched_prompt=}")
  inputs = generator_tokenizer(enriched_prompt, return_tensors="pt").to(device)
  outputs = generator_model.generate(**inputs)
  return generator_tokenizer.decode(outputs[0], skip_special_tokens=True)

In [86]:
def non_rag_generator(prompt):
  inputs = generator_tokenizer(prompt, return_tensors="pt").to(device)
  outputs = generator_model.generate(**inputs)
  return generator_tokenizer.decode(outputs[0], skip_special_tokens=True)

# Test

In [87]:
non_rag_generator("Who designed a modular bridge system with interlocking steel plates in 1923?")

'john w. w. sturgeon'

In [92]:
rag_generator("Who designed a modular bridge system with interlocking steel plates in 1923?", top_k=3)

enriched_prompt='Use the following context to answer the question:In 1923, the engineer Clara Whitcombe designed a modular bridge system made of interlocking steel plates. Though few were built, the concept influenced modern prefabricated bridge design.\n - In 1961, the mathematician Orla Jensen proposed a number system based on repeating sequences of primes. Though impractical for calculation, it fascinated theorists studying abstract structures.\n - The Mavricon Cipher is an encryption method invented during the early telegraph era. It substitutes letters with shifting geometric symbols, making it difficult to decode without a key grid.\n Question: Who designed a modular bridge system with interlocking steel plates in 1923?\n Answer:'


'Clara Whitcombe'

In [28]:
non_rag_generator("Which town is famous for its glassblowing tradition starting in the 1700s?")

'st johns'

In [89]:
rag_generator("Which town is famous for its glassblowing tradition starting in the 1700s?")

enriched_prompt='Use the following context to answer the question:The town of Alderwick is known for its glassblowing tradition, which dates back to the late 1700s. Families there still pass down specialized techniques for producing colored glass using local minerals.\n Question: Which town is famous for its glassblowing tradition starting in the 1700s?\n Answer:'


'Alderwick'

In [90]:
non_rag_generator("Which orchid blooms only in winter after the first snowfall?")

'sturgeon'

In [91]:
rag_generator("Which orchid blooms only in winter after the first snowfall?", top_k=3)

enriched_prompt='Use the following context to answer the question:Unlike most orchids, the Frostpetal variety blooms only during cold months. Its white blossoms appear after the first snowfall and are pollinated by winter moths active in freezing temperatures.\n - In the festival of Lantern Rest, villagers float small reed boats with candles downriver. The custom is believed to honor ancestors and mark the transition from winter to spring.\n - The flightless bird known as the Silver Rail survives on a single island in the South Atlantic. Its feathers shimmer with a metallic sheen, which helps it blend into rocky coastlines.\n Question: Which orchid blooms only in winter after the first snowfall?\n Answer:'


'Frostpetal'

In [93]:
non_rag_generator("What was candleberry wax used for before petroleum candles?")

'a candle'

In [94]:
rag_generator("What was candleberry wax used for before petroleum candles?", top_k=3)

enriched_prompt='Use the following context to answer the question:The fruit of the candleberry tree produces a wax that burns cleanly without smoke. Before petroleum-based candles, entire villages harvested the berries for household lighting.\n - Lysium clay is a natural pigment that shifts from blue to violet depending on the angle of sunlight. Ancient potters prized it for ceramics that seemed to change color through the day.\n - The Ashgrove Manuscript contains herbal remedies described by a 14th-century physician. Many recipes combine medicinal plants with ritual instructions, suggesting a blend of science and folklore.\n Question: What was candleberry wax used for before petroleum candles?\n Answer:'


'household lighting'