### FinRAG Demo with Gradio, Weaviate and Ollama

#### This notebook expects the Ollama server to be backed by a GPU with at least 16 GB of memory.

#### This is a lite version that uses Weaviate's embedded instance.

In [1]:
!pip install pip gradio weaviate_client wget ijson -Uq

In [2]:
import gradio as gr
from huggingface_hub import InferenceClient
import weaviate.classes as wvc
import weaviate
from weaviate.auth import AuthApiKey
import logging
import os
import requests
import json
import ijson
import weaviate

# Base URL of the Ollama server; the OLLAMA_HOST environment variable overrides the default.
ollama_api_endpoint = os.getenv("OLLAMA_HOST", "http://ollama.ollama")
# Model name handed to the text2vec-ollama vectorizer configuration in ingest_data().
# NOTE(review): the chained assignment also binds a module-level `model`; no visible
# cell reads it - confirm nothing else depends on it before removing.
ollama_vectorizer_model = model = "all-minilm"
# Model name handed to the generative-ollama configuration in ingest_data().
ollama_generative_model="granite3-dense:8b"

# Configure the root logger at INFO and record which Ollama endpoint is in use.
logging.basicConfig(level=logging.INFO)
logging.info(f'OLLAMA_API_ENDPOINT = {ollama_api_endpoint}')

INFO:root:OLLAMA_API_ENDPOINT = http://ollama.ollama


In [3]:
def connect_weaviate_embedded():
    """Connect to an embedded Weaviate instance with the Ollama modules enabled.

    The instance is requested at Weaviate version 1.25.6 with the
    ``text2vec-ollama`` and ``generative-ollama`` modules switched on through
    the ``ENABLE_MODULES`` environment variable.

    Returns:
        The client object returned by ``weaviate.connect_to_embedded``.
    """
    # Logging is configured once (at INFO) in the setup cell. A second
    # logging.basicConfig(level=logging.ERROR) here had no effect on an already
    # configured root logger and, had it applied, would have hidden the message below.
    logging.info('Connecting to Weaviate embedded instance')
    client = weaviate.connect_to_embedded(
        environment_variables={"ENABLE_MODULES": "text2vec-ollama,generative-ollama"},
        version="1.25.6"
    )
    return client

In [4]:
# Connect to the embedded Weaviate instance; `client` is reused by the later cells.
client = connect_weaviate_embedded()

if client.is_ready():
    # Ask for the node list once and reuse it for both the count and the listing,
    # instead of issuing the same cluster request twice.
    nodes = client.cluster.nodes()
    logging.info('')
    logging.info(f'Found {len(nodes)} Weaviate nodes.')
    logging.info('')
    for node in nodes:
        logging.info(node)
        logging.info('')
    logging.info(f'client.get_meta(): {client.get_meta()}')
else:
    logging.error("Client is not ready")

