In [None]:
import os

import hopsworks

from sentence_transformers import SentenceTransformer
from openai import OpenAI

from functions.prompt_engineering import get_reranker, get_context_and_source, get_answer_from_gemini, get_answer_from_gpt

import config

import warnings
warnings.filterwarnings('ignore')

## <span style="color:#ff5f27;"> 🔮 Connecting to Hopsworks Feature Store </span>

In [None]:
# Log in to Hopsworks and keep a handle on the project.
project = hopsworks.login()

# Feature store of that project; the feature views below are looked up through it.
fs = project.get_feature_store()

## <span style="color:#ff5f27;">🪄 Get Feature Views </span>

In [None]:
# Version 1 of each feature view; both are later passed as `feature_view`
# to get_context_and_source.
stanford_reports_view = fs.get_feature_view(name="stanford_reports", version=1)
eqt_portfolio_view = fs.get_feature_view(name="eqt_portfolio", version=1)

## <span style="color:#ff5f27;">🗄️ Build Prompt </span>

In [None]:
# User question; reused for context retrieval and for both model calls below.
query = (
    "Tell me  how eqt x portfolio companies are affected by Stanford AI Index report?"
)

In [None]:
# Reranker selected by config.RERANKER.
reranker = get_reranker(config.RERANKER)

# SentenceTransformer model named by config.MODEL_SENTENCE_TRANSFORMER,
# placed on config.DEVICE via .to().
sentence_transformer = SentenceTransformer(config.MODEL_SENTENCE_TRANSFORMER).to(config.DEVICE)

In [None]:
# Retrieval against the Stanford reports feature view, with year=2024 and k=50.
# Downstream cells use element [0] of the result as context and [1] as source.
reports_and_source = get_context_and_source(
    user_query=query,
    sentence_transformer=sentence_transformer,
    feature_view=stanford_reports_view,
    reranker=reranker,
    year=2024,
    k=50,
)

# Same query against the EQT portfolio feature view; `year` and `k` are not
# passed here, so whatever defaults the helper defines apply.
companies_and_source = get_context_and_source(
    user_query=query,
    sentence_transformer=sentence_transformer,
    feature_view=eqt_portfolio_view,
    reranker=reranker,
)

In [None]:
# Start from a copy so the retrieved report context in reports_and_source[0]
# is not modified, then append the company context after it.
reports_company_context = reports_and_source[0].copy()
reports_company_context += companies_and_source[0]

## <span style="color:#ff5f27;">🚀 Model Inference </span>

### OpenAI

In [None]:
# The API key is read from the environment; a missing OPENAI_API_KEY raises KeyError.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

# NOTE(review): the context combines report and company retrievals, but only
# the report-side source (reports_and_source[1]) is passed — confirm intended.
response = get_answer_from_gpt(
    query=query,
    context=reports_company_context,
    source=reports_and_source[1],
    gpt_model=config.GPT_MODEL,
    client=client,
)
print(response)

### Gemini

In [None]:
# The API key is read from the environment; a missing GEMINI_KEY raises KeyError.
# NOTE(review): the context combines report and company retrievals, but only
# the report-side source (reports_and_source[1]) is passed — confirm intended.
response = get_answer_from_gemini(
    query=query,
    context=reports_company_context,
    source=reports_and_source[1],
    api_key=os.environ["GEMINI_KEY"],
)
print(response)