## Imports

In [None]:
import importlib
from typing import TypeVar
import collections
import os

import torch
from torch import nn, optim
from torch.utils import data
import numpy as np

from awe import features, html_utils, awe_graph
from awe.data import swde

# Reload the project modules so source edits take effect without restarting
# the kernel. NOTE(review): importlib.reload() re-executes each module in list
# order and does not rebind references other modules already hold; if an
# earlier-listed module imports a later-listed one (e.g. awe_graph), it may
# keep the stale version — confirm the order matches the import graph.
for module in [swde, features, html_utils, awe_graph]:
    importlib.reload(module)

In [None]:
# Seed the global NumPy and PyTorch RNGs so later random operations in this
# notebook are reproducible across runs.
np.random.seed(42)
torch.manual_seed(42)

In [None]:
#swde.validate()

## Split data

In [None]:
T = TypeVar('T')
def train_val_split(data: list[T], val_split: float, seed: int = 42) -> tuple[list[T], list[T]]:
    """Deterministically shuffle ``data`` and split it into train/validation parts.

    Args:
        data: Items to split. The input is not modified; a shuffled copy is used.
        val_split: Fraction of items, in ``[0, 1]``, that go to the validation
            part. The validation size is ``floor(val_split * len(data))``.
        seed: Seed for the shuffle. A private ``RandomState`` is used, so the
            global NumPy RNG state is left untouched.

    Returns:
        ``(train, val)`` lists that together contain every item exactly once.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1]``.
    """
    if not 0.0 <= val_split <= 1.0:
        raise ValueError(f'val_split must be in [0, 1], got {val_split!r}')
    split = int(np.floor(val_split * len(data)))
    shuffled = list(data)
    # A local RandomState seeded with the same value produces the same
    # permutation as `np.random.seed(seed); np.random.shuffle(...)`, but
    # without resetting the notebook-wide RNG as a hidden side effect.
    np.random.RandomState(seed).shuffle(shuffled)
    return shuffled[split:], shuffled[:split]

In [None]:
# Experiment on the first website of the first SWDE vertical, limited to its
# first 100 pages; hold out 20% of those pages for validation.
vertical = swde.VERTICALS[0]
website = vertical.websites[0]
train_pages, val_pages = train_val_split(website.pages[:100], .2)
len(train_pages), len(val_pages)

## Prepare datapoints

In [None]:
def new_label_id_counter():
    """Return a zero-argument function yielding consecutive ids 1, 2, 3, ...

    Each call to this factory starts an independent sequence, which makes the
    returned function suitable as a ``defaultdict`` factory that hands out a
    fresh id to every new key.
    """
    last_id = [0]  # one-element list so the inner function can mutate it

    def next_label_id():
        last_id[0] += 1
        return last_id[0]

    return next_label_id

def create_label_map():
    """Create a label -> id map that assigns ids on first lookup.

    ``None`` (a node without a label) is pre-mapped to 0. Any other key gets
    the next id (1, 2, ...) the first time it is looked up. Set
    ``default_factory = None`` on the result to stop assigning new ids.
    """
    # defaultdict(factory, mapping): the extra argument seeds the dict contents.
    return collections.defaultdict(new_label_id_counter(), {None: 0})

def prepare_nodes(
    pages: list[awe_graph.HtmlPage],
    label_map: dict[str, int]
):
    """Turn every node of ``pages`` into an ``[x, y]`` training pair.

    ``x`` is a float32 tensor ``[dollar-sign count, relative depth]``. ``y``
    is ``label_map[label]``, where ``label`` is the node's first label, or
    ``None`` when the node has no labels.

    Args:
        pages: Pages whose nodes are featurized; output follows page order,
            then node order within each page.
        label_map: Maps a label (or ``None``) to an integer id. With a
            ``defaultdict`` factory, unseen labels get new ids. Without one,
            an unseen label raises ``KeyError``.

    Returns:
        A flat list of ``[x, y]`` pairs, one per node.
    """
    def prepare_page(page: awe_graph.HtmlPage):
        # Register the features read in prepare_node with a per-page context
        # and return that context's nodes.
        ctx = features.FeatureContext(page)
        ctx.add_all([
            features.DollarSigns,
            features.Depth
        ])
        return ctx.nodes

    def prepare_node(node: awe_graph.HtmlNode):
        # x = features. Pin the dtype: torch.tensor infers int64 when every
        # value is a Python int, which could give some nodes a different dtype
        # from the rest and break batching or float-only model layers.
        x = torch.tensor([
            node.get_feature(features.DollarSigns).count,
            node.get_feature(features.Depth).relative
        ], dtype=torch.float32)

        # y = label id (only the first label for now; None means unlabeled)
        label = node.labels[0] if node.labels else None
        y = label_map[label]

        return [x, y]

    return [prepare_node(node) for page in pages for node in prepare_page(page)]

