In [None]:
# Prompt template asking the LLM to extract year/ticker/quarter filters from a
# user's query as bare JSON. `{query}` is the only template placeholder.
# The schema text is kept in step with the FinancialData model: quarter_s is
# optional with a None default, which matches the "do not include if it's not
# mentioned" instruction (declaring it required contradicted that instruction).
ECT_SCHEMA = """
Strictly output JSON from user's query following the pydantic class below. Do not include any other text, explanations, or formatting:

class FinancialData(BaseModel):
    year_s: str = Field(..., description="Four digit year string, default to be string 2024")
    ticker_s: str = Field(..., description="Ticker for the stock from user's the query")
    quarter_s: Optional[str] = Field(None, description="Quarter mentioned in user's query, do not include if it's not mentioned, should follow Q[d] where [d] is a digit")

Here is the query: {query}

Respond with only the JSON, no other text:
"""

import json
import logging
from typing import Optional

from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field

# Structured filters extracted from a user's query: a year, a stock ticker and
# an optional quarter. The field descriptions below are runtime text.
class FinancialData(BaseModel):
    # Required: no default value is supplied.
    year_s: str = Field(
        default=...,
        description="Four digit year string, default to be string 2024",
    )
    # Required: no default value is supplied.
    ticker_s: str = Field(
        default=...,
        description="Ticker for the stock from user's the query",
    )
    # Optional: falls back to None when the quarter is absent.
    quarter_s: Optional[str] = Field(
        default=None,
        description="Quarter mentioned in user's query, do not include if it's not mentioned, should follow Q[d] where [d] is a digit",
    )


def prepare_vespa_payload(data):
    """Build the Vespa search payload for a query and its parsed filters.

    This function only builds the payload. It deliberately does not execute
    the search: the chain pipes the returned payload into
    ``run_vespa_search``, so searching here as well would issue every request
    twice and throw the first result away.

    Args:
        data: Mapping with a ``'query'`` entry and, optionally, a
            ``'parsed_json'`` entry holding the filters extracted from the
            query.

    Returns:
        The value returned by ``get_vespa_payload`` for the query, an empty
        string, the filters dict, and the constant 30.

    Raises:
        KeyError: If ``data`` has no ``'query'`` entry.
    """
    logging.getLogger(__name__).debug("Preparing Vespa payload from %s", data)

    parsed_json = data.get('parsed_json')
    # Only trust the parsed filters when they are a dict that names a ticker;
    # anything else (failed parse, wrong type, missing ticker) falls back to
    # an empty filter dict.
    has_ticker = isinstance(parsed_json, dict) and 'ticker_s' in parsed_json
    filters = parsed_json if has_ticker else {}

    # NOTE(review): 30 is presumably a result limit — confirm against
    # get_vespa_payload's signature.
    return get_vespa_payload(data['query'], "", filters, 30)

def parse_ect_search(data):
    """Flatten raw search output into a list of ``{title: data}`` records.

    Args:
        data: Mapping with a ``'query'`` entry and a ``'search_data'`` entry
            whose hits are read from ``['data']['searchResults']``. Each hit
            is expected to provide ``['fields']['title']`` and
            ``['fields']['data']``.

    Returns:
        ``{"search_results": [...], "query": data['query']}``. The list holds
        one single-key dict per hit, in the original order. It is empty when
        the expected structure is missing or when any hit is malformed.

    Raises:
        KeyError: If ``data`` has no ``'query'`` entry.
    """
    try:
        hits = data['search_data']['data']['searchResults']
        results = [
            {hit["fields"]["title"]: hit["fields"]["data"]} for hit in hits
        ]
    except (KeyError, TypeError, IndexError):
        # All-or-nothing by design: one unexpected shape anywhere in the
        # response discards every hit rather than returning partial results.
        results = []
    return {"search_results": results, "query": data['query']}

def create_ect_chain(llm):
    """Assemble the runnable pipeline: query -> filters -> search -> summary.

    Steps, in order:
      1. Ask ``llm`` (prompted with ``ECT_SCHEMA``) to extract filters as JSON.
      2. Build a Vespa payload from the query and filters, then run the search.
      3. Reshape the search output with ``parse_ect_search``.
      4. Ask ``llm`` (prompted with ``ECT_PROMPT``) for a summary and attach
         it to the result under the ``summary`` key.

    Args:
        llm: The language-model runnable used for both filter extraction and
            the final summary.

    Returns:
        The composed runnable chain.
    """
    ect_schema_prompt = ChatPromptTemplate.from_template(ECT_SCHEMA)
    ect_prompt = ChatPromptTemplate.from_template(ECT_PROMPT)

    def parse_json_safely(json_str):
        """Return the JSON value found in ``json_str``, or None on failure."""
        if not isinstance(json_str, str):
            return None
        try:
            # First, try to parse the text as-is.
            return json.loads(json_str)
        except json.JSONDecodeError:
            pass
        # Fall back to the outermost {...} span, which tolerates extra text
        # or code fences wrapped around the JSON object.
        start = json_str.find('{')
        end = json_str.rfind('}')
        if start == -1 or end < start:
            return None
        try:
            return json.loads(json_str[start:end + 1])
        except json.JSONDecodeError:
            return None

    def with_dict_filters(state):
        """Replace a missing or non-dict ``parsed_json`` with an empty dict."""
        parsed = state["parsed_json"]
        return {**state, "parsed_json": parsed if isinstance(parsed, dict) else {}}

    parse_schema = RunnableParallel(
        query=RunnablePassthrough(),
        # StrOutputParser normalises the model output to plain text before
        # JSON parsing, mirroring how the summary step consumes `llm` below;
        # without it a chat model's message object would reach json.loads.
        parsed_json=(
            ect_schema_prompt
            | llm
            | StrOutputParser()
            | RunnableLambda(parse_json_safely)
        ),
    )

    vespa_search = RunnableParallel(
        search_data=RunnableLambda(prepare_vespa_payload) | RunnableLambda(run_vespa_search),
        # NOTE(review): this forwards the whole {query, parsed_json} state
        # under the "query" key, so parse_ect_search and ect_prompt receive a
        # dict rather than the bare query — confirm this is intended.
        query=RunnablePassthrough(),
    )

    ect_chain = (
        parse_schema
        | RunnableLambda(with_dict_filters)  # Ensure parsed_json is always a dict
        | vespa_search
        | RunnableLambda(parse_ect_search)
        | RunnablePassthrough.assign(summary=ect_prompt | llm | StrOutputParser())
    )

    return ect_chain