In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch as t
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Run on the first CUDA GPU when torch can see one, otherwise fall back to the CPU.
device = t.device("cuda:0") if t.cuda.is_available() else t.device("cpu")


# A sparse autoencoder architecture
class SAE(nn.Module):
    """Single-hidden-layer sparse autoencoder.

    ``fc1`` maps ``dimension -> hidden_size`` (with bias), ``nonlinearity`` turns that into
    the latent activations, and ``fc2`` maps ``hidden_size -> dimension`` (no bias).

    Args:
        dimension: size of the input and of the reconstruction.
        hidden_size: number of latent units.
        nonlinearity: module applied to the encoder output. Defaults to a fresh
            ``nn.ReLU()`` for each instance.
        freeze: if True, the decoder weight gets ``requires_grad=False``. With no bias on
            ``fc2``, only the encoder parameters then receive gradients.
    """

    def __init__(self, dimension, hidden_size, nonlinearity=None, freeze=True):
        super().__init__()
        self.fc1 = nn.Linear(dimension, hidden_size)
        self.fc2 = nn.Linear(hidden_size, dimension, bias=False)
        # Build the default per instance: a module used as a default argument is created
        # once at definition time and would be shared/registered by every SAE.
        self.nonlinearity = nn.ReLU() if nonlinearity is None else nonlinearity
        if freeze:
            self.fc2.weight.requires_grad_(False)

    def forward(self, x):
        """Return ``(reconstruction, activations)`` for input ``x`` (last dim = ``dimension``)."""
        acts = self.nonlinearity(self.fc1(x))
        return self.fc2(acts), acts


# a dataset of random unit vectors in R^dimension of size dataset_size
class RandomUnitVectors(Dataset):
    """``dataset_size`` random unit vectors in R^``dimension``, held on ``device``.

    Vectors are drawn from a standard normal and L2-normalised row-wise, which makes them
    uniformly distributed on the unit sphere.

    Args:
        dataset_size: number of vectors.
        dimension: ambient dimension.
        seed: optional seed for a private ``torch.Generator``, so the dataset can be
            reproduced without touching the global RNG. When ``None`` (the default), the
            global torch RNG is used.
    """

    def __init__(self, dataset_size, dimension, seed=None):
        self.dataset_size = dataset_size
        self.dimension = dimension
        generator = None if seed is None else t.Generator().manual_seed(seed)
        # Sample on the CPU so a given seed yields the same vectors whatever the device.
        samples = t.randn(self.dataset_size, self.dimension, generator=generator)
        self.data = (samples / samples.norm(dim=1, keepdim=True)).to(device)

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        return self.data[idx]

    def get_data(self):
        """Return the full ``(dataset_size, dimension)`` tensor of unit vectors."""
        return self.data


