## Import

In [None]:
import importlib
from typing import TypeVar

import torch
from torch_geometric import data, loader
import numpy as np

from awe import utils, features, html_utils, awe_graph, predictor, visual
from awe.data import swde, live, dataset

# Re-execute the already imported project modules (importlib.reload) so this
# session uses their current source code.
for module in [utils, dataset, swde, live, features, html_utils, awe_graph, predictor, visual]:
    importlib.reload(module)

In [None]:
# Seed the global NumPy and PyTorch random generators with a fixed value so
# that randomized steps below are repeatable.
np.random.seed(42)
torch.manual_seed(42)

In [None]:
# SWDE dataset object created with the '-exact' suffix.
# NOTE(review): presumably the suffix selects a dataset variant — confirm in awe.data.swde.
sds = swde.Dataset(suffix='-exact')
#sds.validate()
# Collection used by the rest of the notebook; features, prepared data,
# loaders and the label map are all accessed through it.
ds = dataset.DatasetCollection()

## Load visual attributes

In [None]:
# Use the first page of the first website of the first vertical as a sample
# and display the path of its file.
page = sds.verticals[0].websites[0].pages[0]
page.file_path

In [None]:
# Create a context for the sample page; its `nodes` are inspected in the next cells.
ctx = ds.create_context(page)

In [None]:
# Display the number of nodes in the context together with the first ten of them.
len(ctx.nodes), ctx.nodes[:10]

In [None]:
# Display only the nodes whose `labels` differ from the empty list.
[n for n in ctx.nodes if n.labels != []]

## Split data

In [None]:
T = TypeVar('T')
def train_val_split(data: list[T], val_split: float) -> tuple[list[T], list[T]]:
    """Randomly split `data` into a training part and a validation part.

    The input is copied and the copy is shuffled with a fixed seed (42), so
    the same input always produces the same split and the caller's list is
    left untouched.

    Args:
        data: Items to split.
        val_split: Fraction of the items (between 0 and 1) that goes into
            the validation part; the resulting count is rounded down.

    Returns:
        A `(train, val)` tuple of lists.

    Raises:
        ValueError: If `val_split` is outside the range [0, 1].
    """
    # A negative fraction would silently produce a wrong split through
    # negative slicing, so reject out-of-range values explicitly.
    if not 0 <= val_split <= 1:
        raise ValueError(f'val_split must be between 0 and 1, got {val_split}')
    split = int(np.floor(val_split * len(data)))
    copy = list(data)
    # A private generator yields the same permutation as seeding the global
    # one with 42 would, but does not reset the global NumPy random state.
    np.random.RandomState(42).shuffle(copy)
    return copy[split:], copy[:split]

In [None]:
# Work with the first vertical only.
vertical = sds.verticals[0]
#train_pages, val_pages = train_val_split(website.pages[:100], .2)
# Train on the first 100 pages of websites 0 and 1 and validate on the first
# 100 pages of website 2, so validation pages come from a different website
# than any training page.
train_pages = vertical.websites[0].pages[:100] + vertical.websites[1].pages[:100]
val_pages = vertical.websites[2].pages[:100]
# Display the sizes of both parts.
len(train_pages), len(val_pages)

## Prepare datapoints

In [None]:
#we = features.WordEmbedding()
#we.dimension

In [None]:
# Features enabled for this run; the word embedding (`we`) is left disabled.
ds.features = [
    features.Depth(),
    features.IsLeaf(),
    features.CharCategories(),
    #we
]

In [None]:
# Display the feature dimension and feature labels reported by the collection.
ds.feature_dim, ds.feature_labels

In [None]:
# Create the 'train' dataset from the training pages and prepare it.
# NOTE(review): skip_existing=True presumably avoids recomputing pages that
# were prepared earlier — confirm in awe.data.dataset.
ds.create('train', train_pages).prepare(skip_existing=True)

In [None]:
# Create the 'val' dataset from the validation pages and prepare it.
# NOTE(review): skip_existing=True presumably avoids recomputing pages that
# were prepared earlier — confirm in awe.data.dataset.
ds.create('val', val_pages).prepare(skip_existing=True)

In [None]:
# Display the label map; its values are used as label ids and its length as
# the number of labels further below.
ds.label_map

## Create dataloaders

In [None]:
# Every batch holds a single data item.
BATCH_SIZE = 1
# Only the training loader shuffles; the validation loader keeps its order.
ds.loaders['train'] = loader.DataLoader(ds.data['train'], batch_size=BATCH_SIZE, shuffle=True)
ds.loaders['val'] = loader.DataLoader(ds.data['val'], batch_size=BATCH_SIZE)

In [None]:
# Print only the first training batch as a sanity check, then stop iterating.
for batch in ds.loaders['train']:
    print(batch)
    break

In [None]:
# Display the total number of batches across the training and validation loaders.
len(ds.loaders['train']) + len(ds.loaders['val'])

