## Import

In [None]:
import importlib
from typing import TypeVar
import collections
import os
import itertools

import torch
from torch import nn, optim
from torch_geometric import data, loader
import numpy as np

from awe import features, html_utils, awe_graph
from awe.data import swde, live, dataset

# Re-execute the project modules in place each time this cell runs
# (presumably so that source edits are picked up without restarting the
# kernel).
for module in [swde, live, dataset, features, html_utils, awe_graph]:
    importlib.reload(module)

In [None]:
# Seed the global NumPy and PyTorch random number generators with a fixed
# value.
np.random.seed(42)
torch.manual_seed(42)

In [None]:
#swde.validate()

## Split data

In [None]:
T = TypeVar('T')
def train_val_split(data: list[T], val_split: float) -> tuple[list[T], list[T]]:
    """Shuffle ``data`` deterministically and split it into two parts.

    The input list is left unmodified. Shuffling uses a private generator
    seeded with 42, so repeated calls produce the same split and NumPy's
    global random state is not touched.

    Args:
        data: Items to split.
        val_split: Fraction of the items (rounded down) that goes into the
            validation part.

    Returns:
        Tuple ``(train, val)`` of two new lists.
    """
    split = int(np.floor(val_split * len(data)))
    shuffled = list(data)
    # A dedicated RandomState gives the same permutation as reseeding the
    # global generator would, without resetting the global generator.
    np.random.RandomState(42).shuffle(shuffled)
    return shuffled[split:], shuffled[:split]

In [None]:
# Work with the first SWDE vertical only.
vertical = swde.VERTICALS[0]
#train_pages, val_pages = train_val_split(website.pages[:100], .2)
# Train on up to 100 pages from each of the first two websites and validate
# on up to 100 pages from the third one, so the validation pages come from a
# different website entry than the training pages.
train_pages = vertical.websites[0].pages[:100] + vertical.websites[1].pages[:100]
val_pages = vertical.websites[2].pages[:100]
# Show the sizes of both page lists.
len(train_pages), len(val_pages)

## Prepare datapoints

In [None]:
# Create the dataset and register both page lists under the 'train' and
# 'val' keys.
ds = dataset.Dataset()
ds.add('train', train_pages)
ds.add('val', val_pages)
# Show how many items each key holds after adding.
len(ds.data['train']), len(ds.data['val'])

In [None]:
# Feature count reported by the dataset; it is passed to the model
# constructor further down.
ds.feature_count

In [None]:
# Label map of the dataset; later cells iterate its keys as label names
# (skipping a None key) and use its values as the labels to count.
ds.label_map

## Create dataloaders

