## Import

In [None]:
import importlib

import torch
from torch_geometric import data
import numpy as np

from awe import utils, filtering, features, html_utils, awe_graph
from awe.data import swde, live, dataset
from awe.features import extraction

# Re-execute the already-imported project modules so that source edits are
# picked up without restarting the kernel.
for module in [utils, filtering, dataset, swde, live, extraction, html_utils, awe_graph]:
    importlib.reload(module)
# Project helper; presumably reloads the named packages as a whole —
# confirm in awe.utils.
utils.reload('awe.features', 'awe.visual')

In [None]:
# Seed NumPy's global RNG and PyTorch's default generator so that repeated
# runs draw the same random numbers.
np.random.seed(42)
torch.manual_seed(42)

## Split data

In [None]:
# Set to True to validate the pages of the first vertical; when False the
# list of invalid pages is simply left empty.
CHECK_VALIDITY = False
sds = swde.Dataset(suffix='-exact')
if CHECK_VALIDITY:
    invalid_pages = sds.validate(
        parallelize=16,
        skip=0,
        verticals=sds.verticals[:1],
        collect_errors=True,
        #error_callback=lambda i, _, e: print(f'{i}: {str(e)}'),
        save_list=True,
        read_list=True,
    )
else:
    invalid_pages = []

In [None]:
# Each entry of `invalid_pages` is unpacked as a 3-tuple; only its middle
# element (the page) is passed on to the summary helper.
utils.summarize_pages(p for _, p, _ in invalid_pages)

In [None]:
# Split the first vertical into training websites and held-out websites and
# register three datasets: `train`, `val_unseen` (pages from websites not
# used for training) and `val_seen` (a sample of the training pages).
SUBSET = slice(None)  # Narrow this slice to experiment on fewer pages.
websites = sds.verticals[0].websites
# Dedicated generator so the split does not depend on the global NumPy state.
rng = np.random.default_rng(42)
website_indices = rng.choice(len(websites), 5, replace=False)
train_pages = [
    p for i in website_indices
    for p in websites[i].pages
]
val_pages = [
    p for i in range(len(websites))
    if i not in website_indices
    # Clamp the sample size: sampling without replacement raises ValueError
    # when a website has fewer than 50 pages.
    for p in rng.choice(
        websites[i].pages,
        min(50, len(websites[i].pages)),
        replace=False
    )
]
ds = dataset.DatasetCollection()
ds.create('train', train_pages[SUBSET], shuffle=True)
ds.create('val_unseen', val_pages[SUBSET])
# Same clamping for the seen-validation sample, which otherwise fails when
# the (sub)set of training pages is smaller than the requested size.
train_subset = train_pages[SUBSET]
ds.create('val_seen', rng.choice(
    train_subset,
    min(SUBSET.stop or 200, len(train_subset)),
    replace=False
))
ds.get_lengths()

In [None]:
# Map every dataset name to the set of site names its pages come from.
{ name: set(p.site.name for p in items.pages) for name, items in ds.datasets.items() }

In [None]:
# Presumably reports pages that lack visual features (per the method name) —
# confirm in awe.data.dataset.
ds.summarize_pages_without_visual_features()

## Extract features

In [None]:
# Feature extractors used by the dataset collection.
# NOTE(review): the order presumably determines the feature layout — confirm
# before reordering.
ds.features = [
    features.Depth(),
    features.IsLeaf(),
    features.CharCategories(),
    features.Visuals(),
    features.CharIdentifiers(),
    features.WordIdentifiers()
]

### Prepare context

In [None]:
# Display the current description of the root context.
ds.root.describe()

In [None]:
# Uncomment next line to invalidate prepared features.
#ds.delete_saved_root_context()

In [None]:
# Fill in the cutoffs only when they are unset (None), so values that are
# already present on the root context are not overwritten.
if ds.root.cutoff_words is None:
    ds.root.cutoff_words = 15
if ds.root.cutoff_word_length is None:
    ds.root.cutoff_word_length = 10
ds.root.extract_options()

In [None]:
# Snapshot of the root description taken before features are prepared; it is
# compared with `curr_root` when the root context is saved.
prev_root = ds.root.describe()
prev_root

In [None]:
# Prepare features for the datasets; `parallelize` is presumably the number
# of parallel workers — TODO confirm.
ds.prepare_features(parallelize=8)

In [None]:
# Root description after feature preparation (compare with `prev_root`).
curr_root = ds.root.describe()
curr_root

In [None]:
# Display the root context's description of categorical visual features
# (per the method name).
ds.root.describe_visual_categorical()

In [None]:
# Freeze the root context; presumably prevents further changes to the
# prepared state — confirm in the root context class.
ds.root.freeze()

In [None]:
# Overwrite an existing saved context only when the 'pages' entry of the
# root description changed during feature preparation.
ds.save_root_context(
    overwrite_existing=(prev_root['pages'] != curr_root['pages'])
)

### Compute features

In [None]:
# Uncomment next line to update shapes of previously-computed features.
#ds.update_features(parallelize=16)

