# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
# @title Notebook setup.
%cd ..
import random
import pprint
import tensorflow as tf
from semantic_routing.benchmark import utils
from semantic_routing.benchmark.graphs import city_graph
from semantic_routing.benchmark.graphs import grid_graph
from semantic_routing.benchmark.datasets import touring_dataset
from semantic_routing.benchmark.query_engines import labeled_query_engines
from semantic_routing.tokenization import tokenization

tf.compat.v1.enable_eager_execution()
tokenizer = tokenization.FullTokenizer(
    vocab_file=benchmark.DEFAULT_BERT_VOCAB
)

# Datasets

In [2]:
# @title Setup dataset.
# Points-of-interest specs, passed to the query engine here and to the graphs
# and datasets built in the following cells.
poi_specs = utils.get_poi_specs(benchmark.POI_SPECS_PATH)
engine = labeled_query_engines.HumanLabeledQueryEngine(
    poi_specs=poi_specs,
    splits=[0.95, 0, 0.05],
)
# Seeded generator: the graph seeds drawn from it below are reproducible.
rng = random.Random(0)

Let's generate a few datapoints and see how long sampling takes.

In [3]:
# @title Generate grid graph data.
# Sample 10 datapoints, each from a freshly seeded grid graph. Sampling can
# time out, so every datapoint gets up to 10 attempts with a new seed.
for _ in range(10):
  for _ in range(10):
    # Integer bound: `randint` rejects float arguments such as 1e8 on
    # Python 3.12+ (and warns on 3.10/3.11). The range drawn is unchanged.
    seed = rng.randint(0, 10**8)
    graph = grid_graph.GridGraph(poi_specs, 900, seed=seed, splits=[1, 0, 0])
    data = touring_dataset.TouringDataset(tokenizer, graph, engine, poi_specs, 0, 128, 128)
    try:
      datapoint = data.sample_datapoint(True, 0, use_fresh=True)
    except TimeoutError:
      continue
    break
  else:
    # Every attempt timed out; fail with a clear message instead of an opaque
    # "'NoneType' object is not subscriptable" on the print below.
    raise RuntimeError("Timed out sampling a grid graph datapoint in all 10 attempts.")
  print(datapoint["query_text"])

In [4]:
# @title Generate OSM data.
# Sample 5 datapoints, each from a freshly seeded city graph. Sampling can
# time out, so every datapoint gets up to 10 attempts with a new seed.
for _ in range(5):
  for _ in range(10):
    # Integer bound: `randint` rejects float arguments such as 1e8 on
    # Python 3.12+ (and warns on 3.10/3.11). The range drawn is unchanged.
    seed = rng.randint(0, 10**8)
    graph = city_graph.CityGraph(poi_specs, 20000, seed=seed, splits=[1, 0, 0], use_test_city=True)
    data = touring_dataset.TouringDataset(tokenizer, graph, engine, poi_specs, 0, 128, 128, max_segments=600, auto_simplify_datapoint=True)
    try:
      datapoint = data.sample_datapoint(True, 0, use_fresh=True)
    except TimeoutError:
      continue
    break
  else:
    # Every attempt timed out; fail with a clear message instead of an opaque
    # "'NoneType' object is not subscriptable" on the print below.
    raise RuntimeError("Timed out sampling an OSM datapoint in all 10 attempts.")
  print(datapoint["query_text"])

In [5]:
# The graph objects are bound to names so they can be inspected afterwards:
# a bare expression that is not the last line of a cell is evaluated and
# silently discarded.
# For the networkx graph (may be contracted)
datapoint_nx_graph = datapoint["parent"].road_graph.nx_graph
# Uncontracted networkx graph
uncontracted_nx_graph = data.road_graph.nx_graph
# Rest of the task information:
print(datapoint)

Datapoint features can be padded to a consistent length.

In [6]:
# @title Datapoint feature shapes.
from absl import app
import sys
# Parse a dummy argv (an empty program name and no arguments).
# NOTE(review): presumably the featurization code reads absl flags, which must
# be parsed before first use -- confirm.
app.parse_flags_with_usage([""])


def _print_feature_shapes(features):
  """Prints each feature name with its shape; plain ints count as scalars."""
  for name, value in features.items():
    if isinstance(value, int):
      # Python ints have no `.shape` attribute.
      print(name, "TensorShape(Scalar)")
    else:
      pprint.pprint((name, value.shape))


# Sampling can time out, so retry up to 10 times with a freshly seeded graph.
for _ in range(10):
  # Integer bound: `randint` rejects float arguments such as 1e8 on
  # Python 3.12+ (and warns on 3.10/3.11). The range drawn is unchanged.
  seed = rng.randint(0, 10**8)
  graph = city_graph.CityGraph(poi_specs, 20000, seed=seed, splits=[1, 0, 0], use_test_city=True)
  data = touring_dataset.TouringDataset(tokenizer, graph, engine, poi_specs, 0, 128, 128, max_segments=600, auto_simplify_datapoint=True)
  try:
    datapoint = data.sample_datapoint(False, 0, use_fresh=True)
  except TimeoutError:
    continue
  break
else:
  # Every attempt timed out; fail clearly instead of subscripting None below.
  raise RuntimeError("Timed out sampling a datapoint in all 10 attempts.")

print("Shape of datapoint features with padding.")
_print_feature_shapes(datapoint["parent"].featurize_datapoint(datapoint, pad=True))
print()
print("Shape of datapoint features without padding.")
_print_feature_shapes(datapoint["parent"].featurize_datapoint(datapoint, pad=False))

We can evaluate routes.

In [7]:
# @title Datapoint evaluation statistics.

data = datapoint["parent"]

# Each scenario is (title, query factory, route post-processing). For every
# scenario the shortest path from the first edge of the current edgelist to
# the datapoint's end is recomputed under the scenario's query, post-processed,
# written back into `datapoint["edgelist"]`, and then evaluated.
scenarios = (
    (
        "Statistics of ground-truth route with TERM",
        lambda: datapoint["query_data"],
        lambda route: route + (data.term_token,),
    ),
    (
        "Statistics of ground-truth route without TERM",
        lambda: datapoint["query_data"],
        lambda route: route,
    ),
    (
        "Statistics of ground-truth route without last edge",
        lambda: datapoint["query_data"],
        lambda route: route[:-1],
    ),
    (
        "Statistics of ground-truth route ignoring POI",
        lambda: {"linear": "", "pois": ()},
        lambda route: route,
    ),
    (
        "Statistics of ground-truth route ignoring POI and without last edge",
        lambda: {"linear": "", "pois": ()},
        lambda route: route[:-1],
    ),
)

for index, (title, make_query, finalize_route) in enumerate(scenarios):
  if index:
    # Blank line between consecutive reports (none after the last one).
    print()
  print(title)
  # Only the path is needed; the returned length is ignored.
  _, route = data.road_graph.get_shortest_path_len(
      datapoint["edgelist"][0],
      datapoint["end"],
      make_query(),
      return_path=True,
  )
  datapoint["edgelist"] = finalize_route(route)
  # Clear the stored ground truth and candidates before evaluating.
  datapoint["ground_truth"] = None
  datapoint["candidates"] = ()
  pprint.pprint(data.evaluate_datapoint(datapoint))