## Import

In [None]:
import importlib

import torch
from torch_geometric import data
import numpy as np

from awe import utils, filtering, features, extraction, html_utils, awe_graph, visual
from awe.data import swde, live, dataset

# Reload the imported project modules so that edits made to their source
# files are picked up without restarting the kernel.
for module in [utils, filtering, dataset, swde, live, features, extraction, html_utils, awe_graph, visual]:
    importlib.reload(module)

In [None]:
# Seed NumPy's and PyTorch's global random number generators with a fixed value.
np.random.seed(42)
torch.manual_seed(42)

## Split data

In [None]:
sds = swde.Dataset(suffix='-exact')
#sds.validate()

In [None]:
# Slice applied to both page lists below; `slice(2)` keeps only the first two pages.
SUBSET = slice(2)
vertical = sds.verticals[0]
# Training pages are taken from the first three websites of the vertical,
# validation pages from the fourth one.
train_pages = vertical.websites[0].pages[:400] + vertical.websites[1].pages[:100] + vertical.websites[2].pages[:100]
val_pages = vertical.websites[3].pages[:100]
ds = dataset.DatasetCollection()
# NOTE(review): SUBSET is applied after the concatenation above, so with
# `slice(2)` the training set presumably contains only pages from
# `websites[0]` — confirm this is intended.
ds.create('train', train_pages[SUBSET], shuffle=True)
ds.create('val', val_pages[SUBSET])
ds.get_lengths()

In [None]:
ds.summarize_pages_without_visual_features()

## Extract features

In [None]:
ds.features = [
    features.Depth(),
    features.IsLeaf(),
    features.CharCategories(),
    features.FontSize(),
    features.CharIdentifiers(),
    features.WordIdentifiers(),
]
ds.feature_summary

In [None]:
# Uncomment next line to invalidate prepared features.
#ds.delete_saved_root_context()

In [None]:
ds.prepare_features()

In [None]:
len(ds.root.pages), len(ds.root.chars), len(ds.root.tokens)

In [None]:
# Pick the longest token (the earliest one wins on ties) and show it together
# with its length.
tokens = list(ds.root.tokens)
word = max(tokens, key=len)
word, len(word)

In [None]:
ds.save_root_context()

In [None]:
# Uncomment next line to invalidate computed features.
#ds.delete_saved_features()

In [None]:
ds.compute_features(parallelize=None)

In [None]:
ds.first_dataset.label_map

In [None]:
ds.count_labels()

## Create dataloaders

In [None]:
ds.create_dataloaders(batch_size=4)

In [None]:
for batch in ds['train'].loader:
    print(batch)
    break

## Inspect data

In [None]:
# Lazily scan the validation data for nodes whose labels are exactly ['price'].
interesting_nodes = (
    (node, x, y)
    for node, x, y in ds['val'].iterate_data()
    if node.labels == ['price']
)
# Show the first match. The `None` default avoids an uncaught StopIteration
# when no node matches (e.g., when only a small SUBSET of pages is loaded).
next(interesting_nodes, None)

In [None]:
# Are all labeled nodes also leaf nodes?
from tqdm.auto import tqdm
def verify_leaf_nodes(names):
    """Report the first labeled node that is not a text node.

    Walks the pages of every dataset listed in ``names`` (showing one progress
    bar per dataset). As soon as a node with at least one label and a falsy
    ``is_text`` is found, its details are printed and the whole scan stops.
    """
    for name in names:
        for page in tqdm(ds[name].pages, desc=name):
            for node in ds.create_page_context(page).nodes:
                # Skip unlabeled nodes and labeled text nodes.
                if len(node.labels) == 0 or node.is_text:
                    continue
                print(f'Node {node.xpath} with labels {node.labels} in page {page.identifier} is not leaf.')
                return
#verify_leaf_nodes(['train', 'val', 'unseen'])

In [None]:
# Compute how many words were found in the pretrained GloVe embeddings.
import collections
# Running totals under the keys 'pad', 'unknown', 'found' and 'lens'; stays
# empty while the loop below is disabled.
stats = collections.defaultdict(int)
if False:
    for batch in ds['train'].loader:
        word_ids = extraction.collate(batch.word_identifiers)
        for t in word_ids:
            # Identifier 0 is tallied as 'pad', 1 as 'unknown', and anything
            # >= 2 as 'found'.
            stats['pad'] += len([x for x in t if x == 0])
            stats['unknown'] += len([x for x in t if x == 1])
            stats['found'] += len([x for x in t if x >= 2])
            stats['lens'] += len(t)
