# Introduction Figure

In [None]:
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import torch

import pam_ops


In [None]:
# Dense sampling of [0, 8]; every PAM primitive (pam_*) is paired with its exact torch counterpart.
# Note x[0] == 0, so `div` and `log2` hold inf / -inf at the first sample.
x = torch.linspace(0, 8, 1000)
x_np = x.numpy(force=True)
c = 1.5

# Scaling by a constant and reciprocal.
pam_mul, mul = pam_ops.mul(x, c), x * c
pam_div, div = pam_ops.div(1, x), 1 / x

# Square and square root.
pam_square, square = pam_ops.mul(x, x), x * x
pam_sqrt, sqrt = pam_ops.pow(x, 0.5), x ** 0.5

# Base-2 exponential and logarithm.
pam_exp2, exp2 = pam_ops.exp2(x), torch.exp2(x)
pam_log2, log2 = pam_ops.log2(x), torch.log2(x)


In [None]:
# Overview figure: dashed = PAM approximation, solid (translucent) = exact function.
matplotlib.rcParams.update({'font.size': 12})
a = 0.5  # alpha of the exact curves and of the grid
w = 1    # line width shared by all curves

# (PAM result, exact result, color, legend label).
# Raw strings keep the LaTeX backslashes literal instead of relying on invalid
# escape sequences such as '\c' (SyntaxWarning on Python 3.12+).
# The mul label is built from `c` so it cannot drift from the constant actually used.
curves = [
    (pam_mul, mul, 'k', rf'${c} \cdot x$'),
    (pam_div, div, 'tab:red', r'$1 \,/\, x$'),
    (pam_exp2, exp2, 'tab:blue', r'$\exp_2(x)$'),
    (pam_log2, log2, 'darkorange', r'$\log_2(x)$'),
    # (pam_square, square, None, r'$x^2$'),
    # (pam_sqrt, sqrt, None, r'$\sqrt{x}$'),
]

fig, ax = plt.subplots()
exact_lines = []
for pam_y, exact_y, color, label in curves:
    pam_line, = ax.plot(x_np, pam_y.numpy(force=True), ls='--', linewidth=w, color=color)
    # Reuse the dashed curve's resolved color so each PAM/exact pair matches
    # (also when color is None and matplotlib picks from the color cycle).
    exact_line, = ax.plot(x_np, exact_y.numpy(force=True), ls='-', linewidth=w,
                          color=pam_line.get_color(), label=label)
    exact_lines.append(exact_line)

limits = [x_np.min(), x_np.max()]
ax.set_ylim(limits)
ax.set_xlim(limits)
ax.legend()
ax.grid(alpha=a)

# Fade the exact curves only after the legend has been built, so the legend
# samples are not faded.
for line in exact_lines:
    line.set_alpha(a)
ax.set_xlabel('$x$')
fig.tight_layout()
fig.savefig('affine_overview.pdf')
plt.show()

# PAM Error Figure

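One row (or column) shows three subfigures:
* Contour plot of the PAM product (`pam_ops.mul`) over the range $[1, 2] \times [1, 2]$
* Contour plot of the exact product $x_1 \cdot x_2$ over the range $[1, 2] \times [1, 2]$
* Relative error (in %) of the PAM product with respect to the exact product over $[1, 2] \times [1, 2]$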

In [None]:
import matplotlib
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import torch

import pam_ops

# https://stackoverflow.com/questions/32462881/add-colorbar-to-existing-axis
# https://jakevdp.github.io/PythonDataScienceHandbook/04.04-density-and-contour-plots.html

In [None]:
# Sample [1, 2] x [1, 2] on a 101 x 101 grid ('xy' indexing: X1 varies along columns).
x1 = torch.linspace(1, 2, 101)
x2 = torch.linspace(1, 2, 101)
X1, X2 = torch.meshgrid(x1, x2, indexing='xy')


def _product_and_grad_x1(multiply, A, B):
    """Evaluate ``multiply(A, B)`` on fresh leaf copies of A and B and backpropagate ones.

    Returns ``(product, grad_A)`` as NumPy arrays. For an elementwise ``multiply``
    the all-ones upstream gradient makes ``grad_A`` the pointwise partial derivative
    of the product with respect to its first operand.
    """
    a = torch.clone(A).requires_grad_(True)
    b = torch.clone(B).requires_grad_(True)
    out = multiply(a, b)
    # Seed the backward pass with ones shaped like *this* output.
    out.backward(torch.ones_like(out))
    return out.numpy(force=True), a.grad.numpy(force=True)


# PAM product with the exact backward pass (only its gradient is kept).
_, X1_pam_grad_exact = _product_and_grad_x1(
    lambda a, b: pam_ops.mul(a, b, approx_bwd=False), X1, X2)
# PAM product with the approximate backward pass; its forward value is the one plotted.
# NOTE(review): assumes the PAM forward result does not depend on approx_bwd — confirm.
pam_product, X1_pam_grad_approx = _product_and_grad_x1(
    lambda a, b: pam_ops.mul(a, b, approx_bwd=True), X1, X2)
# Standard floating-point product as the reference.
product, X1_grad = _product_and_grad_x1(torch.mul, X1, X2)

# Signed relative error in percent; product >= 1 on this domain, so no division by zero.
relative_error = 100 * (pam_product - product) / product
X1, X2 = X1.numpy(force=True), X2.numpy(force=True)

In [None]:
# Set the font size before creating the figure so every text element uses it,
# regardless of what earlier cells left in rcParams.
matplotlib.rcParams.update({'font.size': 12})
cfs = 8   # font size of the contour labels
cis = 10  # inline spacing around contour labels


def _contour_heatmap(ax, Z, levels, line_color, cmap, cbar_ticks, cbar_format=None, **contour_kwargs):
    """Draw Z over [1, 2] x [1, 2] as an image with labelled contours and a right-hand colorbar.

    Uses the meshgrid ``X1``, ``X2`` from the data cell for the contour coordinates.
    Extra keyword arguments are forwarded to ``ax.contour``. Returns the colorbar.
    """
    cp = ax.contour(X1, X2, Z, levels, colors=line_color, alpha=0.5, **contour_kwargs)
    ax.clabel(cp, inline=True, inline_spacing=cis, fontsize=cfs)
    im = ax.imshow(Z, extent=[1, 2, 1, 2], origin='lower', cmap=cmap, alpha=1.0)
    # Thin colorbar axis appended to the right of the image axis.
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    cb = ax.figure.colorbar(im, cax=cax, orientation='vertical', ticks=cbar_ticks, format=cbar_format)
    ax.set_yticks([1, 2])
    ax.set_xticks([1, 2])
    # Negative label padding pulls the axis labels in close to the axes.
    ax.set_xlabel('$x_1$', labelpad=-15)
    ax.set_ylabel('$x_2$', labelpad=-12)
    return cb


fig, axs = plt.subplots(1, 3, figsize=(9, 6))

# Panels: PAM product, exact product (15 white contour levels each), percent error.
cb = _contour_heatmap(axs[0], pam_product, 15, 'white', 'magma', [1, 4])
# cb.set_label(r'$x_1 \hat{\cdot} x_2$', labelpad=-12)
cb = _contour_heatmap(axs[1], product, 15, 'white', 'magma', [1, 4])
cb = _contour_heatmap(
    axs[2], relative_error, np.linspace(-10, -2, 5), 'black', 'Reds_r',
    [relative_error.min(), relative_error.max()], cbar_format="{x:0.1f}",
    negative_linestyles='solid')

plt.tight_layout()
plt.savefig('pam_error.pdf', bbox_inches='tight')
plt.show()

# PAM Functions and Derivatives

In [None]:
# Primitives along with their real equivalents.
# Each entry: [LaTeX title, PAM function (called with approx_bwd/offset kwargs),
#              reference torch function, x_min, x_max].
c = 1.5
offset = 1.0
N_tensor = 32       # NOTE(review): not used in this cell
dY_value = 1.25     # constant upstream gradient for every backward pass
N = 1000            # samples per function
DEVICE = 'cuda'     # NOTE(review): hard-coded GPU; presumably required by the pam_ops kernels — confirm

func_dict = {
    'mul': [
        # Raw strings keep LaTeX backslashes literal ('\c', '\s', '\l' are invalid escapes).
        [rf'${c} \cdot x$',
         lambda x, **kwargs: pam_ops.mul(x, torch.full_like(x, c), **kwargs),
         lambda x: c*x,
         -3, 3],
        [f'${c}/x$',
         lambda x, **kwargs: pam_ops.div(torch.full_like(x, c), x, **kwargs),
         lambda x: c/x,
         1/3, 3],
        [f'$x/{c}$',
         lambda x, **kwargs: pam_ops.div(x, torch.full_like(x, c), **kwargs),
         lambda x: x/c,
         1/3, 3],
        ['$x^2$',
         lambda x, **kwargs: pam_ops.mul(x, x, **kwargs),
         lambda x: x*x,
         -3, 3],
        [r'$\sqrt{x}$',
         lambda x, **kwargs: pam_ops.pow(x, 0.5, use_kernel=True, **kwargs),
         lambda x: x**0.5,
         0.1, 10],
    ],
    'exp': [
        [r'$\exp_2(x)$', pam_ops.exp2, torch.exp2, -3, 5],
        [r'$\exp(x)$', pam_ops.exp, torch.exp, -3, 5],
        [r'$\log_2(x)$', pam_ops.log2, torch.log2, 1.1, 5],
        [r'$\ln(x)$', pam_ops.log, torch.log, 1.1, 100],
    ]
}


def signed_relative_error(approx, reference, eps=1e-4):
    """Return (approx - reference) / (|reference| + eps), elementwise.

    ``eps`` keeps the ratio finite where the reference is (close to) zero. The
    distinct name avoids overwriting the ``relative_error`` array used by the
    PAM error figure.
    """
    return (approx - reference) / (np.abs(reference) + eps)


def _outputs_and_grad(fn, X, dY, **kwargs):
    """Run ``fn`` on a fresh leaf copy of X, backpropagate dY; return (Y, dX) as NumPy arrays."""
    leaf = torch.clone(X).requires_grad_(True)
    Y = fn(leaf, **kwargs)
    Y.backward(dY)
    return Y.numpy(force=True), leaf.grad.numpy(force=True)


matplotlib.rcParams.update({'font.size': 10})
lw = 1    # line width
a = 0.5   # alpha of the torch reference curve in the first column

for func_types, functions in func_dict.items():
    # squeeze=False keeps axs 2-D even when a group holds a single function.
    fig, axs = plt.subplots(len(functions), 4, figsize=(12, len(functions)*3), squeeze=False)

    for idx, (name, pam_f, torch_f, min_x, max_x) in enumerate(functions):
        X = torch.linspace(min_x, max_x, N, device=DEVICE)
        dY = torch.full_like(X, dY_value)
        X_axis = X.numpy(force=True)

        # PAM with exact backward, PAM with approximate backward, then the torch reference.
        Y_exact, grad_exact = _outputs_and_grad(pam_f, X, dY, approx_bwd=False, offset=offset)
        _, grad_approx = _outputs_and_grad(pam_f, X, dY, approx_bwd=True, offset=offset)
        Y_true, grad_true = _outputs_and_grad(torch_f, X, dY)

        ax_value, ax_value_err, ax_grad, ax_grad_err = axs[idx]

        # Column 1: function values (solid translucent = torch, dashed = PAM).
        ax_value.plot(X_axis, Y_true, linewidth=lw, color='k', alpha=a)
        ax_value.plot(X_axis, Y_exact, linewidth=lw, color='k', linestyle='--')
        ax_value.grid(alpha=0.3)
        ax_value.set_title(name)

        # Column 2: relative error of the PAM function values.
        ax_value_err.plot(X_axis, signed_relative_error(Y_exact, Y_true), linewidth=lw, color='k')
        ax_value_err.grid(alpha=0.3)
        ax_value_err.set_title('Relative error for ' + name)

        # Column 3: gradients from the three variants.
        ax_grad.plot(X_axis, grad_true, linewidth=lw, color='k', alpha=0.7, label='Standard')
        ax_grad.plot(X_axis, grad_exact, linewidth=lw, color='tab:blue', alpha=1, ls=':', label='Exact')
        ax_grad.plot(X_axis, grad_approx, linewidth=lw, color='tab:red', alpha=1, ls='--', label='Approximate')
        ax_grad.grid(alpha=0.3)
        ax_grad.set_title('Derivatives for ' + name)

        # Column 4: relative error of both PAM gradients against the torch gradient.
        ax_grad_err.plot(X_axis, signed_relative_error(grad_exact, grad_true), linewidth=lw, color='tab:blue', alpha=1, ls=':')
        ax_grad_err.plot(X_axis, signed_relative_error(grad_approx, grad_true), linewidth=lw, color='tab:red', alpha=1, ls='--')
        ax_grad_err.grid(alpha=0.3)
        ax_grad_err.set_title('Relative gradient error for ' + name)

    plt.tight_layout()
    plt.savefig(f'error_and_derivatives_{func_types}.pdf', bbox_inches='tight')
    plt.show()
    # Release the figure explicitly so repeated runs do not accumulate open figures.
    plt.close(fig)
    