def train(
    model,
    train_loader,
    optimizer,
    sparsity_loss_fn,
    sparsity_penality,
    epochs=1,
    checkpoint_freq=10,
):
    """Train ``model`` on reconstruction MSE plus a weighted sparsity penalty.

    Args:
        model: module whose forward returns ``(reconstruction, activations)``.
        train_loader: iterable of input batches (e.g. a ``DataLoader``).
        optimizer: optimizer over the model's trainable parameters.
        sparsity_loss_fn: maps the activations to a scalar tensor. It is divided by the
            batch size and multiplied by ``sparsity_penality`` before being added to the loss.
        sparsity_penality: weight of the sparsity term.
        epochs: number of passes over ``train_loader``.
        checkpoint_freq: every ``checkpoint_freq`` epochs, metrics are recorded and printed.
            They are computed from the forward pass of that epoch's last batch, which runs
            before that batch's optimizer step.

    Returns:
        dict with:
        - ``reconstruction_losses`` / ``sparsity_losses``: per-epoch mean over batches,
          aligned with ``all_epochs`` (NaN for an epoch with no batches);
        - ``step_reconstruction_losses`` / ``step_sparsity_losses``: one value per batch;
        - ``all_epochs``: ``list(range(epochs))``;
        - ``checked_epochs``, ``explained_variances``, ``compositenesses``: one entry per
          checkpoint.
    """
    reconstruction_criterion = nn.MSELoss()
    all_epochs = list(range(epochs))
    reconstruction_losses = []
    sparsity_losses = []
    step_reconstruction_losses = []
    step_sparsity_losses = []
    checked_epochs = []
    explained_variances = []
    compositenesses = []

    for epoch in all_epochs:
        epoch_reconstruction = []
        epoch_sparsity = []
        last_batch = None  # (data, output, acts) of the final batch seen this epoch
        for data in train_loader:
            optimizer.zero_grad()
            output, acts = model(data)
            reconstruction_loss = reconstruction_criterion(output, data)
            sparsity_loss = sparsity_loss_fn(acts) / len(data)
            loss = reconstruction_loss + sparsity_penality * sparsity_loss
            loss.backward()
            optimizer.step()
            epoch_reconstruction.append(reconstruction_loss.item())
            epoch_sparsity.append(sparsity_loss.item())
            last_batch = (data, output.detach(), acts.detach())

        step_reconstruction_losses.extend(epoch_reconstruction)
        step_sparsity_losses.extend(epoch_sparsity)
        # NaN keeps the per-epoch lists aligned with all_epochs even for an empty loader.
        reconstruction_losses.append(
            sum(epoch_reconstruction) / len(epoch_reconstruction)
            if epoch_reconstruction
            else float("nan")
        )
        sparsity_losses.append(
            sum(epoch_sparsity) / len(epoch_sparsity) if epoch_sparsity else float("nan")
        )

        if epoch % checkpoint_freq == 0 and last_batch is not None:
            data, output, acts = last_batch
            # Fraction of variance explained by the reconstruction:
            # 1 - Var(residual) / Var(data), with variances taken over all batch elements.
            explained_variance = (1 - t.var(data - output) / t.var(data)).item()
            # Mean number of non-zero latents per example.
            compositeness = (acts != 0).sum(dim=-1).double().mean().item()
            checked_epochs.append(epoch)
            explained_variances.append(explained_variance)
            compositenesses.append(compositeness)
            print(
                f"Epoch {epoch}, reconstruction loss {epoch_reconstruction[-1]}, "
                f"sparsity loss {epoch_sparsity[-1]}, explained variance {explained_variance}, "
                f"compositeness {compositeness}"
            )

    return {
        "reconstruction_losses": reconstruction_losses,
        "sparsity_losses": sparsity_losses,
        "step_reconstruction_losses": step_reconstruction_losses,
        "step_sparsity_losses": step_sparsity_losses,
        "all_epochs": all_epochs,
        "checked_epochs": checked_epochs,
        "explained_variances": explained_variances,
        "compositenesses": compositenesses,
    }

In [None]:
# Experiment: fit an SAE (frozen random decoder by default) to a small set of random
# unit vectors. Seed torch first: model initialisation and data sampling below both draw
# from torch's global RNG, so the whole run is reproducible.
SEED = 0
t.manual_seed(SEED)

dataset_size = 50
dimension = 768
hidden_size = 25000
nonlinearity = nn.ReLU()
freeze = True
sparsity_penality = 1e-10
epochs = 100000
checkpoint_freq = 2000
batch_size = 50
learning_rate = 1e-4
plot = False
model = SAE(dimension, hidden_size, nonlinearity, freeze).to(device)
dataset = RandomUnitVectors(dataset_size, dimension)
train_loader = DataLoader(dataset, batch_size=batch_size)
optimizer = t.optim.Adam(model.parameters(), lr=learning_rate)


def sparsity_loss_fn(x):
    """L_{1/2} quasi-norm of the activations, i.e. (sum_i |x_i|**0.5) ** 2."""
    return t.norm(x, p=0.5)


results = train(
    model,
    train_loader,
    optimizer,
    sparsity_loss_fn,
    sparsity_penality,
    epochs,
    checkpoint_freq,
)

In [None]:
# Number of active (non-zero) latents for each training vector. This is inspection only,
# so skip gradient tracking instead of building and keeping an autograd graph.
with t.no_grad():
    out, acts = model(dataset.data)
(acts != 0).sum(dim=-1).cpu().numpy()

In [None]:
# Decoder weight; fc2 is nn.Linear(hidden_size, dimension), so this is (dimension, hidden_size).
model.get_submodule("fc2").weight

