# Activation Functions

![logo](../images/logo-poster.png)

In [None]:
%run supportvectors-common.ipynb

## The imports

In [19]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Global plot styling: do not route text rendering through LaTeX, and use the
# DejaVu Sans font family for all plot text.
plt.rcParams.update({
    'text.usetex': False,
    'font.family': 'DejaVu Sans',
})

## Function for plotting the activation functions

In [20]:
def plot_activation_function(activation_fn, x_range=(-10, 10), num_points=1000, title=None):
    """Plot an activation function over an evenly spaced 1-D range of inputs.

    Args:
        activation_fn (Callable[[torch.Tensor], torch.Tensor]): Any callable that
            maps a 1-D tensor to a tensor of the same length, e.g. ``F.relu``,
            an ``nn.Module`` instance such as ``torch.nn.ReLU()``, or a lambda
            wrapping a function that needs extra arguments.
        x_range (tuple, optional): ``(min, max)`` of the inputs to evaluate.
            Defaults to (-10, 10).
        num_points (int, optional): The number of x-points on which to compute
            the y-value. Defaults to 1000.
        title (str, optional): Plot title and legend label. Defaults to None,
            in which case the callable's ``__name__`` is used, falling back to
            its class name for callables that have no ``__name__`` (e.g.
            ``nn.Module`` instances or ``functools.partial`` objects).
    """
    # Resolve the label once; getattr avoids an AttributeError for callables
    # without a __name__ attribute.
    label = title or getattr(activation_fn, '__name__', type(activation_fn).__name__)

    x = torch.linspace(x_range[0], x_range[1], num_points)
    # detach(): activations with learnable parameters (e.g. nn.PReLU) return a
    # tensor that requires grad, and .numpy() refuses to convert such tensors.
    y = activation_fn(x).detach()

    plt.figure(figsize=(8, 6))
    plt.plot(x.numpy(), y.numpy(), label=label)
    plt.title(label)
    plt.xlabel('Input')
    plt.ylabel('Output')
    plt.grid(True)
    plt.legend()
    plt.show()


## Plots of some of the PyTorch activation functions

In [None]:
# Example activation functions. F.softmax, F.threshold and F.hardtanh are
# called with extra keyword arguments -- see `extra_arguments` below.
activation_functions = [
    F.relu,
    F.sigmoid,
    F.tanh,
    F.elu,
    F.leaky_relu,
    F.softmax,
    F.threshold,
    F.hardtanh,
]

# Extra keyword arguments and plot title for the functions that are not
# plotted with the input tensor alone. Keyed by the function object itself
# rather than by ``__name__``: F.threshold's ``__name__`` is torch's private
# implementation name "_threshold", which is fragile to match on.
extra_arguments = {
    # Softmax normalises along a dimension; pass it explicitly.
    F.softmax: ({"dim": 0}, "softmax"),
    # threshold(x) = x if x > threshold else value; both arguments are required.
    F.threshold: ({"threshold": 4.0, "value": 4.0}, "threshold"),
    # hardtanh has default bounds; plot it with a custom clamp range instead.
    F.hardtanh: ({"min_val": -3, "max_val": 6.0}, "hardtanh"),
}

for activation_fn in activation_functions:
    if activation_fn in extra_arguments:
        kwargs, title = extra_arguments[activation_fn]
        # Bind fn/kwargs as lambda defaults so the wrapper never depends on
        # the loop variable's later value (late-binding closure pitfall).
        plot_activation_function(
            lambda x, fn=activation_fn, kw=kwargs: fn(x, **kw),
            title=title,
        )
    else:
        plot_activation_function(activation_fn)
