construct dataset

In [1]:
from datasets import load_dataset

ds = load_dataset("JeanKaddour/minipile")

In [6]:
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
import pandas as pd
from datetime import datetime

# Initialize an empty DataFrame
df_all_sentences = pd.DataFrame(columns=['sentence'])

row_index = 0
for text in ds['train']['text'][:10000]:
    sentences = sent_tokenize(text)
    for sentence in sentences:
        # Use .at to set the value in the DataFrame
        df_all_sentences.at[row_index, 'sentence'] = sentence
        row_index += 1

    if row_index % 100 == 0:
        print(row_index)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f'files/all_sentences_{timestamp}.csv'
df_all_sentences.to_csv(filename, index=False)


[nltk_data] Downloading package punkt to /sailhome/joetey/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


2900
5100
8300
14400
14600
15100
22900
25200
25600
29200
34600
35500
36000
36200
38900
41600
51400
53800
55500
57900
58600
59600
61700
69600
71700
73200
74400
77200
79300
82000
83100
86200
92500
93400
96000
99100
99500
106200
109400
113600
115900
121300
126400
135500
143600
145100
145200
146400
156900
158200
158600
158900
164700
175200
176700
190500
192900
195000
197500
197800
198300
202200
203400
213700
216400
218400
220200
221100
223800
224400
227100
228800
238600
244700
247000
253000
253200
258700
263300
268500
270700
277100
283900
287600
288500
289600
298400
309600
311200
316500
319600
319800
322400
323800
324000
338900
343400
345600
346500
348800
354400
354900
357600
359800
365100
369200
369900
374300
374400
375900
376800
381100
389700


In [7]:
from torch.utils.data import Dataset

class MiniPileDataset(Dataset):
    def __init__(self, sentences, embeddings):
        self.sentences = sentences
        self.embeddings = embeddings

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        return self.sentences[idx], self.embeddings[idx]

In [12]:
from transformers import AutoTokenizer, AutoModel

bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert_model = AutoModel.from_pretrained("bert-base-uncased")

In [None]:
# Write the dataset to a CSV file
import csv
from datetime import datetime
import torch

embeddings = []
sentences_passed = []

batch_size = 10
sentences = df_all_sentences['sentence'].tolist()

for i in range(0, len(sentences), batch_size):
    batch = sentences[i:i + batch_size]
    try:
        inputs = bert_tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
        outputs = bert_model(**inputs)
        batch_embeddings = outputs.last_hidden_state.mean(dim=1).detach()
        embeddings.extend(batch_embeddings)
        sentences_passed.extend(batch)
        
    except Exception as e:
        print(f"Error: {e}")
        continue
        
    print(f"Processed {i + len(batch)} sentences")

embeddings = torch.stack(embeddings)

# Initialize the dataset
mini_pile_dataset = MiniPileDataset(sentences, embeddings)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f'files/all_sentences_with_embeddings_{timestamp}.csv'
with open(filename, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['sentence', 'embedding'])
    for sentence, embedding in zip(mini_pile_dataset.sentences, mini_pile_dataset.embeddings):
        writer.writerow([sentence, embedding])


Processed 10 sentences
Processed 20 sentences
Processed 30 sentences
Processed 40 sentences
Processed 50 sentences
Processed 60 sentences
Processed 70 sentences
Processed 80 sentences
Processed 90 sentences
Processed 100 sentences
Processed 110 sentences
Processed 120 sentences
Processed 130 sentences
Processed 140 sentences
Processed 150 sentences
Processed 160 sentences
Processed 170 sentences
Processed 180 sentences
Processed 190 sentences
Processed 200 sentences
Processed 210 sentences
Processed 220 sentences
Processed 230 sentences
Processed 240 sentences
Processed 250 sentences
Processed 260 sentences
Processed 270 sentences
Processed 280 sentences
Processed 290 sentences
Processed 300 sentences
Processed 310 sentences
Processed 320 sentences
Processed 330 sentences
Processed 340 sentences
Processed 350 sentences
Processed 360 sentences
Processed 370 sentences
Processed 380 sentences
Processed 390 sentences
Processed 400 sentences
Processed 410 sentences
Processed 420 sentences
P

train SAE

In [10]:
import torch.optim as optim
import torch
from utils.sae import SparseAutoencoder, SparseAutoencoderConfig

# Assuming `dataset` is a PyTorch Dataset loaded and ready to use
data_loader = torch.utils.data.DataLoader(mini_pile_dataset, batch_size=512, shuffle=True)

# Initialize the model
config = SparseAutoencoderConfig(d_model=768, d_sparse=8 * 768)
model = SparseAutoencoder(config)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 2
for epoch in range(num_epochs):
    total_loss = 0
    for sentence, embeddings in data_loader:
        optimizer.zero_grad()

        # Assuming data is already on the correct device and in the correct format
        _, _, loss, _ = model.forward(embeddings, return_loss=True, sparsity_scale=1.0)
        
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
    print(f"Epoch {epoch+1}, Average Loss: {total_loss / len(data_loader)}")

Epoch 1, Average Loss: 0.0963502898812294
Epoch 2, Average Loss: 0.14200088381767273


In [11]:
import pickle

# Save the model to a pickle file
model_path = "sparse_autoencoder_model_2.pkl"
with open(model_path, "wb") as f:
    pickle.dump(model.state_dict(), f)

print(f"Model saved to {model_path}")

Model saved to sparse_autoencoder_model.pkl
