In [1]:
from torchvision import transforms
import glob
import pathlib
import torch
import torch.utils.data as dt
import torchvision

In [None]:
from datasets import load_dataset
# Open the "CC-MAIN-2024-10" configuration of FineWeb-Edu (train split) with
# streaming=True, so the records are consumed by iterating over `fw`.
fw = load_dataset(
    "HuggingFaceFW/fineweb-edu",
    name="CC-MAIN-2024-10",
    split="train",
    streaming=True,
)

In [None]:
n_stream = 2500

# Gather the "text" field of successive records from the stream. The strict
# `>` comparison stops only once the count exceeds n_stream, so the list ends
# up holding n_stream + 1 documents (or fewer if the stream runs out first).
texts = []
for count, record in enumerate(fw, start=1):
    texts.append(record["text"])
    if count > n_stream:
        break

print(texts[:10])

In [None]:
import torch
from transformers import AutoTokenizer, GPTNeoModel

def extract_embeddings(text, model, tokenizer):
    """Return a single averaged vector taken from the model's last hidden state.

    Args:
        text: Input passed straight to ``tokenizer``.
        model: Callable invoked with the tokenizer output as keyword
            arguments plus ``output_hidden_states=True``; its result must
            expose ``hidden_states``.
        tokenizer: Callable invoked with ``return_tensors="pt"``,
            ``truncation=True`` and ``padding=False``; must return a mapping.

    Returns:
        The final entry of ``hidden_states`` averaged over its first two
        axes (presumably batch and sequence position -- confirm against the
        model in use).
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=False)

    # Gradients are not needed when the model is only used for feature extraction.
    with torch.no_grad():
        result = model(output_hidden_states=True, **encoded)

    final_layer = result.hidden_states[-1]
    return final_layer.mean(dim=(0, 1))


# Load pre-trained model and tokenizer from Hugging Face
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
model = GPTNeoModel.from_pretrained("EleutherAI/gpt-neo-125m")
model.eval()

# Embed the documents already collected in `texts` instead of iterating the
# streaming dataset a second time. This avoids fetching the same records
# again and guarantees that embeddings[i] was computed from texts[i], which
# later cells rely on when they index `texts` by embedding position.
embeddings = [extract_embeddings(text, model, tokenizer) for text in texts]

In [15]:
import numpy as np

def threshold(x, lambd, clamp=True):
    """Soft-threshold ``x`` by ``lambd``, optionally discarding negative results.

    Magnitudes are reduced by ``lambd`` and floored at zero, then the original
    sign is reapplied. With ``clamp=True`` any remaining negative entries are
    overwritten with zero, so the output is non-negative.

    Args:
        x: Array of values to shrink.
        lambd: Amount subtracted from each magnitude.
        clamp: When true, set negative entries of the result to zero.

    Returns:
        A new array holding the thresholded values.
    """
    shrunk_magnitude = np.maximum(np.abs(x) - lambd, 0)
    shrunk = shrunk_magnitude * np.sign(x)
    if not clamp:
        return shrunk
    negative = shrunk < 0
    shrunk[negative] = 0
    return shrunk

def alpha_update(X, D, alpha, lambd, max_iter=20):
    """Compute sparse codes for ``X`` under dictionary ``D`` by accelerated
    proximal-gradient iterations.

    Every iteration takes a step of size ``1 / L`` along ``D.T @ (X - D @ a_)``
    from the extrapolated point ``a_``, applies ``threshold`` with level
    ``lambd / L``, and then forms the next extrapolated point using the
    momentum weight ``(t - 1) / t_new``.

    Args:
        X: Data matrix; must be shape-compatible with ``D @ alpha``.
        D: Dictionary matrix. ``L`` is the square of its ``ord=2`` norm.
        alpha: Used only as a template for the shape and dtype of the
            starting point -- the iterations always start from zeros and the
            values in ``alpha`` are not read.
        lambd: Regularisation weight; the threshold level is ``lambd / L``.
        max_iter: Number of iterations to run.

    Returns:
        The codes after ``max_iter`` iterations. When ``D`` has zero norm the
        all-zero codes are returned immediately.
    """
    L = np.linalg.norm(D, ord=2) ** 2
    a = np.zeros_like(alpha)

    # A zero dictionary gives L == 0; the 1 / L step would then be infinite and
    # inf * 0 would turn every code into NaN. With D == 0 the codes cannot
    # change the residual, so the zero starting point is returned as-is.
    if L == 0:
        return a

    a_ = np.copy(a)
    step = 1 / L

    t = 1
    for _ in range(max_iter):
        # `a` is rebound below rather than modified in place, so keeping a
        # reference to the previous iterate is enough.
        a_prev = a
        a = threshold(a_ + step * D.T @ (X - D @ a_), lambd / L)
        t_new = (1 + np.sqrt(1 + 4 * t ** 2)) / 2
        a_ = a + ((t - 1) / t_new) * (a - a_prev)
        t = t_new

    return a

def dictionary_update(X, a):
    """Refit the dictionary for fixed codes ``a``.

    Args:
        X: Data matrix.
        a: Code matrix whose pseudo-inverse is right-multiplied onto ``X``.

    Returns:
        ``X @ pinv(a)``, where ``pinv`` is ``numpy.linalg.pinv``.
    """
    codes_pinv = np.linalg.pinv(a)
    return np.matmul(X, codes_pinv)

In [None]:
# Example usage
K = 50  # Number of dictionary atoms

# `embeddings` is a Python list of 1-D tensors, so it has no `.shape`; stack it
# into a matrix and transpose so that each column of X is one document.
X = torch.stack(embeddings).numpy().T
D, N = X.shape  # D: embedding dimension, N: number of documents

rng = np.random.default_rng(0)  # Fixed seed so reruns start from the same point
Phi = rng.standard_normal((D, K))  # Initial dictionary
alpha = rng.standard_normal((K, N))  # Initial sparse codes
l = 0.5  # Regularization parameter

# Perform dictionary learning: alternate between updating the codes for the
# current dictionary and refitting the dictionary for the current codes.
for i in range(10):
    alpha = alpha_update(X, Phi, alpha, l)
    Phi = dictionary_update(X, alpha)
    print(i)

In [None]:
# Rank the documents by their coefficient in row 23 of `alpha`, largest first,
# and show the ten texts with the highest values.
row_23_codes = alpha[23, :]
top_10_indices = np.argsort(row_23_codes)[::-1][:10]
[texts[idx] for idx in top_10_indices]