In [None]:
# Uncomment next line to invalidate computed features.
#ds.delete_saved_features(parallelize=16)

In [None]:
# Compute features for all datasets; `parallelize` is presumably the number
# of parallel workers — TODO confirm.
ds.compute_features(parallelize=6)

In [None]:
# Display the label map of the first dataset (its keys and values are used
# below when counting and weighting labels).
ds.first_dataset.label_map

In [None]:
# Display label counts as computed by the dataset collection itself.
ds.count_labels()

## Create dataloaders

In [None]:
# Create the data loaders (exposed as `.loader` on each dataset, as used
# in the cells below).
ds.create_dataloaders(batch_size=64, num_workers=8)

In [None]:
# Display the first two items of the training dataset.
[ds['train'][i] for i in [0, 1]]

In [None]:
# Print only the first batch produced by the training loader.
for batch in ds['train'].loader:
    print(batch)
    break

In [None]:
# %%timeit
# for batch in ds['train'].loader:
#     break

## Inspect data

In [None]:
import itertools
# Lazily yield (file path, node, feature row, target) for every node of the
# `val_seen` dataset whose labels are exactly ['price'].
interesting_nodes = (
    (ctx.page.file_path, node, batch.x[idx], batch.y[idx])
    for ctx, node, batch, idx in ds['val_seen'].iterate_data()
    if node.labels == ['price']
)
# Raise the start (second argument) to skip over the first few matches.
iterator = itertools.islice(interesting_nodes, 0, None)
# A default keeps the cell from failing with StopIteration when no node
# matches; None is shown instead.
next(iterator, None)

In [None]:
# Display the collection's feature summary.
ds.feature_summary

In [None]:
# Are all labeled nodes also leaf nodes?
from tqdm.auto import tqdm
def verify_leaf_nodes(names):
    """Check that every labeled node in the named datasets is a text node.

    Walks the pages of each dataset in ``names`` and prints a message for the
    first labeled node whose ``is_text`` is falsy, then stops immediately
    (remaining pages and datasets are not checked). Prints nothing when no
    such node is found.

    Args:
        names: Iterable of dataset names, each used as a key into ``ds``.
    """
    for name in names:
        for page in tqdm(ds[name].pages, desc=name):
            ctx = ds.create_page_context(page)
            for node in ctx.nodes:
                if node.labels and not node.is_text:
                    print(f'Node {node.xpath} with labels {node.labels} in page {page.identifier} is not leaf.')
                    return
# The names must match the datasets registered through `ds.create`.
#verify_leaf_nodes(['train', 'val_unseen', 'val_seen'])

In [None]:
# Compute how many words were found in the pretrained GloVe embeddings.
import collections
from tqdm.auto import tqdm
# Missing keys start at 0, so `stats` stays empty while the loop is disabled.
stats = collections.defaultdict(int)
# Change to True to run the count over the whole training loader.
if False:
    for batch in tqdm(ds['train'].loader, desc='train'):
        word_ids = extraction.collate(batch.word_identifiers)
        for t in word_ids:
            # Identifier 0 is tallied as unknown, anything >= 1 as found.
            stats['unknown'] += sum(1 for x in t if x == 0)
            stats['found'] += sum(1 for x in t if x >= 1)
            # Total number of identifiers seen.
            stats['lens'] += len(t)
stats

## Weight labels

In [None]:
def count_label(data: list[data.Data], label: int):
    """Count the target entries equal to ``label`` across all samples.

    Args:
        data: Iterable of samples, each exposing an iterable ``y`` of targets.
        label: Label value to count.

    Returns:
        Number of entries ``y`` (over every sample) for which ``y == label``.
    """
    total = 0
    for sample in data:
        for target in sample.y:
            if target == label:
                total += 1
    return total

def count_labels(data: list[data.Data]):
    """Return one count per label, ordered like the label map's values.

    Args:
        data: Iterable of samples, passed on to ``count_label`` unchanged.

    Returns:
        List with the result of ``count_label`` for each value of
        ``ds.first_dataset.label_map``.
    """
    counts = []
    for label in ds.first_dataset.label_map.values():
        counts.append(count_label(data, label))
    return counts

In [None]:
# Pair each label name with its count. The result stays empty until the
# commented-out `count_labels` call is enabled.
label_counts = []
#label_counts = count_labels(ds['train'])
dict(zip(ds.first_dataset.label_map.keys(), label_counts))

In [None]:
#label_weights = [len(ds['train']) / count for count in label_counts]
#label_weights

In [None]:
# Manual override
# Weight 1 for the first label and 300 for each of the remaining labels.
# NOTE(review): presumably the first label is the "no label" class —
# confirm against the label map.
label_count = len(ds.first_dataset.label_map)
label_weights = [1] + [300] * (label_count - 1)
label_weights

## Train a model

In [None]:
import pytorch_lightning as pl

from awe import awe_model, gym, utils
from awe.data import data_module
from awe.features import extraction

# Re-execute the training-related project modules so that source edits are
# picked up without restarting the kernel.
for module in [awe_model, gym, utils, data_module, extraction]:
    importlib.reload(module)

