In [None]:
import itertools
import importlib

import numpy as np
import networkx as nx

import torch
from torch_geometric import utils as gutils

from awe import features, graph_utils
from awe.data import swde, dataset

# Re-import the project modules so edits made to them while the kernel is
# running are picked up without a restart.
# NOTE(review): reload order may matter if these modules import from each
# other (e.g. `dataset` binding names from an old `swde`) — confirm.
for mod in (features, graph_utils, dataset, swde):
    importlib.reload(mod)

In [None]:
# Load the SWDE dataset. `suffix='-exact'` selects a variant of the dataset
# files — presumably the "-exact" pre-processed extraction; TODO confirm in
# `awe.data.swde.Dataset`.
sds = swde.Dataset(suffix='-exact')

In [None]:
# Experiment split configuration.
SUBSET = slice(None)  # e.g. slice(10) to iterate quickly on a few pages per split
# HACK: Some websites are wrongly extracted, so skip them for now.
SKIPPED_WEBSITES = {'motortrend', 'msn', 'yahoo'}
NUM_TRAIN_WEBSITES = 5
TRAIN_PAGES_PER_WEBSITE = 300
VAL_UNSEEN_PAGES_PER_WEBSITE = 100
VAL_SEEN_PAGES_PER_WEBSITE = 40  # 5 training websites x 40 = 200 pages in total

websites = [w for w in sds.verticals[0].websites if w.name not in SKIPPED_WEBSITES]
rng = np.random.default_rng(42)


def sample_pages(pages, count):
    """Return up to ``count`` distinct pages drawn uniformly at random from ``pages``.

    Indices are sampled rather than the page objects themselves, so numpy never
    tries to convert pages into an array. Sampling is without replacement, so no
    page appears twice. If fewer than ``count`` pages exist, all of them are
    returned in random order.
    """
    pages = list(pages)
    chosen = rng.choice(len(pages), size=min(count, len(pages)), replace=False)
    return [pages[i] for i in chosen]


website_indices = rng.choice(len(websites), NUM_TRAIN_WEBSITES, replace=False)
train_website_indices = set(website_indices.tolist())

# Each training website contributes *disjoint* train and val_seen pages, so
# `val_seen` measures generalization to new pages of known websites instead of
# re-scoring pages the model was trained on.
train_pages, val_seen_pages = [], []
for i in website_indices:
    sampled = sample_pages(
        websites[i].pages, TRAIN_PAGES_PER_WEBSITE + VAL_SEEN_PAGES_PER_WEBSITE
    )
    train_pages.extend(sampled[:TRAIN_PAGES_PER_WEBSITE])
    val_seen_pages.extend(sampled[TRAIN_PAGES_PER_WEBSITE:])

# Websites never used for training form the "unseen" validation split.
val_pages = [
    p
    for i, website in enumerate(websites)
    if i not in train_website_indices
    for p in sample_pages(website.pages, VAL_UNSEEN_PAGES_PER_WEBSITE)
]

ds = dataset.DatasetCollection()
ds.create('train', train_pages[SUBSET])
ds.create('val_unseen', val_pages[SUBSET])
ds.create('val_seen', val_seen_pages[SUBSET])
ds.get_lengths()

In [None]:
# Feature extractors attached to the dataset collection, one instance per
# feature class. The order is preserved because it presumably determines the
# column layout of the feature matrix — TODO confirm in `awe.features`.
ds.features = [
    feature_type()
    for feature_type in (
        features.Depth,
        features.IsLeaf,
        features.CharCategories,
        features.FontSize,
        features.CharIdentifiers,
        features.WordIdentifiers,
    )
]

In [None]:
# Build data loaders for the splits created above with batch_size=2
# (presumably two page graphs per batch — TODO confirm in `awe.data.dataset`).
ds.create_dataloaders(batch_size=2)

In [None]:
# Debugging aid: walk the training split and, for every node whose labels are
# exactly ['price'], yield the page file path, the node, its row of `batch.x`,
# its entry of `batch.y` and the batch's `edge_index`.
def _iter_price_nodes(rows):
    for ctx, node, batch, idx in rows:
        if node.labels == ['price']:
            yield ctx.page.file_path, node, batch.x[idx], batch.y[idx], batch.edge_index


interesting_nodes = _iter_price_nodes(ds['train'].iterate_data())
# A start of 0 skips nothing; raise it to jump ahead to a later match.
iterator = itertools.islice(interesting_nodes, 0, None)
next(iterator)

In [None]:
import bokeh.io
import bokeh.plotting
# Send bokeh output to the notebook so `bokeh.io.show` renders plots inline.
# NOTE(review): these imports would normally live in the top import cell.
bokeh.io.output_notebook()

In [None]:
# Take the first page of the training split and build its page context for
# the graph visualization below; the bare last expression displays its file.
page = ds['train'].pages[0]
page_ctx = ds.prepare_page_context(page)
page.file_path

In [None]:
# Convert the page context into a networkx graph using the project helper
# `graph_utils.to_networkx`.
graph = graph_utils.to_networkx(page_ctx)

In [None]:
# Turn the networkx graph into a bokeh GraphRenderer laid out by spring_layout.
# `from_networkx` forwards extra keyword arguments to the layout function, so a
# fixed `seed` makes the otherwise random spring layout reproducible across
# Restart & Run All.
bokeh_graph = bokeh.plotting.from_networkx(graph, nx.spring_layout, seed=42)

In [None]:
# Render the page graph. spring_layout's default `scale=1` about the origin
# puts every node inside [-1, 1] on both axes, so fixed ranges with a small
# margin keep the whole graph in view. The bokeh networkx examples set ranges
# explicitly rather than relying on auto-ranging for graph renderers.
plot = bokeh.plotting.figure(
    title=f'Graph of {page.file_path}',
    x_range=(-1.1, 1.1),
    y_range=(-1.1, 1.1),
)
plot.renderers.append(bokeh_graph)
bokeh.io.show(plot)