In [None]:
from langgraph.graph import StateGraph, START, END
from langgraph.types import Send
from langgraph.checkpoint.memory import MemorySaver

from typing import List, Dict, Any, Annotated
from typing_extensions import TypedDict
import operator

from IPython.display import Image, display

In [None]:
import vertexai

from src.bigquery_utils import get_bigquery_table_schema, create_or_update_table, insert_json_to_bigquery, get_random_rows_from_bigquery
from src.schema_utils import generate_schema, parse_json, transform_schema_to_response_format
from src.extraction_utils import extract_key_values_from_document
from src.storage_utils import list_files_in_bucket  

# Importing config variables from config.py
from config.config import PROJECT_ID, DATASET_ID, TABLE_ID, BUCKET_NAME, FOLDER_PREFIX, MODEL_NAME, VERTEXAI_LOCATION
    
# Initialise the Vertex AI SDK so later Vertex AI calls use this project and region.
vertexai.init(location=VERTEXAI_LOCATION, project=PROJECT_ID)

In [None]:
# Upper bound on how many document URIs are pulled from the bucket listing.
MAX_FILES = 300

# URIs of the documents under FOLDER_PREFIX in BUCKET_NAME (at most MAX_FILES).
file_uris = list_files_in_bucket(
    BUCKET_NAME,
    FOLDER_PREFIX,
    MAX_FILES,
)

In [None]:


# Per-document graph input: target table, document to read, and model to call.
InputState = TypedDict(
    "InputState",
    {
        "project_id": str,
        "dataset_id": str,
        "table_id": str,
        "file_uri": str,
        "model_name": str,
    },
)

# Graph output: the value returned by insert_json_to_bigquery ("Error" triggers a retry).
OutputState = TypedDict("OutputState", {"result": str})


class OverallState(InputState, OutputState):
    """Full working state shared by every node.

    Combines the input and output keys with the intermediate values the nodes
    produce along the way.
    """

    # Table schema read before this document is processed.
    existing_schema: List[Dict[str, Any]]
    # Schema returned by generate_schema for this document.
    new_schema: List[Dict[str, Any]]
    # Schema re-read from BigQuery after create_or_update_table.
    new_schema_bq: List[Dict[str, Any]]
    # new_schema_bq converted by transform_schema_to_response_format.
    response_schema: Dict[str, Any]
    # Rows returned by get_random_rows_from_bigquery, passed to the extractor.
    random_rows: List[Dict[str, Any]]
    # Values returned by extract_key_values_from_document, to be inserted.
    key_values: List[Dict[str, Any]]
    # Status returned by create_or_update_table ("Error" triggers a retry).
    table_update_result: str
    
    
# Node 1: Get existing BigQuery table schema
def get_bigquery_table_schema_(state: InputState) -> OverallState:
    """Read the target table's schema before anything is changed.

    Returns a partial state update: ``{"existing_schema": ...}``.
    """
    project, dataset, table = (state[key] for key in ("project_id", "dataset_id", "table_id"))
    current_schema = get_bigquery_table_schema(project, dataset, table)
    return {"existing_schema": current_schema}

# Node 2: Generate new BigQuery schema
def generate_schema_(state: OverallState) -> OverallState:
    """Call ``generate_schema`` with the document URI, the current schema and the model name.

    Returns a partial state update: ``{"new_schema": ...}``.
    """
    uri, model, current_schema = (
        state[key] for key in ("file_uri", "model_name", "existing_schema")
    )
    proposed_schema = generate_schema(uri, current_schema, model)
    return {"new_schema": proposed_schema}


# Node 3: Create or update BigQuery table
def create_or_update_table_(state: OverallState) -> OverallState:
    """Apply ``new_schema`` to the target table via ``create_or_update_table``.

    Returns a partial state update: ``{"table_update_result": ...}``. The value
    "Error" makes ``table_update_result`` route back to schema generation.
    """
    project, dataset, table, schema_to_apply = (
        state[key] for key in ("project_id", "dataset_id", "table_id", "new_schema")
    )
    # Note the helper's argument order: project id comes last.
    status = create_or_update_table(dataset, table, schema_to_apply, project)
    return {"table_update_result": status}

# Node 4: Get the updated BigQuery table schema
def get_updated_bigquery_table_schema_(state: OverallState) -> OverallState:
    """Re-read the table schema from BigQuery after the create/update step.

    Returns a partial state update: ``{"new_schema_bq": ...}``.
    """
    project, dataset, table = (state[key] for key in ("project_id", "dataset_id", "table_id"))
    refreshed_schema = get_bigquery_table_schema(project, dataset, table)
    return {"new_schema_bq": refreshed_schema}

