In [1]:
import torch
import torch.nn as nn
from torch import Tensor
import math

In [2]:
def get_slopes(num_heads):
    """Return one slope per attention head as a list of floats.

    When ``num_heads`` is a power of two, the slopes are the geometric
    sequence whose first term and common ratio are both
    ``2 ** (-8 / num_heads)``; e.g. 4 heads give
    ``[0.25, 0.0625, 0.015625, 0.00390625]``.

    When ``num_heads`` is not a power of two, the result is the full sequence
    for the closest power of two below ``num_heads``, followed by every other
    slope (indices 0, 2, 4, ...) of the sequence for twice that power of two,
    truncated so that the total length equals ``num_heads``.

    Args:
        num_heads: Number of attention heads; must be an integer >= 1.

    Returns:
        A list of ``num_heads`` slopes.

    Raises:
        ValueError: If ``num_heads`` is smaller than 1.
    """
    # Fail early with a message that names the argument, rather than the
    # opaque "math domain error" that math.log2 raises for values <= 0.
    if num_heads < 1:
        raise ValueError(
            f"num_heads must be a positive integer, got {num_heads!r}"
        )

    def get_slopes_power_of_2(n):
        # Only called with n a power of two; 2 ** (-8 / n) serves as both the
        # first term and the common ratio of the geometric sequence.
        start = 2 ** (-8 / n)
        return [start * start**i for i in range(n)]

    if math.log2(num_heads).is_integer():
        return get_slopes_power_of_2(num_heads)

    # The paper authors note they only trained models that have 2^a heads for
    # some a. For any other head count, fall back to the closest power of two
    # below and fill the remaining heads from the next power of two's sequence.
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    num_extra = num_heads - closest_power_of_2
    # Every other slope of the doubled sequence, so the extra heads are spread
    # over its whole range instead of taking only its first few terms.
    extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:num_extra]
    return get_slopes_power_of_2(closest_power_of_2) + extra_slopes

In [7]:
# Example: 4 heads is a power of two, so the slopes are the geometric
# sequence that starts at 2^(-8/4) = 0.25 and uses 0.25 as its ratio.
res = get_slopes(4)

In [8]:
# Bare last expression so the notebook displays the slope list.
res

[0.25, 0.0625, 0.015625, 0.00390625]