## Import

In [None]:
import importlib
from typing import TypeVar, Optional
import collections

import torch
from torch import nn, optim
from torch.utils import data
import numpy as np

from awe import features, html_utils, awe_dataset, awe_graph
from awe.data import swde

# Reload the project modules so that changes made to their source files are
# picked up without restarting the kernel.
for module in [swde, features, html_utils, awe_dataset, awe_graph]:
    importlib.reload(module)

In [None]:
# Seed NumPy's and PyTorch's global random number generators with a fixed
# value.
np.random.seed(42)
torch.manual_seed(42)

In [None]:
#swde.validate()

## Split data

In [None]:
T = TypeVar('T')
def train_val_split(
    data: list[T],
    val_split: float
) -> tuple[list[T], list[T]]:
    """Shuffle `data` deterministically and split it into two parts.

    The input list is not modified; a shuffled copy is split instead.

    Args:
        data: Items to split.
        val_split: Fraction of the items (rounded down) that goes to the
            validation part.

    Returns:
        A `(train, val)` pair of lists.
    """
    split = int(np.floor(val_split * len(data)))
    copy = list(data)
    # Shuffle with a private, fixed-seed generator. `RandomState(42)` yields
    # the same permutation as `np.random.seed(42)` followed by
    # `np.random.shuffle`, but leaves NumPy's global random state untouched.
    np.random.RandomState(42).shuffle(copy)
    return copy[split:], copy[:split]

In [None]:
# Work with the first 100 pages of the first website of the first vertical;
# 20% of them are held out for validation.
vertical = swde.VERTICALS[0]
website = vertical.websites[0]
train_pages, val_pages = train_val_split(website.pages[:100], .2)
len(train_pages), len(val_pages)

## Prepare datapoints

In [None]:
def new_label_id_counter():
    """Return a function that yields 1, 2, 3, ... on successive calls.

    Every returned function counts independently of the others.
    """
    # One-element list used as a mutable cell shared with the closure.
    last_issued = [0]

    def new_label_id():
        last_issued[0] += 1
        return last_issued[0]

    return new_label_id

def create_label_map():
    """Create a label -> id mapping that assigns ids on first lookup.

    `None` (no label) is pre-assigned id 0; any other key that is looked up
    for the first time receives the next id from a fresh counter, starting
    at 1.
    """
    return collections.defaultdict(new_label_id_counter(), {None: 0})

def prepare_nodes(
    pages: list[awe_graph.HtmlPage],
    label_map: dict[str, int]
):
    """Convert all nodes of `pages` into `[x, y]` datapoints.

    For each node, `x` is a tensor holding the node's `DollarSigns` count and
    relative `Depth`, and `y` is the id that `label_map` gives to the node's
    first label (or to `None` when the node has no labels).

    Returns a flat list of datapoints, page by page in node order.
    """
    feature_types = [
        features.DollarSigns,
        features.Depth
    ]

    def prepare_page(page: awe_graph.HtmlPage):
        # Compute the features for the whole page, then hand out its nodes.
        context = features.FeatureContext(page)
        context.add_all(feature_types)
        return context.nodes

    def prepare_node(node: awe_graph.HtmlNode):
        # x = features
        dollar_count = node.get_feature(features.DollarSigns).count
        relative_depth = node.get_feature(features.Depth).relative
        x = torch.tensor([dollar_count, relative_depth])

        # y = label (only the first one for now)
        if len(node.labels) == 0:
            label = None
        else:
            label = node.labels[0]

        return [x, label_map[label]]

    datapoints = []
    for page in pages:
        for node in prepare_page(page):
            datapoints.append(prepare_node(node))
    return datapoints

In [None]:
label_map = create_label_map()
# Looking up the labels of training nodes adds them to the map.
train_nodes = prepare_nodes(train_pages, label_map)
label_map.default_factory = None # freeze label map
# With the map frozen, a label that was not seen in the training pages raises
# KeyError here instead of getting a new id.
val_nodes = prepare_nodes(val_pages, label_map)
len(train_nodes), len(val_nodes)

In [None]:
# Show the label -> id mapping collected from the training pages.
label_map

In [None]:
# Show the prepared training datapoints.
train_nodes

