## Import

In [None]:
import importlib
from typing import TypeVar

import torch
from torch_geometric import data, loader
import numpy as np

from awe import utils, filtering, features, html_utils, awe_graph, visual
from awe.data import swde, live, dataset

# Reload project modules so edits to their source files are picked up without
# restarting the kernel.
# NOTE(review): importlib.reload re-executes modules in list order; a module
# reloaded before something it imports from may keep stale references —
# confirm this order matches the awe import graph.
for module in [utils, filtering, dataset, swde, live, features, html_utils, awe_graph, visual]:
    importlib.reload(module)

In [None]:
# Seed NumPy's global RNG and PyTorch's RNG so runs are reproducible.
np.random.seed(42)
torch.manual_seed(42)

In [None]:
# Source SWDE dataset (suffix '-exact'; meaning defined in awe.data.swde) and
# the collection that the train/val/unseen splits are registered into below.
sds = swde.Dataset(suffix='-exact')
# Uncomment to run sds.validate() on the source dataset.
#sds.validate()
ds = dataset.DatasetCollection()

## Split data

In [None]:
T = TypeVar('T')
def train_val_split(data: list[T], val_split: float, seed: int = 42) -> tuple[list[T], list[T]]:
    """Shuffle ``data`` reproducibly and split it into train and validation parts.

    Args:
        data: Items to split. The input itself is not modified; a shuffled
            copy is split.
        val_split: Fraction (0..1) of items placed in the validation part;
            the item count is rounded down.
        seed: Seed for the shuffle. A private ``RandomState`` is used, so the
            global NumPy RNG is not reseeded. For a given seed the shuffle is
            identical to ``np.random.seed(seed); np.random.shuffle(...)``.

    Returns:
        ``(train, val)``: ``val`` holds the first
        ``floor(val_split * len(data))`` shuffled items and ``train`` the rest.

    Raises:
        ValueError: If ``val_split`` is not within ``[0, 1]``.
    """
    if not 0 <= val_split <= 1:
        raise ValueError(f'val_split must be within [0, 1], got {val_split}')
    split = int(np.floor(val_split * len(data)))
    shuffled = list(data)
    # A dedicated RandomState reproduces the legacy seeded global shuffle
    # without the side effect of resetting np.random for the whole notebook.
    np.random.RandomState(seed).shuffle(shuffled)
    return shuffled[split:], shuffled[:split]

In [None]:
vertical = sds.verticals[0]
# Train and val pages come from the first website of the vertical; "unseen"
# pages come from a second website (presumably to test generalization to a
# site not seen in training).
website = vertical.websites[0]
# Alternative: a shuffled 80/20 split of the first 100 pages of `website`.
#train_pages, val_pages = train_val_split(website.pages[:100], .2)
train_pages = website.pages[:300]
val_pages = website.pages[300:400]
unseen_pages = vertical.websites[1].pages[:100]
# Slices silently come out shorter if a website has fewer pages, so show the
# actual split sizes.
len(train_pages), len(val_pages), len(unseen_pages)

## Prepare datapoints

In [None]:
# Node feature extractors used by the dataset collection.
ds.features = [
    features.Depth(),
    features.IsLeaf(),
    features.CharCategories(),
    features.FontSize(),
    features.WordEmbedding(),
]

In [None]:
# Inspect the feature dimensionality and labels (presumably derived from
# `ds.features` above — confirm in awe.data.dataset).
ds.feature_dim, ds.feature_labels

In [None]:
# Register one dataset per split, built from the page lists chosen earlier.
ds.create('train', train_pages)
ds.create('val', val_pages)
ds.create('unseen', unseen_pages)

In [None]:
# Uncomment the next line to delete previously saved features so they are
# recomputed.
#ds.delete_saved_features()

In [None]:
# Initialize the collection (see DatasetCollection.initialize in
# awe.data.dataset).
ds.initialize()

