In [None]:
!pip install langchain-experimental langchain-community langchain networkx langchain-google-genai langchain-core json-repair

Collecting langchain-experimental
  Downloading langchain_experimental-0.3.2-py3-none-any.whl.metadata (1.7 kB)
Collecting langchain-community
  Downloading langchain_community-0.3.4-py3-none-any.whl.metadata (2.9 kB)
Collecting langchain-google-genai
  Downloading langchain_google_genai-2.0.3-py3-none-any.whl.metadata (3.9 kB)
Collecting json-repair
  Downloading json_repair-0.30.0-py3-none-any.whl.metadata (10 kB)
Collecting dataclasses-json<0.7,>=0.5.7 (from langchain-community)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting httpx-sse<0.5.0,>=0.4.0 (from langchain-community)
  Downloading httpx_sse-0.4.0-py3-none-any.whl.metadata (9.0 kB)
Collecting langchain
  Downloading langchain-0.3.6-py3-none-any.whl.metadata (7.1 kB)
Collecting langchain-core
  Downloading langchain_core-0.3.14-py3-none-any.whl.metadata (6.3 kB)
Collecting pydantic-settings<3.0.0,>=2.4.0 (from langchain-community)
  Downloading pydantic_settings-2.6.0-py3-none-any.whl.metadat

In [None]:
# Cell 1: Install and import necessary libraries

# Install required libraries
# !pip install langchain
# !pip install networkx
# !pip install google-generative-ai
# !pip install matplotlib

# Import standard libraries
import json
import networkx as nx
import matplotlib.pyplot as plt

import os
from langchain_experimental.graph_transformers import LLMGraphTransformer
# from langchain_google_genai import GoogleGenerativeAI
import networkx as nx
from langchain.chains import GraphQAChain
from langchain_core.documents import Document
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph
import json
import matplotlib.pyplot as plt

# Set up the Google API key. Never hard-code it in the notebook: read it from the
# GOOGLE_API_KEY environment variable, or prompt for it without echoing it.
# NOTE: the key that used to be hard-coded here was exposed in the notebook source;
# revoke it and issue a new one.
from getpass import getpass

GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY') or getpass('Enter your Google API key: ')

In [None]:
# Cell 2: Authenticate with Google Generative AI
# Export the key defined above so integrations that read GOOGLE_API_KEY from the
# process environment can find it.

import os

os.environ.update(GOOGLE_API_KEY=GOOGLE_API_KEY)


In [None]:
# Cell 3: Load the JSON data from the edh.json file

# Upload your edh.json file
from google.colab import files
uploaded = files.upload()

# Load the JSON data. `uploaded` maps each uploaded file name to its raw bytes.
import io

EDH_FILENAME = 'edh.json'
if EDH_FILENAME not in uploaded:
    # Stop here with a clear message rather than continuing with data=None, which
    # would only fail later (as an AttributeError) inside build_graph_from_json.
    raise FileNotFoundError(
        f"{EDH_FILENAME} not found. Please upload the file and try again. "
        f"(uploaded: {sorted(uploaded)})"
    )

data = json.load(io.StringIO(uploaded[EDH_FILENAME].decode('utf-8')))
print("edh.json loaded successfully.")

Saving edh.json to edh.json
edh.json loaded successfully.


In [None]:
# Cell 4: Initialize the knowledge graph

# Initialize an empty NetworkxEntityGraph. build_graph_from_json adds nodes and
# triples to this object in place, so re-run this cell to start from an empty graph.
kg = NetworkxEntityGraph()


In [None]:
from langchain.graphs.networkx_graph import KnowledgeTriple


In [None]:
def _ensure_node(kg, node):
    """Add `node` to `kg` unless it is already present."""
    if not kg.has_node(node):
        kg.add_node(node)


def _link(kg, subject, predicate, object_):
    """Add the edge `subject --predicate--> object_` to `kg`."""
    kg.add_triple(KnowledgeTriple(subject=subject, predicate=predicate, object_=object_))


def _subgoal_ids(data):
    """Return the labels Subgoal_0 .. Subgoal_{n-1}, one per entry of data['future_subgoals']."""
    return [f"Subgoal_{i}" for i in range(len(data.get('future_subgoals') or []))]


def _parse_image_time(image):
    """Extract the timestamp embedded in an image file name.

    NOTE(review): assumes names shaped like '<a>.<b>.<seconds>[.<fraction>].<ext>'
    (e.g. 'driver.frame.27.14.jpeg') -- confirm against the dataset. The full
    '<seconds>.<fraction>' part is used when it parses as a float (the old
    `split('.')[2]` dropped the fraction); otherwise this falls back to the third
    dot-separated field, exactly as before.

    Raises:
        IndexError, ValueError: if no numeric timestamp can be extracted.
    """
    parts = image.split('.')
    if len(parts) >= 4:
        try:
            return float('.'.join(parts[2:-1]))
        except ValueError:
            pass
    return float(parts[2])