In [None]:
# Collect label ids from the training pages only: each label first seen while
# preparing train nodes gets the next id; 0 is reserved for "no label".
label_map = create_label_map()
train_nodes = prepare_nodes(train_pages, label_map)
# With default_factory cleared, a validation label never seen in training
# raises KeyError instead of silently receiving a new id.
label_map.default_factory = None # freeze label map
val_nodes = prepare_nodes(val_pages, label_map)
len(train_nodes), len(val_nodes)

In [None]:
# Display the label -> id mapping collected from the training pages.
label_map

In [None]:
# Convert labels to one-hot encoding.
# One-hot vector length = number of label ids, including 0 for unlabeled nodes.
label_count = len(label_map)
def to_one_hot(nodes: list[list], num_classes=None):
    """Replace each integer label ``y`` of ``[x, y]`` pairs with a one-hot vector.

    Args:
        nodes: ``[x, y]`` pairs where ``y`` is an integer label id.
        num_classes: Length of each one-hot vector (int). Defaults to the
            notebook-level ``label_count``.

    Returns:
        New ``[x, one_hot]`` pairs. ``x`` is passed through unchanged, and
        ``one_hot`` is a float32 tensor of length ``num_classes`` holding a
        single 1 at index ``y``.

    Raises:
        ValueError: If a label id is outside ``[0, num_classes)``.
    """
    if num_classes is None:
        num_classes = label_count

    def encode(y):
        # Reject bad ids explicitly; plain indexing would accept negative ids.
        if not 0 <= y < num_classes:
            raise ValueError(f'label id {y!r} out of range [0, {num_classes})')
        one_hot = torch.zeros(num_classes, dtype=torch.float32)
        one_hot[y] = 1.0
        return one_hot

    return [[x, encode(y)] for x, y in nodes]
# One-hot encode the labels of both splits; features are left as they are.
train_nodes_oh, val_nodes_oh = to_one_hot(train_nodes), to_one_hot(val_nodes)

## Create dataloaders

In [None]:
# NOTE: os.cpu_count() may return None, in which case the commented-out
# expression below would fail.
cpu_count = os.cpu_count()
# DataLoader worker processes; 0 means batches are loaded in the main process.
workers = 0 # cpu_count // 2
workers

In [None]:
BATCH_SIZE = 64
# Reshuffle the training set every epoch. Nodes are stored page by page, so
# without shuffling consecutive samples come from the same page in the same
# fixed order each epoch. The shuffle is reproducible via torch.manual_seed
# above. Validation order does not change the loss, so it stays deterministic.
train_dataloader = data.DataLoader(
    train_nodes_oh, batch_size=BATCH_SIZE, shuffle=True, num_workers=workers
)
val_dataloader = data.DataLoader(val_nodes_oh, batch_size=BATCH_SIZE, num_workers=workers)

In [None]:
# Peek at the first training batch to sanity-check tensor shapes and dtype.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    X, y = first_batch
    print("Shape of X: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)

In [None]:
# Total number of batches across both loaders (the last batch of each may be partial).
len(train_dataloader) + len(val_dataloader)

## Analyze data

In [None]:
def count_label(data: list[list], label: int):
    """Return how many ``[x, y]`` pairs in ``data`` have ``y == label``."""
    return sum(1 for _, y in data if y == label)

def count_labels(data: list[list], label_ids=None):
    """Return the number of ``[x, y]`` pairs in ``data`` for each label id.

    Args:
        data: ``[x, y]`` pairs with hashable integer labels ``y``.
        label_ids: Label ids to report, in output order. Defaults to the
            values of the notebook-level ``label_map``.

    Returns:
        A list with one count per entry of ``label_ids`` (0 for absent labels).
    """
    if label_ids is None:
        label_ids = label_map.values()
    # One pass over `data` instead of one full scan per label.
    counts = collections.Counter(y for _, y in data)
    return [counts[label] for label in label_ids]

In [None]:
# Training-set node count per label id (in label_map order), plus the total.
label_counts = count_labels(train_nodes)
label_counts, len(train_nodes)

In [None]:
# Inverse-frequency class weights: rarer labels get proportionally larger
# weights. Raises ZeroDivisionError if some label has no training nodes.
label_weights = [len(train_nodes) / count for count in label_counts]
label_weights

In [None]:
# Manual override: weight 1 for id 0 (no label), 100_000 for every real label.
# NOTE(review): 100_000 looks hand-tuned — confirm it against the computed
# inverse-frequency weights.
label_weights = [1] + [100_000] * (label_count - 1)
label_weights

## Train a model

In [None]:
from awe import awe_model
import pytorch_lightning as pl
# Reload so edits to awe_model take effect without restarting the kernel.
importlib.reload(awe_model)

In [None]:
# NOTE(review): AweModel presumably uses label_weights to weight its loss —
# confirm in awe_model.
model = awe_model.AweModel(label_count, label_weights)
# NOTE(review): `progress_bar_refresh_rate` was deprecated in PyTorch
# Lightning 1.5 (replaced by the TQDMProgressBar callback) and later removed —
# confirm against the installed version.
trainer = pl.Trainer(max_epochs=10, progress_bar_refresh_rate=100)

In [None]:
# Train the model; the second loader is used for validation.
trainer.fit(model, train_dataloader, val_dataloader)