# Designing Floating Point Activation Functions

In [6]:
from typing import Type
import pyrtl
from pyrtl import *
import numpy as np
from enum import IntEnum
from hardware_accelerators.dtypes import BaseFloat

In [None]:
def relu(x: WireVector):
    """Build ReLU logic for one encoded value.

    The top bit of ``x`` drives a two-way select: when that bit is 1 the
    result is the constant 0, otherwise ``x`` is passed through unchanged.

    Args:
        x: Wire holding the value. Presumably its most significant bit is
            the sign bit of the float encoding -- TODO confirm against the
            ``BaseFloat`` layout.

    Returns:
        The wire produced by ``pyrtl.select``.
    """
    top_bit = x[-1]
    return pyrtl.select(top_bit, truecase=0, falsecase=x)


def sigmoid(x: WireVector):
    """Placeholder for a sigmoid activation; nothing is implemented yet.

    No hardware is generated and ``None`` is returned regardless of ``x``.

    Args:
        x: Wire holding the value the activation would be applied to
            (currently unused).

    Returns:
        Always ``None``.
    """
    # TODO: replace with real sigmoid logic.
    return None


# Illustrative only: an example set of selectable activations.
class ActivationFunction(IntEnum):
    """Integer codes identifying an activation function.

    ``ActivationUnit`` compares its ``activation_select`` wire against these
    values, so each member's integer value is the select encoding.
    """

    RELU = 0b00
    SIGMOID = 0b01
    TANH = 0b10
    SOFTMAX = 0b11


class ActivationUnit:
    """Bank of parallel lanes applying a run-time selected activation.

    Every lane owns one input wire and one output wire, each
    ``dtype.bitwidth()`` bits wide. A single ``activation_select`` wire is
    shared by all lanes; its value is compared against the
    ``ActivationFunction`` codes, and any value that matches none of them
    passes the lane input straight through to the output.

    NOTE(review): ``activation_functions`` is only used to size the select
    wire. Logic for RELU, SIGMOID, TANH and SOFTMAX is generated regardless
    of what the list contains -- confirm whether unlisted functions should
    be left out.
    NOTE(review): ``tanh`` and ``softmax`` are expected to be defined
    elsewhere in the module; verify they exist before instantiating.
    """

    def __init__(self, size: int, dtype: Type[BaseFloat], activation_functions: list):
        """Create the port wires and build the per-lane selection logic.

        Args:
            size: Number of parallel lanes.
            dtype: Float format class; ``dtype.bitwidth()`` gives the width
                of every data wire.
            activation_functions: Supported activations. Only its length is
                used here, to size ``activation_select``.
        """
        self.size = size
        self.dtype = dtype
        self.activation_functions = activation_functions

        # Width of the select wire: bit length of the number of functions.
        select_width = len(activation_functions).bit_length()
        # All data wires share one width, so query the format once.
        lane_width = dtype.bitwidth()

        self.activation_select = WireVector(select_width)
        self.data_ins = [WireVector(lane_width) for _ in range(size)]
        self.outputs = [WireVector(lane_width) for _ in range(size)]

        for lane_in, lane_out in zip(self.data_ins, self.outputs):
            with conditional_assignment:
                with self.activation_select == ActivationFunction.RELU:
                    lane_out |= relu(lane_in)
                with self.activation_select == ActivationFunction.SIGMOID:
                    lane_out |= sigmoid(lane_in)
                with self.activation_select == ActivationFunction.TANH:
                    lane_out |= tanh(lane_in)
                with self.activation_select == ActivationFunction.SOFTMAX:
                    lane_out |= softmax(lane_in)
                with otherwise:
                    # Unrecognised select code: pass the input through.
                    lane_out |= lane_in

    def connect_inputs(
        self, data_ins: list | None = None, activation_select: WireVector | None = None
    ):
        """Drive the unit's input ports from externally created wires.

        Both arguments are optional so the data lanes and the select wire
        can be connected in separate calls.

        Args:
            data_ins: One source per lane; must contain exactly
                ``self.size`` entries.
            activation_select: Source for the select wire; its bitwidth
                must equal the bitwidth of ``self.activation_select``.

        Raises:
            AssertionError: If ``data_ins`` has the wrong length or
                ``activation_select`` has the wrong bitwidth.
        """
        if data_ins is not None:
            assert len(data_ins) == self.size, (
                f"expected {self.size} data inputs, got {len(data_ins)}"
            )
            # strict=True keeps a length mismatch loud even when asserts
            # are stripped (python -O).
            for port, source in zip(self.data_ins, data_ins, strict=True):
                # `<<=` wires the source into the existing port in place.
                port <<= source
        if activation_select is not None:
            # Compare against the real wire so this check cannot drift from
            # the width chosen in __init__.
            expected_width = self.activation_select.bitwidth
            assert activation_select.bitwidth == expected_width, (
                f"activation_select must be {expected_width} bits wide, "
                f"got {activation_select.bitwidth}"
            )
            self.activation_select <<= activation_select