In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torch.distributions as dists
import numpy as np

from gpt_embeds import embeds, toks, enc
from interp_utils import heatmap, reload_module, hist

# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


In [211]:
from interp_utils import reload_module

reload_module('interp_models')
from interp_models import Autoencoder, SparseAutoencoder, SparseNNMF
from tqdm import tqdm
from interp_utils import get_scheduler

reload_module('toy_models')
from toy_models import SparseIndependent

from copy import deepcopy

# --- Configuration for the SparseNNMF fit in the training loop below ---
# Constructor arguments forwarded to SparseNNMF (n_features, d_model).
N_FEATURES = 8000
D_MODEL = 768
# Number of outer-loop passes; each pass draws a fresh random permutation of the rows of `embeds`.
N_EPOCHS = 1
# `n_epochs` for the first nnmf.train call made on every minibatch.
CODE_STEPS = 1000
# `n_epochs` for the follow-up nnmf.train call (skipped for the very first minibatch).
ATOM_STEPS=1000
# Forwarded to nnmf.train as `sparse_coef`.
SPARSE_COEF = 10
# Number of slices each epoch's permutation is cut into.
MINIBATCHES_PER_EPOCH = 4
# Forwarded to nnmf.train as `lr`.
LR = 3e-2

# Integer division: when the row count is not a multiple of MINIBATCHES_PER_EPOCH,
# the leftover rows are absorbed by the last minibatch in the training loop.
batch_size = embeds.shape[0]//MINIBATCHES_PER_EPOCH
print(f'batch_size: {batch_size}')

def inv_perm(p):
    """Return the inverse of a permutation.

    Args:
        p: 1-D integer tensor holding a permutation of ``0 .. len(p) - 1``.

    Returns:
        A tensor ``inv`` with the same dtype and device as ``p`` such that
        ``inv[p[i]] == i`` for every ``i`` (equivalently ``p[inv[j]] == j``).
    """
    inv = torch.zeros_like(p)
    # Single vectorised scatter instead of a Python loop over every element:
    # position p[i] receives the value i.
    inv[p] = torch.arange(p.numel(), dtype=p.dtype, device=p.device)
    return inv

nnmf = SparseNNMF(n_features=N_FEATURES, d_model=D_MODEL, orthog_k=False, bias=True).to(device)

# One permutation per epoch and one codes tensor per minibatch (in permuted
# order), kept so the codes can later be mapped back to the row order of `embeds`.
randperm_cache = []
codes_cache = []

for epoch_idx in range(N_EPOCHS):
    randperm = torch.randperm(embeds.shape[0])
    randperm_cache.append(randperm)
    for minibatch_idx in range(MINIBATCHES_PER_EPOCH):
        is_first_batch = (minibatch_idx == 0 and epoch_idx == 0)
        is_last_minibatch = (minibatch_idx == MINIBATCHES_PER_EPOCH - 1)

        # The last minibatch runs to the end of the permutation so the rows left
        # over by the integer division in `batch_size` are not dropped.
        start = minibatch_idx*batch_size
        stop = None if is_last_minibatch else start + batch_size
        batch_perm = randperm[start:stop]

        # Indexing with an index tensor already returns a fresh copy of the
        # selected rows, so no additional deepcopy is needed before moving it.
        batch = embeds[batch_perm].to(device)

        # Stage 1: re-initialise the codes for this batch and fit them. Atoms are
        # frozen for every batch except the very first one, which instead uses
        # mean_init (per the original note: "just codes unless
        # minibatch_idx = epoch_idx = 0").
        nnmf.train(
            batch,
            n_epochs=CODE_STEPS,
            sparse_coef=SPARSE_COEF,
            reinit_codes=True,
            frozen_atoms=not is_first_batch,
            mean_init=is_first_batch,
            lr=LR,
        )

        if not is_first_batch:
            # Stage 2: continue from the stage-1 codes (reinit_codes=False) with
            # frozen_atoms left at its default (per the original note: "codes and atoms").
            nnmf.train(batch, n_epochs=ATOM_STEPS, sparse_coef=SPARSE_COEF, reinit_codes=False, lr=LR)

        codes_cache.append(nnmf.codes().detach().cpu())

        # Release the batch before the next one is moved to the device.
        del batch
        torch.cuda.empty_cache()





batch_size: 12564