In [None]:
# Training curves, one figure per metric. Loss lists may hold one value per batch (several
# per epoch), so their x-coordinates are spread evenly over the epoch range; this keeps
# x and y the same length whatever the batch count.
n_epochs = len(results["all_epochs"])
for key, title, per_checkpoint in [
    ("reconstruction_losses", "Reconstruction Loss", False),
    ("sparsity_losses", "Sparsity Loss", False),
    ("explained_variances", "Explained Variance", True),
    ("compositenesses", "Compositeness", True),
]:
    values = results[key]
    if per_checkpoint:
        x = results["checked_epochs"]
    else:
        x = np.linspace(0, n_epochs, num=len(values), endpoint=False)
    fig, ax = plt.subplots()
    ax.plot(x, values)
    ax.set(title=title, xlabel="epoch", ylabel=title)
    plt.show()

In [None]:
import scipy.stats as stats

# Parameters of the F distribution and the threshold whose upper tail we want.
dfn = 1  # numerator degrees of freedom
dfd = 767  # denominator degrees of freedom
x = 0.203  # the threshold for the tail

# Upper-tail probability P(F > x) from the survival function (1 - CDF). `sf` can be more
# accurate than computing `1 - cdf` when the tail probability is very small.
tail_prob = stats.f.sf(x, dfn, dfd)

print(
    f"The probability that a variable from an F-distribution with {dfn} and {dfd} degrees of freedom exceeds {x} is {tail_prob}."
)

In [None]:
# How many decimal digits 25000**266 has (exact big-integer arithmetic).
a = pow(25000, 266)
len(repr(a))

In [None]:
import scipy.stats as stats


def prob_explains_variance(variance_explained, dimension, compositeness):
    """Probability that a ``compositeness``-dim subspace captures at least
    ``variance_explained`` of the squared norm of a random unit vector in R^``dimension``.

    Take a vector uniform on the sphere (equivalently, an isotropic Gaussian) and a
    k-dimensional subspace independent of it. Let C and R be the squared norms inside and
    orthogonal to the subspace. Then (R / (d - k)) / (C / k) ~ F(d - k, k), so
        P(C / (C + R) >= v) = P(R / C <= 1/v - 1)
                            = F_cdf((1/v - 1) * k / (d - k); d - k, k).

    Args:
        variance_explained: target fraction v, in (0, 1].
        dimension: ambient dimension d.
        compositeness: subspace dimension k, with 0 < k < d.
        Scalars or NumPy-broadcastable arrays are accepted.

    Returns:
        The probability (float or ndarray). Extremely small values can underflow to 0.0.

    Raises:
        ValueError: if any input lies outside the ranges above.
    """
    v = np.asarray(variance_explained, dtype=float)
    k = np.asarray(compositeness)
    d = np.asarray(dimension)
    if np.any((v <= 0) | (v > 1)):
        raise ValueError("variance_explained must lie in (0, 1]")
    if np.any((k <= 0) | (k >= d)):
        raise ValueError("compositeness must satisfy 0 < compositeness < dimension")
    dfd = compositeness
    dfn = dimension - compositeness
    x = (1 / variance_explained - 1) * dfd / dfn
    return stats.f.cdf(x, dfn, dfd)


# For k = compositeness, compare log10 of the probability from prob_explains_variance
# (variance >= `variance` in R^dimension) against -k * log10(n_latents).
dimension = 768
n_latents = 25000  # NOTE(review): presumably the SAE hidden_size from the training cell
compositenesses = np.arange(60, 80)
variance = 0.9
probs = prob_explains_variance(variance, dimension, compositenesses)
# -k * log10(N) == log10(N ** -k); presumably a 1 / (number of k-latent combinations)
# baseline — TODO confirm intent.
samples = -compositenesses[1:] * np.log10(n_latents)

fig, ax = plt.subplots()
# NOTE: probabilities this small can underflow to 0.0; log10 then gives -inf (not drawn).
ax.plot(compositenesses, np.log10(probs), label=f"log10 P(explained variance >= {variance})")
ax.plot(compositenesses[1:], samples, label=f"-k * log10({n_latents})")
ax.set(xlabel="compositeness k", ylabel="log10 value", title=f"d = {dimension}")
ax.legend();