def _add_agents(kg, data):
    """Step 1: add an Agent_<id> node per agent_id; return {agent_id: label}."""
    agents = {}
    for interaction in data.get('interactions') or []:
        agent_id = interaction.get('agent_id')
        if agent_id is None:
            print(f"Warning: agent_id is None in interaction: {interaction}")
            continue
        # agent_id 0 is treated as the Commander, anything else as the Driver
        # (see the speaker mapping in _add_dialogues).
        agent_label = f"Agent_{agent_id}"
        _ensure_node(kg, agent_label)
        # Record the agent even when its node already existed (earlier run or
        # another instance); otherwise every later edge for it is skipped.
        agents[agent_id] = agent_label
    return agents


def _add_actions_and_interactions(kg, data, agents):
    """Steps 2-3: add Action and Interaction nodes and their edges.

    Returns:
        dict mapping each Interaction_<time_start> label to its time_start as a
        float, in first-seen order; labels with a non-numeric time_start are omitted.
    """
    interaction_times = {}
    for interaction in data.get('interactions') or []:
        action_id = interaction.get('action_id')
        if action_id is None:
            print(f"Warning: action_id is None in interaction: {interaction}")
            continue

        action_label = f"Action_{action_id}_{interaction.get('time_start', '')}"
        _ensure_node(kg, action_label)

        # Agent --Performs--> Action
        agent_id = interaction.get('agent_id')
        agent_label = agents.get(agent_id)
        if agent_label is not None:
            _link(kg, agent_label, 'Performs', action_label)
        else:
            print(f"Warning: agent_label is None for agent_id: {agent_id}")

        # Action --Involves_Object--> Object (if applicable)
        oid = interaction.get('oid')
        if oid:
            _ensure_node(kg, oid)
            _link(kg, action_label, 'Involves_Object', oid)

        time_start = interaction.get('time_start')
        if time_start is None:
            print(f"Warning: time_start is None in interaction: {interaction}")
            continue

        # One Interaction node per time_start. Its edges are added for every
        # interaction (not only when the node is new), so interactions sharing a
        # time_start, or a re-run on an existing graph, keep all their links.
        interaction_label = f"Interaction_{time_start}"
        _ensure_node(kg, interaction_label)
        try:
            interaction_times[interaction_label] = float(time_start)
        except (TypeError, ValueError):
            pass  # cannot be matched to images by time
        if agent_label is not None:
            _link(kg, interaction_label, 'Involves_Agent', agent_label)
        if oid:
            _link(kg, interaction_label, 'Involves_Object', oid)
        _link(kg, interaction_label, 'Related_Action', action_label)
    return interaction_times


def _add_dialogues(kg, data, agents):
    """Step 4: add Dialogue_<idx> nodes with Communicates and Instructs edges."""
    subgoal_ids = _subgoal_ids(data)
    for dialogue_idx, dialogue in enumerate(data.get('dialog_history_with_das') or []):
        dialogue_id = f"Dialogue_{dialogue_idx}"
        _ensure_node(kg, dialogue_id)

        # Agent --Communicates--> Dialogue
        speaker = dialogue.get('speaker')
        if speaker is None:
            print(f"Warning: speaker is None in dialogue: {dialogue}")
        else:
            agent_id = 0 if speaker == 'Commander' else 1
            agent_label = agents.get(agent_id)
            if agent_label is not None:
                _link(kg, agent_label, 'Communicates', dialogue_id)
            else:
                print(f"Warning: agent_label is None for agent_id: {agent_id}")

        # Dialogue --Instructs--> every Subgoal, for utterances tagged 'Instruction'
        das = (dialogue.get('da_metadata') or {}).get('das') or []
        if 'Instruction' in das:
            for subgoal_id in subgoal_ids:
                _ensure_node(kg, subgoal_id)
                _link(kg, dialogue_id, 'Instructs', subgoal_id)


def _add_subgoal_achievements(kg, data, agents):
    """Step 5: link Agent_1 (assumed to be the Driver) to every future subgoal."""
    if 'Agent_1' not in agents.values():
        return
    for subgoal_id in _subgoal_ids(data):
        _ensure_node(kg, subgoal_id)
        _link(kg, 'Agent_1', 'Achieves', subgoal_id)


def _add_objects(kg, data):
    """Step 6: add final-state objects with Contains and Located_At edges."""
    final_state_objects = (data.get('final_state_diff') or {}).get('objects') or {}
    for obj_id, obj_data in final_state_objects.items():
        if obj_id is None:
            print("Warning: obj_id is None in final_state_objects")
            continue
        _ensure_node(kg, obj_id)

        # Object --Contains--> Object
        for contained_obj_id in obj_data.get('receptacleObjectIds') or []:
            if contained_obj_id is None:
                print(f"Warning: contained_obj_id is None in object: {obj_id}")
                continue
            _ensure_node(kg, contained_obj_id)
            _link(kg, obj_id, 'Contains', contained_obj_id)

        # Object --Located_At--> Position
        position = obj_data.get('position')
        if position and all(axis in position for axis in ('x', 'y', 'z')):
            pos_label = f"Position_{position['x']}_{position['y']}_{position['z']}"
            _ensure_node(kg, pos_label)
            _link(kg, obj_id, 'Located_At', pos_label)
        else:
            print(f"Warning: position is invalid for object: {obj_id}")


