In [None]:
import numpy as np 
import pandas as pd
import json
import torch
import torch.nn as nn
from sae import TopKSAE, SAELoss
from tqdm import tqdm
from sklearn.metrics import pairwise_distances_argmin_min
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
import random
from datasets import load_dataset

In [2]:
# Seed every RNG the notebook draws from so a fresh Restart & Run All is repeatable.
# torch.manual_seed covers torch-side randomness (presumably including TopKSAE's
# parameter initialisation -- confirm in sae.py); the trailing semicolon suppresses
# the Generator repr that would otherwise show up as cell output.
torch.manual_seed(0);
# Python's `random` drives random.sample / random.shuffle in balanced_sample.
random.seed(a="SAE Training")

## Load data

In [3]:
# Materialize the "train" split of the Hub dataset as a pandas DataFrame.
df = pd.DataFrame(load_dataset("jam963/indigeneity_fr", split="train"))
# The embedding column arrives JSON-encoded; decode each entry into a Python object.
df["embedding"] = [json.loads(raw_embedding) for raw_embedding in df["embedding"]]

In [4]:
# Read the period definitions; downstream code unpacks each value as a (start, end) pair.
with open("periods_1825_1950_10.json", "r") as f:
    time_periods = json.loads(f.read())

In [5]:
# Stack the per-row embeddings into a single array (one row per DataFrame row).
embeddings = np.stack(df.embedding.tolist())

# Standardize embeddings: zero mean / unit variance per dimension.
mean = embeddings.mean(axis=0)
std = embeddings.std(axis=0)
# Constant dimensions have std == 0; swap in a tiny value so the division is defined.
std[std == 0] = 1e-7
# NOTE(review): this overwrites df["embedding"], so re-running the cell would
# recompute `mean`/`std` from already-standardized data.
df["embedding"] = ((embeddings - mean) / std).tolist()

In [6]:
# Preview the first rows; the embedding column now holds the standardized vectors.
df.head()

Unnamed: 0,file_name,sentence,term,id,title,creator,publisher,date,type,language,relation,length,genre,embedding,sen_len
0,0004976.json,"de La- martine et Victor Hugo se rattachent, a...",indigène,4976,Histoire de la littérature française sous la R...,"Nettement, Alfred",J. Lecoffre (Paris),1853,monograph,fre,ark:/12148/cb37273565t,116867,Littérature française,"[-1.9171812741988257, -2.2278929959485767, 0.7...",59
1,0004976.json,"En littérature, il conti- nuait ses doctrines ...",indigène,4976,Histoire de la littérature française sous la R...,"Nettement, Alfred",J. Lecoffre (Paris),1853,monograph,fre,ark:/12148/cb37273565t,116867,Littérature française,"[-1.1204840080007559, -0.21415018738329644, -0...",85
3,0006815.json,"J'appelle causes positives, celles qui modifie...",indigènes,6815,Race et milieu social : essais d'anthroposocio...,"Vacher de Lapouge, Georges (1854-1936). Auteur...",M. Rivière (Paris),1909,monograph,fre,ark:/12148/cb31515687b,108554,Sociologie,"[0.42317150320799435, -0.510985201874166, 0.61...",52
4,0006815.json,Elles tendent sans cesse à rendre aux indigène...,indigènes,6815,Race et milieu social : essais d'anthroposocio...,"Vacher de Lapouge, Georges (1854-1936). Auteur...",M. Rivière (Paris),1909,monograph,fre,ark:/12148/cb31515687b,108554,Sociologie,"[-1.49437695570914, -0.1559418505879319, 0.931...",37
5,0006815.json,Les nègres d'Améri- que sont le seul exemple t...,indigènes,6815,Race et milieu social : essais d'anthroposocio...,"Vacher de Lapouge, Georges (1854-1936). Auteur...",M. Rivière (Paris),1909,monograph,fre,ark:/12148/cb31515687b,108554,Sociologie,"[0.3046182655151907, 0.1384508724158276, 0.021...",37


## SAE Training

