# Preparation

This notebook creates the files the Streamlit app needs to run: the trained
models and the DuckDB databases.

First, train the digit-recognition CNN, saving a checkpoint after every epoch
(plus the untrained initial weights as epoch 0).

In [8]:
import itertools
import os
import shutil

import duckdb
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms


class Net(nn.Module):
    """MNIST digit classifier: two 3x3 convolutions followed by two linear layers.

    No pooling is applied; conv2 has 16 (= 64 // 4) output channels, so for a
    28x28 input the flattened conv2 output is 16 * 24 * 24 = 9216 features,
    which is exactly what ``fc1`` expects. ``forward`` returns log-probabilities
    over the 10 digit classes.
    """

    def __init__(self):
        super().__init__()
        # Module creation order fixes both the state_dict key order and the
        # order in which random initial weights are drawn.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64 // 4, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        hidden = self.dropout1(F.relu(self.conv2(hidden)))
        flat = torch.flatten(hidden, 1)
        hidden = self.dropout2(F.relu(self.fc1(flat)))
        return F.log_softmax(self.fc2(hidden), dim=1)


def train(args, model, device, train_loader, optimizer, epoch):
    """Run one epoch of NLL-loss training over ``train_loader``.

    Every ``args.log_interval`` batches a progress line is printed; when
    ``args.dry_run`` is set, the epoch stops right after the first such line.
    ``model`` is expected to return log-probabilities (it is fed to
    ``F.nll_loss``).
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()

        if step % args.log_interval != 0:
            continue

        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(inputs), len(train_loader.dataset),
            100. * step / len(train_loader), loss.item()))
        if args.dry_run:
            break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [9]:
# Takes 5-6mins on a decent machine.
def ensure_models_exist():
    """Train the MNIST CNN, saving a checkpoint per epoch, unless already done.

    Checkpoints go to ``models/mnist_cnn_{epoch}.pt``. Epoch 0 holds the
    untrained initial weights, and epochs 1-14 hold the weights after each
    training epoch. If the final checkpoint already exists, nothing is trained.
    """
    num_epochs = 14
    base_save_path = "models/mnist_cnn_{epoch}.pt"

    # If the last one is saved, we assume all of them are
    if os.path.exists(base_save_path.format(epoch=num_epochs)):
        print("All models available")
        return

    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(base_save_path), exist_ok=True)

    # These are the default values for the CLI app.
    class _Args:
        epochs = num_epochs
        device = torch.device("cpu")
        batch_size = 64
        learning_rate = 1.0
        gamma = 0.7
        log_interval = 10
        dry_run = False

    args = _Args()

    # Seed explicitly (not as a side effect of a class body) so the initial
    # weights, and hence every checkpoint, are reproducible.
    torch.manual_seed(1)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    dataset1 = datasets.MNIST('../data', train=True, download=True,
                        transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                        transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset1, batch_size=args.batch_size)
    test_loader = torch.utils.data.DataLoader(dataset2, batch_size=args.batch_size)

    model = Net().to(args.device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.learning_rate)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    # Epoch 0 = the untrained initial weights.
    torch.save(model.state_dict(), base_save_path.format(epoch=0))

    for epoch in range(1, args.epochs + 1):
        train(args, model, args.device, train_loader, optimizer, epoch)
        test(model, args.device, test_loader)
        scheduler.step()
        torch.save(model.state_dict(), base_save_path.format(epoch=epoch))

ensure_models_exist()

All models available


Next, we set up the multi-model database and load every saved checkpoint into
it. This code is taken from the "Eval - CNN" and "Eval - multiple networks" notebooks.

In [10]:
# In-memory DuckDB connection. The loaders below read this module-level `con`
# at call time, so rebinding `con` in a later cell redirects them as well.
con = duckdb.connect()

def batch_insert(con, generator, table, batch_size=8_000_000):
    """Insert rows from ``generator`` into ``table`` on ``con`` in chunks.

    Batching bounds memory use: at most ``batch_size`` rows are materialized
    (as a pandas DataFrame) at a time.

    Args:
        con: DuckDB connection to insert through.
        generator: iterable of rows (lists) whose order matches the columns
            of ``table``. Any iterable works, not only generators.
        table: name of the target table (interpolated into the SQL; pass
            trusted names only).
        batch_size: maximum number of rows per INSERT.
    """
    # islice on a re-iterable (e.g. a list) would restart from the first row
    # on every pass and never terminate, so always work on one iterator.
    rows = iter(generator)
    while True:
        chunk = list(itertools.islice(rows, batch_size))
        if not chunk:
            break

        # DuckDB resolves the name ``df`` in the SQL below to this local
        # pandas DataFrame, so the variable must keep the name ``df``.
        df = pd.DataFrame(chunk)
        con.execute(f"INSERT INTO {table} SELECT * FROM df")

def get_contributing_points(output_point, kernel_size, stride=1, padding=0):
    """List the input positions a square kernel reads for one output position.

    Args:
        output_point: ``(x, y)`` position in the convolution output.
        kernel_size: side length of the square kernel.
        stride: convolution stride.
        padding: convolution padding. It is only subtracted from the origin;
            coordinates outside the input are NOT filtered out.

    Returns:
        A list of ``((x_in, y_in), (dx, dy))`` pairs, one per kernel cell, with
        ``dx`` as the outer and ``dy`` as the inner loop.
    """
    x_out, y_out = output_point
    origin_x = x_out * stride - padding
    origin_y = y_out * stride - padding

    return [
        ((origin_x + dx, origin_y + dy), (dx, dy))
        for dx in range(kernel_size)
        for dy in range(kernel_size)
    ]

def load_model_into_db(model, name):
    """Insert a two-conv / two-fc CNN into the multi-model tables as a graph.

    A row is added to ``model`` and its generated id is used as ``model_id``.
    Then one ``node`` row ``[id, model_id, bias, name]`` is written per input
    pixel, conv activation and fc unit, followed by one ``edge`` row
    ``[model_id, src, dst, weight]`` per weight. Node ids continue after the
    current maximum ``node.id``, so several models can share the same tables.

    Args:
        model: module whose state_dict has ``conv1``, ``conv2``, ``fc1`` and
            ``fc2`` weight/bias entries (e.g. ``Net`` or ``ShrinkingNet``).
            Convolutions are treated as stride 1 without padding, and inputs
            as 28x28 single-channel images.
        name: label stored in ``model.name``.

    Raises:
        Exception: if conv1 has more than one input channel.
        ValueError: if fc1's input width does not equal the flattened conv2
            output size.
    """
    state_dict = model.state_dict()

    # We have to hardcode this: the state_dict does not record the image size.
    input_size = 28

    conv1_weight = state_dict['conv1.weight']
    conv1_bias = state_dict['conv1.bias']
    conv1_out_channels, conv1_in_channels, conv1_kernel_size, _ = conv1_weight.size()
    if conv1_in_channels != 1:
        raise Exception("Not handling >1 input channels for now")
    # A stride-1, unpadded convolution shrinks each side by kernel_size - 1.
    conv1_output_size = input_size - conv1_kernel_size + 1

    conv2_weight = state_dict['conv2.weight']
    conv2_bias = state_dict['conv2.bias']
    conv2_out_channels, conv2_in_channels, conv2_kernel_size, _ = conv2_weight.size()
    conv2_output_size = conv1_output_size - conv2_kernel_size + 1

    fc1_weight = state_dict['fc1.weight']
    fc1_bias = state_dict['fc1.bias']
    fc1_output_size, fc1_input_size = fc1_weight.size()

    fc2_weight = state_dict['fc2.weight']
    fc2_bias = state_dict['fc2.bias']
    fc2_output_size, fc2_input_size = fc2_weight.size()

    # fc1 consumes the flattened conv2 output; reject mismatches before any
    # row is written.
    conv2_flat_size = conv2_out_channels * conv2_output_size * conv2_output_size
    if fc1_input_size != conv2_flat_size:
        raise ValueError(
            f"fc1 expects {fc1_input_size} inputs, but conv2 produces "
            f"{conv2_out_channels}x{conv2_output_size}x{conv2_output_size} "
            f"= {conv2_flat_size}"
        )

    # Register the model only after its shapes were validated, so a rejected
    # model does not leave an orphan row behind.
    result = con.execute(
        "INSERT INTO model (name) VALUES ($name) RETURNING (id)",
        {"name": name}
    )
    (model_id,) = result.fetchone()

    # Maps node name -> node id; filled by node_generator, read by edge_generator.
    nodes = {}

    def node_generator(nodes):
        # Continue numbering after the nodes of previously inserted models.
        (max_id_in_db,) = con.execute("SELECT COALESCE(MAX(id), 0) FROM node").fetchone()
        node_idx = max_id_in_db + 1

        # Input nodes (1 channel for now)
        for y in range(0, input_size):
            for x in range(0, input_size):
                name = f"input.{x}.{y}"
                yield [node_idx, model_id, 0, name]
                nodes[name] = node_idx
                node_idx += 1

        # Conv1
        for y in range(0, conv1_output_size):
            for x in range(0, conv1_output_size):
                for c in range(0, conv1_out_channels):
                    name = f"conv1.{c}.{x}.{y}"
                    # The bias of this layer is simply the bias of the corresponding
                    # kernel
                    yield [node_idx, model_id, conv1_bias[c].item(), name]
                    nodes[name] = node_idx
                    node_idx += 1

        # Conv2
        for y in range(0, conv2_output_size):
            for x in range(0, conv2_output_size):
                for c in range(0, conv2_out_channels):
                    name = f"conv2.{c}.{x}.{y}"
                    yield [node_idx, model_id, conv2_bias[c].item(), name]
                    nodes[name] = node_idx
                    node_idx += 1

        # fc1
        for i in range(0, fc1_output_size):
            name = f"fc1.{i}"
            yield [node_idx, model_id, fc1_bias[i].item(), name]
            nodes[name] = node_idx
            node_idx += 1

        # fc2
        for i in range(0, fc2_output_size):
            name = f"fc2.{i}"
            yield [node_idx, model_id, fc2_bias[i].item(), name]
            nodes[name] = node_idx
            node_idx += 1

    def edge_generator(nodes):
        # Add the edges from input to conv1. Per channel, per output pixel of the
        # convolution, we have to match the kernel_size**2 input pixels to it
        # (9 for a 3x3 kernel).
        for c in range(0, conv1_out_channels):
            # 0 corresponds to the input channel (which we only have one of).
            kernel = conv1_weight[c][0]
            for y_conv in range(0, conv1_output_size):
                for x_conv in range(0, conv1_output_size):
                    dst = nodes[f"conv1.{c}.{x_conv}.{y_conv}"]
                    for (p_in, p_kernel) in get_contributing_points((x_conv, y_conv), conv1_kernel_size):
                        (x_in, y_in) = p_in
                        (x_kernel, y_kernel) = p_kernel

                        weight = kernel[y_kernel][x_kernel]
                        src = nodes[f"input.{x_in}.{y_in}"]

                        yield [model_id, src, dst, weight.item()]

        # Add the edges from conv1 to conv2. This is similar to connecting the
        # input to conv1, except that each conv2 output pixel reads from every
        # conv1 channel (conv2_in_channels of them); contributions are summed
        # per output channel.
        for c_out in range(0, conv2_out_channels):
            for y_conv2 in range(0, conv2_output_size):
                for x_conv2 in range(0, conv2_output_size):
                    dst = nodes[f"conv2.{c_out}.{x_conv2}.{y_conv2}"]
                    for (p_in, p_kernel) in get_contributing_points((x_conv2, y_conv2), conv2_kernel_size):
                        (x_in, y_in) = p_in
                        (x_kernel, y_kernel) = p_kernel

                        for c_in in range(0, conv2_in_channels):
                            weight = conv2_weight[c_out][c_in][y_kernel][x_kernel]
                            src = nodes[f"conv1.{c_in}.{x_in}.{y_in}"]

                            yield [model_id, src, dst, weight.item()]

        # Connect conv2 to fc1.
        for c in range(0, conv2_out_channels):
            # By adding the channel offset, we flatten: channel-major, then row
            # y, then column x, matching torch.flatten on (N, C, H, W).
            channel_offset = c * conv2_output_size * conv2_output_size
            for y_conv in range(0, conv2_output_size):
                for x_conv in range(0, conv2_output_size):
                    flat_index = channel_offset + y_conv * conv2_output_size + x_conv
                    src = nodes[f"conv2.{c}.{x_conv}.{y_conv}"]
                    for i in range(0, fc1_output_size):
                        weight = fc1_weight[i][flat_index]
                        dst = nodes[f"fc1.{i}"]

                        yield [model_id, src, dst, weight.item()]

        # Connect fc1 to fc2.
        for i in range(0, fc2_input_size):
            src = nodes[f"fc1.{i}"]
            for j in range(0, fc2_output_size):
                weight = fc2_weight[j][i]
                dst = nodes[f"fc2.{j}"]

                yield [model_id, src, dst, weight.item()]

    # Nodes must be fully inserted first: edge_generator reads the ids that
    # node_generator records in ``nodes``.
    batch_insert(con, node_generator(nodes), "node")
    batch_insert(con, edge_generator(nodes), "edge")


def create_db():
    """Build ``dbs/cnn_multimodel.db`` containing every per-epoch CNN checkpoint.

    Does nothing if the export already exists. Otherwise it (re)creates the
    ``model``, ``node``, ``edge`` and ``input`` tables on the module-level
    ``con``. It then loads ``models/mnist_cnn_{0..14}.pt`` through
    ``load_model_into_db`` (labelled ``"Epoch {n}"``) and writes the result
    with DuckDB's ``EXPORT DATABASE``.
    """
    save_path = 'dbs/cnn_multimodel.db'
    if os.path.exists(save_path):
        return

    # Drop every table created below (including ``input``) so calling this
    # again on the same connection does not fail on CREATE TABLE.
    con.execute("DROP TABLE IF EXISTS edge")
    con.execute("DROP TABLE IF EXISTS node")
    con.execute("DROP TABLE IF EXISTS model")
    con.execute("DROP TABLE IF EXISTS input")

    con.execute("DROP SEQUENCE IF EXISTS seq_node")
    con.execute("CREATE SEQUENCE seq_node START 1")

    con.execute("DROP SEQUENCE IF EXISTS seq_model")
    con.execute("CREATE SEQUENCE seq_model START 1")

    con.execute(
        """
        CREATE TABLE model(
            id INTEGER PRIMARY KEY DEFAULT NEXTVAL('seq_model'),
            name TEXT
        )
        """
    )
    con.execute(
        """
        CREATE TABLE node(
            id INTEGER PRIMARY KEY DEFAULT NEXTVAL('seq_node'),
            model_id INTEGER,
            bias REAL,
            name TEXT
        )"""
    )
    con.execute(
        """
        CREATE TABLE edge(
            model_id INTEGER,
            src INTEGER,
            dst INTEGER,
            weight REAL
        )"""
    )
    con.execute(
        """
        CREATE TABLE input(
            input_set_id INTEGER,
            input_node_idx INTEGER,
            input_value REAL
        )"""
    )

    # Epoch 0 (untrained) through epoch 14, as written by ensure_models_exist.
    for epoch in range(0, 15):
        model_path = f"models/mnist_cnn_{epoch}.pt"
        model = Net()
        model.load_state_dict(torch.load(model_path, weights_only=True))
        load_model_into_db(model, f"Epoch {epoch}")

    # EXPORT DATABASE writes a directory; make sure its parent exists.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    con.execute(f"EXPORT DATABASE '{save_path}'")

create_db()

We also create a separate database that holds only the final model
(epoch 14).

In [11]:
con = duckdb.connect()

def load_or_create_database(model):
    """Export ``model`` on its own to ``dbs/cnn_single.db`` as a node/edge graph.

    Does nothing if the export already exists (despite the name, nothing is
    loaded back from it). Otherwise it (re)creates the single-model ``node``
    (``id, bias, name``), ``edge`` (``src, dst, weight``) and ``input`` tables
    on the module-level ``con``. It fills them from the model's
    conv1/conv2/fc1/fc2 parameters, with node ids starting at 1, and writes
    the result with DuckDB's ``EXPORT DATABASE``.

    Convolutions are treated as stride 1 without padding, and inputs as 28x28
    single-channel images.

    Raises:
        Exception: if conv1 has more than one input channel.
        ValueError: if fc1's input width does not equal the flattened conv2
            output size.
    """
    save_path = "dbs/cnn_single.db"
    if os.path.exists(save_path):
        return

    # We have to hardcode this: the state_dict does not record the image size.
    input_size = 28

    state_dict = model.state_dict()

    conv1_weight = state_dict['conv1.weight']
    conv1_bias = state_dict['conv1.bias']
    conv1_out_channels, conv1_in_channels, conv1_kernel_size, _ = conv1_weight.size()
    if conv1_in_channels != 1:
        raise Exception("Not handling >1 input channels for now")
    # A stride-1, unpadded convolution shrinks each side by kernel_size - 1.
    conv1_output_size = input_size - conv1_kernel_size + 1

    conv2_weight = state_dict['conv2.weight']
    conv2_bias = state_dict['conv2.bias']
    conv2_out_channels, conv2_in_channels, conv2_kernel_size, _ = conv2_weight.size()
    conv2_output_size = conv1_output_size - conv2_kernel_size + 1

    fc1_weight = state_dict['fc1.weight']
    fc1_bias = state_dict['fc1.bias']
    fc1_output_size, fc1_input_size = fc1_weight.size()

    fc2_weight = state_dict['fc2.weight']
    fc2_bias = state_dict['fc2.bias']
    fc2_output_size, fc2_input_size = fc2_weight.size()

    # fc1 consumes the flattened conv2 output; reject mismatches before the
    # existing tables are dropped.
    conv2_flat_size = conv2_out_channels * conv2_output_size * conv2_output_size
    if fc1_input_size != conv2_flat_size:
        raise ValueError(
            f"fc1 expects {fc1_input_size} inputs, but conv2 produces "
            f"{conv2_out_channels}x{conv2_output_size}x{conv2_output_size} "
            f"= {conv2_flat_size}"
        )

    con.execute("DROP TABLE IF EXISTS edge")
    con.execute("DROP TABLE IF EXISTS node")
    con.execute("DROP SEQUENCE IF EXISTS seq_node")
    con.execute("DROP TABLE IF EXISTS input")

    con.execute("CREATE SEQUENCE seq_node START 1")
    con.execute(
        """
        CREATE TABLE node(
            id INTEGER PRIMARY KEY DEFAULT NEXTVAL('seq_node'),
            bias REAL,
            name TEXT
        )"""
    )
    con.execute(
        """
        CREATE TABLE edge(
            src INTEGER,
            dst INTEGER,
            weight REAL
        )"""
    )

    con.execute(
        """
        CREATE TABLE input(
            input_set_id INTEGER,
            input_node_idx INTEGER,
            input_value REAL
        )"""
    )

    # Maps node name -> node id; filled by node_generator, read by edge_generator.
    nodes = {}

    def node_generator(nodes):
        node_idx = 1

        # Input nodes (1 channel for now)
        for y in range(0, input_size):
            for x in range(0, input_size):
                name = f"input.{x}.{y}"
                yield [node_idx, 0, name]
                nodes[name] = node_idx
                node_idx += 1

        # Conv1
        for y in range(0, conv1_output_size):
            for x in range(0, conv1_output_size):
                for c in range(0, conv1_out_channels):
                    name = f"conv1.{c}.{x}.{y}"
                    # The bias of this layer is simply the bias of the corresponding
                    # kernel
                    yield [node_idx, conv1_bias[c].item(), name]
                    nodes[name] = node_idx
                    node_idx += 1

        # Conv2
        for y in range(0, conv2_output_size):
            for x in range(0, conv2_output_size):
                for c in range(0, conv2_out_channels):
                    name = f"conv2.{c}.{x}.{y}"
                    yield [node_idx, conv2_bias[c].item(), name]
                    nodes[name] = node_idx
                    node_idx += 1

        # fc1
        for i in range(0, fc1_output_size):
            name = f"fc1.{i}"
            yield [node_idx, fc1_bias[i].item(), name]
            nodes[name] = node_idx
            node_idx += 1

        # fc2
        for i in range(0, fc2_output_size):
            name = f"fc2.{i}"
            yield [node_idx, fc2_bias[i].item(), name]
            nodes[name] = node_idx
            node_idx += 1

    def edge_generator(nodes):
        # Add the edges from input to conv1. Per channel, per output pixel of the
        # convolution, we have to match the kernel_size**2 input pixels to it
        # (9 for a 3x3 kernel).
        for c in range(0, conv1_out_channels):
            # 0 corresponds to the input channel (which we only have one of).
            kernel = conv1_weight[c][0]
            for y_conv in range(0, conv1_output_size):
                for x_conv in range(0, conv1_output_size):
                    dst = nodes[f"conv1.{c}.{x_conv}.{y_conv}"]
                    for (p_in, p_kernel) in get_contributing_points((x_conv, y_conv), conv1_kernel_size):
                        (x_in, y_in) = p_in
                        (x_kernel, y_kernel) = p_kernel

                        weight = kernel[y_kernel][x_kernel]
                        src = nodes[f"input.{x_in}.{y_in}"]

                        yield [src, dst, weight.item()]

        # Add the edges from conv1 to conv2. This is similar to connecting the
        # input to conv1, except that each conv2 output pixel reads from every
        # conv1 channel (conv2_in_channels of them); contributions are summed
        # per output channel.
        for c_out in range(0, conv2_out_channels):
            for y_conv2 in range(0, conv2_output_size):
                for x_conv2 in range(0, conv2_output_size):
                    dst = nodes[f"conv2.{c_out}.{x_conv2}.{y_conv2}"]
                    for (p_in, p_kernel) in get_contributing_points((x_conv2, y_conv2), conv2_kernel_size):
                        (x_in, y_in) = p_in
                        (x_kernel, y_kernel) = p_kernel

                        for c_in in range(0, conv2_in_channels):
                            weight = conv2_weight[c_out][c_in][y_kernel][x_kernel]
                            src = nodes[f"conv1.{c_in}.{x_in}.{y_in}"]

                            yield [src, dst, weight.item()]

        # Connect conv2 to fc1.
        for c in range(0, conv2_out_channels):
            # By adding the channel offset, we flatten: channel-major, then row
            # y, then column x, matching torch.flatten on (N, C, H, W).
            channel_offset = c * conv2_output_size * conv2_output_size
            for y_conv in range(0, conv2_output_size):
                for x_conv in range(0, conv2_output_size):
                    flat_index = channel_offset + y_conv * conv2_output_size + x_conv
                    src = nodes[f"conv2.{c}.{x_conv}.{y_conv}"]
                    for i in range(0, fc1_output_size):
                        weight = fc1_weight[i][flat_index]
                        dst = nodes[f"fc1.{i}"]

                        yield [src, dst, weight.item()]

        # Connect fc1 to fc2.
        for i in range(0, fc2_input_size):
            src = nodes[f"fc1.{i}"]
            for j in range(0, fc2_output_size):
                weight = fc2_weight[j][i]
                dst = nodes[f"fc2.{j}"]

                yield [src, dst, weight.item()]

    # Nodes must be fully inserted first: edge_generator reads the ids that
    # node_generator records in ``nodes``.
    batch_insert(con, node_generator(nodes), "node")
    batch_insert(con, edge_generator(nodes), "edge")

    # EXPORT DATABASE writes a directory; make sure its parent exists.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    con.execute(f"EXPORT DATABASE '{save_path}'")

model_path = "models/mnist_cnn_14.pt"
model = Net()
model.load_state_dict(torch.load(model_path, weights_only=True))
load_or_create_database(model)

Next, we do something similar: we train several models of ever-decreasing size
(each step halves the layer widths). This lets us check whether smaller models still yield good results.

In [12]:
class ShrinkingNet(nn.Module):
    """``Net`` with its hidden widths divided by ``2 ** shrink_factor``.

    ``shrink_factor=0`` gives the same layer sizes as ``Net``. For the factors
    used in this notebook (0-4), each increment halves the conv1/conv2 channel
    counts and fc1's width, and keeps fc1's input equal to the flattened conv2
    output. The 10-class output layer is unchanged.
    """

    def __init__(self, shrink_factor=0):
        super().__init__()
        divisor = 2 ** shrink_factor
        conv1_channels = 32 // divisor
        conv2_channels = 64 // (divisor * 4)
        fc1_width = 128 // divisor

        # Same module creation order as Net (state_dict keys, init RNG order).
        self.conv1 = nn.Conv2d(1, conv1_channels, 3, 1)
        self.conv2 = nn.Conv2d(conv1_channels, conv2_channels, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216 // divisor, fc1_width)
        self.fc2 = nn.Linear(fc1_width, 10)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        hidden = self.dropout1(F.relu(self.conv2(hidden)))
        flat = torch.flatten(hidden, 1)
        hidden = self.dropout2(F.relu(self.fc1(flat)))
        return F.log_softmax(self.fc2(hidden), dim=1)

base_save_path = "models/mnist_cnn_shrinking_{i}.pt"
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(base_save_path), exist_ok=True)


# These are the default values for the CLI app (same as for the full-size CNN).
class _ShrinkArgs:
    epochs = 14
    device = torch.device("cpu")
    batch_size = 64
    learning_rate = 1.0
    gamma = 0.7
    log_interval = 10
    dry_run = False


args = _ShrinkArgs()

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Built on first use, so MNIST is not loaded at all when every model is cached,
# and it is loaded only once otherwise.
train_loader = test_loader = None

for i in range(0, 5):
    save_path = base_save_path.format(i=i)

    if os.path.exists(save_path):
        continue

    print(f"Training with shrinkfactor {i}")

    if train_loader is None:
        dataset1 = datasets.MNIST('../data', train=True, download=True,
                            transform=transform)
        dataset2 = datasets.MNIST('../data', train=False,
                            transform=transform)
        train_loader = torch.utils.data.DataLoader(dataset1, batch_size=args.batch_size)
        test_loader = torch.utils.data.DataLoader(dataset2, batch_size=args.batch_size)

    # Re-seed before building each model so every shrink factor starts from
    # the same RNG state.
    torch.manual_seed(1)
    model = ShrinkingNet(i).to(args.device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.learning_rate)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    for epoch in range(1, args.epochs + 1):
        train(args, model, args.device, train_loader, optimizer, epoch)
        test(model, args.device, test_loader)
        scheduler.step()

    torch.save(model.state_dict(), save_path)


In [13]:
def create_db():
    """Build ``dbs/cnn_multimodel_size.db`` containing the five ShrinkingNet models.

    Does nothing if the export already exists. Otherwise it (re)creates the
    ``model``, ``node``, ``edge`` and ``input`` tables on the module-level
    ``con``. It then loads ``models/mnist_cnn_shrinking_{0..4}.pt`` through
    ``load_model_into_db`` with human-readable size labels and writes the
    result with DuckDB's ``EXPORT DATABASE``.

    NOTE(review): this redefines ``create_db`` from the per-epoch cell.
    """
    save_path = 'dbs/cnn_multimodel_size.db'
    if os.path.exists(save_path):
        return

    # ``con`` may be the connection an earlier cell already filled, including
    # an ``input`` table. Drop every table created below so CREATE TABLE
    # cannot fail on a leftover.
    con.execute("DROP TABLE IF EXISTS edge")
    con.execute("DROP TABLE IF EXISTS node")
    con.execute("DROP TABLE IF EXISTS model")
    con.execute("DROP TABLE IF EXISTS input")

    con.execute("DROP SEQUENCE IF EXISTS seq_node")
    con.execute("CREATE SEQUENCE seq_node START 1")

    con.execute("DROP SEQUENCE IF EXISTS seq_model")
    con.execute("CREATE SEQUENCE seq_model START 1")

    con.execute(
        """
        CREATE TABLE model(
            id INTEGER PRIMARY KEY DEFAULT NEXTVAL('seq_model'),
            name TEXT
        )
        """
    )
    con.execute(
        """
        CREATE TABLE node(
            id INTEGER PRIMARY KEY DEFAULT NEXTVAL('seq_node'),
            model_id INTEGER,
            bias REAL,
            name TEXT
        )"""
    )
    con.execute(
        """
        CREATE TABLE edge(
            model_id INTEGER,
            src INTEGER,
            dst INTEGER,
            weight REAL
        )"""
    )
    con.execute(
        """
        CREATE TABLE input(
            input_set_id INTEGER,
            input_node_idx INTEGER,
            input_value REAL
        )"""
    )

    # Label i describes ShrinkingNet(i), whose widths are divided by 2 ** i.
    labels = ["Regular", "2x smaller", "4x smaller", "8x smaller", "16x smaller"]
    for i, label in enumerate(labels):
        model_path = f"models/mnist_cnn_shrinking_{i}.pt"
        model = ShrinkingNet(i)
        model.load_state_dict(torch.load(model_path, weights_only=True))
        load_model_into_db(model, label)

    # EXPORT DATABASE writes a directory; make sure its parent exists.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    con.execute(f"EXPORT DATABASE '{save_path}'")

create_db()

## Basic eval

In [17]:
import torch
from torch import nn
from torch import optim
import os


class ReLUFNN(nn.Module):
    """Fully connected ReLU network.

    The architecture is ``input_size`` -> ``num_hidden_layers`` hidden layers
    of ``hidden_size`` ReLU units -> ``output_size`` linear outputs. All layers
    live in ``linear_relu_stack`` (an ``nn.Sequential``), so linear layers
    appear in the state_dict at even indices 0, 2, 4, ... Values of
    ``num_hidden_layers`` below 1 behave like 1.
    """

    def __init__(
        self, input_size=1, hidden_size=4, num_hidden_layers=10, output_size=1
    ):
        super().__init__()
        # Each additional hidden layer contributes a (Linear, ReLU) pair; the
        # generator is consumed in order, so layers are created front to back.
        hidden_blocks = (
            module
            for _ in range(num_hidden_layers - 1)
            for module in (nn.Linear(hidden_size, hidden_size), nn.ReLU())
        )
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            *hidden_blocks,
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        return self.linear_relu_stack(x)


def train(model, x_train, y_train, epochs=1000, save_path=None):
    """Fit ``model`` to ``(x_train, y_train)`` with full-batch Adam on MSE loss.

    If ``save_path`` already exists, the weights are loaded from it and no
    training happens. Otherwise the model is trained and, when ``save_path``
    is given, the weights are saved there (its parent directory is created if
    needed).

    Args:
        model: network mapping ``(N, 1)`` inputs to ``(N, 1)`` outputs.
        x_train, y_train: 1-D sequences, arrays or tensors of equal length;
            each gets a trailing feature dimension of size 1.
        epochs: number of full-batch optimisation steps.
        save_path: optional checkpoint path, used as a cache.

    NOTE(review): this rebinds ``train``, shadowing the earlier CNN
    ``train(args, model, device, ...)``. Re-running the CNN training cells
    after this one would call the wrong function.
    """
    if save_path and os.path.exists(save_path):
        model.load_state_dict(torch.load(save_path, weights_only=True))
        return

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    x_train_tensor = ensure_tensor(x_train).unsqueeze(1)
    y_train_tensor = ensure_tensor(y_train).unsqueeze(1)

    # The training-mode flag never changes inside the loop, so set it once.
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(x_train_tensor), y_train_tensor)
        loss.backward()
        optimizer.step()

    if save_path:
        # torch.save does not create missing directories.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torch.save(model.state_dict(), save_path)


def ensure_tensor(tensor_or_array):
    """Return the argument unchanged if it is a tensor, else a float32 copy of it."""
    if not torch.is_tensor(tensor_or_array):
        return torch.tensor(tensor_or_array, dtype=torch.float32)
    return tensor_or_array

con = duckdb.connect()

def _initialize_database(save_path=None):
    """(Re)create the single-model ``node``/``edge``/``input`` tables on ``con``.

    If ``save_path`` names an existing directory (the output of a previous
    ``EXPORT DATABASE``), it is deleted first so that the export done by
    ``load_state_dict_into_db`` writes into a fresh directory.

    Args:
        save_path: optional export directory to clear; ``None`` skips this.
    """
    # ``shutil`` must be imported with the notebook's other imports; without
    # it this branch raises NameError whenever the directory already exists.
    if save_path and os.path.isdir(save_path):
        shutil.rmtree(save_path)

    con.execute("DROP TABLE IF EXISTS edge")
    con.execute("DROP TABLE IF EXISTS node")
    con.execute("DROP SEQUENCE IF EXISTS seq_node")
    con.execute("DROP TABLE IF EXISTS input")

    con.execute("CREATE SEQUENCE seq_node START 1")
    con.execute(
        """
        CREATE TABLE node(
            id INTEGER PRIMARY KEY DEFAULT NEXTVAL('seq_node'),
            bias REAL,
            name TEXT
        )"""
    )
    # Foreign keys are omitted for performance.
    con.execute(
        """
        CREATE TABLE edge(
            src INTEGER,
            dst INTEGER,
            weight REAL
        )"""
    )

    con.execute(
        """
        CREATE TABLE input(
            input_set_id INTEGER,
            input_node_idx INTEGER,
            input_value REAL
        )"""
    )


def load_pytorch_model_into_db(model, save_path=None):
    """Store ``model``'s parameters via ``load_state_dict_into_db`` and return its result."""
    state_dict = model.state_dict()
    return load_state_dict_into_db(state_dict, save_path)