## Inspect data

In [None]:
# Lazily scan the validation data for entries whose node is labelled exactly
# ['price'].
interesting_nodes = (
    (node, x, y)
    for node, x, y in ds.iterate_data('val')
    if node.labels == ['price']
)
# Display the first match. A generator is already an iterator, so no iter()
# call is needed; the None default makes the cell show "no match" instead of
# failing with StopIteration when nothing is labelled that way.
next(interesting_nodes, None)

## Weight labels

In [None]:
def count_label(data: list[data.Data], label: int):
    """Return how many entries of `y`, summed over all items of `data`, equal `label`."""
    matches = 0
    for sample in data:
        for target in sample.y:
            if target == label:
                matches += 1
    return matches

def count_labels(data: list[data.Data]):
    """Return one count per value of `ds.label_map`, in the map's iteration order.

    Each count is obtained by calling `count_label` on `data` for that value.
    """
    counts = []
    for label_id in ds.label_map.values():
        counts.append(count_label(data, label_id))
    return counts

In [None]:
#label_counts = count_labels(ds.data['train'])
#label_counts, len(ds.data['train'])

In [None]:
#label_weights = [len(ds.data['train']) / count for count in label_counts]
#label_weights

In [None]:
# Manual override of the computed weights (the computed variant is commented out above).
label_count = len(ds.label_map)
# Weight 1 for the first label and 100_000 for every other label.
# NOTE(review): presumably index 0 is the "no label" class that dominates the
# data — confirm against ds.label_map.
label_weights = [1] + [100_000] * (label_count - 1)
label_weights

## Train a model

In [None]:
import pytorch_lightning as pl

from awe import awe_model, gym, utils
# Re-execute the training-related project modules (importlib.reload) so this
# session uses their current source code.
for module in [awe_model, gym, utils]:
    importlib.reload(module)

In [None]:
# Build the model from the feature dimension, the number of labels and the
# per-label weights chosen above, then display it.
model = awe_model.AweModel(ds.feature_dim, label_count, label_weights)
model

In [None]:
# Wrap the dataset collection and the model in a Gym, switch its
# `restore_checkpoint` flag off and display the checkpoint path it then reports.
g = gym.Gym(ds, model)
g.restore_checkpoint = False
g.get_last_checkpoint_path()

In [None]:
# Trainer limited to 50 epochs, with a progress bar refresh rate of 100, that
# resumes from whatever checkpoint path the gym reports.
# NOTE(review): `progress_bar_refresh_rate` and `resume_from_checkpoint` are
# deprecated or removed in newer PyTorch Lightning releases — confirm against
# the installed version before upgrading.
g.trainer = pl.Trainer(
    max_epochs=50,
    progress_bar_refresh_rate=100,
    resume_from_checkpoint=g.get_last_checkpoint_path()
)

In [None]:
# Display the version reported by the trainer's logger for this run.
g.trainer.logger.version

In [None]:
# Train the model on the training loader, using the validation loader for validation.
g.trainer.fit(model, ds.loaders['train'], ds.loaders['val'])

In [None]:
# TODO: Breaks torch_geometric for some reason!
#g.save_model()

In [None]:
# Save a text form of the model through the gym.
# NOTE(review): format and output location are defined in awe.gym — not visible here.
g.save_model_text()

In [None]:
# Save results for the 'val' dataset through the gym.
# NOTE(review): what is stored and where is defined in awe.gym — not visible here.
g.save_results('val')

In [None]:
#g.save_named_version('leaf-only')

## Example prediction

In [None]:
# Predictor bound to the 'val' dataset and the trained model.
val_predictor = predictor.Predictor(ds, 'val', model)

In [None]:
# Evaluate indices 0-4 of the 'val' dataset.
# NOTE(review): presumably these are page indices — confirm in awe.predictor.
val_predictor.evaluate(range(5))

In [None]:
# Display example texts for indices 0-4 of the 'val' dataset.
# NOTE(review): presumably these are page indices — confirm in awe.predictor.
val_predictor.get_example_texts(range(5))

## Live prediction

In [None]:
# URLs of live pages to run the model on.
urls = [
    'https://www.cars.com/vehicledetail/81d8ee1f-155e-44ec-8ea4-0b25b0ca608a/'
]
# One live.Page object per URL.
live_pages = [live.Page(url) for url in urls]

In [None]:
# Download pages by reading the `dom` attribute of every live page.
# NOTE(review): presumably `dom` fetches the page lazily on first access —
# confirm in awe.data.live.
[page.dom for page in live_pages]

In [None]:
# Register the live pages in the collection under the name 'live'.
ds.add('live', live_pages)

In [None]:
# Predictor bound to the 'live' dataset and the trained model.
live_predictor = predictor.Predictor(ds, 'live', model)

In [None]:
# Display example texts for one index per live page.
live_predictor.get_example_texts(range(len(live_pages)))