In [7]:
def group_data_by_time_period(df, time_periods):
    """Bucket each row's embedding by the time period its ``date`` falls in.

    Args:
        df: DataFrame with a ``date`` column (compared against the period
            bounds) and an ``embedding`` column.
        time_periods: Mapping of period label -> ``(start, end)``. Both bounds
            are inclusive. A row goes to the first period, in the mapping's
            iteration order, that contains its date.

    Returns:
        ``defaultdict(list)`` mapping period label -> list of
        ``(row_position, embedding)`` tuples, where ``row_position`` is the
        row's 0-based position in ``df``. Periods that matched no row are
        absent from the result; rows that match no period are skipped.
    """
    data_by_period = defaultdict(list)
    # Walk the two needed columns directly instead of iterrows() + df.iloc[i]:
    # same values, without building a Series per row.
    for position, (year, embedding) in enumerate(zip(df["date"], df["embedding"])):
        for period, (start, end) in time_periods.items():
            if start <= year <= end:
                data_by_period[period].append((position, embedding))
                break

    # Compare against the configured periods rather than the dict being built:
    # a period with no rows never becomes a key, so checking the dict's own
    # entries for emptiness could never report anything.
    empty_periods = [period for period in time_periods if period not in data_by_period]
    if empty_periods:
        print(f"Warning: The following periods have no data: {', '.join(map(str, empty_periods))}")

    return data_by_period

def balanced_sample(data_by_period, sample_size):
    """Draw at most ``sample_size`` items from every period and shuffle the pool.

    Periods holding at least ``sample_size`` items contribute a random sample
    (without replacement) of exactly that size. Smaller, non-empty periods
    contribute all of their items and a warning is printed. Empty periods are
    ignored.

    Args:
        data_by_period: Mapping of period label -> list of items.
        sample_size: Number of items to draw per period.

    Returns:
        A new, shuffled list with the items drawn from all periods.
    """
    pooled = []
    for period, data in data_by_period.items():
        available = len(data)
        if available == 0:
            continue
        if available >= sample_size:
            pooled += random.sample(data, sample_size)
        else:
            # Too few items to sample from: take the whole period as-is.
            print(f"Warning: Period {period} has fewer samples ({len(data)}) than the requested sample size ({sample_size})")
            pooled += data
    random.shuffle(pooled)
    return pooled

