# Training

This is a notebook for interactive training.
Its non-interactive script counterpart is `train.py`.

In [None]:
# Prepare the notebook environment before importing the rest of `awe`
# (presumably working directory / logging setup — see `awe.utils.init_notebook`).
import awe.utils
awe.utils.init_notebook()

In [None]:
import torchinfo

In [None]:
import awe.data.set.pages
import awe.training.versioning
import awe.training.params
import awe.training.trainer
# Re-import the `awe` package so backend code edits take effect without
# restarting the kernel. `awe.data.glove` is excluded from reloading —
# presumably because it holds expensive-to-load state (e.g. embeddings); confirm.
awe.utils.reload('awe', exclude=['awe.data.glove'])

## Load parameters

This section assumes that `data/params.json` exists and contains the desired hyper-parameters.
If it does not exist, the file is created with default values on the first run.

In [None]:
# Load user hyper-parameters (from `data/params.json`, per the note above).
# NOTE(review): exact effect of `normalize=True` is defined in `Params.load_user`.
params = awe.training.params.Params.load_user(normalize=True)
# Last expression: displayed by the notebook for inspection.
params

In [None]:
# Compare the current hyper-parameters with those of the previously trained
# version (if any) so changes are easy to spot before training again.
# Explicit `is not None` checks: "no previous version" is signalled by None,
# and an object that happens to be falsy must not be mistaken for it.
latest_version = awe.training.versioning.Version.get_latest()
latest_params = (
    awe.training.params.Params.load_version(latest_version)
    if latest_version is not None else None
)
# Last expression: the notebook displays the difference (or nothing on first run).
params.difference(latest_params) if latest_params is not None else None

In [None]:
# Note that previous trainer is preserved in case this cell is executed again
# (possibly with changed backend code).
# On the first run `globals().get('trainer')` is None, so no previous trainer
# is passed.
trainer = awe.training.trainer.Trainer(params, prev_trainer=globals().get('trainer'))

## Load data

In [None]:
# Load the dataset into the trainer (presumably populating `trainer.ds`,
# which the next cell uses).
trainer.load_dataset()

In [None]:
# Only needed when re-running training after changing data-loading
# hyper-parameters: otherwise previously loaded data stay cached, and the
# features used may no longer match the current hyper-parameters.
# With no arguments, the defaults of `ClearCacheRequest` decide which caches
# are cleared; uncomment the fields below to override them.
trainer.ds.clear_cache(awe.data.set.pages.ClearCacheRequest(
    # dom=False,
    # labels=False,
    # dom_dirty_flags=True,
))

In [None]:
# Initialize features for the loaded data (presumably configured by the
# hyper-parameters above; run before splitting / creating data loaders).
trainer.init_features()

In [None]:
# Split the data into training/validation parts (presumably this populates
# `trainer.val_websites` and `trainer.val_pages` used in the evaluation cells;
# confirm in `Trainer.split_data`).
trainer.split_data()

In [None]:
# Build the data loaders consumed by `trainer.train()`.
trainer.create_dataloaders()

## Explore data

In [None]:
# Summarize the loaded data; only the first 15 rows are displayed.
# (`.iloc` suggests `explore_data` returns a pandas object — confirm.)
trainer.explore_data().iloc[:15]

In [None]:
# Show visual exploration output for the data (see `Trainer.explore_visuals`).
trainer.explore_visuals()

## Train

In [None]:
# Build the model, then show its layer summary.
trainer.create_model()
# `verbose=0` stops torchinfo from printing; the returned summary object is the
# cell's last expression, so the notebook displays it exactly once.
torchinfo.summary(trainer.model, verbose=0)

In [None]:
# Create a new version for this training run (presumably what
# `Version.get_latest()` will return on the next run — confirm).
trainer.create_version()

In [None]:
# Train the model using the data loaders created above.
trainer.train()

## Evaluate

In [None]:
# Evaluate on a few unseen pages: pages 100-119 of each validation website.
# Websites with at most 100 pages contribute nothing to this slice.
# NOTE(review): confirm these pages do not overlap with the pages used for
# validation during training.
test_pages = [p for w in trainer.val_websites for p in w.pages[100:120]]
trainer.validate(trainer.create_run(test_pages, desc='test'))

### Example predictions

In [None]:
# Choose a few pages to predict (positions 10-19 of the validation pages).
pred_pages = trainer.val_pages[10:20]
# Last expression: display which HTML files were chosen.
[p.html_path for p in pred_pages]

In [None]:
# Initialize features on those pages by creating a run labelled 'pred';
# the next cells reuse this run for validation and prediction.
pred_run = trainer.create_run(pred_pages, desc='pred')

In [None]:
# Evaluate performance metrics on those pages (same call as the 'test'
# evaluation above, applied to the prediction run).
trainer.validate(pred_run)

In [None]:
# Display predictions: run the model on the prediction run, then decode the
# raw outputs (presumably into readable extracted values — see `Trainer.decode`).
preds = trainer.predict(pred_run)
trainer.decode(preds)