In [1]:
import copy
import os
from functools import partial
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor
from torch import nn
from torch.optim import Optimizer, lr_scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from layers import (
    BinaryTreeActivation,
    BinaryTreeAdaptivePooling,
    BinaryTreeConv,
    BinaryTreeInstanceNorm,
)
from regressor import BinaryTreeRegressor, BinaryTreeSequential

# Creating DataLoader

In [2]:
# code from https://github.com/zinchse/hbo_bench/blob/main/dataset.py

def paddify_sequences(sequences: "List[Tensor]", target_length: "int") -> "List[Tensor]":
    """
    Pads sequences with zero rows along their first axis so that all have `target_length` rows.

    Padding uses the trailing shape of the first sequence (e.g. `(n_channels,)` for 2-D inputs)
    and the dtype/device of each padded sequence. Returns a new list; an empty input gives `[]`.
    Raises ValueError if some sequence is already longer than `target_length`.
    """
    if not sequences:
        return []
    trailing_shape = tuple(sequences[0].shape[1:])
    padded_sequences = []
    for seq in sequences:
        n_missing = target_length - len(seq)
        if n_missing < 0:
            # torch.zeros would otherwise fail with an unhelpful negative-dimension error
            raise ValueError(f"sequence of length {len(seq)} exceeds target_length={target_length}")
        padding_tokens = torch.zeros((n_missing, *trailing_shape), dtype=seq.dtype, device=seq.device)
        padded_sequences.append(torch.cat((seq, padding_tokens), dim=0))
    return padded_sequences


class WeightedBinaryTreeDataset(Dataset):
    def __init__(
        self,
        list_vertices: "List[Tensor]",
        list_edges: "List[Tensor]",
        list_time: "List[Tensor]",
        device: "torch.device",
    ):
        """
        Dataset of unique trees, each yielded as <vertices, edges, frequency, mean execution time>.

        Samples whose `vertices` and `edges` have identical flattened contents are merged into a
        single entry: `frequency` counts the merged samples and the target is the mean of their
        times. All stored tensors are moved to `device` during construction.
        """
        self.data_dict: "Dict[Tuple, Dict]" = {}

        for vertices, edges, time in zip(list_vertices, list_edges, list_time):
            # NOTE: the key is built from flattened contents only; tensor shapes are not part of it
            key = str(vertices.flatten().tolist()), str(edges.flatten().tolist())
            record = self.data_dict.get(key)
            if record is None:
                self.data_dict[key] = {"vertices": vertices, "edges": edges, "time": [time], "freq": 1}
            else:
                record["freq"] += 1
                record["time"].append(time)

        records = list(self.data_dict.values())
        self.list_vertices = [entry["vertices"] for entry in records]
        self.list_edges = [entry["edges"] for entry in records]
        self.list_time = [torch.stack(entry["time"]).mean() for entry in records]
        self.list_frequencies = [torch.tensor(entry["freq"]) for entry in records]
        self.size = len(self.data_dict)
        self.device = device
        self.move_to_device()

    def move_to_device(self) -> "None":
        """Replaces, in place, the first `self.size` tensors of every per-item list with copies on `self.device`."""
        per_item_lists = (self.list_vertices, self.list_edges, self.list_frequencies, self.list_time)
        for idx in range(self.size):
            for tensors in per_item_lists:
                tensors[idx] = tensors[idx].to(device=self.device)

    def __len__(self) -> "int":
        """Number of unique trees (after merging duplicates)."""
        return self.size

    def __getitem__(self, idx) -> "Tuple[Tensor, Tensor, Tensor, Tensor]":
        """Returns (vertices, edges, frequency, mean time) of the `idx`-th unique tree."""
        return (
            self.list_vertices[idx],
            self.list_edges[idx],
            self.list_frequencies[idx],
            self.list_time[idx],
        )

def weighted_binary_tree_collate(
    batch: "List[Tuple[Tensor, Tensor, Tensor, Tensor]]", target_length: "int"
) -> "Tuple[Tuple[Tensor, Tensor, Tensor], Tensor]":
    """
    Collates (vertices, edges, frequency, time) samples into batched tensors.

    Vertices and edges are zero-padded to `target_length` rows. The stacked vertices then have
    their last two axes swapped, and the stacked edges get a new axis at position 1.
    Returns ((batch_vertices, batch_edges, batch_freq), batch_time).
    """
    samples = [(vertices, edges, freq, time) for vertices, edges, freq, time in batch]
    vertices_seq = [sample[0] for sample in samples]
    edges_seq = [sample[1] for sample in samples]
    freq_seq = [sample[2] for sample in samples]
    time_seq = [sample[3] for sample in samples]

    padded_vertices = paddify_sequences(vertices_seq, target_length)
    # (batch, length, channels) -> (batch, channels, length)
    batch_vertices = torch.stack(padded_vertices).transpose(1, 2)

    padded_edges = paddify_sequences(edges_seq, target_length)
    batch_edges = torch.stack(padded_edges).unsqueeze(1)

    return (batch_vertices, batch_edges, torch.stack(freq_seq)), torch.stack(time_seq)


