In [None]:
import importlib
from typing import TypeVar, Optional
import collections

import torch
from torch.utils import data
import numpy as np

from awe import features, html_utils, awe_dataset, awe_graph
from awe.data import swde

# Re-execute the project modules so edits to their source files are picked up
# without restarting the kernel.
# NOTE(review): reload order may matter — a module reloaded early could keep
# references to objects from a module reloaded later (e.g. `awe_graph`); verify.
for mod in (swde, features, html_utils, awe_dataset, awe_graph):
    importlib.reload(mod)

In [None]:
# Seed NumPy's and PyTorch's global RNGs so the stochastic parts of this notebook
# are reproducible. The trailing `;` suppresses the `torch.Generator` object that
# `torch.manual_seed` returns, which would otherwise be displayed as cell output.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED);

In [None]:
#swde.validate()

## Split data

In [None]:
T = TypeVar('T')
def train_val_split(data: list[T], val_split: float, seed: int = 42) -> tuple[list[T], list[T]]:
    """Shuffle ``data`` reproducibly and split it into ``(train, val)``.

    Args:
        data: Items to split. It is copied, never mutated.
        val_split: Fraction of items that go to the validation part, in [0, 1].
            The validation size is ``floor(val_split * len(data))``.
        seed: Seed for the shuffle. The default 42 reproduces the permutation
            produced by ``np.random.seed(42); np.random.shuffle(...)``.

    Returns:
        ``(train, val)`` lists; together they hold every item of ``data`` exactly once.

    Raises:
        ValueError: if ``val_split`` is outside [0, 1].
    """
    if not 0 <= val_split <= 1:
        raise ValueError(f'val_split must be in [0, 1], got {val_split}')
    split = int(np.floor(val_split * len(data)))
    shuffled = list(data)
    # A private legacy RandomState gives the same permutation as reseeding the
    # global RNG, without resetting NumPy's global random state as a side effect.
    np.random.RandomState(seed).shuffle(shuffled)
    return shuffled[split:], shuffled[:split]

In [None]:
# Work on a single website (the first website of the first vertical), using at
# most its first N_PAGES pages and holding out VAL_SPLIT of them for validation.
N_PAGES = 100
VAL_SPLIT = .2
vertical = swde.VERTICALS[0]
website = vertical.websites[0]
train_pages, val_pages = train_val_split(website.pages[:N_PAGES], VAL_SPLIT)
len(train_pages), len(val_pages)

## Prepare datapoints

In [None]:
def new_label_id_counter():
    """Return a zero-argument callable that yields 1, 2, 3, ... on successive calls.

    Every call to ``new_label_id_counter`` creates an independent counter. Ids
    start at 1, which leaves 0 free (``create_label_map`` assigns 0 to ``None``).
    """
    state = [0]  # mutable cell shared with the closure below

    def new_label_id():
        state[0] += 1
        return state[0]

    return new_label_id

def create_label_map():
    """Create a label -> integer id mapping that assigns ids on first lookup.

    ``None`` (a node without a label) is pre-assigned id 0. Any other key receives
    the next id from a fresh ``new_label_id_counter`` (1, 2, ...) the first time it
    is looked up. Setting ``default_factory = None`` on the result stops new ids
    from being assigned; unknown keys then raise ``KeyError``.
    """
    return collections.defaultdict(new_label_id_counter(), {None: 0})

def prepare_nodes(
    pages: list[awe_graph.HtmlPage],
    label_map: dict[Optional[str], int]
) -> list[list]:
    """Convert ``pages`` into a flat list of ``[x, y]`` pairs, one per node.

    Each page gets a ``features.FeatureContext`` with the ``DollarSigns`` and
    ``Depth`` features added. Every node in the context's ``nodes`` then yields:

    * ``x`` -- ``[DollarSigns.count, Depth.relative]`` as read from the node;
    * ``y`` -- ``label_map[label]``, where ``label`` is the node's first label,
      or ``None`` when the node has no labels.

    Args:
        pages: Pages to featurize.
        label_map: Label -> integer id. ``None`` must be a usable key because
            unlabeled nodes are looked up with it. If this is a ``defaultdict``
            with a ``default_factory``, unseen labels are inserted on lookup, so
            the map is mutated. A plain dict, or a ``defaultdict`` whose factory is
            ``None``, raises ``KeyError`` for unseen labels instead.

    Returns:
        ``[x, y]`` pairs for all nodes of all pages, in page order.
    """
    def prepare_page(page: awe_graph.HtmlPage):
        # Register the features that prepare_node reads, then return the page's nodes.
        ctx = features.FeatureContext(page)
        ctx.add_all([
            features.DollarSigns,
            features.Depth
        ])
        return ctx.nodes

    def prepare_node(node: awe_graph.HtmlNode):
        # x = features
        x = [
            node.get_feature(features.DollarSigns).count,
            node.get_feature(features.Depth).relative
        ]

        # y = label (only the first one for now); None marks an unlabeled node
        label = None if len(node.labels) == 0 else node.labels[0]
        y = label_map[label]

        return [x, y]

    return [prepare_node(node) for page in pages for node in prepare_page(page)]

In [None]:
# Build the label vocabulary from the training pages only: each new label seen
# while preparing `train_nodes` gets the next integer id (None -> 0).
label_map = create_label_map()
train_nodes = prepare_nodes(train_pages, label_map)
label_map.default_factory = None # freeze label map
# With the factory removed, a label that appears only in the validation pages
# raises KeyError here instead of silently receiving a new id.
val_nodes = prepare_nodes(val_pages, label_map)
len(train_nodes), len(val_nodes)

In [None]:
# Preview only the first few prepared `[x, y]` pairs; displaying the whole list
# floods the notebook output.
train_nodes[:5]

## Old: legacy `AweDataset`-based pipeline

In [None]:
def page_transform(page: awe_graph.HtmlPage):
    """Build a ``FeatureContext`` for ``page``, add the ``DollarSigns`` and
    ``Depth`` features to it, and return the context's nodes."""
    context = features.FeatureContext(page)
    wanted = [features.DollarSigns, features.Depth]
    context.add_all(wanted)
    return context.nodes

def node_transform(node: awe_graph.HtmlNode):
    """Identity node transform: hand ``node`` back unchanged."""
    unchanged = node
    return unchanged

def target_transform(labels: list[str]):
    """Identity target transform: hand the same ``labels`` list back (no copy)."""
    unchanged = labels
    return unchanged

def create_dataset(websites: list[swde.Website]):
    """Build an ``AweDataset`` over every page of every website in ``websites``,
    using this section's page/node/target transforms."""
    all_pages = []
    for site in websites:
        all_pages.extend(site.pages)
    return awe_dataset.AweDataset(all_pages, page_transform, node_transform, target_transform)

In [None]:
# Split the current vertical at the website level, so the validation dataset is
# built from websites that do not appear in the training one.
train_websites, val_websites = train_val_split(vertical.websites, .2)
#train_dataset = create_dataset(train_websites)
val_dataset = create_dataset(val_websites)

In [None]:
# Size of the validation dataset as reported by AweDataset.__len__
# (NOTE(review): whether this counts pages or nodes depends on AweDataset — confirm).
len(val_dataset)