In [8]:
def train_balanced_topk_sae(model, df, time_periods, total_steps, learning_rate, batch_size, device, 
                            dead_latent_update_pct=1.0,
                            sample_size_pct=1.0,
                            n_latents=768*2,
                            k_active=16):
    """Train ``model`` on period-balanced resamples of ``df`` for ``total_steps`` steps.

    Rows are grouped by time period. Every pass draws a fresh balanced sample
    (``sample_size`` rows per period, see below), walks it in chunks of
    ``batch_size`` and performs one Adam step per chunk. Training stops as
    soon as ``total_steps`` optimizer steps have been taken, even in the
    middle of a pass.

    Args:
        model: Module called as ``model(batch)`` and returning
            ``(x_recon, h_sparse, h)``; must expose ``encoder`` and ``decoder``.
        df: DataFrame with ``date`` and ``embedding`` columns.
        time_periods: Mapping of period label -> ``(start, end)``.
        total_steps: Number of optimizer steps to run.
        learning_rate: Adam learning rate.
        batch_size: Number of samples per optimizer step.
        device: Device the batches and the loss module are moved to.
        dead_latent_update_pct: ``round(len(df) * pct)`` is handed to
            ``SAELoss`` as ``samples_per_epoch``.
        sample_size_pct: Fraction of the smallest non-empty period used as the
            per-period sample size.
        n_latents: Forwarded to ``SAELoss``.
        k_active: Forwarded to ``SAELoss``.

    Returns:
        ``(model, sample_size, stats)`` where ``stats`` holds the per-step
        lists ``"total_loss"``, ``"aux_loss"`` and ``"dead_latents"``.

    Raises:
        ValueError: If no period contains data, or a balanced sample is empty
            (e.g. ``sample_size`` rounds down to 0).
    """
    data_by_period = group_data_by_time_period(df, time_periods)
    non_empty_periods = [period for period, data in data_by_period.items() if len(data) > 0]
    if not non_empty_periods:
        raise ValueError("No data available for any time period")

    # Balance against the smallest period so every period can supply a full sample.
    min_period_size = min(len(data) for period, data in data_by_period.items() if len(data) > 0)
    sample_size = int(sample_size_pct * min_period_size)

    print(f"Sample size per period: {sample_size}")
    for period, data in data_by_period.items():
        print(f"Period {period}: {len(data)} samples")

    samples_for_dead_latent_update = round(len(df)*dead_latent_update_pct)

    criterion = SAELoss(n_latents, k_active, samples_per_epoch=samples_for_dead_latent_update).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    writer = SummaryWriter()
    global_step = 0

    training_total_loss, training_aux_loss, training_dead_latents = [], [], []

    # try/finally so the TensorBoard writer is flushed and closed even if a step raises.
    try:
        while global_step < total_steps:
            sampled_data = balanced_sample(data_by_period, sample_size)
            if not sampled_data:
                raise ValueError("No data sampled for training")

            for i in range(0, len(sampled_data), batch_size):
                # Honour total_steps exactly instead of finishing the current pass.
                if global_step >= total_steps:
                    break

                _batch_indices, batch_embeddings = zip(*sampled_data[i:i+batch_size])
                batch = torch.tensor(batch_embeddings, dtype=torch.float32).to(device)

                optimizer.zero_grad()
                x_recon, h_sparse, h = model(batch)

                loss, mse_loss, aux_loss = criterion(batch, x_recon, h_sparse, model.encoder, model.decoder)
                loss.backward()

                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                training_total_loss.append(loss.item())
                training_aux_loss.append(aux_loss.item())
                # The loss module may not expose `dead_latents`; record 0 in that case.
                if hasattr(criterion, "dead_latents"):
                    training_dead_latents.append(criterion.dead_latents.sum().item())
                else:
                    training_dead_latents.append(0)

                if global_step % 100 == 0:
                    writer.add_scalar('Loss/total', loss.item(), global_step)

                    # Log grad norms (measured after clipping)
                    for name, param in model.named_parameters():
                        if param.grad is not None:
                            grad_norm = param.grad.norm().item()
                            writer.add_scalar(f'grad_norm/{name}', grad_norm, global_step)

                    # Log weight norms
                    for name, param in model.named_parameters():
                        weight_norm = param.norm().item()
                        writer.add_scalar(f'weight_norm/{name}', weight_norm, global_step)

                    # Log sparsity
                    # NOTE(review): this measures zeros in `h`, the third model output,
                    # not `h_sparse` -- confirm which activation was intended.
                    sparsity = (h == 0).float().mean().item()
                    writer.add_scalar('Sparsity/hidden', sparsity, global_step)

                    # Log a plain float, consistent with the other scalars.
                    writer.add_scalar('Loss/aux', aux_loss.item(), global_step)

                global_step += 1
    finally:
        writer.close()

    return model, sample_size, {"total_loss": training_total_loss, "aux_loss": training_aux_loss, "dead_latents": training_dead_latents}

In [9]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [10]:
# Model size: the latent layer is `d_multiple` times as wide as one input embedding.
input_dim = len(df.embedding.iloc[0])
d_multiple = 4
n_latents = int(d_multiple * input_dim)
k_active = 8

# Optimisation settings.
batch_size = 1024
total_steps = 20000
learning_rate = 1e-4
samples_per_epoch = df.shape[0]

# NOTE(review): `samples_per_epoch`, `log_interval` and this `writer` are not
# referenced by any other cell shown in this notebook (the training function
# creates its own SummaryWriter and logs every 100 steps) -- confirm whether
# they are still needed.
log_interval = 100
writer = SummaryWriter()

In [11]:
# Build the sparse autoencoder on the selected device.
model = TopKSAE(input_dim, n_latents, k_active).to(device)

# Train on period-balanced samples; n_latents / k_active are passed again so the
# loss is configured with the same values as the model.
trained_model, sample_size, training_stats = train_balanced_topk_sae(
    model,
    df,
    time_periods,
    total_steps,
    learning_rate,
    batch_size,
    device,
    n_latents=n_latents,
    k_active=k_active,
)

Sample size per period: 4942
Period 1851-1875: 4942 samples
Period 1901-1925: 92767 samples
Period 1876-1900: 67331 samples
Period 1926-1950: 39823 samples
Period 1825-1850: 5442 samples


In [12]:
import os