loss: 0.008, mse: 0.008, sparse: 0.000: 100%|██████████| 1000/1000 [02:01<00:00,  8.23it/s]
loss: 0.014, mse: 0.014, sparse: 0.000: 100%|██████████| 1000/1000 [01:37<00:00, 10.22it/s]
loss: 0.008, mse: 0.008, sparse: 0.000: 100%|██████████| 1000/1000 [02:01<00:00,  8.23it/s]
loss: 0.014, mse: 0.014, sparse: 0.000: 100%|██████████| 1000/1000 [01:38<00:00, 10.19it/s]
loss: 0.008, mse: 0.008, sparse: 0.000: 100%|██████████| 1000/1000 [02:01<00:00,  8.25it/s]


reinitializing codes because train_data size changed


loss: 0.014, mse: 0.014, sparse: 0.000: 100%|██████████| 1000/1000 [01:37<00:00, 10.23it/s]
loss: 0.007, mse: 0.007, sparse: 0.000: 100%|██████████| 1000/1000 [02:01<00:00,  8.25it/s]


In [23]:
# Rows covered by MINIBATCHES_PER_EPOCH equal-sized minibatches. Because
# batch_size comes from an integer division, this can be smaller than
# embeds.shape[0]; the final minibatch of each epoch slices to the end of the
# permutation and so picks up any remaining rows.
MINIBATCHES_PER_EPOCH*batch_size

50256

In [71]:
# Reassemble the per-minibatch codes into the original row order of `embeds`.
# Only the most recent epoch is used: concatenating every epoch but inverting
# the first epoch's permutation would silently return stale first-epoch codes
# whenever N_EPOCHS > 1. With a single epoch this is identical to using
# everything in the caches.
assert len(codes_cache) == len(randperm_cache) * MINIBATCHES_PER_EPOCH, \
    'codes_cache is incomplete; re-run the training cell to completion'
codes = torch.cat(codes_cache[-MINIBATCHES_PER_EPOCH:], dim=0)

# Minibatches are consecutive slices of the permutation, so row j of the
# concatenation belongs to original row randperm[j] (assuming nnmf.codes()
# returns one row per batch row, in batch order). Indexing with the inverse
# permutation puts the codes for original row i back at position i.
inverted_perm = inv_perm(randperm_cache[-1])
assert codes.shape[0] == inverted_perm.shape[0], \
    'number of cached code rows does not match the permutation length'
codes = codes[inverted_perm]


In [106]:
import html
import random
# `display` and `HTML` are imported from the public IPython.display module;
# importing them from IPython.core.display is deprecated and emits a warning.
from IPython.display import display, HTML

# Make whitespace visible in the vocabulary strings: spaces become '∘' and
# newlines become '⏎'.
toks = np.array([tok.replace(' ', '∘').replace('\n', '⏎') for tok in toks])

def html_escape(text):
    """Return ``text`` with HTML-special characters replaced by entities.

    Thin wrapper around :func:`html.escape` with quote escaping enabled, so
    ``&``, ``<``, ``>``, ``"`` and ``'`` are all converted.
    """
    escaped = html.escape(text, quote=True)
    return escaped

def decode(tok_id_list):
    """Decode token ids one at a time into display strings.

    Each id is decoded on its own with ``enc.decode``; spaces are then shown
    as '∘' and newlines as '⏎' so whitespace is visible when rendered.

    Args:
        tok_id_list: Iterable of token ids.

    Returns:
        A list with one string per input id, in input order.
    """
    rendered = []
    for tok_id in tok_id_list:
        text = enc.decode([tok_id])
        rendered.append(text.replace(' ', '∘').replace('\n', '⏎'))
    return rendered

def render_toks_w_weights(toks, weights):
    """Display tokens as an HTML strip, each shaded according to its weight.

    Args:
        toks: Either a ``torch.Tensor`` of token ids (converted to strings
            with ``decode``) or an iterable of already-decoded tokens.
        weights: Per-token weights, as a ``torch.Tensor`` or anything
            ``torch.tensor`` accepts. Tokens and weights are paired with
            ``zip``, so the shorter of the two determines how many are shown.

    Returns:
        The ``HTML`` object that was displayed. The function calls ``display``
        itself, so callers end the statement with ``;`` when they do not want
        the returned object rendered a second time.
    """
    if isinstance(weights, torch.Tensor):
        weights = weights.cpu().detach()
    else:
        weights = torch.tensor(weights)
    if isinstance(toks, torch.Tensor):
        toks = decode(toks)

    spans = []
    for weight, tok in zip(weights.tolist(), toks):
        # Background opacity: positive weights are scaled by 0.7 and capped at
        # 0.7; non-positive weights use 1.3 * |weight| capped at 1.
        if weight > 0.0:
            alpha = min(.7*weight, 0.7)
        else:
            alpha = min(-1.3*weight, 1)
        # Token text can contain '<', '>', '&' or quotes; escape it so it is
        # shown literally instead of being parsed as markup.
        safe_tok = html.escape(str(tok))
        spans.append(f'<span style="background-color:rgba(135,206,250,{alpha});border: 0.3px solid black;padding: 0.3px">{safe_tok}</span>')
    highlighted_text = ' '.join(spans)
    # Put the spans in a fixed-width container so long token lists wrap
    # instead of running off the side of the cell.
    highlighted_text = f'<div style="width: 40rem">{highlighted_text}</div>'
    rendered = HTML(highlighted_text)
    display(rendered)
    return rendered



  from IPython.core.display import display, HTML


