In [3]:
import itertools
import json
import os
from collections import defaultdict

import chromadb
import pandas as pd
import replicate
from dotenv import load_dotenv
from elasticsearch import Elasticsearch
from elasticsearch.exceptions import AuthenticationException, ConnectionError
from openai import OpenAI
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
from transformers import T5ForConditionalGeneration, T5Tokenizer

load_dotenv()

from utils import search_podcasts
from ingest import create_index, download_podcast, transcribe_podcast, encode_podcast, index_podcast
from rag import rag, search

[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/minasonbol/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [4]:
def update_session(**kwargs):
    """Merge the given keyword arguments into the module-level ``session_state``.

    Keys that already exist are overwritten; keys that are not mentioned are
    left untouched. Calling it without arguments is a no-op.
    """
    session_state.update(kwargs)

def text_input(input_text, key=None, type="default"):
    """Prompt the user on the console and return what they typed.

    Args:
        input_text: Prompt shown to the user.
        key: Accepted so that widget-style call sites (``choose_encoder`` passes
            ``key=...``) do not fail with ``TypeError``; unused on the console.
        type: ``"password"`` reads the value without echoing it to the
            notebook output; any other value uses a normal visible prompt.

    Returns:
        The string entered by the user.
    """
    if type == "password":
        # Local import keeps this helper self-contained.
        import getpass

        return getpass.getpass(input_text)
    return input(input_text)

def choose_podcast_option(episode_option):
    """Record which podcast episode the user wants to work with.

    ``episode_option_selected`` is reset to ``False`` first and only set back
    to ``True`` once the chosen option has been completed successfully.

    Args:
        episode_option: One of the menu labels ``"1. Try a sample"``,
            ``"2. Provide the iTunes URL for a specific podcast episode"`` or
            ``"3. Provide a name of a podcast to explore its most recent episode"``.
            Any other value leaves the selection unset.
    """
    update_session(episode_option_selected=False)

    if episode_option == "1. Try a sample":
        update_session(episode_option_selected=True, episode_option=episode_option)
        return

    if episode_option == "2. Provide the iTunes URL for a specific podcast episode":
        episode_url = text_input("Enter the iTunes URL of the episode you want:")
        update_session(episode_option_selected=True, episode_option=episode_option, episode_url=episode_url)
        return

    if episode_option != "3. Provide a name of a podcast to explore its most recent episode":
        return

    term = text_input("Enter a search term for podcasts:")
    if term == '':
        print("Please enter a valid search term.")
        return

    # Only the search is guarded: a failure here really is a search problem,
    # whereas errors further down should not be reported as a bad search term.
    try:
        found_podcasts = search_podcasts(term)
        podcasts = [] if found_podcasts['status'] == 'Fail' else found_podcasts['podcasts']
    except Exception:
        podcasts = []
    if not podcasts:
        print("Please enter a valid search term.")
        return

    podcast_names = [f"{podcast['collectionName']} by {podcast['artistName']}" for podcast in podcasts]

    # Console replacement for a select box: show a numbered list and read the number.
    for number, name in enumerate(podcast_names, start=1):
        print(f"{number}. {name}")
    choice = text_input("Select a podcast:")
    try:
        selected_index = int(choice) - 1
    except ValueError:
        selected_index = -1
    if not 0 <= selected_index < len(podcast_names):
        print("Please enter a valid podcast number.")
        return

    update_session(episode_option_selected=True, episode_option=episode_option, found_podcasts=podcasts, selected_index=selected_index)

def choose_encoder(sentence_encoder):
    """Set up the sentence encoder and store it in the session.

    ``sentence_encoder_selected`` is reset to ``False`` first and only set to
    ``True`` when the encoder (or embedding client) was created successfully.

    Args:
        sentence_encoder: ``"1. T5"`` loads the local
            ``sentence-transformers/sentence-t5-base`` model; ``"2. OpenAI"``
            creates an OpenAI client for ``text-embedding-3-large``. Any other
            value leaves the selection unset.
    """
    update_session(sentence_encoder_selected=False)
    if sentence_encoder == "1. T5":
        encoder = SentenceTransformer("sentence-transformers/sentence-t5-base")
        update_session(sentence_encoder_selected=True, sentence_encoder=sentence_encoder, encoder=encoder)
    elif sentence_encoder == "2. OpenAI":
        import getpass

        embedding_model = "text-embedding-3-large"
        # Prefer the environment (as choose_llm does); fall back to a hidden
        # prompt so the key is never echoed into the notebook output.
        openai_api_key = os.getenv('OPENAI_API_KEY') or getpass.getpass("OpenAI API Key")
        if not openai_api_key:
            print("Invalid API key. Please provide a valid API token.")
            return
        try:
            oa_embedding_client = OpenAI(api_key=openai_api_key)
            # Listing models is only used to check that the key is accepted.
            oa_embedding_client.models.list()
        except Exception:
            print("Invalid API key. Please provide a valid API token.")
            return
        update_session(sentence_encoder_selected=True, sentence_encoder=sentence_encoder, embedding_client=oa_embedding_client, embedding_model=embedding_model)

def choose_transcription_method(transcription_method, session_state):
    """Select how the podcast audio will be transcribed.

    Does nothing until an episode option has been stored in ``session_state``.
    The sample episode needs no transcription, so the step is simply marked as
    done. Otherwise ``transcription_method_selected`` is reset and only set to
    ``True`` once the chosen method is usable.

    Args:
        transcription_method: ``"1. Replicate"`` (needs the
            ``REPLICATE_API_KEY`` environment variable) or
            ``"2. Local transcription"``.
        session_state: Mapping that is read for ``'episode_option'``. Updates go
            through ``update_session``, i.e. to the module-level session state,
            which is the object the notebook passes in here.
    """
    if not session_state.get('episode_option', False):
        return

    if session_state['episode_option'] == "1. Try a sample":
        print("The sample podcast doesn't require a transcription method.")
        update_session(transcription_method_selected=True)
        return

    update_session(transcription_method_selected=False)
    if transcription_method == "1. Replicate":
        # os.getenv returns None (not '') when the variable is missing.
        replicate_api_key = os.getenv('REPLICATE_API_KEY')
        if not replicate_api_key:
            print("REPLICATE_API_KEY is not set. Please provide a valid API token.")
            return
        try:
            replicate_client = replicate.Client(api_token=replicate_api_key)
            # Listing models is only used to check that the token is accepted.
            replicate_client.models.list()
        except Exception:
            print("Invalid API key. Please provide a valid API token.")
            return
        update_session(transcription_method_selected=True, transcription_method=transcription_method, transcription_client=replicate_client)
    elif transcription_method == "2. Local transcription":
        update_session(transcription_method_selected=True, transcription_method=transcription_method)

def choose_vector_db(vector_db):
    """Create the search index in the chosen vector store.

    Always stores ``index_name="podcast-transcriber"`` and resets
    ``vector_db_selected`` to ``False``; the flag (together with
    ``index_created``) is set to ``True`` only after ``create_index`` returned.

    Args:
        vector_db: ``"1. Minsearch"``, ``"2. Elasticsearch"`` (needs the
            ``ES_API_KEY`` and ``ES_CLOUD_ID`` environment variables) or
            ``"3. ChromaDB"`` (persistent client stored under ``./chroma_db``).
            Any other value leaves the selection unset.
    """
    update_session(index_name="podcast-transcriber", vector_db_selected=False)
    if vector_db == "1. Minsearch":
        update_session(vector_db=vector_db)
        # create_index receives the whole session, so vector_db must be stored first.
        update_session(index=create_index(**session_state))
        update_session(vector_db_selected=True, index_created=True)
        print(f"Index {session_state['index'].index_name} was created successfully.")
    elif vector_db == "2. Elasticsearch":
        # os.getenv returns None (not '') when a variable is missing.
        elasticsearch_api_key = os.getenv('ES_API_KEY')
        elasticsearch_cloud_id = os.getenv('ES_CLOUD_ID')
        if not (elasticsearch_api_key and elasticsearch_cloud_id):
            print("ES_API_KEY and ES_CLOUD_ID must both be set to use Elasticsearch.")
            return
        try:
            es_client = Elasticsearch(cloud_id=elasticsearch_cloud_id, api_key=elasticsearch_api_key)
            # The health call is only used to check credentials and connectivity.
            es_client.cluster.health()
            update_session(vector_db=vector_db, vector_db_client=es_client)
            update_session(index=create_index(**session_state))
            update_session(vector_db_selected=True, index_created=True)
            # The stored index is treated as a mapping; its first key is reported as the name.
            print(f"Index {[k for k,v in session_state['index'].items()][0]} was created successfully.")
        except AuthenticationException:
            print("Invalid API key or Cloud ID. Please provide a valid tokens.")
        except ConnectionError:
            print("Connection error. Could not connect to the cluster.")
        except Exception as e:
            print(f"An error occurred: {e}")
    elif vector_db == "3. ChromaDB":
        chroma_client = chromadb.PersistentClient(path="./chroma_db")
        update_session(vector_db=vector_db, vector_db_client=chroma_client)
        update_session(index=create_index(**session_state))
        update_session(vector_db_selected=True, index_created=True)
        # NOTE(review): reports the first listed collection, which is presumably
        # the one just created — confirm if ./chroma_db can hold several.
        print(f"Index {session_state['vector_db_client'].list_collections()[0].name} was created successfully.")

def choose_llm(llm_option):
    """Set up the answer-generation model and store it in the session.

    ``llm_option_selected`` is reset to ``False`` first and only set to ``True``
    once a usable client/model is available.

    Args:
        llm_option: ``"1. GPT-4o"`` reuses the OpenAI embedding client when the
            OpenAI encoder was selected, otherwise creates a client from the
            ``OPENAI_API_KEY`` environment variable. ``"2. FLAN-5"`` loads
            ``google/flan-t5-large`` with its tokenizer. Any other value leaves
            the selection unset.
    """
    update_session(llm_option_selected=False)
    if llm_option == "1. GPT-4o":
        if session_state.get('sentence_encoder') != "2. OpenAI":
            # os.getenv returns None (not '') when the variable is missing.
            openai_api_key = os.getenv('OPENAI_API_KEY')
            if not openai_api_key:
                print("OPENAI_API_KEY is not set. Please provide a valid API token.")
                return
            try:
                oa_client = OpenAI(api_key=openai_api_key)
                # Listing models is only used to check that the key is accepted.
                oa_client.models.list()
            except Exception:
                print("Invalid API key. Please provide a valid API token.")
                return
        else:
            # The OpenAI encoder was chosen, so its client can be reused — but it
            # only exists if choose_encoder succeeded.
            oa_client = session_state.get('embedding_client')
            if oa_client is None:
                print("No OpenAI embedding client found. Run choose_encoder with the OpenAI option first.")
                return
        update_session(llm_option_selected=True, llm_option=llm_option, llm_client=oa_client)

    elif llm_option == "2. FLAN-5":
        model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
        tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
        update_session(llm_option_selected=True, llm_option=llm_option, llm_client=model, llm_tokenizer=tokenizer)

In [5]:
# Initial choice for every pipeline stage, keyed by the names the choose_* helpers read.
# No default_factory is given, so looking up a missing key raises KeyError just
# like a plain dict would.
session_state = defaultdict(None, {
    "episode_option": "1. Try a sample",
    "sentence_encoder": "1. T5",
    "transcription_method": "1. Replicate",
    "vector_db": "1. Minsearch",
    "llm_option": "1. GPT-4o",
})

In [6]:
# Apply the episode option currently stored in session_state
# ("1. Try a sample" in the initial configuration).
choose_podcast_option(session_state['episode_option'])
# Example Apple Podcasts episode URL (presumably kept for trying option 2 — confirm):
# https://podcasts.apple.com/us/podcast/what-if-the-russian-revolution-hadnt-been-bolshevik/id1682047968?i=1000668755545

In [7]:
# Set up the sentence encoder named in session_state ("1. T5" in the initial configuration).
choose_encoder(session_state['sentence_encoder'])



In [8]:
# The session is passed explicitly because the helper reads 'episode_option'
# from its second argument to decide whether transcription is needed.
choose_transcription_method(session_state['transcription_method'], session_state)

The sample podcast doesn't require a transcription method.


In [9]:
# Create the index in the vector store named in session_state
# ("1. Minsearch" in the initial configuration).
choose_vector_db(session_state['vector_db'])

Index podcast-transcriber was created successfully.


In [10]:
# Set up the LLM named in session_state ("1. GPT-4o" in the initial configuration).
choose_llm(session_state['llm_option'])

In [11]:
# download
# download_podcast receives the whole session as keyword arguments; its result is
# stored under 'episode_details' and the download step is flagged as done.
episode_details = download_podcast(**session_state)
update_session(episode_details=episode_details, podcast_downloaded=True)

In [12]:
# no need to convert to pandas
# no need to convert to pandas
# Each chunk is later unpacked into prompt_template.format(**doc), so it has to be
# a mapping that provides at least a 'text' key.
documents = session_state['episode_details']['chunks']

In [13]:
# Prompt used to generate ground-truth questions for one transcript chunk.
# It is filled with str.format: {text} is replaced by the chunk text, and the
# doubled braces render as literal braces in the JSON example.
prompt_template = """
You emulate a user of our deep-pod application.
Formulate 5 questions this user might ask based on a provided text.
Make the questions specific to this text.
The record should contain the answer to the questions, and the questions should
be complete and not too short. Use as few words as possible from the record. 

The record:

text: {text}

Provide the output in parsable JSON without using code blocks:

{{"questions": ["question1", "question2", ..., "question5"]}}
""".strip()

In [14]:
# Client used by llm() and generate_questions() below; it is separate from the
# client stored in session_state['llm_client'].
# NOTE(review): no key is passed here — presumably picked up from the
# OPENAI_API_KEY environment variable loaded by load_dotenv(); confirm.
client = OpenAI()

In [15]:
# Render the prompt for the first chunk only, as a quick check of the template.
prompt = prompt_template.format(**documents[0])

In [16]:
# Display the rendered prompt to verify the substitution and brace escaping.
prompt

'You emulate a user of our deep-pod application.\nFormulate 5 questions this user might ask based on a provided text.\nMake the questions specific to this text.\nThe record should contain the answer to the questions, and the questions should\nbe complete and not too short. Use as few words as possible from the record. \n\nThe record:\n\ntext:  Balancing a wellness routine and busy travel plans?\n\nProvide the output in parsable JSON without using code blocks:\n\n{"questions": ["question1", "question2", ..., "question5"]}'

In [17]:
def llm(prompt):
    """Send ``prompt`` as a single user message to ``gpt-4o-mini``.

    Uses the module-level ``client`` and returns the text content of the first
    choice in the response.
    """
    user_message = {"role": "user", "content": prompt}
    completion = client.chat.completions.create(model='gpt-4o-mini', messages=[user_message])
    first_choice = completion.choices[0]
    return first_choice.message.content

In [18]:
# Trial run: ask the model for questions about the first chunk only.
questions = llm(prompt)

In [19]:
# Check that the raw reply parses as JSON with the expected "questions" list.
json.loads(questions)

{'questions': ['What are some tips for maintaining a wellness routine while traveling?',
  'How can I prioritize wellness activities during a busy trip?',
  'What challenges might I face in keeping my wellness routine while on the road?',
  'Are there specific wellness practices that are easy to incorporate while traveling?',
  'How can I adapt my wellness routine to fit a hectic travel schedule?']}

In [20]:
def generate_questions(doc):
    """Ask ``gpt-4o-mini`` for evaluation questions about one transcript chunk.

    Args:
        doc: Mapping unpacked into ``prompt_template.format``; it must provide
            every placeholder of the template (``text``).

    Returns:
        The raw reply as a string. The request uses JSON mode, so the reply is
        a JSON object that the caller can pass to ``json.loads``.
    """
    prompt = prompt_template.format(**doc)

    response = client.chat.completions.create(
        model='gpt-4o-mini',
        messages=[{"role": "user", "content": prompt}],
        # The caller parses the reply with json.loads; JSON mode keeps the model
        # from wrapping it in code fences or prose. The prompt already asks for
        # JSON, which JSON mode requires.
        response_format={"type": "json_object"},
    )

    return response.choices[0].message.content

In [21]:
# Maps document id (string) -> list of generated questions. It lives in its own
# cell so that re-running the generation cell skips ids already present here
# instead of starting over.
results = {}

In [22]:
# Generate questions for every chunk. Ids are the 1-based positions in `documents`.
# Chunks already present in `results` are skipped, so this cell can be re-run to
# resume after an interruption or to retry chunks whose reply could not be parsed.
failed_ids = []
for position, doc in enumerate(tqdm(documents), start=1):
    doc_id = str(position)
    if doc_id in results:
        continue

    questions_raw = generate_questions(doc)
    try:
        questions = json.loads(questions_raw)
        results[doc_id] = questions['questions']
    except (json.JSONDecodeError, KeyError, TypeError):
        # One malformed reply should not abort a run that takes many minutes;
        # the id stays out of `results`, so a re-run retries it.
        failed_ids.append(doc_id)

if failed_ids:
    print(f"Could not parse questions for {len(failed_ids)} document(s): {failed_ids}")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 537/537 [13:32<00:00,  1.51s/it]


In [23]:
# Flatten {doc_id: [question, ...]} into one (doc_id, question) pair per question.
final_results = [
    (doc_id, question)
    for doc_id, doc_questions in results.items()
    for question in doc_questions
]

In [24]:
# One row per generated question: 'id' is the document id, 'question' the text.
ground_truth = pd.DataFrame(final_results, columns=['id', 'question'])
# to_csv does not create missing directories, so make sure the target exists.
os.makedirs('sample', exist_ok=True)
ground_truth.to_csv('sample/ground-truth-retrieval.csv', index=False)