def save_model_and_hyperparams(model, hyperparams, base_dir='models'):
    """Save the model weights and its hyperparameters under ``base_dir``.

    The target sub-directory is named
    ``{n_latents}_{k_active}_{hidden_state}`` after the corresponding
    ``hyperparams`` entries. It receives ``SAE.pth`` (the model's
    ``state_dict``) and ``hyperparams.json``.

    The directory is created if missing and reused if it already exists, so
    re-running the notebook overwrites the files of an earlier run with the
    same configuration instead of raising ``FileExistsError``.

    Args:
        model: Object exposing ``state_dict()``.
        hyperparams: JSON-serializable dict containing at least the keys
            ``"n_latents"``, ``"k_active"`` and ``"hidden_state"``.
        base_dir: Parent directory for all saved models.

    Returns:
        The name of the sub-directory (relative to ``base_dir``).
    """
    # Single quotes inside the f-string: reusing the enclosing quote character
    # is a SyntaxError on Python versions before 3.12.
    model_dir = f"{hyperparams['n_latents']}_{hyperparams['k_active']}_{hyperparams['hidden_state']}"

    target_dir = os.path.join(base_dir, model_dir)
    # Creates base_dir as well when needed; exist_ok keeps the cell re-runnable.
    os.makedirs(target_dir, exist_ok=True)

    # Save model
    model_path = os.path.join(target_dir, "SAE.pth")
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    # Save hyperparameters
    hyperparams_path = os.path.join(target_dir, "hyperparams.json")
    with open(hyperparams_path, 'w') as f:
        json.dump(hyperparams, f, indent=4)
    print(f"Hyperparameters saved to {hyperparams_path}")
    return model_dir


# Record the standardization statistics and the training configuration so they
# are stored together with the weights.
hyperparams = dict(
    embed_mean=mean.tolist(),
    embed_std=std.tolist(),
    n_latents=n_latents,
    k_active=k_active,
    batch_size=batch_size,
    total_steps=total_steps,
    time_periods=time_periods,
    sample_size=sample_size,
    hidden_state=-1,
)

model_dir = save_model_and_hyperparams(trained_model, hyperparams)

# Keep the per-step training curves in the same directory as the saved model.
with open(f"models/{model_dir}/training_stats.json", "w") as f:
    json.dump(training_stats, f, indent=4)

Model saved to models/4096_8_-1/SAE.pth
Hyperparameters saved to models/4096_8_-1/hyperparams.json


In [13]:
from scipy import sparse


def generate_sparse_embeddings(model, df, batch_size=1024):
    """Encode every embedding in ``df`` with ``model`` and collect the sparse codes.

    The model is switched to eval mode (and left there) and run without
    gradient tracking on whatever device its parameters live on.

    Args:
        model: Module called as ``model(batch)`` and returning a 3-tuple whose
            second element is the sparse hidden activation; must expose
            ``encoder.out_features``.
        df: DataFrame with an ``embedding`` column of equal-length vectors.
        batch_size: Number of rows encoded per forward pass.

    Returns:
        A SciPy sparse array with one row per row of ``df``. For an empty
        ``df`` the result is an empty CSR array of shape
        ``(0, model.encoder.out_features)``.
    """
    model.eval()
    device = next(model.parameters()).device
    print(f"Device: {device}")
    hidden_dim = model.encoder.out_features

    if len(df) == 0:
        # sparse.vstack cannot stack an empty list of blocks, so build the empty
        # result directly; float32 matches the dtype the batches are fed in with.
        return sparse.csr_array((0, hidden_dim), dtype="float32")

    all_sparse_embeddings = []
    with torch.no_grad():
        for i in tqdm(range(0, len(df), batch_size), desc="Processing batches..."):
            batch = torch.tensor(df.embedding.iloc[i:i+batch_size].tolist(), dtype=torch.float32).to(device)
            _, h_sparse, _ = model(batch)
            h_sparse = h_sparse.cpu().numpy()
            # Convert each batch to CSR right away so only the non-zeros are kept.
            all_sparse_embeddings.append(sparse.csr_array(h_sparse))

    all_sparse_embeddings = sparse.vstack(all_sparse_embeddings)
    assert len(df) == all_sparse_embeddings.shape[0]

    return all_sparse_embeddings

In [14]:
# Encode every row of df with the trained SAE; the result has one sparse row per DataFrame row.
all_sparse_embeddings = generate_sparse_embeddings(trained_model, df)

Device: cuda:0


Processing batches...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 206/206 [00:11<00:00, 18.02it/s]


In [15]:
# Show the sub-directory name (under models/) that the trained model was saved to.
model_dir

'4096_8_-1'

In [16]:
# Optional: uncomment to write the sparse codes to disk in the saved model's directory.
# sparse.save_npz(f"models/{model_dir}/sparse_embeddings.npz", all_sparse_embeddings)