In [1]:
from collections import defaultdict
from openai import OpenAI
import replicate
from elasticsearch import Elasticsearch
from elasticsearch.exceptions import AuthenticationException, ConnectionError
from transformers import T5ForConditionalGeneration, T5Tokenizer
from sentence_transformers import SentenceTransformer
import chromadb
import os
from dotenv import load_dotenv

load_dotenv()

from utils import search_podcasts
from ingest import create_index, download_podcast, transcribe_podcast, encode_podcast, index_podcast
from rag import rag, search

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/minasonbol/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


# Setup

In [2]:
def update_session(**kwargs):
    for k, v in kwargs.items():
        session_state[k] = v

def text_input(input_text):
    return input(input_text)

def choose_podcast_option(episode_option):
    update_session(episode_option_selected=False)
    if episode_option == "1. Try a sample":
        update_session(episode_option_selected=True, episode_option=episode_option)
    elif episode_option == "2. Provide the iTunes URL for a specific podcast episode":
        episode_url = text_input("Enter the iTunes URL of the episode you want:")
        update_session(episode_option_selected=True, episode_option=episode_option, episode_url=episode_url)
    elif episode_option == "3. Provide a name of a podcast to explore its most recent episode":
        term = text_input("Enter a search term for podcasts:")
        try:
            if term != '':
                found_podcasts = search_podcasts(term)
                if found_podcasts['status'] == 'Fail':
                    raise Exception
                else:
                    podcast_names = [f"{podcast['collectionName']} by {podcast['artistName']}" for podcast in found_podcasts['podcasts']]
                    selected_podcast = selectbox("Select a podcast:", podcast_names)
                    selected_index=podcast_names.index(selected_podcast)
                    update_session(episode_option_selected=True, episode_option=episode_option, found_podcasts=found_podcasts['podcasts'], selected_index=selected_index)
        except Exception:
            print("Please enter a valid search term.")

def choose_encoder(sentence_encoder):
    update_session(sentence_encoder_selected=False)
    if sentence_encoder == "1. T5":
        encoder=SentenceTransformer("sentence-transformers/sentence-t5-base")
        update_session(sentence_encoder_selected=True, sentence_encoder=sentence_encoder, encoder=encoder)
    elif sentence_encoder == "2. OpenAI":
        embedding_model = "text-embedding-3-large"
        openai_api_key = text_input("OpenAI API Key", key="file_oa_api_key", type="password")
        if openai_api_key != '':
            try:
                oa_embedding_client = OpenAI(api_key=openai_api_key)
                response = oa_embedding_client.models.list()
                update_session(sentence_encoder_selected=True, sentence_encoder=sentence_encoder, embedding_client=oa_embedding_client, embedding_model=embedding_model)
            except:
                print("Invalid API key. Please provide a valid API token.")

def choose_transcription_method(transcription_method, session_state):
    if session_state.get('episode_option', False):
        if session_state['episode_option'] != "1. Try a sample":
            update_session(transcription_method_selected=False)
            if transcription_method=="1. Replicate":
                replicate_api_key = os.getenv('REPLICATE_API_KEY')
                if replicate_api_key != '':
                    try:
                        replicate_client = replicate.Client(api_token=replicate_api_key)
                        response = replicate_client.models.list()
                        update_session(transcription_method_selected=True, transcription_method=transcription_method, transcription_client=replicate_client)
                    except:
                        print("Invalid API key. Please provide a valid API token.")
            elif transcription_method=="2. Local transcription":
                update_session(transcription_method_selected=True, transcription_method=transcription_method)
        else:
            print("The sample podcast doesn't require a transcription method.")
            update_session(transcription_method_selected=True)

def choose_vector_db(vector_db):
    update_session(index_name="podcast-transcriber", vector_db_selected=False)
    if vector_db=="1. Minsearch":
        update_session(vector_db=vector_db)
        update_session(index=create_index(**session_state))
        update_session(vector_db_selected=True, index_created=True)
        print(f"Index {session_state['index'].index_name} was created successfully.")
    elif vector_db=="2. Elasticsearch":
        elasticsearch_api_key = os.getenv('ES_API_KEY')
        elasticsearch_cloud_id = os.getenv('ES_CLOUD_ID')
        if elasticsearch_api_key != '' and elasticsearch_cloud_id != '':
            try:
                es_client = Elasticsearch(cloud_id=elasticsearch_cloud_id, api_key=elasticsearch_api_key)
                response = es_client.cluster.health()
                update_session(vector_db=vector_db, vector_db_client=es_client)
                update_session(index=create_index(**session_state))
                update_session(vector_db_selected=True, index_created=True)
                print(f"Index {[k for k,v in session_state['index'].items()][0]} was created successfully.")
            except AuthenticationException:
                print("Invalid API key or Cloud ID. Please provide a valid tokens.")
            except ConnectionError:
                print("Connection error. Could not connect to the cluster.")
            except Exception as e:
                print(f"An error occurred: {e}")
    elif vector_db=="3. ChromaDB":
        chroma_client = chromadb.PersistentClient(path="./chroma_db")
        update_session(vector_db=vector_db, vector_db_client=chroma_client)
        update_session(index=create_index(**session_state))
        update_session(vector_db_selected=True, index_created=True)
        print(f"Index {session_state['vector_db_client'].list_collections()[0].name} was created successfully.")