# Node 5: Transform schema to response format
def transform_schema_to_response_format_(state: OverallState) -> OverallState:
    """Convert ``new_schema_bq`` with ``transform_schema_to_response_format``.

    Returns a partial state update: ``{"response_schema": ...}``.
    """
    converted = transform_schema_to_response_format(state["new_schema_bq"])
    return {"response_schema": converted}

# Node 6: Get random rows from BigQuery table
def get_random_rows_(state: OverallState) -> OverallState:
    """Fetch a fixed-size sample of rows from the target table.

    Returns a partial state update: ``{"random_rows": ...}``.
    """
    project, dataset, table = (state[key] for key in ("project_id", "dataset_id", "table_id"))
    sample_size = 6  # number of rows requested from get_random_rows_from_bigquery
    sampled = get_random_rows_from_bigquery(project, dataset, table, sample_size)
    return {"random_rows": sampled}

# Node 7: Extract key-value pairs from the document
def extract_key_values_(state: OverallState) -> OverallState:
    """Run ``extract_key_values_from_document`` on ``file_uri``.

    Passes the refreshed BigQuery schema, the sampled rows and the response
    schema to the helper. Returns a partial state update: ``{"key_values": ...}``.
    """
    uri, model, bq_schema, answer_schema, sample_rows = (
        state[key]
        for key in ("file_uri", "model_name", "new_schema_bq", "response_schema", "random_rows")
    )
    extracted = extract_key_values_from_document(uri, bq_schema, sample_rows, answer_schema, model)
    return {"key_values": extracted}

# Node 8: Insert extracted key-value pairs into BigQuery table
def insert_json_to_bigquery_(state: OverallState) -> OutputState:
    """Insert ``key_values`` into the target table via ``insert_json_to_bigquery``.

    Returns the graph output: ``{"result": ...}``. The value "Error" makes
    ``table_insert_result`` route back to extraction.
    """
    project, dataset, table, rows = (
        state[key] for key in ("project_id", "dataset_id", "table_id", "key_values")
    )
    outcome = insert_json_to_bigquery(project, dataset, table, rows)
    return {"result": outcome}


def table_update_result(state: OverallState):
    """Route after create_or_update_table_: "retry" if it reported "Error", else "success"."""
    return "retry" if state["table_update_result"] == "Error" else "success"

def table_insert_result(state: OverallState):
    """Route after insert_json_to_bigquery_: "retry" if it reported "Error", else "success"."""
    return "retry" if state["result"] == "Error" else "success"

# Build the per-document graph; callers only supply InputState keys and only get OutputState keys back.
graph = StateGraph(OverallState, input=InputState, output=OutputState)

# Register every node under its function name.
for _node in (
    get_bigquery_table_schema_,
    generate_schema_,
    create_or_update_table_,
    get_updated_bigquery_table_schema_,
    transform_schema_to_response_format_,
    get_random_rows_,
    extract_key_values_,
    insert_json_to_bigquery_,
):
    graph.add_node(_node)
del _node

# Schema branch: read schema -> propose schema -> apply it, looping back to the
# proposal step while create_or_update_table_ reports "Error".
graph.add_edge(START, "get_bigquery_table_schema_")
graph.add_edge("get_bigquery_table_schema_", "generate_schema_")
graph.add_edge("generate_schema_", "create_or_update_table_")
graph.add_conditional_edges(
    "create_or_update_table_",
    table_update_result,
    {"retry": "generate_schema_", "success": "get_updated_bigquery_table_schema_"},
)
graph.add_edge("get_updated_bigquery_table_schema_", "transform_schema_to_response_format_")

# Sampling branch also starts at START, so rows are read before the schema step
# has run. NOTE(review): on a brand-new table this depends on how
# get_random_rows_from_bigquery handles a missing table. Confirm.
graph.add_edge(START, "get_random_rows_")

# Extraction waits for both branches, then its output is inserted.
graph.add_edge(["get_random_rows_", "transform_schema_to_response_format_"], "extract_key_values_")
graph.add_edge("extract_key_values_", "insert_json_to_bigquery_")

# Retry extraction while the insert reports "Error"; otherwise finish.
graph.add_conditional_edges(
    "insert_json_to_bigquery_",
    table_insert_result,
    {"retry": "extract_key_values_", "success": END},
)

