In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader, TensorDataset
import torch

In [None]:
class ImprovedTwoLayerNN(nn.Module):
    """MLP with two hidden blocks (Linear -> LeakyReLU -> BatchNorm1d -> Dropout) and a linear head.

    Despite the name, it contains three ``nn.Linear`` layers:
    ``input_size -> hidden_size -> 2 * hidden_size -> output_size``.

    Because of ``BatchNorm1d`` on the hidden features, ``x`` is expected to be
    2-D ``(batch, input_size)``; in training mode a batch needs more than one sample.

    Args:
        input_size: number of input features.
        hidden_size: width of the first hidden layer (the second is twice as wide).
        output_size: number of output features.
        dropout: dropout probability after each hidden block (default 0.5, the
            previously hard-coded value).
    """

    def __init__(self, input_size, hidden_size, output_size, dropout=0.5):
        super(ImprovedTwoLayerNN, self).__init__()
        # First hidden block.
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.relu1 = nn.LeakyReLU()
        self.batch_norm1 = nn.BatchNorm1d(hidden_size)
        self.dropout1 = nn.Dropout(dropout)

        # Second, wider hidden block.
        self.layer2 = nn.Linear(hidden_size, hidden_size * 2)
        self.relu2 = nn.LeakyReLU()
        self.batch_norm2 = nn.BatchNorm1d(hidden_size * 2)
        self.dropout2 = nn.Dropout(dropout)

        # Output layer: raw (unactivated) outputs.
        self.layer3 = nn.Linear(hidden_size * 2, output_size)

    def forward(self, x):
        """Map ``(batch, input_size)`` inputs to ``(batch, output_size)`` outputs."""
        x = self.layer1(x)
        x = self.relu1(x)
        # Normalization is applied after the activation here (not before).
        x = self.batch_norm1(x)
        x = self.dropout1(x)

        x = self.layer2(x)
        x = self.relu2(x)
        x = self.batch_norm2(x)
        x = self.dropout2(x)

        x = self.layer3(x)
        return x
    
class TwoLayerNN(nn.Module):
    """Single-hidden-layer MLP: Linear -> ReLU -> Linear.

    Args:
        input_size: number of input features (last dimension of ``x``).
        hidden_size: width of the hidden layer.
        output_size: number of output features.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Registration order (layer1, relu, layer2) fixes init order and state_dict keys.
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return ``layer2(relu(layer1(x)))``."""
        hidden = self.relu(self.layer1(x))
        return self.layer2(hidden)



In [None]:
def _resolve_device(device):
    """Return `device` as a torch.device, defaulting to CUDA when available, else CPU."""
    if device is not None:
        return torch.device(device)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _run_epoch(model, dataloader, criterion, device, optimizer=None):
    """Run one pass over `dataloader`: a training step per batch if `optimizer` is given, else evaluation.

    Returns the sample-weighted mean loss, or NaN when the loader's dataset is empty.
    """
    n_samples = len(dataloader.dataset)
    if n_samples == 0:
        return float("nan")
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for inputs, targets in dataloader:
            inputs = inputs.to(device).float()
            targets = targets.to(device).float()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Assumes `criterion` averages over the batch; rescale to a per-batch sum.
            total_loss += loss.item() * inputs.size(0)
    return total_loss / n_samples


def _save_predictions(model, dataset, device, output_path, batch_size):
    """Predict every sample of `dataset` in order and save the result to `output_path`."""
    model.eval()
    loader = DataLoader(dataset, batch_size=batch_size)  # no shuffle: keep sample order
    with torch.no_grad():
        batches = [model(inputs.to(device).float()).cpu() for inputs, _ in loader]
    # (N, *out) -> (N, 1, *out): the same layout the earlier per-sample
    # `unsqueeze(0)` + `torch.stack` produced, so existing readers of the file keep working.
    all_predictions = torch.cat(batches, dim=0).unsqueeze(1)
    torch.save(all_predictions, output_path)