def choose_llm(llm_option):
    update_session(llm_option_selected=False)
    if llm_option == "1. GPT-4o":
        if session_state['sentence_encoder'] != "2. OpenAI":
            openai_api_key = os.getenv('OPENAI_API_KEY')
            if openai_api_key != '':
                try:
                    oa_client = OpenAI(api_key=openai_api_key)
                    response = oa_client.models.list()
                    update_session(llm_option_selected=True, llm_option=llm_option, llm_client=oa_client)
                except:
                    print("Invalid API key. Please provide a valid API token.")
        else:
            oa_client = session_state['embedding_client']
            update_session(llm_option_selected=True, llm_option=llm_option, llm_client=oa_client)

    elif llm_option == "2. FLAN-5":
        model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
        tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
        update_session(llm_option_selected=True, llm_option=llm_option, llm_client=model, llm_tokenizer=tokenizer)

# Main

In [3]:
session_state = defaultdict(
    episode_option = "1. Try a sample",
    sentence_encoder = "1. T5",
    transcription_method = "1. Replicate",
    vector_db = "3. ChromaDB",
    llm_option = "1. GPT-4o"
)

In [4]:
choose_podcast_option(session_state['episode_option'])
# https://podcasts.apple.com/us/podcast/what-if-the-russian-revolution-hadnt-been-bolshevik/id1682047968?i=1000668755545

In [5]:
choose_encoder(session_state['sentence_encoder'])



In [6]:
choose_transcription_method(session_state['transcription_method'], session_state)

The sample podcast doesn't require a transcription method.


In [7]:
choose_vector_db(session_state['vector_db'])

Index podcast-transcriber was created successfully.


In [8]:
choose_llm(session_state['llm_option'])

In [9]:
import json

# Specify the path to your JSON file
file_path = 'transcription_data.json'

# Open the JSON file and load the data
with open(file_path, 'r') as file:
    data = json.load(file)

In [10]:
for d in data:
    for i, chunk in enumerate(d['chunks']):
        chunk.update({'id': str(i+1)})

In [11]:
import time 
document_lengths = []
embedding_speeds = []
indexing_speeds = []

for d in data:
    session_state['episode_details'] = d
    document_lengths.append(len(d['text'].split(" ")))
    
    start_time = time.time()
    session_state['episode_details'].update(encode_podcast(**session_state))
    end_time = time.time()
    embedding_speed = end_time - start_time
    embedding_speeds.append(embedding_speed)
    
    start_time = time.time()
    index_podcast(**session_state)
    end_time = time.time()
    execution_time = end_time - start_time
    indexing_speeds.append(execution_time)


Insert of existing embedding ID: 1
Insert of existing embedding ID: 2
Insert of existing embedding ID: 3
Insert of existing embedding ID: 4
Insert of existing embedding ID: 5
Insert of existing embedding ID: 6
Insert of existing embedding ID: 7
Insert of existing embedding ID: 8
Insert of existing embedding ID: 9
Insert of existing embedding ID: 10
Insert of existing embedding ID: 11
Insert of existing embedding ID: 12
Insert of existing embedding ID: 13
Add of existing embedding ID: 1
Add of existing embedding ID: 2
Add of existing embedding ID: 3
Add of existing embedding ID: 4
Add of existing embedding ID: 5
Add of existing embedding ID: 6
Add of existing embedding ID: 7
Add of existing embedding ID: 8
Add of existing embedding ID: 9
Add of existing embedding ID: 10
Add of existing embedding ID: 11
Add of existing embedding ID: 12
Add of existing embedding ID: 13
Insert of existing embedding ID: 1
Insert of existing embedding ID: 2
Insert of existing embedding ID: 3
Insert of existi

In [12]:
document_lengths

[1551, 164, 1322, 893, 429, 234, 658, 961, 971, 299]

In [13]:
embedding_speeds

[9.905736446380615,
 0.7764761447906494,
 7.16798210144043,
 4.440104246139526,
 3.2460227012634277,
 1.0705726146697998,
 3.7475805282592773,
 5.765589475631714,
 4.3084917068481445,
 1.7004859447479248]

In [14]:
average_t5_embedding_speed = sum(embedding_speeds) / len(embedding_speeds)
average_t5_embedding_speed

4.2129041910171505

In [15]:
indexing_speeds

[0.22745394706726074,
 0.05115485191345215,
 0.41826939582824707,
 0.26992321014404297,
 0.1572573184967041,
 0.05965089797973633,
 0.16653776168823242,
 0.2607290744781494,
 0.202528715133667,
 0.09075641632080078]

In [16]:
average_indexing_speed = sum(indexing_speeds) / len(indexing_speeds)
average_indexing_speed

0.19042615890502929

In [34]:
len(d['text'].split(" "))

299