In [None]:
%load_ext autoreload
%autoreload 2


In [None]:
import torch

from dictionary_learning.dictionary import (
    AutoEncoder,
    GatedAutoEncoder,
    AutoEncoderNew,
    JumpReluAutoEncoder,
)

In [None]:
# Shared fixtures for every equivariance check below. Autograd is switched off
# globally: the cells only run forward passes and overwrite parameters via
# `.data`.
torch.set_grad_enabled(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

d_model = 100  # activation dimension fed to every autoencoder
scale = 4      # `x` is exactly `scale` times `x_scaled`

# Seed right before sampling so `x` is reproducible from a fresh kernel.
torch.manual_seed(1)
x = torch.randn(1000, d_model, device=device)
x_scaled = x / scale



In [None]:
# JumpReLU SAE whose biases and thresholds are overwritten with random values,
# so the scale check below exercises non-trivial values for all of them.
jumprelu_ae = JumpReluAutoEncoder(
    activation_dim=d_model,
    dict_size=d_model * 8,
    device=device,
)
jumprelu_ae.b_enc.data = torch.randn_like(jumprelu_ae.b_enc.data)
jumprelu_ae.b_dec.data = torch.randn_like(jumprelu_ae.b_dec.data)
# |N(0, 1)| samples, so every threshold is non-negative.
jumprelu_ae.threshold.data = torch.randn_like(jumprelu_ae.threshold.data).abs()

# Reference reconstruction of the down-scaled inputs, taken before rescaling.
reconstruction_1 = jumprelu_ae(x_scaled)

def scale_jumprelu(ae: JumpReluAutoEncoder, scale: float):
    """Rescale a JumpReLU SAE in place so it accepts inputs multiplied by `scale`.

    Multiplies `b_enc`, `b_dec` and `threshold` by `scale`; the weight
    matrices are left untouched.
    NOTE(review): this gives `ae_new(scale * x) == scale * ae_old(x)` only if
    the encoder/decoder are affine maps around a thresholded ReLU and
    `scale > 0` — confirm against JumpReluAutoEncoder.encode/decode.
    """
    for param in (ae.b_enc, ae.b_dec, ae.threshold):
        param.data *= scale

# Rescale the model in place; the mean threshold printed before and after
# should differ by a factor of `scale`.
print(jumprelu_ae.threshold.mean())
scale_jumprelu(jumprelu_ae, scale)
print(jumprelu_ae.threshold.mean())

# A correctly rescaled model satisfies ae_new(x) == scale * ae_old(x / scale).
reconstruction_2 = jumprelu_ae(x)
reconstruction_1 = scale * reconstruction_1

diff = (reconstruction_1 - reconstruction_2).abs()
print(f"max diff: {diff.max()}, mean diff: {diff.mean()}")
assert torch.allclose(reconstruction_1, reconstruction_2, atol=1e-5)

In [None]:
# Gated SAE with `r_mag` and the three bias vectors overwritten by random
# values, so the scale check below exercises non-trivial values for them.
gated_ae = GatedAutoEncoder(
    activation_dim=d_model,
    dict_size=d_model * 8,
    device=device,
)
gated_ae.r_mag.data = torch.randn_like(gated_ae.r_mag.data)
gated_ae.decoder_bias.data = torch.randn_like(gated_ae.decoder_bias.data)
gated_ae.mag_bias.data = torch.randn_like(gated_ae.mag_bias.data)
gated_ae.gate_bias.data = torch.randn_like(gated_ae.gate_bias.data)

# Reference reconstruction of the down-scaled inputs, taken before rescaling.
reconstruction_1 = gated_ae(x_scaled)

def scale_gated(ae: GatedAutoEncoder, scale: float):
    """Rescale a gated SAE in place so it accepts inputs multiplied by `scale`.

    Only the additive terms `gate_bias`, `mag_bias` and `decoder_bias` are
    multiplied by `scale`; the weights and `r_mag` are left untouched.
    NOTE(review): leaving `r_mag` alone is correct only if it acts as a
    multiplicative gain (e.g. via exp) rather than an additive offset, and the
    equivalence needs `scale > 0` — confirm against GatedAutoEncoder.encode.
    """
    for bias in (ae.gate_bias, ae.mag_bias, ae.decoder_bias):
        bias.data *= scale

# Rescale the model in place with a single call. The parameter means are
# printed before and after: `r_mag` should be unchanged, and the three biases
# should grow by a factor of `scale`.
print(gated_ae.r_mag.mean(), gated_ae.decoder_bias.mean(), gated_ae.mag_bias.mean(), gated_ae.gate_bias.mean())
scale_gated(gated_ae, scale)
print(gated_ae.r_mag.mean(), gated_ae.decoder_bias.mean(), gated_ae.mag_bias.mean(), gated_ae.gate_bias.mean())

# A correctly rescaled model satisfies ae_new(x) == scale * ae_old(x / scale).
reconstruction_2 = gated_ae(x)

reconstruction_1 = reconstruction_1 * scale

diff = torch.abs(reconstruction_1 - reconstruction_2)

print(f"max diff: {diff.max()}, mean diff: {diff.mean()}")
assert torch.allclose(reconstruction_1, reconstruction_2, atol=1e-5)

In [None]:
# Plain ReLU SAE. `.to(device)` puts its parameters on the same device as `x`;
# the constructor here is not given a `device` argument.
relu_ae = AutoEncoder(activation_dim=d_model, dict_size=d_model * 8).to(device)

# Randomise the additive biases so the scale check exercises non-trivial
# values. Each bias gets noise shaped like itself.
relu_ae.encoder.bias.data = torch.randn_like(relu_ae.encoder.bias.data)
relu_ae.bias.data = torch.randn_like(relu_ae.bias.data)

# Reference reconstruction of the down-scaled inputs, taken before rescaling.
reconstruction_1 = relu_ae(x_scaled)


def scale_relu(ae: AutoEncoder, scale: float):
    """Rescale a ReLU SAE in place so it accepts inputs multiplied by `scale`.

    Multiplies every additive bias by `scale` and leaves the weights alone:
    `ae.encoder.bias`, `ae.bias`, and `ae.decoder.bias` if the decoder has one
    (a bias-free `nn.Linear` exposes `bias = None`, which is skipped).
    NOTE(review): this gives `ae_new(scale * x) == scale * ae_old(x)` only if
    the model is affine maps around a ReLU and `scale > 0` — confirm against
    AutoEncoder.encode/decode.
    """
    ae.encoder.bias.data *= scale
    ae.bias.data *= scale
    decoder_bias = getattr(ae.decoder, "bias", None)
    if decoder_bias is not None:
        decoder_bias.data *= scale


# Apply the rescaling before the second forward pass; without it the
# comparison below would be against the unscaled model.
scale_relu(relu_ae, scale)
reconstruction_2 = relu_ae(x)

reconstruction_1 = reconstruction_1 * scale

diff = torch.abs(reconstruction_1 - reconstruction_2)

print(f"max diff: {diff.max()}, mean diff: {diff.mean()}")
assert torch.allclose(reconstruction_1, reconstruction_2, atol=1e-5)