def train(data_in, data_out, model, criterion, num_epochs, save=False, learning_rate=0.001,
          device=None, batch_size=400, train_fraction=0.8, seed=None,
          output_path='model_predictions.pth'):
    """Fit `model` with Adam on a random train/validation split of (data_in, data_out).

    Prints the mean training and validation loss after every epoch.

    Args:
        data_in: input tensor; the first dimension indexes samples.
        data_out: target tensor with the same number of samples as `data_in`.
        model: nn.Module to train; it is moved to `device` in place.
        criterion: loss callable ``criterion(outputs, targets)``; epoch losses are
            averaged assuming it uses mean reduction.
        num_epochs: number of passes over the training split.
        save: if True, after training write predictions for all samples (training and
            validation, in original order) to `output_path`, shaped
            ``(num_samples, 1, *model_output_shape)``.
        learning_rate: Adam learning rate.
        device: device (or string) to train on; defaults to CUDA when available, else CPU.
        batch_size: training mini-batch size.
        train_fraction: fraction of samples used for training; the rest is validation.
        seed: optional seed that makes the train/validation split reproducible.
        output_path: file written when `save` is True.

    Raises:
        ValueError: if `train_fraction` is not in (0, 1] or the training split is empty.
    """
    if not 0.0 < train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in (0, 1], got {train_fraction}")
    device = _resolve_device(device)
    eval_batch_size = 500  # evaluation batch size; does not affect results in eval mode

    # Move and cast once, rather than per batch inside the loops.
    full_dataset = TensorDataset(data_in.to(device).float(), data_out.to(device).float())

    total_size = len(full_dataset)
    train_size = int(train_fraction * total_size)
    val_size = total_size - train_size
    if train_size == 0:
        raise ValueError(f"training split is empty ({total_size} samples, train_fraction={train_fraction})")

    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], **split_kwargs)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=eval_batch_size)  # order irrelevant for validation

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        epoch_loss = _run_epoch(model, train_dataloader, criterion, device, optimizer)
        val_loss = _run_epoch(model, val_dataloader, criterion, device)  # NaN if no validation samples
        print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.8f}, Validation Loss: {val_loss:.8f}')

    if save:
        _save_predictions(model, full_dataset, device, output_path, eval_batch_size)



In [72]:
# Run identifier selecting which saved .pth artifacts to load.
# NOTE(review): `id` shadows Python's built-in id(); kept because other cells may reference it.
id = "1"

# Load the four artifacts in order: activations, embeddings, residual stream, one-hot.
# NOTE(review): torch.load unpickles the file — only load artifacts you produced/trust.
loaded_activations, loaded_embeddings, loaded_residual_stream, loaded_one_hot = (
    torch.load(path)
    for path in (
        f'activations-{id}.pth',
        f"embeddings-{id}.pth",
        f"residual_stream-{id}.pth",
        f"one-hot-{id}.pth",
    )
)
# Disabled alternative input (kept for reference):
#loaded_rebased_embeddings = torch.load(f"rebased_embeddings-{id}.pth")

torch.Size([25, 19, 768])


In [93]:
# NOTE(review): `data_in` is defined in the training-setup cell further down (In [101]);
# this cell fails on a fresh kernel — move it below that cell.
data_in.size()

torch.Size([195, 768])

In [101]:
# consts
# NOTE(review): `attention` and `ff` are not used in this cell; presumably indices for
# other experiments — confirm or remove.
attention = 2
ff = 3

layer = 10  # index into the first axis of each residual-stream tensor
SEED = 0    # fixes weight init and the train/val split so runs are reproducible

# Probe input: for each sample, take index `layer` on the first axis and the last position
# on the second axis. Assumes each element is laid out (layer, position, d_model) — TODO confirm.
residual_stream = torch.stack([sample[layer, -1, :] for sample in loaded_residual_stream])

data_in = residual_stream
# Assumes loaded_one_hot is stored as (num_classes, num_samples), hence the transpose — confirm.
data_out = loaded_one_hot.T
assert data_in.shape[0] == data_out.shape[0], (
    f"input/target sample counts differ: {data_in.shape[0]} vs {data_out.shape[0]}"
)

input_size = data_in.shape[1]
output_size = data_out.shape[1]
hidden_size = 8000

criterion = nn.MSELoss()

# Seed the global RNG before init; train()'s random_split and shuffling also draw from it.
torch.manual_seed(SEED)
model = TwoLayerNN(input_size, hidden_size, output_size)
train(data_in, data_out, model, criterion, num_epochs=20)
    

Epoch [1/20], Training Loss: 0.06029065, Validation Loss: 3.91040969
Epoch [2/20], Training Loss: 3.95423579, Validation Loss: 0.31829616
Epoch [3/20], Training Loss: 0.30948395, Validation Loss: 0.21837139
Epoch [4/20], Training Loss: 0.22504313, Validation Loss: 0.45714369
Epoch [5/20], Training Loss: 0.47738150, Validation Loss: 0.34621349
Epoch [6/20], Training Loss: 0.36371037, Validation Loss: 0.19990803
Epoch [7/20], Training Loss: 0.21037224, Validation Loss: 0.11369795
Epoch [8/20], Training Loss: 0.11905828, Validation Loss: 0.07359096
Epoch [9/20], Training Loss: 0.07595395, Validation Loss: 0.05665905
Epoch [10/20], Training Loss: 0.05709793, Validation Loss: 0.04952656
Epoch [11/20], Training Loss: 0.04878549, Validation Loss: 0.04680055
Epoch [12/20], Training Loss: 0.04523652, Validation Loss: 0.04648336
Epoch [13/20], Training Loss: 0.04422811, Validation Loss: 0.04743056
Epoch [14/20], Training Loss: 0.04458777, Validation Loss: 0.04831565
Epoch [15/20], Training Loss: