In [None]:
import copy
import shutil
from pathlib import Path
from time import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import trange

from kgi import apply_kgi_to_model

In [None]:
# Enable LaTeX text rendering only when a `latex` executable is found on PATH,
# so the notebook still runs end-to-end on machines without a TeX installation.
# Set this to False by hand if plotting still fails (e.g. missing TeX packages).
plt.rcParams['text.usetex'] = shutil.which("latex") is not None

In [None]:
def default_init_to_he_uniform(model):
    """ Change PyTorch's default initialization to He uniform

    Multiplies, in place, the weight and (when present) the bias of every
    ``torch.nn.Linear`` found in ``model.modules()`` by sqrt(3). The function
    assumes those parameters still hold PyTorch's default initialization;
    calling it twice scales them twice.

    Args:
        model: a ``torch.nn.Module`` whose linear layers are rescaled.

    Returns:
        None; ``model`` is modified in place.
    """
    # PyTorch default uses 1/sqrt(m) as the bound
    # He uniform uses sqrt(3) / sqrt(m) as the bound
    # NOTE(review): with the ReLU gain sqrt(2), Kaiming/He uniform is usually
    # quoted with bound sqrt(6)/sqrt(m); sqrt(3)/sqrt(m) corresponds to
    # gain 1 -- confirm this is the intended convention.
    scale = float(np.sqrt(3))
    # no_grad lets us rescale leaf parameters in place without touching `.data`
    with torch.no_grad():
        for layer in model.modules():
            if not isinstance(layer, torch.nn.Linear):
                continue
            layer.weight.mul_(scale)
            # Linear(..., bias=False) stores bias as None
            if layer.bias is not None:
                layer.bias.mul_(scale)