In [3]:
# Toy example: a root with one child, which in turn has two leaf children.
# Each node is described by a 2-dimensional feature vector.
root_node = [1.0, 1.0]
l_child_node = [1.0, -1.0]
ll_child_node = [-1.0, -1.0]
rl_child_node = [1.0, 1.0]
vertices = torch.tensor([root_node, l_child_node, ll_child_node, rl_child_node])

# Each edge row lists a node followed by its two children; index 0 stands for "no child".
# Node ids in `edges` are 1-based (row i of `vertices` is node i + 1), presumably because the
# tree layers reserve index 0 for a padding vector -- TODO confirm against BinaryTreeConv.
padding_idx = 0
root, l_child, ll_child, rl_child = range(1, 5)
edges = torch.tensor(
    [
        (root, l_child, padding_idx),
        (l_child, ll_child, rl_child),
        (ll_child, padding_idx, padding_idx),
        (rl_child, padding_idx, padding_idx),
    ],
    dtype=torch.long,
)

time = torch.tensor(42.0)

In [4]:
batch_size, dataset_size = 8, 8
device = torch.device("cpu")
dataloader = DataLoader(
    # The dataset merges identical trees, so these 8 copies become one entry with frequency 8.
    dataset=WeightedBinaryTreeDataset(
        [vertices] * dataset_size, [edges] * dataset_size, [time] * dataset_size, device
    ),
    batch_size=batch_size,
    shuffle=True,
    # `partial` (unlike a lambda) is picklable, so this also works with num_workers > 0 under the
    # "spawn" start method. target_length must be >= the node count of every tree in the dataset.
    collate_fn=partial(weighted_binary_tree_collate, target_length=10),
    drop_last=False,
)

# Building NN Architecture

In [5]:
# Seed before building the model so its random weight initialization (e.g. of nn.Linear below)
# is reproducible; the training cell re-seeds with the same value right before the loop.
torch.manual_seed(42)
model = BinaryTreeRegressor(
    btcnn=BinaryTreeSequential(
        BinaryTreeConv(2, 128),  # presumably (in_channels, out_channels); 2 = node feature size of `vertices`
        BinaryTreeInstanceNorm(128),
        BinaryTreeActivation(torch.nn.functional.leaky_relu),
        BinaryTreeAdaptivePooling(torch.nn.AdaptiveMaxPool1d(1)),
    ),
    fcnn=nn.Sequential(
        nn.Linear(128, 1),
        nn.Softplus(),  # keeps predicted times positive
    ),
    name="SimpleBTCNNRegressor",
    device=device,
)

# Training

## Helpers

In [6]:
def save_ckpt(
    model: "BinaryTreeRegressor", optimizer: "Optimizer", scheduler: "ReduceLROnPlateau", epoch: "int", path: "str"
) -> "None":
    """
    Saves a training checkpoint to `path`.

    The checkpoint is a dict with the epoch number and the model, optimizer and scheduler state
    dicts. Missing parent directories are created. A bare file name (no directory part) is
    written to the current working directory.
    """
    state = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
    }
    directory = os.path.dirname(path)
    if directory:  # os.makedirs("") raises FileNotFoundError, so skip it for bare file names
        os.makedirs(directory, exist_ok=True)
    torch.save(state, path)

In [7]:
def load_model(model: "BinaryTreeRegressor", path: "str", device: "torch.device") -> "BinaryTreeRegressor":
    """
    Restores `model`'s weights from the checkpoint at `path` and moves it to `device`.

    Only the "model_state_dict" entry of the checkpoint is used. The file is read with
    `weights_only=True` and mapped onto `device`. The returned model also gets a `device` attribute.
    """
    checkpoint = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    relocated = model.to(device)
    relocated.device = device
    return relocated

In [8]:
def set_seed(seed: "int") -> "None":
    """Seeds PyTorch's global random number generator with `seed`."""
    torch.random.manual_seed(seed)