stats

## Weight labels

In [None]:
def count_label(data: list[data.Data], label: int):
    """Count how many entries of ``y`` across all items of ``data`` equal ``label``."""
    total = 0
    for sample in data:
        for target in sample.y:
            if target == label:
                total += 1
    return total

def count_labels(data: list[data.Data]):
    """Return one count per label, ordered like ``ds.first_dataset.label_map.values()``."""
    counts = []
    for label in ds.first_dataset.label_map.values():
        counts.append(count_label(data, label))
    return counts

In [None]:
label_counts = []
#label_counts = count_labels(ds['train'])
# Pair each `label_map` key with its count; the result is empty while
# `label_counts` is left as an empty list.
dict(zip(ds.first_dataset.label_map.keys(), label_counts))

In [None]:
#label_weights = [len(ds['train']) / count for count in label_counts]
#label_weights

In [None]:
# Manual override: weight 1 for the first label and 100,000 for every
# remaining label.
label_count = len(ds.first_dataset.label_map)
label_weights = [1] + [100_000] * (label_count - 1)
label_weights

## Train a model

In [None]:
import pytorch_lightning as pl

from awe import awe_model, gym, utils
# Reload the training-related modules so that source edits are picked up
# without restarting the kernel.
for module in [awe_model, gym, utils]:
    importlib.reload(module)

In [None]:
model = awe_model.AweModel(
    feature_count=ds.feature_dim,
    label_count=label_count,
    label_weights=label_weights,
    char_count=len(ds.root.chars),
    use_gnn=False)
model, model.hparams

In [None]:
g = gym.Gym(ds, model)
# Comment out the next line to restore a previously trained model.
g.restore_checkpoint = False
g.get_last_checkpoint_path(), g.get_last_checkpoint_version()

In [None]:
# NOTE(review): `resume_from_checkpoint` was deprecated in PyTorch Lightning 1.5
# and removed in 2.0 in favor of `trainer.fit(..., ckpt_path=...)` — confirm
# against the installed version before upgrading.
g.trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[gym.CustomProgressBar(refresh_rate=100)],
    resume_from_checkpoint=g.get_last_checkpoint_path(),
    logger=g.create_logger(),
)

In [None]:
g.trainer.logger.version

In [None]:
g.trainer.fit(model, ds['train'].loader, ds['val'].loader)

In [None]:
# TODO: Breaks torch_geometric for some reason!
#g.save_model()

In [None]:
g.save_model_text()

In [None]:
g.save_results('val')

In [None]:
g.save_pages()

In [None]:
# Uncomment and change name to save interesting results.
#g.save_named_version('seed-3')

## Example prediction

In [None]:
from awe import predictor
importlib.reload(predictor)

In [None]:
val_predictor = predictor.Predictor(ds, 'val', model)

In [None]:
predict_total = len(val_predictor.items)
# Select the first two and the last two items. Out-of-range and duplicate
# indices are dropped so that fewer than four items (e.g., a tiny debugging
# subset) do not yield negative, repeated or invalid indices. With four or
# more items the result is [0, 1, predict_total - 2, predict_total - 1].
predict_indices = sorted({
    index
    for index in (0, 1, predict_total - 2, predict_total - 1)
    if 0 <= index < predict_total
})

In [None]:
val_predictor.evaluate(predict_indices)

In [None]:
val_predictor.get_example_texts(predict_indices)

## Live prediction

In [None]:
urls = [
    'https://www.cars.com/vehicledetail/81d8ee1f-155e-44ec-8ea4-0b25b0ca608a/'
]
live_pages = [live.Page(url) for url in urls]

In [None]:
# Download pages.
[page.dom for page in live_pages]

In [None]:
ds.create('live', live_pages).prepare(skip_existing=False)

In [None]:
live_predictor = predictor.Predictor(ds, 'live', model)

In [None]:
live_predictor.get_example_texts(range(len(live_pages)))