# Setup

In [None]:
!pip install --upgrade google-cloud-aiplatform
!gcloud auth application-default login

In [None]:
!gcloud auth application-default set-quota-project graph-localization

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.image import imread
import random

import base64
import vertexai
from vertexai.generative_models import GenerativeModel, Part, SafetySetting, Image
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
from pprint import pprint
import json
from json.decoder import JSONDecodeError
import os
import numpy as np
from time import sleep

# Code for generating scene graphs using Gemini


In [None]:
# Instruction text that generate_graph sends to Gemini together with each image.
# It asks for a JSON object keyed by node id, where every node carries a "type",
# a "feature" and six lists of spatial relations. convert_format consumes exactly
# that layout, so the keys named here must stay in sync with that function.
prompt = """
Task: Static Environmental Features Analysis

Analyze the provided image to identify the following types of objects:
- buildings
- parks
- trees
- overpasses
- bridges
- the road in the line of sight of the camera

Ignore any temporary objects, such as vehicles or people.

For each of these objects in the scene, generate a structured JSON output as described below. Each object must have:

A unique node ID (e.g., "brick_building_2_story", "oak_tree_left", "road_center", etc.).
A description that excludes spatial information but focuses on the object's attributes.
Defined relationships with every other object using the following spatial relations:
is_left_of
is_right_of
is_above
is_below
is_in_front_of
is_behind

If the type of the object is a building, write the feature as main building materal and estimated number of floors, such as
"glass, 12 stories"

Example Output Format:
{
  "node_id_1": {
    "type": "the type of object, e.g. building or overpass",
    "feature": "as described above. Leave as empty string if not a building.",
    "is_left_of": ["node_id_2", "node_id_3"],
    "is_right_of": ["node_id_4"],
    "is_above": ["node_id_5"],
    "is_below": [],
    "is_in_front_of": ["node_id_3"],
    "is_behind": ["node_id_4"]
  },
  "node_id_2": {
    ...
  },
  ...
}
Requirements:
Ensure that every object has defined relationships with all other objects using the specified spatial relation types.
Maintain consistency and completeness for all relationships.
"""

# Sampling settings passed to every Gemini request made by generate().
# The JSON mime type matches the json.loads() call applied to the response.
generation_config = dict(
    max_output_tokens=8192,
    temperature=0.1,
    top_p=0.95,
    response_mime_type="application/json",
)

# Harm categories whose blocking threshold is switched off for every request.
_UNBLOCKED_HARM_CATEGORIES = (
    SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
    SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
    SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
)

# One SafetySetting per category above, all with the threshold set to OFF.
safety_settings = [
    SafetySetting(
        category=harm_category,
        threshold=SafetySetting.HarmBlockThreshold.OFF,
    )
    for harm_category in _UNBLOCKED_HARM_CATEGORIES
]

def generate(prompt_content) -> str:
    """
    Send the prompt content to Gemini and return the complete answer text.

    The request is streamed; the text of every streamed chunk is concatenated
    in arrival order and returned as one string.

    Args:
        prompt_content: The content forwarded unchanged to
            ``GenerativeModel.generate_content``.

    Returns:
        The concatenated text of all streamed response chunks.
    """
    vertexai.init(project="graph-localization", location="us-central1")
    gemini = GenerativeModel("gemini-1.5-pro-002")

    chunks = gemini.generate_content(
        prompt_content,
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=True,
    )
    return "".join(chunk.text for chunk in chunks)


