## Introduction

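This notebook demonstrates Retrieval-Augmented Generation (RAG) over the scripts of Wet Toast Talk Radio: it indexes the show's scripts, retrieves those relevant to a question, and has an LLM answer from them.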

**Requirements**

* Configure [AWS credentials](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html#cli-configure-files-examples)
* Set up the project config (the notebook loads `../config-s3.yml`), e.g.:

```yaml
media_store:
  s3:
    local: false
    bucket_name: "media-store-XXXXXXXXXXXX" 
scriptwriter:
  llm:
    openai_api_key: sm:/wet-toast-talk-radio/scriptwriter/openai-api-key
```
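The `openai_api_key` above is a reference rather than a literal key, and the notebook later resolves it with `.value()`. A minimal sketch of how such a reference could be resolved follows, assuming (not confirmed here) that the `sm:` prefix denotes an AWS Secrets Manager secret stored as a plain string. `resolve_secret_ref` is a hypothetical helper, and whether the leading `/` belongs to the secret name is also an assumption.

```python
from typing import Any

# Assumption: "sm:" marks an AWS Secrets Manager reference
SM_PREFIX = "sm:"


def resolve_secret_ref(value: str, secrets_client: Any) -> str:
    """Hypothetical helper: return literal values unchanged, fetch "sm:" references.

    `secrets_client` is expected to be a boto3 Secrets Manager client,
    e.g. boto3.Session().client("secretsmanager").
    """
    if not value.startswith(SM_PREFIX):
        return value  # plain value, nothing to resolve
    # Assumption: everything after "sm:" (leading "/" included) is the secret ID
    secret_id = value[len(SM_PREFIX):]
    response = secrets_client.get_secret_value(SecretId=secret_id)
    return response["SecretString"]
```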

## Data Loading

Set up the AWS SDK

```python
import boto3

session = boto3.Session()
s3 = session.client("s3")
```

Load the Wet Toast Talk Radio config

```python
from pathlib import Path
from wet_toast_talk_radio.command.root import load_config

config_path = Path.cwd().parent / "config-s3.yml"
config = load_config(config_path)
s3_config = config.media_store.s3
```

We list all scripts from September and October, since the last improvement to the scriptwriter was made in late August.

```python
from wet_toast_talk_radio.disc_jockey.copy import Migrator

september_prefix = "script/2023-09"
october_prefix = "script/2023-10"

migrator = Migrator(cfg=s3_config)
september_objects = migrator.list_all_objects(september_prefix)
october_objects = migrator.list_all_objects(october_prefix)
objects = september_objects + october_objects
# Keep only the script files; other keys under these prefixes are skipped
scripts = [obj["Key"] for obj in objects if obj["Key"].endswith(".jsonl")]
```

```text
2023-10-15 17:33:41 [info     ] Listing objects                prefix=script/2023-09
2023-10-15 17:33:47 [info     ] Listing objects                prefix=script/2023-10
```
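`Migrator.list_all_objects` belongs to this project and its implementation is not shown here. For readers without it, a rough equivalent using a boto3 S3 client (such as the `s3` client created earlier) might look like the sketch below. `month_prefixes` and `list_jsonl_keys` are hypothetical helpers, and we assume keys follow the `script/YYYY-MM...` layout used above.

```python
from typing import Any, Iterable, List


def month_prefixes(year: int, months: Iterable[int], root: str = "script") -> List[str]:
    """Build S3 prefixes such as "script/2023-09" for the given months."""
    return [f"{root}/{year}-{month:02d}" for month in months]


def list_jsonl_keys(s3_client: Any, bucket: str, prefixes: Iterable[str]) -> List[str]:
    """List every key under each prefix that ends with ".jsonl".

    Uses the list_objects_v2 paginator because a single call returns
    at most 1,000 keys.
    """
    keys: List[str] = []
    paginator = s3_client.get_paginator("list_objects_v2")
    for prefix in prefixes:
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            # "Contents" is absent from pages with no matching objects
            for obj in page.get("Contents", []):
                if obj["Key"].endswith(".jsonl"):
                    keys.append(obj["Key"])
    return keys
```

Usage would be along the lines of `list_jsonl_keys(s3, s3_config.bucket_name, month_prefixes(2023, [9, 10]))`.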


Download all scripts locally

```python
data_path = Path.cwd().parent / "data"
data_path.mkdir(exist_ok=True)
```

```python
import concurrent.futures

import structlog
from smart_open import open  # handles both s3:// URIs and local paths
from tqdm import tqdm

logger = structlog.get_logger()


def download_script(script: str):
    """Copy one script from S3 to the same relative path under data_path."""
    (data_path / script).parent.mkdir(exist_ok=True, parents=True)
    with open(f"s3://{s3_config.bucket_name}/{script}") as fin:
        with open(data_path / script, "w") as fout:
            fout.write(fin.read())


logger.info("Downloading scripts", count=len(scripts))
with concurrent.futures.ThreadPoolExecutor(max_workers=s3_config.max_workers) as executor:
    futures = [executor.submit(download_script, script) for script in scripts]
    # Track progress as downloads finish; result() re-raises any download error
    for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
        future.result()

logger.info("Done downloading scripts")
```

```text
2023-10-15 16:34:59 [info     ] Downloading scripts            count=12595
100%|██████████| 12595/12595 [17:48<00:00, 11.79it/s]
2023-10-15 16:52:47 [info     ] Done downloading scripts
```
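Downloading the 12,595 scripts took close to 18 minutes, so re-running the cell is costly. One possible refinement, sketched under our own assumptions, is to skip keys that already exist locally and to write each file under a temporary name first, so an interrupted run never leaves a truncated file that looks complete. `download_missing` and its `fetch` parameter are hypothetical, not part of the project.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Callable, Iterable, List


def download_missing(
    keys: Iterable[str],
    data_dir: Path,
    fetch: Callable[[str], str],
    max_workers: int = 8,
) -> List[str]:
    """Download only keys with no local copy yet; return the keys fetched.

    `fetch(key)` must return the object's text, e.g. by reading
    f"s3://{bucket}/{key}" with smart_open.
    """
    missing = [key for key in keys if not (data_dir / key).exists()]

    def download(key: str) -> str:
        target = data_dir / key
        target.parent.mkdir(parents=True, exist_ok=True)
        tmp = target.with_name(target.name + ".part")
        tmp.write_text(fetch(key))
        tmp.replace(target)  # only complete files get the final name
        return key

    fetched: List[str] = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(download, key) for key in missing]
        for future in as_completed(futures):
            fetched.append(future.result())  # re-raises any download error
    return fetched
```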
Convert each script into a Haystack `Document` whose content has one `<speaker name>: <line>` entry per line of dialogue

```python
from haystack import Document
from tqdm import tqdm
from wet_toast_talk_radio.common.dialogue import read_lines


def load_script(script_path: Path) -> dict:
    """Render a script as "<speaker name>: <content>" lines in a Document-compatible dict."""
    lines = read_lines(script_path)
    rendered = (f"{line.speaker.name}: {line.content}" for line in lines)
    return {"content": "\n".join(rendered)}


documents = []
for script in tqdm(scripts):
    script_json = load_script(data_path / script)
    documents.append(Document.from_json(script_json))
```

```text
100%|██████████| 12595/12595 [00:11<00:00, 1125.93it/s]
```
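To make the resulting document text concrete without the project's `read_lines`, the sketch below renders a JSONL dialogue the same way. It assumes, hypothetically, one JSON object per line with `speaker` and `content` keys; the real schema is whatever `read_lines` parses. The point it illustrates is that speaker names become part of each document's searchable text.

```python
import json
from typing import Dict


def script_to_text(jsonl: str) -> str:
    """Render a JSONL dialogue as "<speaker>: <content>" lines.

    Assumes (hypothetically) one JSON object per line with "speaker"
    and "content" keys.
    """
    rendered = []
    for raw in jsonl.splitlines():
        if not raw.strip():
            continue  # tolerate blank lines, e.g. a trailing newline
        line: Dict[str, str] = json.loads(raw)
        rendered.append(f"{line['speaker']}: {line['content']}")
    return "\n".join(rendered)
```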


## RAG Pipeline

```python
from haystack.document_stores import InMemoryDocumentStore

# use_bm25=True builds a BM25 index so the store can serve keyword retrieval
document_store = InMemoryDocumentStore(use_bm25=True)
document_store.write_documents(documents)
```

```text
Updating BM25 representation...: 100%|██████████| 12595/12595 [00:01<00:00, 6752.07 docs/s]
```
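`use_bm25=True` makes the in-memory store rank documents with BM25, a term-frequency and inverse-document-frequency scoring function. To make the retrieval step less of a black box, here is a self-contained sketch of Okapi BM25 with the common defaults `k1=1.5` and `b=0.75` and the non-negative IDF variant used by Lucene. It is an illustration, not Haystack's implementation; its tokenizer and IDF formula may differ.

```python
import math
import re
from collections import Counter
from typing import List, Tuple


def tokenize(text: str) -> List[str]:
    # Simple lowercase word tokenizer; real stores use their own tokenizers
    return re.findall(r"\w+", text.lower())


def bm25_top_k(
    query: str, docs: List[str], k: int = 2, k1: float = 1.5, b: float = 0.75
) -> List[Tuple[int, float]]:
    """Return (doc_index, score) pairs for the k highest-scoring documents."""
    tokenized = [tokenize(d) for d in docs]
    n_docs = len(tokenized)
    # Average document length; fall back to 1.0 if every document is empty
    avgdl = sum(len(t) for t in tokenized) / n_docs or 1.0
    # Document frequency: how many documents contain each term
    df = Counter(term for t in tokenized for term in set(t))
    scores = []
    for i, terms in enumerate(tokenized):
        tf = Counter(terms)
        score = 0.0
        for term in tokenize(query):
            if term not in tf:
                continue
            idf = math.log((n_docs - df[term] + 0.5) / (df[term] + 0.5) + 1)
            norm = tf[term] + k1 * (1 - b + b * len(terms) / avgdl)
            score += idf * tf[term] * (k1 + 1) / norm
        scores.append((i, score))
    scores.sort(key=lambda pair: pair[1], reverse=True)
    return scores[:k]
```

Because speaker names are part of each document's content, a query that names a host such as Wolfgang matches every line that host speaks, and scripts where he speaks more often score higher.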


```python
from haystack import Pipeline
from haystack.nodes import BM25Retriever, PromptNode, PromptTemplate

# Retrieve the 2 best-matching scripts for each query
retriever = BM25Retriever(document_store, top_k=2)

# Answer from the retrieved scripts with gpt-3.5-turbo,
# using the "deepset/question-answering" prompt template
api_key = config.scriptwriter.llm.openai_api_key.value()
qa_template = PromptTemplate("deepset/question-answering")
prompt_node = PromptNode("gpt-3.5-turbo", api_key=api_key, default_prompt_template=qa_template)

# Query -> retriever -> prompt_node
rag_pipeline = Pipeline()
rag_pipeline.add_node(component=retriever, name="retriever", inputs=["Query"])
rag_pipeline.add_node(component=prompt_node, name="prompt_node", inputs=["retriever"])
```
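Conceptually, this two-node pipeline retrieves the `top_k=2` best-matching scripts and passes them, together with the query, through the prompt template before calling the model. The sketch below spells that flow out with plain callables. `answer_with_context`, `retrieve`, and `generate` are hypothetical stand-ins for `BM25Retriever`, `PromptNode`, and the OpenAI call, and the prompt wording is our own, not the text of the `deepset/question-answering` template.

```python
from typing import Callable, List


def answer_with_context(
    query: str,
    retrieve: Callable[[str, int], List[str]],
    generate: Callable[[str], str],
    top_k: int = 2,
) -> str:
    """Minimal retrieve-then-generate loop: fetch top_k documents, stuff them into a prompt, ask the LLM."""
    documents = retrieve(query, top_k)
    context = "\n\n".join(documents)
    # Hypothetical prompt wording, not the deepset/question-answering template
    prompt = (
        "Answer the question using only the context below. "
        "If the answer is not in the context, say so.\n\n"
        f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"
    )
    return generate(prompt)
```

Since the model only sees the two retrieved scripts, anything outside them cannot inform the answer; raising `top_k` trades recall against prompt length.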

```python
from pprint import pprint


def print_answer(out):
    """Pretty-print the first answer produced by the prompt node."""
    pprint(out["results"][0].strip())


out = rag_pipeline.run(query="Who was debating with Wolfgang on the show The Great Debate?")
print_answer(out)
```

```text
'Isabella'
```


```python
out = rag_pipeline.run(query="What were Wolfgang and Isabella debating on The Great Debate?")
print_answer(out)
```

```text
('Wolfgang and Isabella were debating on the topic of driving under the '
 'influence of drugs or alcohol.')
```


```python
out = rag_pipeline.run(query="What were Wolfgang's main arguments when debating Isabella on The Great Debate?")
print_answer(out)
```

```text
("Wolfgang's main arguments when debating Isabella on The Great Debate were "
 'that driving under the influence is dangerous, reckless, and puts innocent '
 'lives at risk. He emphasized the need for responsible behavior and '
 'prioritizing safety on the roads.')
```


```python
out = rag_pipeline.run(query="Who won the debate between Wolfgang and Isabella on The Great Debate?")
print_answer(out)
```

```text
('It is not stated who won the debate between Wolfgang and Isabella on The '
 'Great Debate.')
```
