In [1]:
import torch
from geqtrain.nn.nonlinearities import ShiftedSoftPlus, ShiftedSoftPlusModule, SwiGLUModule, SwiGLU
from geqtrain.nn._fc import select_nonlinearity
from e3nn.math import normalize2mom
import random
import math

In [3]:
from tqdm import tqdm

@torch.no_grad()
def test_gain(non_linearity:str, gain:float=None, use_normalize2mom:bool=False, seed: int=42):
    activation = select_nonlinearity(non_linearity)
    in_size = 512
    out_size = in_size if non_linearity != "swiglu" else 2*in_size

    if not gain and not use_normalize2mom: raise ValueError('no gain (nor automatic gain) provided, idk what to do')
    if gain and use_normalize2mom: raise ValueError('both gain and automatic gain provided, idk what to do')

  # todo implement automatic gain
  # nonlinearity = {
  #   None: None,
  #   "silu": torch.nn.functional.silu,
  #   "ssp": ShiftedSoftPlusModule,
  #   "selu": torch.nn.functional.selu,
  #   "relu": torch.nn.functional.relu,
  #   "swiglu": SwiGLUModule,
  # }[non_linearity]
  # if use_normalize2mom: gain = normalize2mom()

    incr_counter = 0
    decr_counter = 0
    for _ in tqdm(range(5000), desc="Test gain Progress"):
        seed = random.randint(0, 100000)
        torch.manual_seed(seed)
        a = torch.randn(in_size, in_size)
        # l = torch.nn.Sequential(
        #         torch.nn.LayerNorm(in_size),
        #         torch.nn.Linear(in_size, out_size, bias=False),
        #         activation,
        #         torch.nn.LayerNorm(in_size),
        #         torch.nn.Linear(in_size, out_size, bias=False),
        #         activation,
        #         torch.nn.LayerNorm(in_size),
        #         torch.nn.Linear(in_size, out_size, bias=False),
        #         activation,
        #         torch.nn.LayerNorm(in_size),
        #         torch.nn.Linear(in_size, out_size, bias=False),
        # )

        l = torch.nn.Sequential(
                torch.nn.LayerNorm(in_size),
                torch.nn.Linear(in_size, out_size, bias=False),
                activation,
                torch.nn.LayerNorm(in_size),
                torch.nn.Linear(in_size, out_size, bias=False),
        )

    # l = torch.nn.Linear(in_size, out_size, bias=False)
    # torch.nn.init.orthogonal_(l.weight, gain=gain)
        for i, layer in enumerate(l):
            if i == len(l) - 1:
                gain = 1.0
            if isinstance(layer, torch.nn.Linear):
                # torch.nn.init.orthogonal_(layer.weight, gain=gain)
                fan_out, fan_in = layer.weight.size()
                std = gain / math.sqrt(fan_in)
                torch.nn.init.normal_(layer.weight, mean=0, std=std)

        for _ in range(100):
            x = l(a)

        # print (f"Test {test} - input std: {a.std().item():.6f}, out std: {x.std().item():.6f}, difference: {a.std().item() - x.std().item()}")
        if a.std().item() - x.std().item() > 0:
            incr_counter +=1
        else:
            decr_counter +=1
    print(f'incr_counter {incr_counter}, decr_counter {decr_counter}')

In [4]:
# test_gain('swiglu', gain=1.735) # decr
# test_gain('swiglu', gain=1.55) # decr
# test_gain('swiglu', gain=1.33) # decr
# test_gain('swiglu', gain=1.3) # decr
# test_gain('swiglu', gain=1.25) # incr
# test_gain('swiglu', gain=1.27) # incr_counter 174, decr_counter 326
# test_gain('swiglu', gain=1.26) # incr
# test_gain('swiglu', gain=1.265) # incr_counter 448, decr_counter 52
# test_gain('swiglu', gain=1.267) # incr_counter 341, decr_counter 159
# test_gain('swiglu', gain=1.268) # incr_counter 293, decr_counter 207
# test_gain('swiglu', gain=1.2685) # incr_counter 276, decr_counter 224
# test_gain('swiglu', gain=1.2687) # incr_counter 545, decr_counter 455
# test_gain('swiglu', gain=1.2688) # incr_counter 486, decr_counter 514
# test_gain('swiglu', gain=1.2687) # incr_counter 2568, decr_counter 2432
# test_gain('swiglu', gain=1.26875) # incr_counter 2533, decr_counter 2467
# test_gain('swiglu', gain=1.26876) # incr_counter 2501, decr_counter 2499
# test_gain('swiglu', gain=1.26877) # incr_counter 2501, decr_counter 2499
# test_gain('swiglu', gain=1.268765) # incr_counter 5117, decr_counter 4883
# test_gain('swiglu', gain=1.2687657) # incr_counter 5053, decr_counter 4947

In [6]:
# test_gain('swiglu', gain=1.26876575) incr_counter 2621, decr_counter 2379
test_gain('swiglu', gain=1.2687658)


Test gain Progress: 100%|██████████| 5000/5000 [16:06<00:00,  5.17it/s]

incr_counter 2523, decr_counter 2477