# In-memory checkpointer: state is stored per thread_id given in each invoke's config.
memory = MemorySaver()
app = graph.compile(checkpointer=memory)


# Render the graph topology. This is optional and needs extra dependencies, so a
# failure is reported instead of being silently swallowed or raised.
try:
    display(Image(app.get_graph(xray=True).draw_mermaid_png()))
except Exception as exc:
    print(f"Skipping graph rendering: {exc}")

In [None]:
# Run the full pipeline on the first N_FILES_TO_PROCESS documents. Each document
# gets its own checkpoint thread (thread_id = file URI) so it can be inspected later.
N_FILES_TO_PROCESS = 30
# Caps the schema-retry and insert-retry loops. Exceeding it raises inside invoke,
# and the file is then recorded as failed below.
RECURSION_LIMIT = 50

failed_uris = []  # file URIs whose pipeline run raised; treat them later

for file_uri in file_uris[:N_FILES_TO_PROCESS]:
    print("\n")
    print(f"{'=' * 10} {file_uri} {'=' * (102 - len(file_uri))}")
    # `recursion_limit` is a top-level config key. Nested inside "configurable",
    # LangGraph does not read it and falls back to its default limit.
    config = {
        "configurable": {"thread_id": file_uri},
        "recursion_limit": RECURSION_LIMIT,
    }

    try:
        app.invoke(
            {
                "project_id": PROJECT_ID,
                "dataset_id": DATASET_ID,
                "table_id": TABLE_ID,
                "model_name": MODEL_NAME,
                "file_uri": file_uri,
            },
            config,
        )
    except Exception as e:
        print(f"Failed to process {file_uri}: {e}")
        failed_uris.append(file_uri)

# You can treat the failed_uris list later
print("Failed URIs to treat later:", failed_uris)

In [None]:
# Checkpointed state of the last document processed above. get_state needs a config
# naming the thread; calling it with no argument raises a TypeError.
app.get_state({"configurable": {"thread_id": file_uri}})

In [None]:
# Config pointing at the checkpoint thread of the last processed document.
# The following inspection cells reuse it.
config = dict(configurable=dict(thread_id=file_uri))
app.get_state(config=config).values["existing_schema"]

In [None]:
# Schema written by generate_schema_ for that document.
app.get_state(config=config).values["new_schema"]

In [None]:
from langgraph.graph import StateGraph, START, END
from typing import List, Dict, Any, Annotated
from typing_extensions import TypedDict
from langgraph.types import Send
import operator

from IPython.display import Image, display

class InputStateSubset(TypedDict):
    """Inputs for re-running only extraction + insert on one document.

    ``response_schema`` and ``random_rows`` are supplied by the caller (e.g. taken
    from an earlier full-pipeline run), so the schema-generation steps are skipped.
    """

    project_id: str
    dataset_id: str
    table_id: str
    file_uri: str
    model_name: str
    response_schema: Dict[str, Any]
    random_rows: List[Dict[str, Any]]

class OutputStateSubset(TypedDict):
    result: str

# Define the overall schema, combining both input and output
class OverallStateSubset(InputStateSubset, OutputStateSubset):
    # extract_key_values_ reads state["new_schema_bq"]. It is fetched inside the
    # graph so callers do not have to pass it.
    new_schema_bq: List[Dict[str, Any]]
    key_values: List[Dict[str, Any]]

# Build the graph with input and output schemas specified
graph_subset = StateGraph(OverallStateSubset, input=InputStateSubset, output=OutputStateSubset)

# Nodes must be registered before edges can refer to them by name; without these
# calls compile() rejects the edges as pointing at unknown nodes.
graph_subset.add_node(get_updated_bigquery_table_schema_)
graph_subset.add_node(extract_key_values_)
graph_subset.add_node(insert_json_to_bigquery_)

graph_subset.add_edge(START, "get_updated_bigquery_table_schema_")
graph_subset.add_edge("get_updated_bigquery_table_schema_", "extract_key_values_")
graph_subset.add_edge("extract_key_values_", "insert_json_to_bigquery_")
graph_subset.add_conditional_edges(
    "insert_json_to_bigquery_",
    table_insert_result,
    {
        # Insert reported "Error": extract again.
        "retry": "extract_key_values_",
        # Insert succeeded: finish the run.
        "success": END,
    },
)

app_subset = graph_subset.compile()  # Compile the graph


# Render the subset graph. This is optional and needs extra dependencies, so a
# failure is reported instead of being silently swallowed or raised.
try:
    display(Image(app_subset.get_graph(xray=True).draw_mermaid_png()))