def batch_insert(generator, table, batch_size=8_000_000):
    """
    Inserts data in batches into duckdb, to find a middle ground between
    performance and memory consumption. A batch size of 10M consumes ~4GB RAM.

    Rows come from ``generator`` (any iterable of row lists whose order matches
    the columns of ``table``) and are written through the module-level ``con``.
    ``table`` is interpolated into the SQL, so pass trusted names only.

    NOTE(review): this redefines the earlier ``batch_insert(con, generator,
    table, ...)`` with a different signature. Cells that call the old form
    must run before this one.
    """
    # islice on a re-iterable (e.g. a list) would restart from the first row
    # on every pass and never terminate, so always work on one iterator.
    rows = iter(generator)
    while True:
        chunk = list(itertools.islice(rows, batch_size))
        if not chunk:
            break

        # DuckDB resolves the name ``df`` in the SQL below to this local
        # pandas DataFrame, so the variable must keep the name ``df``.
        df = pd.DataFrame(chunk)
        con.execute(f"INSERT INTO {table} SELECT * FROM df")


def load_state_dict_into_db(state_dict, save_path=None):
    """Store a fully connected network's ``state_dict`` as a node/edge graph.

    The tables are reset via ``_initialize_database(save_path)``. Then one
    ``node`` row is inserted per input feature (bias 0, named ``input.{i}``)
    and per bias entry (named ``{param_name}.{i}``), followed by one ``edge``
    row per weight. Finally the database is exported to ``save_path`` when
    one is given.

    Assumes ``state_dict`` lists the weight and bias of consecutive
    ``nn.Linear`` layers in order (as ``ReLUFNN`` produces). The first entry
    must then be the first layer's ``(out_features, in_features)`` weight
    matrix, and every layer is fully connected to the next.
    """
    _initialize_database(save_path)

    # Node ids per layer (index 0 = inputs), recorded while nodes are
    # generated so the edges can be wired up afterwards.
    node_ids = [[]]

    def nodes():
        # The input width is the column count of the first weight matrix; read
        # it from the shape instead of converting the matrix to nested lists.
        first_weight = next(iter(state_dict.values()))
        num_input_nodes = first_weight.shape[1]

        node_id = 0
        for i in range(num_input_nodes):
            node_id += 1
            yield [node_id, 0, f"input.{i}"]
            node_ids[0].append(node_id)

        layer = 0
        # In the first pass, insert all nodes with their biases
        for name, values in state_dict.items():
            # state_dict alternates between weight and bias tensors.
            if "bias" not in name:
                continue

            node_ids.append([])
            layer += 1
            for i, bias in enumerate(values.tolist()):
                node_id += 1
                yield [node_id, bias, f"{name}.{i}"]
                node_ids[layer].append(node_id)

    def edges():
        # In the second pass, insert all edges and their weights. This assumes
        # a fully connected network.
        layer = 0
        for name, values in state_dict.items():
            # state_dict alternates between weight and bias tensors.
            if "weight" not in name:
                continue

            # Each weight tensor has a row for each node in the next layer; the
            # columns of that row correspond to the nodes of the current layer.
            weight_tensor = values.tolist()
            for from_index, from_node in enumerate(node_ids[layer]):
                for to_index, to_node in enumerate(node_ids[layer + 1]):
                    yield [from_node, to_node, weight_tensor[to_index][from_index]]

            layer += 1

    # nodes() must be fully consumed first: edges() reads node_ids.
    batch_insert(nodes(), "node")
    batch_insert(edges(), "edge")

    if save_path:
        con.execute(f"EXPORT DATABASE '{save_path}'")

