The goal of this notebook is to implement additional ops and compare each one against its exact PyTorch equivalent, in particular:
* Division
* Exponentials and logarithms (arbitrary base)
* Normalization
* Square root
* Softmax
* Softmax cross entropy

In [None]:
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
import pam
import cuda_bindings
import native

In [None]:
# Ops under test, each paired with its exact ("real") PyTorch equivalent.
# Every row has the layout
#     [label, cuda_fn, native_fn, reference_fn, x_min, x_max]
# - cuda_fn / native_fn: the two implementations being compared. The plotting
#   loop calls them with extra keyword arguments (`offset`, `approx_bwd`).
# - reference_fn: exact PyTorch version, called with the input tensor only.
# - [x_min, x_max]: range the test inputs are drawn from.
# NOTE: the lambdas look up `c` and `N_tensor` when they are *called*, while the
# labels are formatted *now*. Reassigning those globals later changes the ops
# but not their labels.
c = 1.7         # constant operand for the multiply / divide rows
offset = 1.0    # forwarded as `offset=` to the cuda / native implementations
N_tensor = 32   # vector length for the softmax / log_softmax / layer_norm rows

functions = [
    [f'{c}*x',
     lambda x, **kwargs: cuda_bindings.pam(x, torch.full_like(x, c), **kwargs),
     lambda x, **kwargs: native.pam(x, torch.full_like(x, c), **kwargs),
     lambda x: c*x,
     -3, 3],
    ['x^2',
     lambda x, **kwargs: cuda_bindings.pam(x, x, **kwargs),
     lambda x, **kwargs: native.pam(x, x, **kwargs),
     lambda x: x*x,
     -3, 3],
    # Division rows keep inputs strictly positive (away from the pole at 0).
    [f'{c}/x',
     lambda x, **kwargs: cuda_bindings.pad(torch.full_like(x, c), x, **kwargs),
     lambda x, **kwargs: native.pad(torch.full_like(x, c), x, **kwargs),
     lambda x: c/x,
     1/3, 3],
    [f'x/{c}',
     lambda x, **kwargs: cuda_bindings.pad(x, torch.full_like(x, c), **kwargs),
     lambda x, **kwargs: native.pad(x, torch.full_like(x, c), **kwargs),
     lambda x: x/c,
     1/3, 3],
    ['exp2', cuda_bindings.pa_exp2, native.pa_exp2, torch.exp2, -3, 5],
    ['exp', cuda_bindings.pa_exp, native.pa_exp, torch.exp, -3, 5],
    # Logarithm and sqrt rows keep inputs positive (inside the real domain).
    ['log2', cuda_bindings.pa_log2, native.pa_log2, torch.log2, 0.5, 5],
    ['ln', cuda_bindings.pa_log, native.pa_log, torch.log, 0.5, 5],
    ['sqrt',
     lambda x, **kwargs: pam.pow(x, 0.5, use_kernel=True, **kwargs),
     lambda x, **kwargs: pam.pow(x, 0.5, use_kernel=False, **kwargs),
     lambda x: x**0.5,
     0.1, 10],
    # Vector ops: the references reduce over dim 0 / normalize over the last
    # N_tensor elements, so they are fed a single length-N_tensor vector.
    ['softmax',
     lambda x, **kwargs: pam.softmax(x, use_kernel=True, **kwargs),
     lambda x, **kwargs: pam.softmax(x, use_kernel=False, **kwargs),
     lambda x: torch.nn.functional.softmax(x, dim=0),
     -2, 2],
    ['log_softmax',
     lambda x, **kwargs: pam.log_softmax(x, use_kernel=True, **kwargs),
     lambda x, **kwargs: pam.log_softmax(x, use_kernel=False, **kwargs),
     lambda x: torch.nn.functional.log_softmax(x, dim=0),
     -2, 2],
    ['layer_norm',
     lambda x, **kwargs: pam.layer_norm(x, N_tensor, use_kernel=True, **kwargs),
     lambda x, **kwargs: pam.layer_norm(x, N_tensor, use_kernel=False, **kwargs),
     lambda x: torch.nn.functional.layer_norm(x, (N_tensor,)),
     -5, 5],
]

# Vector-valued ops (softmax / log_softmax / layer_norm) are evaluated on a random
# length-N_tensor vector and plotted against element index. All other ops are
# elementwise and are sampled on an evenly spaced grid of N points.
VECTOR_OPS = {'softmax', 'log_softmax', 'layer_norm'}
N = 1000  # grid size for elementwise ops


def _forward_backward(fn, X, dY, **kwargs):
    """Run `fn` on a fresh leaf copy of `X` and backpropagate `dY` through it.

    Returns (leaf, output). The input gradient is available as `leaf.grad`.
    """
    leaf = torch.clone(X).requires_grad_(True)
    out = fn(leaf, **kwargs)
    out.backward(dY)
    return leaf, out


def _plot_panel(X_axis, cuda_vals, native_vals, real_vals, ylabel):
    """Overlay the cuda, native and real curves in the current subplot."""
    plt.plot(X_axis, cuda_vals.numpy(force=True), 'k', label='cuda')
    plt.plot(X_axis, native_vals.numpy(force=True), 'r:', label='native')
    plt.plot(X_axis, real_vals.numpy(force=True), 'b--', label='real')
    plt.ylabel(ylabel)
    plt.grid()


for name, cuda_f, native_f, real_f, min_x, max_x in functions:
    if name in VECTOR_OPS:
        X = torch.rand(N_tensor, device='cuda') * (max_x - min_x) + min_x
        dY = torch.randn_like(X)  # random upstream gradient exercises all outputs
        X_axis = torch.arange(N_tensor).numpy()
    else:
        X = torch.linspace(min_x, max_x, N, device='cuda')
        dY = torch.full_like(X, 1.2)
        X_axis = X.numpy(force=True)

    # Exact and approximate backward passes for both implementations, plus the
    # exact PyTorch reference. The leaf tensors keep their original names so
    # their .grad stays inspectable after the loop.
    X1, Y1 = _forward_backward(cuda_f, X, dY, offset=offset)
    X1a, Y1a = _forward_backward(cuda_f, X, dY, approx_bwd=True, offset=offset)
    X2, Y2 = _forward_backward(native_f, X, dY, offset=offset)
    X2a, Y2a = _forward_backward(native_f, X, dY, approx_bwd=True, offset=offset)
    X3, Y3 = _forward_backward(real_f, X, dY)

    plt.figure(figsize=(9, 3))
    plt.subplot(131)
    _plot_panel(X_axis, Y1, Y2, Y3, name)

    plt.subplot(132)
    _plot_panel(X_axis, X1.grad, X2.grad, X3.grad, 'd/dx ' + name)

    plt.subplot(133)
    # Approximate gradients are compared against the *exact* reference gradient.
    _plot_panel(X_axis, X1a.grad, X2a.grad, X3.grad, 'approx d/dx ' + name)
    plt.legend()  # one legend suffices: all panels share the same line styles

    plt.tight_layout()
    