except Exception as exc:
    print(f"Skipping graph rendering: {exc}")

In [None]:
# Response schema written by transform_schema_to_response_format_ for that document.
app.get_state(config=config).values["response_schema"]

In [None]:
# Rows sampled from the table by get_random_rows_ for that document.
app.get_state(config=config).values["random_rows"]

In [None]:
# Re-run only extraction + insert for one document. Reuse the response schema and
# sampled rows checkpointed by the last full-pipeline run (the thread in `config`).
file_uri = file_uris[10]

# There are no notebook-level `response_schema` / `random_rows` variables, so read
# them from that run's checkpoint instead of referencing undefined names.
reference_values = app.get_state(config).values
app_subset.invoke(
    {
        "project_id": PROJECT_ID,
        "dataset_id": DATASET_ID,
        "table_id": TABLE_ID,
        "model_name": MODEL_NAME,
        "file_uri": file_uri,
        "response_schema": reference_values["response_schema"],
        "random_rows": reference_values["random_rows"],
    }
)

In [None]:
# Batch graph input: shared target/model settings plus the documents to process.
InputStateBatch = TypedDict(
    "InputStateBatch",
    {
        "project_id": str,
        "dataset_id": str,
        "table_id": str,
        "model_name": str,
        "file_uris": List[str],
    },
)

# Batch graph output: per-document results, concatenated across branches by operator.add.
OutputStateBatch = TypedDict("OutputStateBatch", {"results": Annotated[List, operator.add]})


class OverallStateBatch(InputStateBatch, OutputStateBatch):
    """Batch working state: the input and output keys plus a per-document ``file_uri``."""

    file_uri: str
    
def extract_and_insert(state: OverallStateBatch) -> OverallStateBatch:
    """Run the full per-document pipeline (``app``) for ``state["file_uri"]``.

    Called once per URI with the payload built by ``continue_to_insert``. Returns
    ``{"results": [result]}``; the ``operator.add`` reducer on ``results``
    concatenates the per-document outputs.

    NOTE(review): the fanned-out branches run in the same step and may run
    concurrently, so several schema updates can hit the same table at once.
    Consider ``max_concurrency`` in the batch invoke config. Confirm.
    """
    file_uri = state["file_uri"]
    # `app` was compiled with a checkpointer, so each run needs a thread_id. The
    # URI is used as thread_id, as in the sequential loop. recursion_limit is a
    # top-level config key and caps the retry loops.
    run_config = {"configurable": {"thread_id": file_uri}, "recursion_limit": 50}
    result = app.invoke(
        {
            "project_id": state["project_id"],
            "dataset_id": state["dataset_id"],
            "table_id": state["table_id"],
            "model_name": state["model_name"],
            "file_uri": file_uri,
        },
        run_config,
    )

    return {"results": [result]}

def continue_to_insert(state: OverallStateBatch):
    """Fan out: return one ``Send("extract_and_insert", payload)`` per URI in ``file_uris``.

    Each Send gets its own payload dict carrying the shared settings and its
    ``file_uri``. The graph state is not mutated, and ``results`` is left to the
    nodes and their reducer.
    """
    shared_settings = {
        key: state[key] for key in ("project_id", "dataset_id", "table_id", "model_name")
    }
    return [
        Send("extract_and_insert", {**shared_settings, "file_uri": uri})
        for uri in state["file_uris"]
    ]


# Batch graph: START routes through continue_to_insert, whose Send objects target
# extract_and_insert; every extract_and_insert branch then ends the run.
graph_batch = StateGraph(OverallStateBatch, input=InputStateBatch, output=OutputStateBatch)
graph_batch.add_node(extract_and_insert)
graph_batch.add_conditional_edges(START, continue_to_insert, path_map=["extract_and_insert"])
graph_batch.add_edge("extract_and_insert", END)

app_batch = graph_batch.compile()

# Render the batch graph, guarded like the other graph renderings because
# rendering is optional and can fail when extra dependencies are missing.
try:
    display(Image(app_batch.get_graph().draw_mermaid_png()))
except Exception as exc:
    print(f"Skipping graph rendering: {exc}")


In [None]:
# Process the first 30 documents through the batch (fan-out) graph.
batch_uris = file_uris[:30]
app_batch.invoke(
    {
        "project_id": PROJECT_ID,
        "dataset_id": DATASET_ID,
        "table_id": TABLE_ID,
        "model_name": MODEL_NAME,
        "file_uris": batch_uris,
    }
)