In [None]:
# One item per batch. The example helpers further down pair the i-th batch
# of a loader with the i-th page, which relies on this.
BATCH_SIZE = 1
# Only the training loader shuffles its items.
train_dataloader = loader.DataLoader(ds.data['train'], batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = loader.DataLoader(ds.data['val'], batch_size=BATCH_SIZE)

In [None]:
# Print the first training batch only, then stop iterating.
for batch in train_dataloader:
    print(batch)
    break

In [None]:
# Total number of batches across the training and validation loaders.
len(train_dataloader) + len(val_dataloader)

## Weight labels

In [None]:
def count_label(data: list[data.Data], label: int):
    """Count the entries of ``y``, over all items of ``data``, equal to ``label``."""
    return sum(
        1
        for sample in data
        for target in sample.y
        if target == label
    )

def count_labels(data: list[data.Data]):
    """Return the ``count_label`` result for every value of ``ds.label_map``.

    The counts are listed in the iteration order of ``ds.label_map.values()``.
    """
    mapped_labels = ds.label_map.values()
    return [count_label(data, mapped) for mapped in mapped_labels]

In [None]:
# Occurrences of each label in the training data, shown next to the number
# of training items.
label_counts = count_labels(ds.data['train'])
label_counts, len(ds.data['train'])

In [None]:
# Weight every label by (number of training items) / (its occurrence count),
# so labels that occur less often get larger weights. A label that never
# occurs gets weight 0.0 instead of aborting the cell with
# ZeroDivisionError.
# NOTE: the manual-override cell further down replaces these weights.
label_weights = [
    len(ds.data['train']) / count if count else 0.0
    for count in label_counts
]
label_weights

In [None]:
# Manual override: replace the computed weights with weight 1 for the first
# label and 100 000 for every other label.
# NOTE(review): presumably the first label is the "no label" class -- confirm
# against ds.label_map.
label_count = len(ds.label_map)
label_weights = [1] + [100_000] * (label_count - 1)
label_weights

## Train a model

In [None]:
from awe import awe_model
import pytorch_lightning as pl
# Re-execute the model module in place each time this cell runs.
importlib.reload(awe_model)

In [None]:
# Instantiate the model from the dataset's feature count, the number of
# labels and the per-label weights.
model = awe_model.AweModel(ds.feature_count, label_count, label_weights)

In [None]:
# Resume from the checkpoint written by the save cell further down, but only
# when that file already exists; on a fresh run there is nothing to resume
# from and training starts from scratch.
# NOTE(review): `progress_bar_refresh_rate` and `resume_from_checkpoint` are
# deprecated in newer PyTorch Lightning releases -- check the installed
# version before upgrading.
checkpoint_path = 'lightning_logs/chkpt'
trainer = pl.Trainer(
    max_epochs=50,
    progress_bar_refresh_rate=100,
    resume_from_checkpoint=checkpoint_path if os.path.exists(checkpoint_path) else None
)
trainer.fit(model, train_dataloader, val_dataloader)

In [None]:
# Run validation on the validation loader; the return value is assigned to
# `_` (presumably to keep it out of the cell output).
_ = trainer.validate(model, val_dataloader)

In [None]:
# Save a checkpoint to the same path the trainer cell resumes from.
trainer.save_checkpoint('lightning_logs/chkpt')

## Example prediction

In [None]:
def get_example(dataloader: loader.DataLoader, pages: list[awe_graph.HtmlPage], index: int):
    """Return the ``index``-th batch of ``dataloader`` and the nodes of ``pages[index]``.

    Batch and page are paired purely by position, so the result is only
    meaningful when the dataloader yields one batch per page in the same
    order as ``pages`` (batch size 1, no shuffling).

    Args:
        dataloader: Iterable of batches; it is iterated from its start up to
            ``index`` on every call.
        pages: Pages in the same order as the batches.
        index: Zero-based position of the example.

    Returns:
        Tuple ``(batch, nodes)`` where ``nodes`` is ``list(pages[index].nodes)``.

    Raises:
        IndexError: If ``index`` lies beyond the end of ``dataloader`` or
            ``pages``.
        ValueError: If ``index`` is negative (raised by ``itertools.islice``).
    """
    missing = object()
    example_batch = next(itertools.islice(dataloader, index, None), missing)
    if example_batch is missing:
        # A bare ``next()`` would leak ``StopIteration``, which can silently
        # end an enclosing iteration or turn into ``RuntimeError`` inside a
        # generator.
        raise IndexError(f'dataloader has no batch at index {index}')
    example_nodes = list(pages[index].nodes)
    return example_batch, example_nodes

In [None]:
def evaluate(dataloader: loader.DataLoader, pages: list[awe_graph.HtmlPage], index: int, label: str):
    """Return the model's SWDE metrics for ``label`` on the example at ``index``.

    Only the batch of the example is used; ``label`` is translated through
    ``ds.label_map`` before being handed to ``model.compute_swde_metrics``.
    """
    example_batch = get_example(dataloader, pages, index)[0]
    mapped_label = ds.label_map[label]
    return model.compute_swde_metrics(example_batch, mapped_label)

In [None]:
# Metrics for the first five validation examples: one dict per example,
# keyed by label name (a None key in the label map is skipped).
[{ label: evaluate(val_dataloader, val_pages, i, label) for label in ds.label_map if label is not None } for i in range(5)]

In [None]:
def predict(dataloader: loader.DataLoader, pages: list[awe_graph.HtmlPage], index: int, label: str) -> list[awe_graph.HtmlNode]:
    """Collect the page nodes reported by ``model.predict_swde`` for ``label``.

    The example at ``index`` is fetched with ``get_example``. The model is
    given a callback and the mapped label; every callback invocation whose
    ``name`` has ``'p'`` as its second character contributes one node to the
    returned list.
    """
    example_batch, page_nodes = get_example(dataloader, pages, index)
    collected = []

    def handle(name: str, mask, idx=None):
        # Ignore every result kind whose second character is not 'p'
        # (presumably only 'tp'/'fp' style names pass -- TODO confirm against
        # the model).
        if name[1] != 'p':
            return
        # Keep only the nodes selected by the mask, then take the one at
        # position ``idx`` among them (the first one when ``idx`` is None).
        selected = itertools.compress(page_nodes, mask)
        collected.append(next(itertools.islice(selected, idx, None)))

    model.predict_swde(example_batch, ds.label_map[label], handle)
    return collected

In [None]:
def get_example_texts(dataloader: loader.DataLoader, pages: list[awe_graph.HtmlPage], indices):
    """Return the text content of the predicted nodes per label and example.

    The result maps every non-None key of ``ds.label_map`` to a list with one
    entry per element of ``indices``; each entry is the list of
    ``text_content`` values of the nodes returned by ``predict``.
    ``indices`` is iterated once per label.
    """
    texts_by_label = {}
    for label in ds.label_map:
        if label is None:
            continue
        texts_by_label[label] = [
            [node.text_content for node in predict(dataloader, pages, i, label)]
            for i in indices
        ]
    return texts_by_label

In [None]:
# Predicted node texts for the first five validation examples.
get_example_texts(val_dataloader, val_pages, range(5))

## Live prediction

In [None]:
# URLs to run the model on; each one is wrapped in a live.Page.
urls = [
    'https://www.cars.com/vehicledetail/81d8ee1f-155e-44ec-8ea4-0b25b0ca608a/'
]
live_pages = [live.Page(url) for url in urls]

In [None]:
# Download pages.
# NOTE(review): presumably reading `dom` triggers the fetch -- confirm in
# awe.data.live. The list of results is also shown as the cell output.
[page.dom for page in live_pages]

In [None]:
# Register the live pages in the existing dataset under the 'live' key.
ds.add('live', live_pages)

In [None]:
# Loader over the live items; no batch size or shuffling is specified, so
# the loader's defaults apply.
live_dataloader = loader.DataLoader(ds.data['live'])

In [None]:
# Print the first live batch only, then stop iterating.
for batch in live_dataloader:
    print(batch)
    break

In [None]:
# Number of batches in the live loader.
len(live_dataloader)

In [None]:
# Predicted node texts for every batch of the live loader.
get_example_texts(live_dataloader, live_pages, range(len(live_dataloader)))