In [2]:
# KLN vs Standard MLP Benchmark on OpenMathReasoning Dataset

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from kln import FlexibleConditional

# --- Standard MLP Model ---
class StandardMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(StandardMLP, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.model(x)

# --- KLN Model ---
class KLNModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(KLNModel, self).__init__()
        self.model = FlexibleConditional(input_dim, hidden_dim, output_dim)
        self.input_dim = input_dim

    def forward(self, x):
        i = x[:, :self.input_dim // 2]
        j = x[:, self.input_dim // 2:]
        return self.model(i, j)

# --- Load OpenMathReasoning Dataset ---
dataset = load_dataset("nvidia/OpenMathReasoning", split="train")

# Load a sentence transformer model
encoder = SentenceTransformer('all-MiniLM-L6-v2')

# Encode questions into dense vectors
questions = [d['question'] for d in dataset]
X = torch.tensor(encoder.encode(questions, convert_to_numpy=True), dtype=torch.float32)

# Create numeric targets based on answer length (number of words)
answers = [d['answer'] for d in dataset]
y = torch.tensor([len(ans.split()) for ans in answers], dtype=torch.float32)

print(f"Encoded X shape: {X.shape}")
print(f"Target y shape: {y.shape}")

# --- Train Function ---
def train_model(model, X_train, y_train, epochs=500, lr=5e-4, track_alpha=False):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    losses = []
    alphas = []
    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(X_train).squeeze()
        loss = criterion(output, y_train)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        if track_alpha:
            with torch.no_grad():
                alphas.append(torch.sigmoid(model.model.alpha).item())

        if epoch % 50 == 0:
            print(f"Epoch {epoch}: Loss = {loss.item():.6f}")

    if track_alpha:
        return losses, alphas
    else:
        return losses

# --- Train both models ---
input_dim = X.shape[1]
hidden_dim = 64
output_dim = 1

# Standard MLP
standard_model = StandardMLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)
standard_losses = train_model(standard_model, X, y)

# KLN Model
kln_model = KLNModel(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)
kln_losses, kln_alphas = train_model(kln_model, X, y, track_alpha=True)

# --- Plot Training Losses ---
plt.figure(figsize=(10,6))
plt.plot(standard_losses, label='Standard MLP Loss')
plt.plot(kln_losses, label='KLN FlexibleConditional Loss')
plt.xlabel('Epoch')
plt.ylabel('Training Loss (MSE)')
plt.title('Training Loss Comparison on OpenMathReasoning Dataset')
plt.legend()
plt.grid(True)
plt.savefig('figures/openmath_training_loss.png')
plt.show()

# --- Plot Alpha Evolution ---
plt.figure(figsize=(8,6))
plt.plot(kln_alphas)
plt.xlabel('Epoch')
plt.ylabel('Alpha Value (Mixing Coefficient)')
plt.title('Evolution of Alpha during KLN Training on OpenMathReasoning Dataset')
plt.grid(True)
plt.savefig('figures/openmath_alpha_evolution.png')
plt.show()

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Downloading data:  76%|███████▌  | 109/144 [13:17<04:15,  7.31s/files]


OSError: [Errno 28] No space left on device