def convert_format(input_data, remove_doubles=True) -> dict:
    """
    Convert the decoded JSON object to an object readable by NetworkX.

    Every relation listed for a node becomes a directed edge, and the mirrored
    relation is added as a second edge in the opposite direction (e.g.
    ``a is_left_of b`` also yields ``b is_right_of a``). Edges are returned in
    the order in which they are first encountered, so the output is
    deterministic for a given input.

    Args:
        input_data (dict): Mapping of node id to an attribute dict holding
            "type", "feature" and one list of target node ids per relation.
        remove_doubles (bool, optional): If True (default), only the first
            relation encountered between any two nodes is kept (together with
            its mirror), so a pair of nodes is never connected by more than
            one relation. If False, every listed relation is kept; exact
            duplicates are still collapsed.

    Returns:
        A dict with "nodes" (list of {"id", "feature"}) and "edges"
        (list of {"source", "type", "target"}).

    Raises:
        KeyError: If a node lacks "type" or "feature", or if a non-empty
            attribute other than those two is not a known relation name.
    """
    # Each relation and the relation that describes the same fact from the
    # other node's point of view.
    opposite_relation = {
        "is_right_of": "is_left_of",
        "is_behind": "is_in_front_of",
        "is_below": "is_above",
        "is_above": "is_below",
        "is_in_front_of": "is_behind",
        "is_left_of": "is_right_of",
    }

    nodes = []
    # A dict is used as an insertion-ordered set of (source, type, target).
    edges = {}
    # Unordered node pairs that already have a relation (used for remove_doubles).
    connected_pairs = set()

    for source, attributes in input_data.items():
        # The node feature combines the object type with its description.
        nodes.append({"id": source, "feature": attributes["type"] + ", " + attributes["feature"]})

        for relation, targets in attributes.items():
            # "feature" and "type" describe the node itself, not a relation.
            if relation in ("feature", "type"):
                continue
            for target in targets:
                if remove_doubles:
                    pair = frozenset((source, target))
                    if pair in connected_pairs:
                        # This pair is already connected; keep the first relation only.
                        continue
                    connected_pairs.add(pair)
                edges[(source, relation, target)] = None
                edges[(target, opposite_relation[relation], source)] = None

    return {
        "nodes": nodes,
        "edges": [
            {"source": edge_source, "type": edge_type, "target": edge_target}
            for edge_source, edge_type, edge_target in edges
        ],
    }

# Code for adding feature embeddings to nodes and edges

In [None]:
def embed_text(texts) -> list[list[float]]:
    """
    Embed a batch of texts with the pre-trained "text-embedding-005" model.

    All texts are sent in a single ``get_embeddings`` call using the
    SEMANTIC_SIMILARITY task type.

    Args:
        texts: An iterable of strings to embed.

    Returns:
        One embedding vector per input text, in the order returned by the API.
    """
    embedding_model = TextEmbeddingModel.from_pretrained("text-embedding-005")

    requests = []
    for text in texts:
        # SEMANTIC_SIMILARITY is the task type that matches our goal the best.
        requests.append(TextEmbeddingInput(text, "SEMANTIC_SIMILARITY"))

    results = embedding_model.get_embeddings(requests)
    return [result.values for result in results]

def add_embeddings(graph_data) -> None:
    """
    Attach an "embedding" entry to every node and edge of the graph data.

    The graph data is modified in place. Node features are embedded with one
    request; edge types are embedded once per distinct type and shared by all
    edges of that type.
    """
    graph_nodes = graph_data["nodes"]
    feature_vectors = embed_text([graph_node["feature"] for graph_node in graph_nodes])
    for graph_node, vector in zip(graph_nodes, feature_vectors):
        graph_node["embedding"] = vector

    # Nothing left to do for a graph without edges.
    if not graph_data["edges"]:
        return

    # Edge types have little variety, so each distinct type is embedded only once.
    edge_types = {edge["type"] for edge in graph_data["edges"]}
    embedding_map = dict(zip(edge_types, embed_text(edge_types)))
    for edge in graph_data["edges"]:
        edge["embedding"] = embedding_map[edge["type"]]


# Code for visualizing a graph using NetworkX

