## Import

In [None]:
import importlib
from typing import TypeVar
import collections
import os
import itertools

import torch
from torch import nn, optim
from torch_geometric import data, loader
import numpy as np

from awe import features, html_utils, awe_graph
from awe.data import swde

# Reload the project modules so that changes made to their source files are
# picked up without restarting the notebook kernel.
for module in [swde, features, html_utils, awe_graph]:
    importlib.reload(module)

In [None]:
# Seed the global NumPy and PyTorch random number generators with a fixed
# value so that repeated runs of the notebook draw the same random numbers.
np.random.seed(42)
torch.manual_seed(42)

In [None]:
#swde.validate()

## Split data

In [None]:
T = TypeVar('T')
def train_val_split(data: list[T], val_split: float, seed: int = 42):
    """Shuffle `data` deterministically and split it into two parts.

    The shuffle uses a private random generator seeded with `seed`, so the
    result is repeatable and the global NumPy random state is not modified.

    Args:
        data: Items to split. The list itself is left unchanged.
        val_split: Fraction of the items (between 0 and 1) that goes to the
            validation part. The validation size is rounded down.
        seed: Seed of the shuffle.

    Returns:
        A `(train, val)` tuple of lists that together contain every item of
        `data` exactly once.

    Raises:
        ValueError: If `val_split` is not within the interval [0, 1].
    """
    if not 0 <= val_split <= 1:
        raise ValueError(f'val_split must be within [0, 1], got {val_split}')

    split = int(np.floor(val_split * len(data)))
    shuffled = list(data)
    # `RandomState(seed)` yields the same permutation as seeding the global
    # generator with `seed`, but without resetting the caller's random state.
    np.random.RandomState(seed).shuffle(shuffled)
    return shuffled[split:], shuffled[:split]

In [None]:
# Work with (at most) the first 100 pages of the first website of the first
# vertical and hold out 20% of them for validation.
vertical = swde.VERTICALS[0]
website = vertical.websites[0]
train_pages, val_pages = train_val_split(website.pages[:100], .2)
# Display the sizes of both parts.
len(train_pages), len(val_pages)

## Prepare datapoints

In [None]:
def new_label_id_counter():
    """Create a source of sequential label IDs.

    Returns a zero-argument function whose first call returns 1 and every
    later call returns the next integer. Each returned function counts
    independently of the others.
    """
    ids = itertools.count(1)

    def new_label_id():
        return next(ids)

    return new_label_id

def create_label_map():
    """Create a mapping from labels to integer IDs.

    The missing label (`None`) is pre-assigned ID 0. Any other key gets the
    next ID from a fresh counter (starting at 1) the first time it is looked
    up.
    """
    return collections.defaultdict(new_label_id_counter(), {None: 0})

def prepare_data(
    pages: list[awe_graph.HtmlPage],
    label_map: dict[str, int]
):
    """Convert pages into `data.Data` objects, one per page.

    For every page a feature context is created and the `DollarSigns` and
    `Depth` features are added to it. Each node of the context then
    contributes one row of `x` (dollar sign count, relative depth) and one
    entry of `y` (the ID of its first label, or of `None` when it has no
    labels).

    Args:
        pages: Pages to convert.
        label_map: Mapping indexed with the node label to obtain its ID. How
            a label missing from the mapping is handled depends on the
            mapping that is passed in.

    Returns:
        A list of `data.Data` objects in the order of `pages`.
    """
    def featurize(node: awe_graph.HtmlNode):
        dollar_count = node.get_feature(features.DollarSigns).count
        relative_depth = node.get_feature(features.Depth).relative
        return [dollar_count, relative_depth]

    def labelize(node: awe_graph.HtmlNode):
        # Only the first label is used for now.
        if len(node.labels) == 0:
            return label_map[None]
        return label_map[node.labels[0]]

    def to_datapoint(page: awe_graph.HtmlPage):
        ctx = features.FeatureContext(page)
        ctx.add_all([features.DollarSigns, features.Depth])
        x = torch.tensor([featurize(node) for node in ctx.nodes])
        y = torch.tensor([labelize(node) for node in ctx.nodes])
        return data.Data(x=x, y=y)

    return [to_datapoint(page) for page in pages]

In [None]:
label_map = create_label_map()
# Label IDs are assigned while the training pages are being processed.
train_data = prepare_data(train_pages, label_map)
label_map.default_factory = None # freeze label map
# With the default factory removed, a label that did not occur in the training
# pages raises KeyError here instead of silently getting a new ID.
val_data = prepare_data(val_pages, label_map)
# Display the number of prepared datapoints in each part.
len(train_data), len(val_data)

In [None]:
# Display the label -> ID mapping collected from the training pages.
label_map

## Create dataloaders