In [None]:
# Build the model from the dataset's feature dimension and the label count
# and weights defined in the "Weight labels" section.
model = awe_model.AweModel(
    feature_count=ds.feature_dim,
    label_count=label_count,
    label_weights=label_weights,
    # NOTE(review): the +1 presumably reserves an extra id (e.g. for unknown
    # characters or padding) — confirm in awe_model.
    char_count=len(ds.root.chars) + 1,
    use_gnn=True,
    use_lstm=True,
    use_cnn=False,
    lstm_args={
        # https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
        'bidirectional': True,
        'num_layers': 2
    },
    filter_node_words=True,
    label_smoothing=0.1
)
# Display the model together with its hyperparameters.
model, model.hparams

In [None]:
# Wrap the dataset collection and the model in the project's training helper.
g = gym.Gym(ds, model)
# Comment next line to restore previously trained model.
g.restore_checkpoint = False
# Display the checkpoint path and version the gym currently resolves.
g.get_last_checkpoint_path(), g.get_last_checkpoint_version()

In [None]:
# Create the trainer and attach it to the gym.
# NOTE(review): `gpus=` and `resume_from_checkpoint=` are deprecated in newer
# PyTorch Lightning releases — confirm against the installed version before
# upgrading.
g.trainer = pl.Trainer(
    # Number of CUDA devices visible to PyTorch.
    gpus=torch.cuda.device_count(),
    max_epochs=50,
    callbacks=[gym.CustomProgressBar(refresh_rate=10)],
    resume_from_checkpoint=g.get_last_checkpoint_path(),
    logger=g.create_logger(),
)

In [None]:
# Display the version assigned by the logger for this run.
g.trainer.logger.version

In [None]:
# Persist the run's inputs and a textual form of the model through the gym
# (exact contents and destination are defined in awe.gym).
g.save_inputs()
g.save_model_text()

In [None]:
# Train the model; the data module wraps the dataset collection.
g.trainer.fit(model, data_module.DataModule(ds))

In [None]:
# TODO: Breaks torch_geometric for some reason!
#g.save_model()

In [None]:
# Save results for the dataset holding pages from websites not used in training.
g.save_results('val_unseen')

In [None]:
# Save results for the dataset sampled from the training pages.
g.save_results('val_seen')

In [None]:
# Uncomment and change name to save interesting results.
#g.save_named_version('29-lstm-2-layers')

## Example prediction

In [None]:
from awe import predictor

# Re-execute the prediction-related project modules so that source edits are
# picked up without restarting the kernel.
for module in [predictor, awe_graph, awe_model]:
    importlib.reload(module)

In [None]:
# Predict on manually-selected pages.
# When enabled, register a 'pred' dataset from the first four pages of the
# first website of the first vertical and compute its features.
PREDICT_MANUAL = False
if PREDICT_MANUAL:
    ds.create('pred', sds.verticals[0].websites[0].pages[:4])
    ds.compute_features()

In [None]:
# Predict on the manually created dataset when enabled, otherwise on the
# pages from websites not used in training.
target_ds = 'pred' if PREDICT_MANUAL else 'val_unseen'
val_predictor = predictor.Predictor(ds, target_ds, model)

In [None]:
# Pick up to four distinct example indices among the predictor's items.
predict_total = len(val_predictor.items)
# Clamp the sample size: sampling without replacement raises ValueError when
# more samples are requested than there are items.
predict_indices = np.random.choice(
    predict_total, min(4, predict_total), replace=False
)
predict_indices

In [None]:
# Both calls run inside the caching context — presumably so pages parsed
# during evaluation are reused when extracting example texts; confirm in
# awe_graph.
with awe_graph.HtmlPageCaching():
    pred_metrics = val_predictor.evaluate(predict_indices)
    pred_texts = val_predictor.get_example_texts(predict_indices)

In [None]:
# Display the metrics returned by the predictor's evaluation.
pred_metrics

In [None]:
# Display the example texts returned by the predictor.
pred_texts

## Live prediction

In [None]:
# URLs of live pages to run the trained model on; one page object per URL.
urls = [
    'https://www.cars.com/vehicledetail/81d8ee1f-155e-44ec-8ea4-0b25b0ca608a/'
]
live_pages = [live.Page(url) for url in urls]

In [None]:
# Download pages.
# Accessing `.dom` on each page is what triggers the download (per the
# comment above); the resulting list is displayed as the cell output.
[page.dom for page in live_pages]

In [None]:
# Register the live pages as a dataset named 'live'.
ds.create('live', live_pages)

In [None]:
# Compute features again so the newly created 'live' dataset has them.
# NOTE(review): presumably skips datasets whose features already exist —
# confirm.
ds.compute_features()

In [None]:
# Predictor over the 'live' dataset using the trained model.
live_predictor = predictor.Predictor(ds, 'live', model)

In [None]:
# Display example texts for every live page (indices 0..len-1).
live_predictor.get_example_texts(range(len(live_pages)))