In [None]:
def display_graph(graph_data, image_path) -> None:
    """
    Display a figure showing the image next to its scene graph.

    Args:
        graph_data: Node-link data with "nodes" and "edges" lists, as produced
            by convert_format.
        image_path: Path of the image file shown in the left subplot.
    """
    graph = nx.node_link_graph(graph_data, edges="edges", directed=True)
    layout = nx.shell_layout(graph)
    _, (image_ax, graph_ax) = plt.subplots(1, 2, figsize=(16, 8))

    # Left subplot: the image, titled with its file name and without axes.
    image_ax.imshow(imread(image_path))
    image_ax.axis("off")
    image_ax.set_title(image_path.rsplit("/", 1)[-1])

    # Right subplot: the graph. The leading newlines push each label below its node.
    node_labels = {}
    for node in graph_data["nodes"]:
        node_labels[node["id"]] = f"\n\n\n\n\n\n{node['feature']}"
    edge_labels = {}
    for link in graph_data["edges"]:
        edge_labels[(link["source"], link["target"])] = link["type"]

    nx.draw_networkx_nodes(
        graph, layout, node_size=300, node_color="skyblue", edgecolors="black", ax=graph_ax
    )
    nx.draw_networkx_labels(
        graph,
        layout,
        verticalalignment="center",
        horizontalalignment="center",
        labels=node_labels,
        font_size=8,
        ax=graph_ax,
    )
    nx.draw_networkx_edges(graph, layout, arrowstyle="->", arrowsize=30, ax=graph_ax)
    nx.draw_networkx_edge_labels(
        graph, layout, edge_labels=edge_labels, font_size=8, label_pos=0.3, ax=graph_ax
    )
    graph_ax.axis("off")

    # Adjust spacing and show the figure.
    plt.tight_layout()
    plt.show()

# Generate a dataset of graphs



In [None]:
def generate_graph(image_path, embeddings=True, plot=False, remove_multiple_edges=True, attempt=1) -> dict | None:
    """
    Generate a graph for the given image and return the graph data.

    The whole pipeline (LLM request, format conversion, optional embeddings)
    is retried from the start whenever one of its steps fails, until the
    attempt counter exceeds 5.

    Args:
        image_path (str): The path to the image file.
        embeddings (bool, optional): Whether to add embeddings to the nodes and edges.
        plot (bool, optional): Whether to display the graph.
        remove_multiple_edges (bool, optional): Forwarded to convert_format as
            ``remove_doubles``; meant to limit node pairs to a single relation.
        attempt (int, optional): The attempt number to start counting from.

    Returns:
        A dictionary containing the graph data if successful, otherwise None.
    """
    max_attempts = 5
    current_attempt = attempt

    while current_attempt <= max_attempts:
        current_attempt += 1

        # Prompt the LLM with the image followed by the instruction text.
        image = Image.load_from_file(image_path)
        try:
            response = generate([image, prompt])
        except Exception as error:
            # This can sometimes happen due to API limits - waiting usually helps.
            print("Failed to get a response from the LLM:")
            print(error)
            print("Waiting 5 seconds and Trying again!")
            sleep(5)
            continue

        # Convert the answer into a NetworkX readable format.
        try:
            graph_data = convert_format(json.loads(response), remove_doubles=remove_multiple_edges)
        except JSONDecodeError:
            print("Got invalid JSON. Trying again!")
            continue
        except Exception as error:
            print("Got an error when generating graph:")
            print(error)
            print("Ignoring and trying again!")
            continue

        # Add embeddings if requested.
        if embeddings:
            try:
                add_embeddings(graph_data)
            except Exception as error:
                print("Got an error when adding embeddings:")
                print(error)
                print("Ignoring and trying again!")
                continue

        if plot:
            display_graph(graph_data, image_path)

        return graph_data

    # Every allowed attempt failed.
    return None

In [None]:
from scipy.io import loadmat

# Coordinates of the dataset in two representations: lat/long (with compass)
# and cartesian. The cartesian coordinates are constructed so that the euclidean
# distance between two points is the real-world distance in meters.
lat_long_data = loadmat('/content/drive/Shareddrives/CS224W/Dataset/GPS_Long_Lat_Compass.mat')
cartesian_data = loadmat('/content/drive/Shareddrives/CS224W/Dataset/Cartesian_Location_Coordinates.mat')

# List the image directory once and keep both the count and the sorted names.
img_dir = "/content/drive/Shareddrives/CS224W/Dataset/all"
available_images = os.listdir(img_dir)
number_of_files = len(available_images)
sorted_img_list = sorted(available_images)
print("Number of images: ", len(sorted_img_list))