In [None]:
# One datapoint (page) per batch. `predict` further below pairs the batch at
# position `index` of `val_dataloader` with `val_pages[index]`, which lines up
# only while the batch size is 1 and the validation loader is not shuffled.
BATCH_SIZE = 1
train_dataloader = loader.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = loader.DataLoader(val_data, batch_size=BATCH_SIZE)

In [None]:
# Print only the first training batch as a sanity check.
for batch in train_dataloader:
    print(batch)
    break

In [None]:
# Display the total number of batches across both loaders.
len(train_dataloader) + len(val_dataloader)

## Weight labels

In [None]:
def count_label(data: 'list[data.Data]', label: int):
    """Count the nodes that carry `label` across all datapoints.

    Args:
        data: Datapoints whose `y` attribute is a tensor of label IDs.
        label: Label ID to count.

    Returns:
        The number of entries equal to `label`, as a plain `int` (0 when
        `data` is empty).
    """
    # Compare each label tensor as a whole; iterating a tensor element by
    # element in Python is needlessly slow for pages with many nodes.
    return sum(int((datapoint.y == label).sum()) for datapoint in data)

def count_labels(data: list[data.Data]):
    """Count the occurrences of every known label in `data`.

    Returns a list with one count per value of the global `label_map`, in the
    iteration order of that mapping.
    """
    label_ids = label_map.values()
    return [count_label(data, label_id) for label_id in label_ids]

In [None]:
# Number of training nodes per label, shown next to the number of training
# datapoints.
label_counts = count_labels(train_data)
label_counts, len(train_data)

In [None]:
# Weight each label by the number of training datapoints divided by the
# number of nodes carrying that label, so rarer labels get larger weights.
# A label with a count of zero would raise ZeroDivisionError here.
# These weights are replaced by the manual override in the next cell.
label_weights = [len(train_data) / count for count in label_counts]
label_weights

In [None]:
# Manual override
label_count = len(label_map)
label_weights = [1] + [100_000] * (label_count - 1)
label_weights

## Train a model

In [None]:
from awe import awe_model
import pytorch_lightning as pl
# Reload the model module so that edits to its source take effect without
# restarting the kernel.
importlib.reload(awe_model)

In [None]:
# Build the model from the number of labels and the per-label weights.
model = awe_model.AweModel(label_count, label_weights)

In [None]:
# Resume from the checkpoint written by the save cell further below, but only
# when that file already exists. On the very first run there is no checkpoint
# yet, and pointing the trainer at a missing file must not abort the training.
checkpoint_path = 'lightning_logs/chkpt'
trainer = pl.Trainer(
    max_epochs=50,
    progress_bar_refresh_rate=100,
    resume_from_checkpoint=(
        checkpoint_path if os.path.exists(checkpoint_path) else None
    )
)
trainer.fit(model, train_dataloader, val_dataloader)

In [None]:
# Run the validation loop of the model on the validation loader.
trainer.validate(model, val_dataloader)

In [None]:
# Save the trainer state to the same path the trainer cell above resumes from.
trainer.save_checkpoint('lightning_logs/chkpt')

## Example prediction

In [None]:
def predict(index: int, label: str = 'model'):
    """Collect the nodes of one validation page that are reported for `label`.

    Args:
        index: Position of the page in the global `val_pages`. The batch at
            the same position of `val_dataloader` is used, which matches the
            page only while that loader is unshuffled with a batch size of 1.
        label: Key of the global `label_map` whose ID is passed to the model.
            Raises KeyError when the label is missing from a frozen map.

    Returns:
        The list of nodes appended by the callback given to
        `model.predict_swde`, in the order the callback was invoked.
    """
    example_batch = next(itertools.islice(val_dataloader, index, None))
    example_page = val_pages[index]
    example_nodes = list(example_page.nodes)

    predicted_nodes = []
    def handle(name: str, mask, idx=None):
        # NOTE(review): presumably `name` is a two-letter outcome code (such
        # as 'tp' / 'fp') whose second letter 'p' marks predicted positives -
        # confirm against `AweModel.predict_swde`.
        if name[1] == 'p':
            # Keep the nodes selected by `mask` and take the one at position
            # `idx` among them (the first one when `idx` is None).
            # NOTE(review): assumes `example_page.nodes` has the same order
            # as the nodes the batch was built from - confirm.
            masked = itertools.compress(example_nodes, mask)
            node = next(itertools.islice(masked, idx, None))
            predicted_nodes.append(node)
    model.predict_swde(example_batch, label_map[label], handle)

    return predicted_nodes

In [None]:
# Text content of the first predicted node for each of the first ten
# validation pages. A page without any predicted node raises IndexError.
[predict(i)[0].element.text_content() for i in range(10)]