In [10]:
import sys
sys.path.append("../")
import torch
from main import *
import numpy as np
import pickle
import wandb
from torch.utils.data import DataLoader

In [85]:
class LinearAutoencoder(torch.nn.Module):

    def __init__(self,
                 in_features,
                 out_features,
                 device="cpu"):
        
        super().__init__()
        
        self.in_features = in_features
        self.out_features = out_features
        self.device = device

        self.encoder = torch.nn.Sequential(
                torch.nn.Linear(in_features, out_features,
                bias=True, 
                device=device), 
                torch.nn.ReLU())

        self.decoder = torch.nn.Linear(
                    out_features,
                    in_features,
                    bias=False,
                    device=device)
        torch.nn.init.orthogonal_(self.decoder.weight.data)

    def encode(self, x):
        out = self.encoder(x)
        return out
    
    def decode(self, z):
        self.decoder.weight.data = self.decoder.weight.data / (1e-10 + self.decoder.weight.data.norm(dim=0))
        out = self.decoder(z)
        return out

    def forward(self, x):
        latent = self.encode(x)
        recon = self.decode(latent)
        return latent, recon

In [63]:
def sparse_ae_loss(x, x_hat, z, lam, in_features, out_features):
    recon_loss = torch.nn.functional.mse_loss(x, x_hat) # divide by in_features?
    l1_reg = lam * torch.norm(z, p=1) # divide by out_features?
    total_loss = recon_loss + l1_reg
    return total_loss

def train_sparse_autoencoder(
        model, 
        train_loader,
        val_loader, 
        lam,
        optimizer,
        n_epochs):
    
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            latent, recon = model(data)
            loss = sparse_ae_loss(data, 
                                  recon, 
                                  latent, 
                                  lam,
                                  model.in_features, 
                                  model.out_features)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        train_loss /= len(train_loader.dataset)


        # Validate the model
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for data, target in val_loader:
                latent, recon = model(data)
                loss = sparse_ae_loss(data, 
                                  recon, 
                                  latent, 
                                  lam,
                                  model.in_features, 
                                  model.out_features)
                val_loss += loss.item()

        val_loss /= len(val_loader.dataset)

                
        # Log training loss to wandb
        wandb.log({
            "epoch": epoch,
            "n_examples": len(train_loader.dataset) * (epoch+1),
            "train_loss": train_loss, 
            "val_loss": val_loss
            })
        
        print(f"Epoch: {epoch+1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

In [58]:
from torch.utils.data import Dataset
class ActivationDataset(Dataset):

    def __init__(self, activations, labels=None):
        super().__init__()
        self.data = activations
        if labels is not None:
            self.labels = labels
        else:
            self.labels = torch.arange(len(activations))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

In [57]:
# GENERATE DATA
import os
try:
    os.makedirs("resnet-data")
except:
    print("directory already exists")

file = open("../logs/results.pkl",'rb')
results = pickle.load(file)


for l in np.arange(1, 5):
    layer = "layer{}".format(l)
    activations = results[layer]['activations_normed']
    activations = activations - activations.mean(axis=0)
    activations = activations / (1e-10 + activations.std(axis=0))
    activations = torch.tensor(activations)
    torch.save(activations, "resnet-data/{}-activations-zscored.pt".format(layer))

directory already exists


In [81]:
# CONFIG

activations = torch.load("resnet-data/layer1-activations-zscored.pt")

config = {
    "in_features": activations.shape[-1],
    "out_features": activations.shape[-1] * 2,
    "lr": 0.001,
    "lam": 0.1,
    "fraction_val": fraction_val,
    "batch_size": batch_size
}

wandb.init(project="monosemanticity")
wandb.log({"config": config})



VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011168181944716101, max=1.0…

In [82]:
config

{'in_features': 256,
 'out_features': 512,
 'lr': 0.001,
 'lam': 0.1,
 'fraction_val': 0.2,
 'batch_size': 64}

In [86]:
model = LinearAutoencoder(
    in_features=config['in_features'],
    out_features=config['out_features']
)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config['lr']
    )

train_sparse_autoencoder(
    model=model, 
    train_loader=train_loader,
    val_loader=val_loader,
    lam=config["lam"],
    optimizer=optimizer,
    n_epochs=1000)

Epoch: 1, Training Loss: 0.8856, Validation Loss: 0.0188
Epoch: 2, Training Loss: 0.0169, Validation Loss: 0.0166
Epoch: 3, Training Loss: 0.0162, Validation Loss: 0.0165
Epoch: 4, Training Loss: 0.0165, Validation Loss: 0.0169