import torch
import numpy as np
import pandas as pd

# Add a manual seed for reproducibility.
torch.manual_seed(223)

# A simple function to train the network on: a triangular bump that rises from
# 0 at x=0 to 5 at x=5, falls back towards 0 as x approaches 10, and is 0
# everywhere else.
def f(x):
    if 0 <= x < 5:
        return x
    if 5 <= x < 10:
        return 10 - x
    return 0

# f only does something interesting on [0, 10). The training grid spans
# [-5, 15], so it also covers the flat zero regions on both sides.
x_train = np.linspace(-5, 15, 1000)
y_train = np.array(list(map(f, x_train)))

# A tiny network: two hidden layers of two ReLU units each.
model = ReLUFNN(input_size=1, hidden_size=2, num_hidden_layers=2, output_size=1)
train(model, x_train, y_train, save_path="models/basic_eval.pt")
load_pytorch_model_into_db(model)


## PWL model

In [18]:
import math

torch.manual_seed(1)

# Fit a single wide hidden layer (1000 ReLU units) to sin(x) over two periods.
x_train = np.linspace(-2 * math.pi, 2 * math.pi, 10000)
y_train = np.array(list(map(math.sin, x_train)))

model = ReLUFNN(input_size=1, output_size=1, hidden_size=1000, num_hidden_layers=1)
train(model, x_train, y_train, epochs=750, save_path="models/pwl_geometric_sine.pt")

# Predictions on the training grid, computed in inference mode.
model.eval()
with torch.no_grad():
    predicted = (
        model(torch.tensor(x_train, dtype=torch.float32).unsqueeze(1))
        .detach()
        .numpy()
    )

load_pytorch_model_into_db(model, save_path="dbs/pwl_geometric_sine.db")