# Using T5 Small for Retrieval Augmented Generation

## Dependencies

In [1]:
%pip install accelerate "transformers[torch]" torch sentencepiece "chromadb<0.4" --user

Note: you may need to restart the kernel to use updated packages.
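The retrieval cell further down uses the pre-0.4 `chromadb` configuration API, so after restarting the kernel it can help to confirm which package versions are actually installed. Below is a minimal sketch that uses only the standard library's `importlib.metadata`. The helper name `installed_versions` and the package list are illustrative choices, not part of the original notebook.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Packages this notebook relies on (illustrative list).
    for pkg, ver in installed_versions(
        ["transformers", "torch", "sentencepiece", "accelerate", "chromadb"]
    ).items():
        print(f"{pkg:15} {ver or 'not installed'}")
```

A `chromadb` version of 0.4 or later means the legacy `Settings(chroma_db_impl=...)` call below will not work as written.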


## Model Setup

In [None]:
import os
from pathlib import Path
from transformers import T5Tokenizer, T5ForConditionalGeneration

In [2]:
model_id = "google/flan-t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_id)
model = T5ForConditionalGeneration.from_pretrained(model_id)

<pad> Wie ich er bitten?</s>




### Storing Model [Optional]

In [17]:
# uncomment if you want to store your models in a local folder
# model_name = model_id.replace('/', '_')
# models_path = Path(f"{os.getcwd()}/models/{model_name}")
# models_path.mkdir(parents=True, exist_ok=True)  # no-op if the folder already exists

# # Storing tokenizer locally
# tokenizer.save_pretrained(str(models_path))
# print("Tokenizer saved successfully!")
# # Storing model locally
# model.save_pretrained(str(models_path))
# print("Model saved successfully!")

Tokenizer saved successfully!
Model saved successfully!


### Loading Stored Model [Optional]

In [18]:
# uncomment if you want to load your stored model
# model_name = model_id.replace('/', '_')
# models_path = Path(f"{os.getcwd()}/models/{model_name}")
# tokenizer = T5Tokenizer.from_pretrained(str(models_path))
# model = T5ForConditionalGeneration.from_pretrained(str(models_path))
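The two optional cells above differ only in where `from_pretrained` reads from: a Hub model ID or a local directory. A small helper can pick the local copy when one exists and fall back to the Hub ID otherwise. This is a sketch under stated assumptions: the helper name `resolve_model_source` is hypothetical, and it treats the presence of `config.json` as the marker of a saved model because `model.save_pretrained` writes that file.

```python
from pathlib import Path


def resolve_model_source(model_id: str, base_dir: Path) -> str:
    """Return a local model directory if one was saved, else the Hub model ID.

    Uses the same folder layout as the optional cells above:
    <base_dir>/models/<model_id with '/' replaced by '_'>.
    """
    local_dir = Path(base_dir) / "models" / model_id.replace("/", "_")
    # save_pretrained() writes config.json, so use it as a "saved model" marker.
    if (local_dir / "config.json").is_file():
        return str(local_dir)
    return model_id
```

With `source = resolve_model_source(model_id, Path.cwd())`, both `T5Tokenizer.from_pretrained(source)` and `T5ForConditionalGeneration.from_pretrained(source)` can be called unconditionally.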

## Accessing Embeddings Database

In [7]:
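import chromadb
from chromadb.config import Settings

# Legacy client configuration (chromadb < 0.4, DuckDB + Parquet storage).
# chromadb 0.4+ removed this setting in favour of chromadb.PersistentClient(path=...),
# which uses a different (SQLite-based) on-disk format.
client = chromadb.Client(Settings(
    chroma_db_impl="duckdb+parquet",
    persist_directory="./db/"
))
# The collection must already exist in ./db/; get_collection does not create it.
collection = client.get_collection(name="airflow_docs_stable")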

In [20]:
question = "Python Code to create a Dag Class"
results = collection.query(
    query_texts=[question],
    n_results=1,
)
formatted_result = "\n\n".join(results["documents"][0])
print(formatted_result)

dag_loader.py¶  from airflow import DAG  from airflow.decorators import task   import pendulum    def create_dag(dag_id, schedule, dag_number, default_args):      dag = DAG(          dag_id,          schedule=schedule,          default_args=default_args,          pendulum.datetime(2021, 9, 13, tz="UTC"),      )       with dag:           @task()          def hello_world():              print("Hello World")              print(f"This is DAG: {dag_number}")           hello_world()       return dag       DAG construction¶


## Setting up Retrieval Augmented Generation (RAG)

In [21]:
prompt = (
    "You are a helpful question and answer bot, your task is to provide the best answer to a given user's question.\n"
    "Only use the context below to answer the user's question, if you don't have the necessary information to answer say: 'I don't know!'\n"
    "Context and Question are denoted by ```\n"
    f"Context: ```{formatted_result}```\n\n"
    f"Question: ```{question}?```\n\n"
    "Response:"
)
# response = text_generation(prompt) # WIP: add text generation pipeline
# print(response[0]["generated_text"].lstrip())

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


You are a helpful question and answer bot, your task is to provide the best answer to a given user's question.
Only use the context below to answer the user's question, if you don't have the necessary information to answer say: 'I don't know!'
Context and Question are denoted by ```
Context: ```dag_loader.py¶  from airflow import DAG  from airflow.decorators import task   import pendulum    def create_dag(dag_id, schedule, dag_number, default_args):      dag = DAG(          dag_id,          schedule=schedule,          default_args=default_args,          pendulum.datetime(2021, 9, 13, tz="UTC"),      )       with dag:           @task()          def hello_world():              print("Hello World")              print(f"This is DAG: {dag_number}")           hello_world()       return dag       DAG construction¶```

Question: ```Python Code to create a Dag Class?```

Response: ```

Parameters:

dag - a DAG object - the DAG method to create.

default - the default DAG function that returns t
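The generation step above is still marked WIP. T5 is an encoder-decoder model, so the causal-LM `text-generation` route does not apply to it. With the tokenizer and model loaded earlier, generation would instead go through `model.generate` (or the `text2text-generation` pipeline task). The sketch below wires this notebook's steps together (retrieve, build the prompt, generate). Retrieval and generation stay plain callables, so the control flow can be exercised without downloading a model. The names `build_prompt`, `answer`, `retrieve` and `generate` are hypothetical. The prompt text is the one used above. The `clean_chunk` step, which drops the `¶` heading markers and collapses whitespace in retrieved chunks, is an added assumption meant to spend fewer input tokens on layout debris.

```python
import re
from typing import Callable, List

# Prompt instructions copied from the notebook cell above.
INSTRUCTIONS = (
    "You are a helpful question and answer bot, your task is to provide the best answer to a given user's question.\n"
    "Only use the context below to answer the user's question, if you don't have the necessary information to answer say: 'I don't know!'\n"
    "Context and Question are denoted by ```\n"
)


def clean_chunk(text: str) -> str:
    """Drop Sphinx '¶' anchors and collapse runs of whitespace (assumed clean-up step)."""
    return re.sub(r"\s+", " ", text.replace("¶", " ")).strip()


def build_prompt(chunks: List[str], question: str) -> str:
    """Fill the prompt template with the cleaned chunks and the question."""
    context = "\n\n".join(clean_chunk(c) for c in chunks)
    return (
        INSTRUCTIONS
        + f"Context: ```{context}```\n\n"
        + f"Question: ```{question}?```\n\n"
        + "Response:"
    )


def answer(
    question: str,
    retrieve: Callable[[str], List[str]],
    generate: Callable[[str], str],
) -> str:
    """Retrieve context, format the prompt, and return the stripped model output."""
    prompt = build_prompt(retrieve(question), question)
    return generate(prompt).strip()
```

To plug in the notebook's objects, `retrieve` could be `lambda q: collection.query(query_texts=[q], n_results=1)["documents"][0]`. `generate` could tokenize the prompt with `tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)`, call `model.generate(**inputs, max_new_tokens=128)`, and decode the first sequence with `tokenizer.decode(..., skip_special_tokens=True)`. The `max_length` and `max_new_tokens` values are illustrative, not tuned.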