In [None]:
# Wrap the datapoint lists in loaders that yield batches of BATCH_SIZE items.
# NOTE(review): the training loader does not pass `shuffle=True`, so batches
# follow page/node order -- confirm that this is intended.
BATCH_SIZE = 64
train_dataloader = data.DataLoader(train_nodes, batch_size=BATCH_SIZE)
val_dataloader = data.DataLoader(val_nodes, batch_size=BATCH_SIZE)

In [None]:
# Print the first training batch as a sanity check; the `break` stops the
# loop after one batch.
for X, y in train_dataloader:
    print("X =", X)
    print("Y =", y)
    #print("Shape of X [N, C, H, W]: ", X.shape)
    #print("Shape of y: ", y.shape, y.dtype)
    break

## Create model

In [None]:
class NeuralNetwork(nn.Module):
    """Fully connected network: two hidden ReLU layers and a linear output.

    Args:
        in_features: Number of input features per datapoint.
        hidden_size: Width of both hidden layers.
        num_classes: Number of outputs. When omitted, the size of the
            notebook-level `label_map` is used.
    """

    def __init__(
        self,
        in_features: int = 2,
        hidden_size: int = 32,
        num_classes: Optional[int] = None
    ):
        super().__init__()
        if num_classes is None:
            # Default: one output per label known to the notebook-level map.
            num_classes = len(label_map)
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes)
        )

    def forward(self, x):
        """Apply the layer stack to `x` and return the raw outputs."""
        return self.layers(x)

In [None]:
# Instantiate the network and print its layer structure.
model = NeuralNetwork()
print(model)

## Train

In [None]:
# Cross-entropy loss over the model outputs, optimized with plain SGD.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)

In [None]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one pass over `dataloader`, updating `model` after every batch.

    The model is put into training mode first. The loss of every 100th batch
    (starting with the first) is printed together with the number of
    datapoints processed before it.
    """
    size = len(dataloader.dataset)
    model.train()
    for step, (inputs, targets) in enumerate(dataloader):
        # Forward pass and loss for this batch.
        batch_loss = loss_fn(model(inputs), targets)

        # Clear old gradients, backpropagate, and apply the update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if step % 100 != 0:
            continue
        loss = batch_loss.item()
        current = step * len(inputs)
        print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [None]:
def val(dataloader, model, loss_fn):
    """Evaluate `model` on `dataloader` and print accuracy and average loss.

    The model is put into evaluation mode and no gradients are recorded.
    Accuracy is the share of datapoints whose highest-scoring output equals
    the target; the loss is averaged over batches.
    """
    total_samples = len(dataloader.dataset)
    batch_count = len(dataloader)
    model.eval()
    loss_total = 0
    hit_total = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            outputs = model(inputs)
            loss_total += loss_fn(outputs, targets).item()
            matches = outputs.argmax(1) == targets
            hit_total += matches.type(torch.float).sum().item()
    val_loss = loss_total / batch_count
    correct = hit_total / total_samples
    print(f"Validation Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {val_loss:>8f} \n")

In [None]:
# Alternate one training pass and one validation pass for each epoch.
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    val(val_dataloader, model, loss_fn)
print("Done!")

## Old

In [None]:
def page_transform(page: awe_graph.HtmlPage):
    """Compute the `DollarSigns` and `Depth` features and return the nodes."""
    context = features.FeatureContext(page)
    feature_types = [features.DollarSigns, features.Depth]
    context.add_all(feature_types)
    return context.nodes

def node_transform(node: awe_graph.HtmlNode):
    """Return `node` unchanged (identity transform)."""
    return node

def target_transform(labels: list[str]) -> list[str]:
    """Return `labels` unchanged (identity transform)."""
    return labels

def create_dataset(websites: list[swde.Website]):
    """Build an `AweDataset` over every page of the given websites.

    Pages are gathered website by website and passed to the dataset together
    with the page, node and target transforms defined in this notebook.
    """
    all_pages = []
    for site in websites:
        all_pages.extend(site.pages)
    transforms = (page_transform, node_transform, target_transform)
    return awe_dataset.AweDataset(all_pages, *transforms)

In [None]:
# NOTE(review): `val_websites` (and `train_websites` in the disabled line) is
# not defined in any visible cell of this notebook, so this legacy cell
# presumably fails with NameError as written -- confirm before running.
#train_dataset = create_dataset(train_websites)
val_dataset = create_dataset(val_websites)

In [None]:
# Show the size of the legacy validation dataset.
len(val_dataset)