In [9]:
def calculate_loss(
    model: "BinaryTreeRegressor",
    optimizer: "Optimizer",
    criterion: "nn.Module",
    dataloader: "DataLoader",
    train_mode: "bool" = True,
) -> "float":
    """
    Runs one pass over `dataloader` and returns the frequency-weighted mean loss as a Python float.

    Each batch is ((vertices, edges, freq), time). `criterion` must return per-sample losses
    (e.g. `reduction="none"`). Each loss is weighted by its sample's frequency.
    In train mode the optimizer takes one step per batch on the batch mean of weighted losses.
    In eval mode no gradients are tracked.

    Returns sum(freq * loss) / sum(freq) over the whole pass.
    Raises ValueError if the loader yields no samples (total frequency 0).
    """
    if train_mode:
        model.train()
    else:
        model.eval()

    running_loss, total_weight = 0.0, 0.0
    # autograd graphs are only needed when we are going to call backward()
    with torch.set_grad_enabled(train_mode):
        for (vertices, edges, freq), time in dataloader:
            if train_mode:
                optimizer.zero_grad()

            outputs = model(vertices, edges)
            per_sample_loss = freq.float().squeeze(-1) * criterion(outputs.squeeze(-1), time)
            weighted_loss = per_sample_loss.mean()

            if train_mode:
                weighted_loss.backward()
                optimizer.step()

            # mean * batch size == sum of frequency-weighted losses in this batch
            running_loss += weighted_loss.item() * vertices.size(0)
            # .item() keeps the accumulator (and therefore the return value) a Python number
            total_weight += freq.sum().item()

    if total_weight == 0:
        raise ValueError("dataloader yielded no samples (total frequency is 0)")
    return running_loss / total_weight

In [10]:
def weighted_train_loop(
    model: "BinaryTreeRegressor",
    optimizer: "Optimizer",
    criterion: "nn.Module",
    scheduler: "ReduceLROnPlateau",
    train_dataloader: "DataLoader",
    num_epochs: "int",
    start_epoch: "int" = 0,
    ckpt_period: "int" = 10,
    path_to_save: "Optional[str]" = None,
) -> "None":
    """
    Trains `model` for `num_epochs` epochs, numbered start_epoch + 1 .. start_epoch + num_epochs.

    After every epoch the scheduler is stepped with that epoch's weighted training loss, and the
    loss is shown in the progress bar. When `path_to_save` is set, a checkpoint is written after
    every epoch whose number is a multiple of `ckpt_period`.
    """
    first_epoch = start_epoch + 1
    last_epoch = start_epoch + num_epochs
    progress = tqdm(range(first_epoch, last_epoch + 1), desc="Initialization", leave=True, position=0)
    for epoch in progress:
        train_loss = calculate_loss(model, optimizer, criterion, train_dataloader)
        scheduler.step(train_loss)
        progress.set_description(f"[{epoch}/{start_epoch + num_epochs}] MSE: {train_loss:.4f}")
        # short-circuit: the modulo is only evaluated when checkpointing is enabled
        if path_to_save and epoch % ckpt_period == 0:
            save_ckpt(model, optimizer, scheduler, epoch, path_to_save)


## Training loop

In [11]:
lr = 3e-4
epochs = 1000

optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
# halve the learning rate when the training loss plateaus (patience = 20 epochs)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=20)
# seeds torch's global RNG, which the shuffling DataLoader draws from
set_seed(42)

# ckpt_period == epochs, so exactly one checkpoint is written, after the last epoch
weighted_train_loop(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=nn.MSELoss(reduction="none"),
    train_dataloader=dataloader,
    num_epochs=epochs,
    ckpt_period=epochs,
    path_to_save=f"/tmp/{model.name}.pth",
)

[1000/1000] MSE: 0.0000: 100%|██████████| 1000/1000 [00:01<00:00, 858.52it/s]


In [12]:
# Re-evaluate the trained model on the training data without updating it (train_mode=False).
final_loss = calculate_loss(
    model,
    optimizer,
    nn.MSELoss(reduction="none"),
    dataloader,
    train_mode=False,
)
assert final_loss < 1e-3, "Problems with fitting"

In [13]:
# Load the checkpoint into a copy whose weights have been wiped first. The comparison below then
# only passes if the checkpoint really restores the trained parameters. Loading back into the
# already-trained `model` would make the check pass even if loading did nothing.
reloaded_model = copy.deepcopy(model)
with torch.no_grad():
    for param in reloaded_model.parameters():
        param.zero_()
reloaded_model = load_model(reloaded_model, f"/tmp/{model.name}.pth", device)
final_loss_after_reloading = calculate_loss(
    model=reloaded_model,
    optimizer=optimizer,
    criterion=nn.MSELoss(reduction="none"),
    dataloader=dataloader,
    train_mode=False,
)
assert abs(final_loss - final_loss_after_reloading) < 1e-3, "Inconsistency after reloading"