In [73]:

# Reorder the atom columns of `codes` by their mean over all rows, largest
# first. From here on atom indices refer to this sorted order, and
# `atom_perm[j]` is the pre-sort column index of sorted column j.
# NOTE(review): re-running this cell on already-sorted `codes` recomputes
# `atom_perm` relative to the sorted matrix, so the mapping back to the
# original column order is lost.
atom_perm = torch.argsort(codes.mean(dim=0), descending=True)
codes = codes[:, atom_perm]


In [None]:
def render_atom(atom_idx, k=200):
    """Render the tokens with the largest codes for a single atom.

    Args:
        atom_idx: Column index into the global ``codes`` matrix.
        k: Maximum number of top entries to show. Defaults to 200, the value
            that used to be hard-coded, and is clamped to the number of
            entries available so an oversized ``k`` does not raise.
    """
    atom_codes = codes[:, atom_idx]
    topk = atom_codes.topk(k=min(k, atom_codes.shape[-1]))
    # The row indices are passed as token ids, the code values as weights.
    render_toks_w_weights(topk.indices, topk.values)

In [204]:
# Token to inspect in the cells below. Only the first id returned by
# enc.encode is kept, which assumes ' bright' (with its leading space) encodes
# to a single token; decoding the id and printing it checks that round trip.
tok_id = enc.encode(' bright')[0]
print(enc.decode([tok_id]))

 bright


In [205]:
# Subtract the NNMF model's bias from every embedding row, then scale each row
# to unit L2 norm so that a dot product between rows is a cosine similarity.
whitened_embeds = embeds - nnmf.bias.data.unsqueeze(0).cpu()
normed_embeds = torch.div(whitened_embeds, whitened_embeds.norm(dim=-1, keepdim=True))

# Similarity of every row to the chosen token; keep the 50 largest.
similarities = torch.matmul(normed_embeds, normed_embeds[tok_id])
topk = torch.topk(similarities, k=50)

# The top hit has similarity 1.0 (presumably the query token itself), so it is
# skipped when rendering.
render_toks_w_weights(topk.indices[1:], topk.values[1:]);

print(topk.values)

tensor([1.0000, 0.6334, 0.5137, 0.4763, 0.4644, 0.4232, 0.4094, 0.3924, 0.3915,
        0.3888, 0.3826, 0.3823, 0.3701, 0.3658, 0.3613, 0.3603, 0.3550, 0.3517,
        0.3495, 0.3468, 0.3408, 0.3315, 0.3293, 0.3173, 0.3151, 0.3106, 0.3080,
        0.3077, 0.3077, 0.3072, 0.3071, 0.3048, 0.3045, 0.3043, 0.3034, 0.3020,
        0.3004, 0.3003, 0.3000, 0.2973, 0.2972, 0.2963, 0.2943, 0.2911, 0.2911,
        0.2904, 0.2899, 0.2886, 0.2877, 0.2859])


In [206]:
# The ten atoms with the largest codes for the chosen token, strongest first.
topk = torch.topk(codes[tok_id], k=10)

# For each of them, print its index and code value, then render its top tokens.
for rank in range(len(topk.indices)):
    atom, val = topk.indices[rank], topk.values[rank]
    print(f'Atom {atom.item()}: {val.item()}')
    render_atom(atom)

Atom 2851: 0.7682396769523621


Atom 289: 0.6017424464225769


Atom 25: 0.5098152160644531


Atom 1818: 0.4537302255630493


Atom 1569: 0.43581685423851013


Atom 319: 0.3700118958950043


Atom 3151: 0.2729633152484894


Atom 1262: 0.238998681306839


Atom 6241: 0.22085803747177124


Atom 5576: 0.1721521019935608


In [209]:
# Closer look at one atom: the 500 rows with the largest codes for column 223
# of `codes` (an index in the current column order), rendered as tokens.
topk = codes[:, 223].topk(k=500)
render_toks_w_weights(topk.indices, topk.values);