In [None]:
# NOTE(review): `parallelize = None` presumably disables parallel
# preparation — confirm in awe.data.dataset.
ds.parallelize = None
# Prepare every split, in train -> val -> unseen order.
ds['train'].prepare()
ds['val'].prepare()
ds['unseen'].prepare()

In [None]:
# Label mapping of the first dataset; its values are the class ids that the
# `y` targets are compared against when counting labels.
ds.first_dataset.label_map

## Create dataloaders

In [None]:
BATCH_SIZE = 1
# Attach a graph DataLoader to each split; only the training split shuffles.
for split_name in ('train', 'val', 'unseen'):
    ds[split_name].loader = loader.DataLoader(
        ds[split_name],
        batch_size=BATCH_SIZE,
        shuffle=(split_name == 'train'),
    )

In [None]:
# Print the first training batch as a quick sanity check of the collated data.
for batch in ds['train'].loader:
    print(batch)
    break

In [None]:
# Number of train batches plus number of val batches ('unseen' not included).
len(ds['train'].loader) + len(ds['val'].loader)

## Inspect data

In [None]:
# First validation (node, x, y) triple whose node is labeled exactly ['price'].
# The generator is lazy, so scanning stops at the first match. `next(..., None)`
# returns None instead of raising StopIteration when no such node exists.
interesting_nodes = (
    (node, x, y)
    for node, x, y in ds['val'].iterate_data()
    if node.labels == ['price']
)
next(interesting_nodes, None)

In [None]:
# Are all labeled nodes also leaf nodes?
from tqdm.auto import tqdm
def verify_leaf_nodes(names):
    """Check that every labeled node in the given splits is a text node.

    Iterates the pages of each split (with a tqdm progress bar), builds a
    context for each page, and inspects its nodes. On the first labeled
    node that is not a text node, prints it and stops scanning.

    Args:
        names: Split names (keys of the global ``ds`` collection) to scan.

    Returns:
        ``True`` if no labeled non-text node was found, ``False`` after
        reporting the first one found.
    """
    for name in names:
        for page in tqdm(ds[name].pages, desc=name):
            ctx = ds.create_context(page)
            for node in ctx.nodes:
                # A text node is what this check treats as a "leaf".
                if node.labels and not node.is_text:
                    print(f'Node {node.xpath} with labels {node.labels} in page {page.identifier} is not leaf.')
                    return False
    return True
#verify_leaf_nodes(['train', 'val', 'unseen'])

## Weight labels

In [None]:
def count_label(data: list[data.Data], label: int) -> int:
    """Return how many target entries across ``data`` equal ``label``.

    Each item's ``y`` is iterated element by element and every element equal
    to ``label`` is counted once.

    Note: the annotation ``data.Data`` is evaluated at definition time, when
    ``data`` still names the ``torch_geometric.data`` module; inside the body
    ``data`` is the parameter.
    """
    # Count with a generator rather than materialising a throwaway list.
    return sum(1 for d in data for y in d.y if y == label)

def count_labels(data: list[data.Data]):
    """Count target occurrences of every class id in ``data``.

    Returns a list with one count per value of
    ``ds.first_dataset.label_map``, in that mapping's iteration order.
    """
    label_ids = ds.first_dataset.label_map.values()
    return [count_label(data, label_id) for label_id in label_ids]

In [None]:
#label_counts = count_labels(ds['train'])
#label_counts, len(ds['train'])

In [None]:
#label_weights = [len(ds['train']) / count for count in label_counts]
#label_weights

In [None]:
# Manual override instead of count-based weights: the first class gets
# weight 1 and every other class gets 100_000.
# NOTE(review): presumably class 0 is the "no label" class, so this heavily
# up-weights the rare labeled nodes — confirm against label_map.
# NOTE(review): an empty label_map would still yield [1], i.e. one weight for
# zero classes.
label_count = len(ds.first_dataset.label_map)
label_weights = [1] + [100_000] * (label_count - 1)
label_weights

## Train a model

In [None]:
import pytorch_lightning as pl

