In [None]:
import importlib

import numpy as np
import torch
import pytorch_lightning as pl

from awe import awe_model, gym, features
from awe.data import swde, dataset, data_module

## Load data

In [None]:
# Open the SWDE dataset stored under the '-exact' suffix and take the websites
# of its first vertical.
sds = swde.Dataset(suffix='-exact')
websites = sds.verticals[0].websites

# Seeded generator: the website split and the page samples drawn below are
# repeatable as long as the order of draws stays the same.
rng = np.random.default_rng(42)

# Five distinct websites contribute all of their pages to the training set.
website_indices = rng.choice(len(websites), 5, replace=False)
train_pages = []
for site_index in website_indices:
    train_pages.extend(websites[site_index].pages)

# Every website that was not picked for training contributes 50 distinct,
# randomly drawn pages to the "unseen" validation set.
val_pages = []
for site_index in range(len(websites)):
    if site_index in website_indices:
        continue
    sampled_pages = rng.choice(websites[site_index].pages, 50, replace=False)
    val_pages.extend(sampled_pages)

ds = dataset.DatasetCollection()
ds.root.freeze()

# Feature extractors are instantiated in this order.
ds.features = [
    feature_type()
    for feature_type in (
        features.Depth,
        features.IsLeaf,
        features.CharCategories,
        features.Visuals,
        features.CharIdentifiers,
        features.WordIdentifiers,
    )
]

ds.create('train', train_pages, shuffle=True)
ds.create('val_unseen', val_pages)

# "Seen" validation set: 200 distinct pages drawn from the training pages.
val_seen_pages = rng.choice(train_pages, 200, replace=False)
ds.create('val_seen', val_seen_pages)

ds.create_dataloaders(batch_size=64, num_workers=8)
ds.get_lengths()

## Load model

In [None]:
# One weight per entry of the label map: 1 for the first label and 300 for
# every other label.
# NOTE(review): presumably label 0 is the dominant "no label" class that is
# being down-weighted — confirm against the label map.
label_count = len(ds.first_dataset.label_map)
label_weights = [1] + [300 for _ in range(label_count - 1)]
label_weights

In [None]:
# Constructor arguments gathered in one mapping so the configuration can be
# inspected separately from the model instance.
model_kwargs = dict(
    # Sizes derived from the prepared dataset collection.
    feature_count=ds.feature_dim,
    label_count=label_count,
    label_weights=label_weights,
    char_count=len(ds.root.chars) + 1,
    # Architecture switches.
    use_gnn=True,
    use_lstm=True,
    use_cnn=False,
    lstm_args={ 'bidirectional': True },
    filter_node_words=True,
    label_smoothing=0.1,
)
model = awe_model.AweModel(**model_kwargs)

In [None]:
g = gym.Gym(ds, model)

# Epoch number that appears in the checkpoint file name; it is also used as
# the trainer's `max_epochs` below.
epoch = 9
g.restore_checkpoint = f'logs/26-all-visuals-version_16/checkpoints/epoch={epoch}-step=789.ckpt'

# Trainer arguments are built in the same order in which they are passed.
progress_bar = gym.CustomProgressBar(refresh_rate=10)
checkpoint_path = g.get_last_checkpoint_path()
run_logger = g.create_logger()

# NOTE(review): `max_epochs` equals the epoch number in the checkpoint name, so
# this `fit` call presumably only restores the checkpoint state instead of
# training further — confirm against the installed Lightning version.
g.trainer = pl.Trainer(
    max_epochs=epoch,
    callbacks=[progress_bar],
    resume_from_checkpoint=checkpoint_path,
    logger=run_logger,
)

lightning_data = data_module.DataModule(ds)
g.trainer.fit(model, lightning_data)

## Predict

In [None]:
# Imported in this cell so that the module can be reloaded right below.
from awe import predictor

# Re-execute the already imported module in place (presumably so that edits to
# `awe/predictor.py` are picked up without restarting the kernel). The result
# is assigned to `_` so that the returned module object is not shown as the
# cell's output.
_ = importlib.reload(predictor)

In [None]:
# Predictor bound to the dataset collection, the name of the 'train' dataset
# created in the data loading cell, and the model defined above.
pred = predictor.Predictor(ds, 'train', model)

In [None]:
# Evaluate for the single index 0 (`range(1)`).
# NOTE(review): presumably these are indices of examples in the 'train'
# dataset — confirm against `Predictor.evaluate`.
pred.evaluate(range(1))

In [None]:
# Running totals: the summed metric vector and the number of (batch, label)
# evaluations that were added to it.
count = 0
metrics = awe_model.SwdeMetrics(0, 0, 0).to_vector()

loader = pred.ds[pred.name].loader
for batch in loader:
    label_ids = pred.ds.first_dataset.label_map.values()
    for label in label_ids:
        # Label 0 does not contribute to the average.
        if label == 0:
            continue
        inputs = awe_model.ModelInputs(batch)
        label_metrics = pred.model.compute_swde_metrics(inputs, label)
        # In-place accumulation into the vector created above.
        metrics += label_metrics.to_vector()
        count += 1
    break # TODO: Temporary.

# Average over all accumulated evaluations, converted back from vector form.
awe_model.SwdeMetrics.from_vector(metrics / count)

In [None]:
# `batch` is the loop variable left over from the metrics loop in the previous
# cell, so that cell has to be executed before this one.
batch.num_graphs