In [None]:
import pandas as pd
from datetime import datetime

# Create matching lists of coordinates, both lat-long and cartesian. The number
# before the first underscore of a file name is the 1-based row of that image
# in the coordinate tables.
coordinate_rows = [int(name.split("_")[0]) - 1 for name in sorted_img_list]
lat_longs = [lat_long_data["GPS_Compass"][row_number] for row_number in coordinate_rows]
cartesians = [cartesian_data["XYZ_Cartesian"][row_number] for row_number in coordinate_rows]

# Merge these lists into a dataframe with one row per image.
coordinate_columns = ['GPS_Lat', 'GPS_Long', 'Compass', 'X', 'Y', 'Z']
full_df = pd.DataFrame(
    np.concatenate((lat_longs, cartesians), axis=1),
    columns=coordinate_columns,
)
full_df['image_filename'] = sorted_img_list

# The last character of the file stem is the image direction; storing it
# simplifies any filtering we might want to do later.
full_df['image_direction'] = [int(name.split(".")[0][-1]) for name in sorted_img_list]

# Graphs share the same filename, just a different file type.
full_df['graph_filename'] = [name.replace(".jpg", ".json") for name in sorted_img_list]

# Every run writes into its own folder named after the current date and time.
date_time_str = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
output_folder_path = f"/content/drive/Shareddrives/CS224W/outputs/{date_time_str}"
os.makedirs(output_folder_path)

# Keep only the images facing in the driving direction.
faces_driving_direction = full_df["image_direction"] == 4
df = full_df.loc[faces_driving_direction]
print("Number of images in the chosen direction: ", len(df))

# Keep only the images from Pittsburgh (bounds are inclusive).
in_pittsburgh = df['GPS_Long'].between(-80.01, -79.8) & (df['GPS_Lat'] >= 40)
df = df[in_pittsburgh]
print("Number of images from Pittsburgh: ", len(df))

# Cap the dataset at a fixed number of images.
NUMBER_OF_ROWS = 1000
df = df.iloc[:NUMBER_OF_ROWS]

# Save a draft of the dataframe now. Graphs that cannot be generated are
# skipped; their row indices are collected here so that those rows can be
# deleted before the final dataframe is saved.
df.to_csv(os.path.join(output_folder_path, "df_initial.csv"))
failed_row_indices = []

# To resume after a previous failure, set the path manually here and set
# continue_from to the position to resume from.
# output_folder_path = "/content/drive/Shareddrives/CS224W/outputs/2024-12-08_02:36:13"
continue_from = 0

# Generate a graph for each image. The row index returned by iterrows is not
# suitable for progress updates, so a separate 1-based counter is kept.
total_images = len(df)
for progress, (row_index, row) in enumerate(df.iterrows(), start=1):
    # Skip everything before the resume position.
    if progress < continue_from:
        continue

    image_name = row['image_filename']
    print(f"Generating graph for {image_name} ({progress}/{total_images})")

    # Generate the graph data; None means every attempt failed.
    image_path = os.path.join(img_dir, image_name)
    graph_data = generate_graph(image_path, embeddings=True, plot=False)
    if graph_data is None:
        failed_row_indices.append(row_index)
        print(f"Failed to generate graph for {image_name}")
        continue

    # Serialize first, then write the graph next to the dataframe files.
    json_object = json.dumps(graph_data, indent=4)
    graph_path = os.path.join(output_folder_path, row['graph_filename'])
    with open(graph_path, "w") as outfile:
        outfile.write(json_object)

    print()

# Delete all failed rows and save the final dataframe.
print("Failed rows: ", failed_row_indices)
df = df.drop(failed_row_indices)
df.to_csv(os.path.join(output_folder_path, "df_final.csv"))

In [None]:
# Verify that every graph file listed in the dataframe exists in the output folder.
files = set(os.listdir(output_folder_path))
missing_graphs = [graph_name for graph_name in df['graph_filename'] if graph_name not in files]
for graph_name in missing_graphs:
    print(f"Missing {graph_name}")