class MLP1d(nn.Module):
    """ Fully connected network mapping 1 input feature to 1 output feature

    Three linear layers (1 -> hidden_size -> hidden_size -> 1) with the given
    activation applied after the first two. After construction, the layers
    are rescaled by `default_init_to_he_uniform`.
    """

    def __init__(self, hidden_size, activation=torch.relu):
        super().__init__()
        # layers are created in forward order
        self.fc1 = nn.Linear(1, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
        self.act = activation
        default_init_to_he_uniform(self)

    def forward(self, x):
        """ Apply fc1 -> act -> fc2 -> act -> fc3 to `x` """
        hidden = self.act(self.fc1(x))
        hidden = self.act(self.fc2(hidden))
        return self.fc3(hidden)

# Visualize initialized models

In [None]:
def find_knots(y, dx, threshold=1e-11):
    """ Locate the sample indices at which a sampled curve changes slope

    Args:
        y: tensor of curve values sampled with spacing `dx`.
        dx: spacing used to turn value differences into slopes.
        threshold: a slope change must exceed this (in absolute value)
            to count as a knot.

    Returns:
        A list of indices into `y`; empty when `y` has fewer than 2 entries.
    """
    if len(y) < 2:
        return []
    # slope of each segment between neighbouring samples
    segment_slopes = torch.diff(y, dim=0) / dx
    # magnitude of the slope jump between neighbouring segments
    jump = torch.diff(segment_slopes).abs()
    # jump k sits between segments k and k + 1, i.e. at sample k + 1
    knot_idx = torch.nonzero(jump > threshold, as_tuple=True)[0] + 1
    return knot_idx.tolist()

In [None]:
# use double for more accurate knot finding
torch.set_default_dtype(torch.float64)
torch.manual_seed(42)

# set number of curves
n_curve = 10

# hidden sizes to consider
n_hidden = [5, 20]
# number of panels per initialization scheme; the default-init panels occupy
# the first `n_panel` axes and the KGI panels the next `n_panel`
n_panel = len(n_hidden)

# input
x_in = torch.linspace(0, 1, 1000).unsqueeze(1)
# uniform grid spacing, computed once instead of twice per curve
dx_in = x_in[1] - x_in[0]

# the figure is saved under figs/, which may not exist in a fresh checkout
fig_dir = Path("figs")
fig_dir.mkdir(exist_ok=True, parents=True)

fig, axes = plt.subplots(1, 2 * n_panel, figsize=(10, 3), dpi=200)
plt.subplots_adjust(wspace=.15)
for i_n, n in enumerate(n_hidden):
    ax_def = axes[i_n]
    ax_kgi = axes[n_panel + i_n]
    n_knots_def = 0
    n_knots_kgi = 0
    for i in range(n_curve):
        # default model
        model_def = MLP1d(n)
        model_def.eval()

        # KGI model: starts from a copy of the default model so both curves
        # come from the same random draw
        model_kgi = copy.deepcopy(model_def)
        apply_kgi_to_model(model_kgi, knot_low=0.2, knot_high=0.8,
                           perturb_factor=0.2, kgi_by_bias=False)

        # forward
        with torch.no_grad():
            y_def = model_def(x_in).squeeze(1)
            y_kgi = model_kgi(x_in).squeeze(1)

        # find knots
        knots_def = find_knots(y_def, dx_in)
        knots_kgi = find_knots(y_kgi, dx_in)
        n_knots_def += len(knots_def)
        n_knots_kgi += len(knots_kgi)

        # plot each curve with a tick mark at every detected knot
        ax_def.plot(x_in, y_def, lw=1)
        ax_def.scatter(x_in[knots_def], y_def[knots_def], marker="|", s=80, lw=.5)
        ax_kgi.plot(x_in, y_kgi, lw=1)
        ax_kgi.scatter(x_in[knots_kgi], y_kgi[knots_kgi], marker="|", s=80, lw=.5)

    # title: panel letter and hidden size below, mean knot count above
    ax_def.text(x=0.5, y=-0.1,
                s="(%s) No KGI, $H=%d$" % (chr(ord('a') + i_n), n),
                fontsize=12, ha='center',
                transform=ax_def.transAxes)
    ax_def.text(x=0.5, y=1.03,
                s="$N_\\mathrm{knot}=%d$" % (round(n_knots_def / n_curve),),
                fontsize=12, ha='center', va='bottom',
                transform=ax_def.transAxes)
    # KGI panel letters continue after the default-init ones
    ax_kgi.text(x=0.5, y=-0.1,
                s="(%s) KGI, $H=%d$" % (chr(ord('a') + n_panel + i_n), n,),
                fontsize=12, ha='center',
                transform=ax_kgi.transAxes)
    ax_kgi.text(x=0.5, y=1.03,
                s="$N_\\mathrm{knot}=%d$" % (round(n_knots_kgi / n_curve),),
                fontsize=12, ha='center', va='bottom',
                transform=ax_kgi.transAxes)

# setup
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlim(0, 1)
plt.savefig(fig_dir / "curve_knots.pdf", bbox_inches='tight', pad_inches=0.01)
plt.show()

# Curve fitting

### Target

In [None]:
# switch back to single precision for curve fitting (the knot-visualization
# cell above changed the default dtype to float64)
torch.set_default_dtype(torch.float32)


def gaussian(x, mu, sigma, height):
    """ Evaluate height * exp(-(x - mu)^2 / (2 * sigma^2)) at `x`

    Args:
        x: tensor of evaluation points.
        mu: center of the bump.
        sigma: width of the bump.
        height: value of the bump at `x == mu`.

    Returns:
        Tensor with the Gaussian evaluated at each entry of `x`.
    """
    offset = x - mu
    two_var = 2 * sigma ** 2
    return height * torch.exp(-(offset ** 2) / two_var)


# ground truth: zero baseline plus five narrow Gaussian bumps on [0, 1];
# each entry is (center, width, signed height), a negative height is a dip
x_in = torch.linspace(0, 1, 1000)
bumps = [
    (0.15, 0.005, 1),
    (0.34, 0.003, -.8),
    (0.41, 0.003, -.6),
    (0.62, 0.004, 1.2),
    (0.8, 0.002, 0.8),
]
y_true = torch.zeros_like(x_in)
for bump_mu, bump_sigma, bump_height in bumps:
    y_true += gaussian(x_in, bump_mu, bump_sigma, bump_height)

# show the target curve
fig_target, ax_target = plt.subplots(figsize=(10 / 3, 8 / 3), dpi=200)
ax_target.set_xlim(0, 1)
ax_target.plot(x_in, y_true, lw=1)

# dataset: inputs and targets as column vectors of shape (1000, 1)
dataset = TensorDataset(x_in[:, None], y_true[:, None])
# batch_size equals the number of samples, so every epoch is one full batch
dataloader = DataLoader(dataset, batch_size=1000, shuffle=True)  # no need to batch

### Training

In [None]:
def train(kgi, hidden_size, seed, activation=torch.relu,
          num_epochs=30000, log_loss_every=100, device="cpu", pbar=True):
    """ Fit an `MLP1d` to the module-level `dataloader` with Adam and MSE loss

    Args:
        kgi: if true, `apply_kgi_to_model` is applied to the fresh model
            before training.
        hidden_size: width passed to `MLP1d`.
        seed: value given to `torch.manual_seed` before the model is built.
        activation: activation callable passed to `MLP1d`.
        num_epochs: number of passes over `dataloader`.
        log_loss_every: the epoch loss is recorded whenever the 1-based epoch
            number is a multiple of this value.
        device: device the model and the batches are moved to.
        pbar: show a progress bar over epochs.

    Returns:
        (model, loss_hist): the trained model and the list of recorded
        per-epoch losses (sample-weighted mean over the dataset).
    """
    torch.manual_seed(seed)
    net = MLP1d(hidden_size, activation)
    if kgi:
        apply_kgi_to_model(net, knot_low=0.2, knot_high=0.8,
                           perturb_factor=0.2, kgi_by_bias=False)
    net.to(device)
    net.train()
    loss_fn = nn.MSELoss()
    adam = torch.optim.Adam(net.parameters(), lr=0.001)

    # training loop
    loss_hist = []
    for epoch in trange(num_epochs, disable=not pbar):
        epoch_loss_sum = 0.0
        for xb, yb in dataloader:
            xb = xb.to(device)
            yb = yb.to(device)
            adam.zero_grad()
            batch_loss = loss_fn(net(xb), yb)
            batch_loss.backward()
            adam.step()
            # weight by batch size so the epoch value is a per-sample mean
            epoch_loss_sum += batch_loss.item() * xb.size(0)
        if (epoch + 1) % log_loss_every != 0:
            continue
        loss_hist.append(epoch_loss_sum / len(dataset))
    return net, loss_hist

In [None]:
# train all models
# `reproduce_paper = True` runs the full sweep (several seeds, widths and
# activations); the default is a single short run that finishes on a CPU
reproduce_paper = False
if reproduce_paper:
    seeds = [0, 1, 2, 3, 4]
    hidden_sizes = [25, 50, 100, 200, 400, 600, 800, 1000, 1200]
    # NOTE(review): `glu` halves the last dimension of its input, which does
    # not match MLP1d's fc2 (hidden_size -> hidden_size) -- confirm whether
    # another activation (e.g. `gelu`) was intended.
    activations = [torch.nn.functional.relu, torch.nn.functional.leaky_relu,
                   torch.nn.functional.glu, torch.nn.functional.tanh]
    epochs = 30000
    # fall back to the CPU so the cell still runs on machines without a GPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    seeds = [0]
    hidden_sizes = [200]
    activations = [torch.nn.functional.relu]
    epochs = 10000
    device = "cpu"

out_dir = Path("results/curve")
out_dir.mkdir(exist_ok=True, parents=True)
for act_ in activations:
    # `__name__` gives the same name as parsing `str(act_)` for plain Python
    # functions, and also works for built-in callables such as `torch.relu`;
    # objects without `__name__` (e.g. nn.Module instances) use their class name
    act_name = getattr(act_, "__name__", type(act_).__name__)
    for seed_ in seeds:
        for hidden_size_ in hidden_sizes:
            for kgi_ in [False, True]:
                # one loss-history file per (activation, seed, width, kgi)
                name = f"{act_name}_{seed_}_{hidden_size_}_{kgi_}"
                t0 = time()
                _, hist = train(kgi_, hidden_size_, seed_, act_, device=device,
                                num_epochs=epochs, pbar=False)
                np.savetxt(out_dir / name, hist)
                # `hist` is empty when `epochs` is smaller than the logging interval
                final_loss = f"{hist[-1]:.2e}" if len(hist) > 0 else "n/a"
                print(f"{name} trained in {(time() - t0) / 60:.1f} min, loss={final_loss}")

In [None]:
# quick comparison of the relu / seed 0 / hidden size 200 runs, which are
# trained for both settings of `reproduce_paper`
hist_def = np.loadtxt(out_dir / "relu_0_200_False")
hist_kgi = np.loadtxt(out_dir / "relu_0_200_True")
# label the curves and axes so the figure can be read on its own
plt.plot(hist_def, label="No KGI")
plt.plot(hist_kgi, label="KGI")
plt.xlabel("Logged step")
plt.ylabel("Training loss")
plt.legend()
plt.show()

### Analysis

The following cells work only when `reproduce_paper` was set to `True` for training.