INFO:root:Connecting to Weaviate embedded instance
INFO:weaviate-client:Started /opt/app-root/src/.cache/weaviate-embedded: process ID 10008
{"action":"startup","default_vectorizer_module":"none","level":"info","msg":"the default vectorizer modules is set to \"none\", as a result all new schema classes without an explicit vectorizer setting, will use this vectorizer","time":"2025-01-21T21:55:22Z"}
{"action":"startup","auto_schema_enabled":true,"level":"info","msg":"auto schema enabled setting is set to \"true\"","time":"2025-01-21T21:55:22Z"}
{"level":"info","msg":"No resource limits set, weaviate will use all available memory and CPU. To limit resources, set LIMIT_RESOURCES=true","time":"2025-01-21T21:55:22Z"}
{"level":"info","msg":"open cluster service","servers":{"Embedded_at_8079":41775},"time":"2025-01-21T21:55:22Z"}
{"address":"10.128.1.110:41776","level":"info","msg":"starting cloud rpc server ...","time":"2025-01-21T21:55:22Z"}
{"level":"info","msg":"starting raft sub-system ..

In [5]:
# Destructive: drop every collection in this Weaviate instance so that
# ingest_data() can create the "Symbols" collection from scratch on each run.
client.collections.delete_all()

INFO:httpx:HTTP Request: GET http://localhost:8079/v1/schema "HTTP/1.1 200 OK"
{"action":"load_all_shards","level":"error","msg":"failed to load all shards: context canceled","time":"2025-01-21T21:55:24Z"}
INFO:httpx:HTTP Request: DELETE http://localhost:8079/v1/schema/Symbols "HTTP/1.1 200 OK"


In [6]:
def download_data():
    """Make sure ``data/symbols.json`` exists locally, downloading it if needed.

    When the file is already present nothing is downloaded. Otherwise the
    ``data`` directory is created if missing and the file is fetched with the
    ``wget`` package.

    Returns:
        None.

    Raises:
        ImportError: if the file must be downloaded and ``wget`` is not installed.
    """
    target = "data/symbols.json"
    if os.path.isfile(target):
        logging.info("Symbols already downloaded")
        return

    logging.info("Downloading symbols...")
    # The target directory is not guaranteed to exist on a fresh checkout.
    os.makedirs(os.path.dirname(target), exist_ok=True)
    # Imported here because the notebook's import cell does not import wget;
    # it is only needed when the file is actually missing.
    import wget
    url = "https://people.redhat.com/bkozdemb/downloads/symbols.json"
    wget.download(url, target)

In [7]:
def ingest_data(client):
    """Create the ``Symbols`` collection and import ``data/symbols.json`` into it.

    The collection is configured with the Ollama vectorizer and generative
    modules, using the module-level ``ollama_api_endpoint``,
    ``ollama_vectorizer_model`` and ``ollama_generative_model`` settings.
    The JSON file is streamed with ``ijson`` and written in fixed-size batches
    of 50 objects.

    Args:
        client: a connected Weaviate client (as returned by
            ``connect_weaviate_embedded``).

    Returns:
        The collection object returned by ``client.collections.create``.

    Raises:
        KeyError: if a record in the JSON file lacks one of the expected fields.
    """
    collection_name = "Symbols"

    # ===== Define the collection =====
    symbols = client.collections.create(
        name=collection_name,
        vectorizer_config=wvc.config.Configure.Vectorizer.text2vec_ollama(
            api_endpoint=ollama_api_endpoint,
            model=ollama_vectorizer_model
        ),
        generative_config=wvc.config.Configure.Generative.ollama(
            api_endpoint=ollama_api_endpoint,
            model=ollama_generative_model
        )
    )

    # Fields stored under the same name they have in the source JSON.
    direct_fields = (
        "Symbol", "Name", "Description", "CIK", "Exchange", "Currency",
        "Country", "Sector", "Industry", "Address", "FiscalYearEnd",
        "LatestQuarter", "MarketCapitalization", "BookValue", "EBITDA",
        "PERatio", "PEGRatio", "DividendPerShare", "DividendYield", "EPS",
        "RevenuePerShareTTM", "ProfitMargin", "OperatingMarginTTM",
        "ReturnOnAssetsTTM", "ReturnOnEquityTTM", "RevenueTTM",
        "GrossProfitTTM", "DilutedEPSTTM", "QuarterlyEarningsGrowthYOY",
        "QuarterlyRevenueGrowthYOY", "AnalystTargetPrice",
        "AnalystRatingStrongBuy", "AnalystRatingBuy", "AnalystRatingHold",
        "AnalystRatingSell", "AnalystRatingStrongSell", "TrailingPE",
        "ForwardPE", "PriceToSalesRatioTTM", "PriceToBookRatio",
        "EVToRevenue", "EVToEBITDA", "Beta", "SharesOutstanding",
        "DividendDate", "ExDividendDate",
    )
    # Source keys that begin with a digit are stored under spelled-out names
    # (presumably because property names may not start with a digit - confirm).
    renamed_fields = {
        "52WeekHigh": "fiftytwoWeekHigh",
        "52WeekLow": "fiftytwoWeekLow",
        "50DayMovingAverage": "fiftyDayMovingAverage",
        "200DayMovingAverage": "twohundredDayMovingAverage",
    }

    # Settings for displaying the import progress
    counter = 0
    interval = 1000  # Log progress every this many records; should be bigger than the batch_size

    logging.info("JSON streaming, to avoid running out of memory on large files...")
    with client.batch.fixed_size(batch_size=50) as batch:
        with open("data/symbols.json", "rb") as f:
            for obj in ijson.items(f, "item"):
                properties = {field: obj[field] for field in direct_fields}
                for source_key, property_name in renamed_fields.items():
                    properties[property_name] = obj[source_key]

                batch.add_object(
                    collection=collection_name,
                    properties=properties,
                    # If you Bring Your Own Vectors, add the `vector` parameter here
                    # vector=obj.vector["default"]
                )

                # Report progress without assuming how many records the file holds.
                counter += 1
                if counter % interval == 0:
                    logging.info(f"Imported {counter} stock symbols so far.")

    # Batch imports do not raise on per-object errors; surface them explicitly.
    failed = client.batch.failed_objects
    if failed:
        logging.error(f"{len(failed)} symbols failed to import; first failure: {failed[0]}")

    logging.info(f"Finished importing {counter} symbols.")
    return symbols

In [8]:
def semantic_search(query='computers', limit=2) -> list:
    """Run a near-text search on the ``Symbols`` collection and return company names.

    Reads the module-level ``symbols`` collection, which is assigned in the
    main cell from ``ingest_data(client)``.

    Args:
        query: free-text concept to search for.
        limit: maximum number of results to request.

    Returns:
        A list with the ``name`` property of each returned object; it holds at
        most ``limit`` entries and may be shorter (or empty) when the search
        yields fewer matches.
    """
    print(f'\nSemantic Search, query = {query}.')
    print(f'limit = {limit}')
    response = symbols.query.near_text(
        query=query,
        limit=limit
    )

    # Iterate over what actually came back: indexing 0..limit-1 raised
    # IndexError whenever the search returned fewer than `limit` objects.
    return [match.properties['name'] for match in response.objects[:limit]]

In [9]:
def generative_search(query='computers', task=None, limit=2) -> str:
    """Run a grouped generative near-text search on the ``Symbols`` collection.

    Reads the module-level ``symbols`` collection, which is assigned in the
    main cell from ``ingest_data(client)``.

    Args:
        query: free-text concept used to retrieve objects.
        task: prompt passed as the grouped task for the retrieved objects.
        limit: maximum number of objects to retrieve.

    Returns:
        The ``generated`` attribute of the search response.
    """
    banner_lines = (
        f'\nPerforming generative search, query = {query}, limit = {limit}.',
        f'Prompt: {task}',
        f'limit = {limit}',
    )
    for line in banner_lines:
        print(line)

    search_args = dict(query=query, limit=limit, grouped_task=task)
    result = symbols.generate.near_text(**search_args)
    return result.generated

In [10]:
# Entry point: fetch the data, load it into Weaviate, then build and launch the Gradio UI.
if __name__ == '__main__':
    # NOTE(review): logging was already configured at INFO in an earlier cell, and the
    # INFO lines in this cell's saved output suggest this call has no effect - confirm.
    logging.basicConfig(level=logging.ERROR)
    download_data()
    # `symbols` becomes a module-level name; semantic_search() and generative_search() read it.
    symbols = ingest_data(client)
    #
    # Build the Gradio user interface.
    #
    with gr.Blocks(title='Summarizing Financial Data using RAG') as demo:
            gr.Markdown("""# Summarizing Financial Data using Retrieval Augmented Generation (RAG).""")
            # Each example is a one-element list because it feeds a single input component.
            semantic_examples = [
                ["Computers"],
                ["Computer Software"],
                ["Pharmaceuticals"],
                ["Consumer Products"],
                ["Commodities"],
                ["Retail"],
                ["Manufacturing"],
                ["Energy"],
                ["National Defense"],
                ["Auto Makers"]
            ]
            gr.Markdown("""### Begin with a search.""")
            # The search textbox starts out pre-filled with the first example.
            semantic_input_text = gr.Textbox(label="Enter a search concept or choose an example below:", value=semantic_examples[0][0])
            # NOTE(review): no `outputs` are given here, so choosing an example presumably
            # only fills the textbox rather than running `fn` - confirm.
            gr.Examples(semantic_examples,
                fn=semantic_search,
                inputs=semantic_input_text, label="Example search concepts:"
                )
            # The same slider value is passed as `limit` to both the search and the summary.
            limit_slider = gr.Slider(label="Adjust the query return limit. (Optional)",value=2, minimum=1, maximum=5, step=1)
            vdb_button = gr.Button(value="Search the financial vector database.")
            # Inputs map positionally onto semantic_search(query, limit).
            vdb_button.click(fn=semantic_search, inputs=[semantic_input_text, limit_slider], outputs=gr.Textbox(label="Search Results (Filters = Name)"))
            
            prompt_examples = [
                ["Generate a paragraph that summarizes the given information from a financial perspective for the fiscal year end of December 2024."],
                ["Summarize the information from a financial investment perspective."],
                ["Summarize the potential financial investment risks and rewards."]
            ]

            gr.Markdown("""### Summarize""")
            # The prompt textbox starts out pre-filled with the first example prompt.
            generative_search_prompt_text = gr.Textbox(label="Enter a summarization task or choose an example below.", value=prompt_examples[0][0])
            gr.Examples(prompt_examples,
                fn=generative_search,
                inputs=[generative_search_prompt_text]
            )
            button = gr.Button(value="Generate the summary.")
            # Inputs map positionally onto generative_search(query, task, limit);
            # the query comes from the search textbox above.
            button.click(fn=generative_search,
            inputs=[semantic_input_text, generative_search_prompt_text, limit_slider],
            outputs=gr.Textbox(label="Summary"))
            
    # Enable request queuing with a queue size of 10.
    demo.queue(max_size=10)
    # share=True also requests a public gradio.live URL (visible in the saved output),
    # so the demo is reachable from outside the local network while it runs.
    demo.launch(server_name='0.0.0.0', server_port=8081, share=True)



INFO:root:Symbols already downloaded
{"action":"hnsw_prefill_cache_async","level":"info","msg":"not waiting for vector cache prefill, running in background","time":"2025-01-21T21:55:24Z","wait_for_cache_prefill":false}
{"level":"info","msg":"Created shard symbols_85SAhcdenrBD in 2.488107ms","time":"2025-01-21T21:55:24Z"}
{"action":"hnsw_vector_cache_prefill","count":1000,"index_id":"main","level":"info","limit":1000000000000,"msg":"prefilled vector cache","time":"2025-01-21T21:55:24Z","took":49810}
INFO:httpx:HTTP Request: POST http://localhost:8079/v1/schema "HTTP/1.1 200 OK"
INFO:root:JSON streaming, to avoid running out of memory on large files...
INFO:httpx:HTTP Request: GET http://localhost:8079/v1/schema "HTTP/1.1 200 OK"
{"action":"telemetry_push","level":"info","msg":"telemetry started","payload":"\u0026{MachineID:ca8d0aa4-6c1c-4a4f-9ad2-34d6cd1b1e56 Type:INIT Version:1.25.6 NumObjects:0 OS:linux Arch:amd64 UsedModules:[generative-ollama text2vec-ollama]}","time":"2025-01-21T21

* Running on local URL:  http://0.0.0.0:8081


INFO:httpx:HTTP Request: GET http://localhost:8081/gradio_api/startup-events "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: HEAD http://localhost:8081/ "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: GET https://api.gradio.app/v3/tunnel-request "HTTP/1.1 200 OK"


* Running on public URL: https://adb32ca4e2d87ef8b8.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


INFO:httpx:HTTP Request: HEAD https://adb32ca4e2d87ef8b8.gradio.live "HTTP/1.1 200 OK"



Semantic Search, query = Commodities.
limit = 2

Performing generative search, query = Commodities, limit = 2.
Prompt: Summarize the information from a financial investment perspective.
limit = 2


{"action":"restapi_management","level":"info","msg":"Shutting down... ","time":"2025-01-21T21:56:51Z"}
{"action":"restapi_management","level":"info","msg":"Stopped serving weaviate at http://127.0.0.1:8079","time":"2025-01-21T21:56:51Z"}
{"action":"telemetry_push","level":"info","msg":"telemetry terminated","payload":"\u0026{MachineID:ca8d0aa4-6c1c-4a4f-9ad2-34d6cd1b1e56 Type:TERMINATE Version:1.25.6 NumObjects:7116 OS:linux Arch:amd64 UsedModules:[generative-ollama text2vec-ollama]}","time":"2025-01-21T21:56:51Z"}
{"level":"info","msg":"closing raft FSM store ...","time":"2025-01-21T21:56:51Z"}
{"level":"info","msg":"shutting down raft sub-system ...","time":"2025-01-21T21:56:51Z"}
{"level":"info","msg":"transferring leadership to another server","time":"2025-01-21T21:56:51Z"}
{"error":"cannot find peer","level":"error","msg":"transferring leadership","time":"2025-01-21T21:56:51Z"}
{"level":"info","msg":"closing raft-net ...","time":"2025-01-21T21:56:51Z"}
{"level":"info","msg":"closi