from awe import awe_model, gym, utils
# Reload training-related modules so source edits are picked up without
# restarting the kernel.
for module in [awe_model, gym, utils]:
    importlib.reload(module)

In [None]:
# Build the model from the feature dimension, the number of classes, and the
# per-class weights (presumably used in the loss — see awe.awe_model).
model = awe_model.AweModel(ds.feature_dim, label_count, label_weights)
model

In [None]:
g = gym.Gym(ds, model)
# The next line is active, so the last checkpoint is presumably NOT restored
# and the model is re-trained. Comment it out to resume from the last
# checkpoint instead. (Semantics inferred from the name — confirm in awe.gym.)
g.restore_checkpoint = False
g.get_last_checkpoint_path(), g.get_last_checkpoint_version()

In [None]:
# NOTE(review): Trainer(resume_from_checkpoint=...) is deprecated since
# PyTorch Lightning 1.5 and removed in 2.0; newer versions expect
# `trainer.fit(..., ckpt_path=...)` instead.
g.trainer = pl.Trainer(
    max_epochs=50,
    # refresh_rate=100: presumably steps between progress-bar updates.
    callbacks=[gym.CustomProgressBar(refresh_rate=100)],
    resume_from_checkpoint=g.get_last_checkpoint_path(),
    logger=g.create_logger(),
)

In [None]:
# Version of this run's logger.
g.trainer.logger.version

In [None]:
# Fit on the train loader, using the val loader for validation.
g.trainer.fit(model, ds['train'].loader, ds['val'].loader)

In [None]:
# TODO: Breaks torch_geometric for some reason!
#g.save_model()

In [None]:
# Persist artifacts of this run through the Gym helpers (see awe.gym for the
# exact formats and output locations).
g.save_model_text()

In [None]:
# Results for the 'val' split.
g.save_results('val')

In [None]:
# Results for the 'unseen' split.
g.save_results('unseen')

In [None]:
g.save_pages()

In [None]:
# Uncomment and change name to save interesting results.
#g.save_named_version('only-leaf-300')

## Example prediction

In [None]:
from awe import predictor
# Reload the predictor module so source edits are picked up.
importlib.reload(predictor)

In [None]:
# NOTE(review): despite its name, this predictor runs on the 'unseen' split,
# not 'val' — rename it or switch the split if that is unintended.
val_predictor = predictor.Predictor(ds, 'unseen', model)

In [None]:
predict_total = len(val_predictor.items)
# Up to two items from each end of the list. Out-of-range indices are dropped
# and duplicates merged, so fewer than four items cannot yield an invalid or
# repeated index. With four or more items this is exactly
# [0, 1, predict_total - 2, predict_total - 1].
predict_indices = [
    i for i in sorted({0, 1, predict_total - 2, predict_total - 1})
    if 0 <= i < predict_total
]

In [None]:
# Evaluate the model on the selected example items.
val_predictor.evaluate(predict_indices)

In [None]:
# Show example texts for the same items (see awe.predictor for details).
val_predictor.get_example_texts(predict_indices)

## Live prediction

In [None]:
# URLs of live pages to run the trained model on; each becomes a live.Page.
urls = [
    'https://www.cars.com/vehicledetail/81d8ee1f-155e-44ec-8ea4-0b25b0ca608a/'
]
live_pages = list(map(live.Page, urls))

In [None]:
# Download pages: accessing `page.dom` triggers the fetch (presumably a lazily
# loaded attribute — confirm in awe.data.live). A plain loop is used because
# only the side effect matters; there is no need to build and display a list
# of every DOM.
for page in live_pages:
    page.dom

In [None]:
# Register the live pages as a 'live' split and prepare it. skip_existing=False
# presumably forces recomputation even if saved data exists.
ds.create('live', live_pages).prepare(skip_existing=False)

In [None]:
live_predictor = predictor.Predictor(ds, 'live', model)

In [None]:
# Example texts for every live page.
live_predictor.get_example_texts(range(len(live_pages)))