def _add_images(kg, data, interaction_times):
    """Step 7: link each driver image to the Interaction closest to it in time."""
    for image in data.get('driver_images_future') or []:
        if not image:
            print("Warning: image is None or empty in driver_images")
            continue
        _ensure_node(kg, image)
        try:
            image_time = _parse_image_time(image)
        except (IndexError, ValueError):
            print(f"Warning: Invalid image time in image: {image}")
            continue
        # min() returns the first of equally close interactions (insertion order).
        closest_node = min(
            interaction_times,
            key=lambda label: abs(image_time - interaction_times[label]),
            default=None,
        )
        if closest_node is None:
            print(f"Warning: No matching interaction for image: {image}")
            continue
        _link(kg, image, 'Captured_At', closest_node)


def build_graph_from_json(kg, data):
    """Populate `kg` with nodes and edges derived from one EDH JSON instance.

    Edges written (subject --predicate--> object):
      Agent_<id>          --Performs-->        Action_<action_id>_<time_start>
      Action_*            --Involves_Object--> <oid>
      Interaction_<t>     --Involves_Agent-->  Agent_<id>
      Interaction_<t>     --Involves_Object--> <oid>
      Interaction_<t>     --Related_Action-->  Action_*
      Agent_<id>          --Communicates-->    Dialogue_<idx>
      Dialogue_<idx>      --Instructs-->       Subgoal_<i>  (utterances whose das include 'Instruction')
      Agent_1             --Achieves-->        Subgoal_<i>
      <obj_id>            --Contains-->        <obj_id>
      <obj_id>            --Located_At-->      Position_<x>_<y>_<z>
      <image>             --Captured_At-->     Interaction_<t>  (closest time_start)

    Agent and interaction lookups are built from `data`, not from whether a node
    was newly created, so calling this on an already-populated `kg` (a re-run,
    or merging another instance) does not silently drop edges.

    Args:
        kg: graph exposing has_node / add_node / add_triple (a NetworkxEntityGraph here).
        data: parsed edh.json dict. Keys read: interactions, dialog_history_with_das,
            future_subgoals, final_state_diff, driver_images_future.

    Raises:
        ValueError: if `data` is None (edh.json was not loaded).
    """
    if data is None:
        raise ValueError("data is None: load edh.json before building the graph.")

    agents = _add_agents(kg, data)
    interaction_times = _add_actions_and_interactions(kg, data, agents)
    _add_dialogues(kg, data, agents)
    _add_subgoal_achievements(kg, data, agents)
    _add_objects(kg, data)
    _add_images(kg, data, interaction_times)
    # Step 8 ([Agent] --Located_At--> [Position]) is not implemented: agent
    # positions are not read here. Add them like object positions if available.

# Populate `kg` in place from the loaded EDH instance.
build_graph_from_json(kg=kg, data=data)

print("Knowledge graph constructed according to the schema.")


Knowledge graph constructed according to the schema.


In [None]:
# Cell 6: Explore the knowledge graph

# NetworkxEntityGraph.get_triples() returns plain tuples ordered
# (subject, object, predicate) -- not KnowledgeTriple objects -- so they are
# unpacked positionally (attribute access such as `.subject` raises AttributeError).
triples = kg.get_triples()

# Nodes that take part in at least one edge (isolated nodes are not listed).
unique_nodes = {node for subject, object_, _ in triples for node in (subject, object_)}

print("Nodes in the knowledge graph:")
# Sorted so the listing is reproducible across runs (set order is not).
for node in sorted(unique_nodes, key=str):
    print(node)

print("\nEdges in the knowledge graph:")
for subject, object_, predicate in triples:
    print(f"{subject} --{predicate}--> {object_}")


AttributeError: 'tuple' object has no attribute 'subject'

In [None]:
# Cell 8: Use GraphQAChain to query the knowledge graph

# `GooglePalm` was never imported in this notebook and targets the deprecated PaLM
# model 'models/text-bison-001'. Use the Gemini LLM wrapper from
# langchain-google-genai, which the first cell installs.
from langchain_google_genai import GoogleGenerativeAI

# Initialize the LLM
llm = GoogleGenerativeAI(
    model='gemini-1.5-flash',  # Adjust as necessary
    google_api_key=GOOGLE_API_KEY,
    temperature=0.0,
)

# Initialize the GraphQAChain. Its entity-extraction and QA sub-chains are required
# fields, so build it with from_llm() instead of passing `llm=` to the constructor.
# allow_dangerous_requests is the explicit opt-in that recent langchain-community
# releases require before a graph QA chain can be created.
graph_qa_chain = GraphQAChain.from_llm(llm=llm, graph=kg, allow_dangerous_requests=True)

# Example query
query = "Which objects did Agent_1 interact with during the task?"

# Run the query (.run() is deprecated; .invoke() returns a dict keyed by the chain's output_key)
result = graph_qa_chain.invoke({graph_qa_chain.input_key: query})
answer = result[graph_qa_chain.output_key]

print(f"Query: {query}")
print(f"Answer: {answer}")

AttributeError: 'NetworkxEntityGraph' object has no attribute 'graph'