In [1]:
from transformers import GPTNeoXForCausalLM, AutoTokenizer
from datasets import load_dataset
import multiprocessing

  from .autonotebook import tqdm as notebook_tqdm


#### Basic Code Setup 

Loading in model and dataset as well a functionality for iterating over the training data

NOTE: we are just randomly picking a model 

In [2]:
model_size = "1.4b"
checkpoint_step = 143_000

model = GPTNeoXForCausalLM.from_pretrained(
    f"EleutherAI/pythia-{model_size}-deduped",
    revision=f"step{checkpoint_step}",
    cache_dir=f"./pythia-{model_size}-deduped/step{checkpoint_step}",
).to('cuda')

  return self.fget.__get__(instance, owner)()


In [3]:
tokenizer = AutoTokenizer.from_pretrained(
  f"EleutherAI/pythia-{model_size}-deduped",
  revision=f"step{checkpoint_step}",
  cache_dir=f"./pythia-{model_size}-deduped/step{checkpoint_step}",
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
checkpoint_dataset = load_dataset(
    "rdiehlmartinez/pythia-pile-presampled",
    "checkpoints",
    split='train',
    num_proc=multiprocessing.cpu_count()
)

Resolving data files: 100%|██████████| 77/77 [00:01<00:00, 46.45it/s]


In [5]:
# checkpointing steps used in evaluation by pythia 
checkpoint_steps = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1000, ]
checkpoint_steps.extend([3000 + (i * 10000) for i in range(0, 15)])

ORIGINAL_BATCH_SIZE = 1024 

MAX_STEP = 142_999 # Last step in training (used to index final batc)

# NOTE: setting up the data batch sizes 

ordered_steps = list(set(checkpoint_dataset['step']))
ordered_steps.sort()
step_to_start_index = {step: i*ORIGINAL_BATCH_SIZE for i, step in enumerate(ordered_steps)}

def get_data_batch(step):
    """
    Get a data batch for a given step in the training process.
    """

    assert(step in step_to_start_index), f"Step {step} not valid checkpoint step."
    start_idx = step_to_start_index[step]
    end_idx = start_idx + 1024

    return {
        "input_ids": torch.tensor(checkpoint_dataset[start_idx:end_idx]['ids'], device='cuda'),
    }


In [6]:
target_layers_suffix = ["attention.dense", "mlp.dense_4h_to_h", "attention.query_key_value"]

### Computing the Weight Matrix SVD 

In [8]:
def compute_explained_variance(S):
    """ For a matrix of singular values, compute the explained variance."""
    variances = S**2 / S.numel()
    explained_variances = variances / torch.sum(variances)
    return explained_variances

In [9]:
def compute_conditon_number(S):
    """ For a matrix of singular values, compute the condition number."""
    return S[0] / S[-1]

In [10]:
from torch.linalg import svd
import torch

from collections import defaultdict

from tabulate import tabulate


In [11]:
de_embedding_matrix = None 
for name, module in model.named_modules():
    if name == "embed_out":
        de_embedding_matrix = module.weight

In [25]:
projection_matrices_dict = dict()

for name, param in model.named_parameters():
    # only do this for the weight matrix of the target_layers_suffix
    if any(suff_name in name for suff_name in target_layers_suffix) and "weight" in name:
        projection_matrices_dict[name] = param

In [26]:
projection_matrices_dict.keys()

dict_keys(['gpt_neox.layers.0.attention.query_key_value.weight', 'gpt_neox.layers.0.attention.dense.weight', 'gpt_neox.layers.0.mlp.dense_4h_to_h.weight', 'gpt_neox.layers.1.attention.query_key_value.weight', 'gpt_neox.layers.1.attention.dense.weight', 'gpt_neox.layers.1.mlp.dense_4h_to_h.weight', 'gpt_neox.layers.2.attention.query_key_value.weight', 'gpt_neox.layers.2.attention.dense.weight', 'gpt_neox.layers.2.mlp.dense_4h_to_h.weight', 'gpt_neox.layers.3.attention.query_key_value.weight', 'gpt_neox.layers.3.attention.dense.weight', 'gpt_neox.layers.3.mlp.dense_4h_to_h.weight', 'gpt_neox.layers.4.attention.query_key_value.weight', 'gpt_neox.layers.4.attention.dense.weight', 'gpt_neox.layers.4.mlp.dense_4h_to_h.weight', 'gpt_neox.layers.5.attention.query_key_value.weight', 'gpt_neox.layers.5.attention.dense.weight', 'gpt_neox.layers.5.mlp.dense_4h_to_h.weight', 'gpt_neox.layers.6.attention.query_key_value.weight', 'gpt_neox.layers.6.attention.dense.weight', 'gpt_neox.layers.6.mlp.dens

In [27]:
for layer_num in range(model.config.num_hidden_layers):
    qkv_projection = projection_matrices_dict[f"gpt_neox.layers.{layer_num}.attention.query_key_value.weight"]
    value_projection = qkv_projection[-qkv_projection.shape[0]//3:,]
    output_projection = projection_matrices_dict[f"gpt_neox.layers.{layer_num}.attention.dense.weight"]


    ov_projection =  output_projection @ value_projection 

    projection_matrices_dict[f"gpt_neox.layers.{layer_num}.output_value_projection"] = ov_projection

In [28]:
# remove attention.dense and attention.query_key_value
projection_matrices_dict = {k: v for k, v in projection_matrices_dict.items() if "attention.query_key_value" not in k }

# sort dictionary by layer number
projection_matrices_dict = dict(sorted(projection_matrices_dict.items(), key=lambda x: int(x[0].split(".")[2])))

In [30]:
NUM_TOP_SINGULAR_VECTORS = 10
NUM_TOKENS = 20 

for global_idx, (proj_name, proj_matrix) in enumerate(projection_matrices_dict.items()):
    print("Analysis for projection matrix: ", proj_name)

    U, S, Vh = svd(proj_matrix, full_matrices=False)

    explained_variances = compute_explained_variance(S)
    print("Explained variance: ", explained_variances[:10])
    condition_number = compute_conditon_number(S)
    print("Condition number: ", condition_number)

    results = []

    for vec_idx in range(NUM_TOP_SINGULAR_VECTORS):
        row = [f"Top singular vector {vec_idx}"]
        
        if "dense_4h_to_h" == proj_name:
            singular_vec = Vh[vec_idx, :]
        else:
            singular_vec = U[:, vec_idx]

        logits = de_embedding_matrix @ singular_vec.T
        top_indices = torch.topk(logits, NUM_TOKENS, dim=0).indices

        tokens = [tokenizer.decode(tok_idx) for tok_idx in top_indices]
        row.extend(tokens)
        results.append(row)

    # Printing the results as a table
    print(tabulate(results, headers=['Vector Index'] + [f"Token {i+1}" for i in range(NUM_TOKENS)]))

    print("\n\n")


Analysis for projection matrix:  gpt_neox.layers.0.attention.dense.weight
Explained variance:  tensor([0.0326, 0.0091, 0.0079, 0.0065, 0.0063, 0.0060, 0.0059, 0.0056, 0.0054,
        0.0054], device='cuda:0', grad_fn=<SliceBackward0>)
Condition number:  tensor(61296.4375, device='cuda:0', grad_fn=<DivBackward0>)
Vector Index           Token 1        Token 2     Token 3                           Token 4    Token 5    Token 6    Token 7    Token 8     Token 9    Token 10    Token 11    Token 12       Token 13                                       Token 14    Token 15       Token 16    Token 17      Token 18    Token 19    Token 20
---------------------  -------------  ----------  --------------------------------  ---------  ---------  ---------  ---------  ----------  ---------  ----------  ----------  -------------  ---------------------------------------------  ----------  -------------  ----------  ------------  ----------  ----------  ----------
Top singular vector 0  CURI           

### Computing the Grad Matrix SVD 

In [177]:
target_layers_suffix = ["attention.dense", "mlp.dense_4h_to_h", ]

In [161]:
REDUCED_BATCH_SIZE = 128 
import gc


def forward_pass(model, batch):
    """
    Perform a forward pass of the model on a given batch of data; assumes that the model 
    has hooks setup to save the hidden states at each layer.
    """

    batch_index = 0

    batch_size = 1
    static_batch_size = None # NOTE: static_batch_size is only set when batch size is reduced

    while batch_index < REDUCED_BATCH_SIZE:

        try:
            if static_batch_size is None:
                _batch_size = batch_size
            else: 
                # NOTE: reached when we've run out of memory and have reduced the batch size
                _batch_size = static_batch_size

            batch_end_index = min(batch_index + _batch_size, REDUCED_BATCH_SIZE)

            _inputs = batch['input_ids'][batch_index : batch_end_index]
            if 'labels' in batch: 
                # NOTE: If labels are present, then we are iterating over the gradient batches 
                _labels = batch['labels'][batch_index : batch_end_index]
            else:
                _labels = None 

            if _labels is None:
                # we can throw away the outputs, we are only interested in the hidden states
                with torch.no_grad():
                    _ = model(_inputs)

            else: 
                # NOTE: we are performing the forward and backward passes to get the gradients 
                _outputs = model(_inputs, labels=_labels)

                try: 
                    # TODO: test whether the graidnet losses are what is expected
                    _outputs['loss'].backward()
                except: 
                    # NOTE - can't figure out how often we'll have an issue in the backward call 
                    # so just exit
                    raise Exception("Error in backward pass")

        except RuntimeError:
            # NOTE: Exception is thrown when the batch size is too large for the GPU

            if batch_size == 1:
                raise Exception("Batch size of 1 is too large for the GPU")

            _batch_size //= 2
            static_batch_size = _batch_size

            gc.collect()
            torch.cuda.empty_cache()

            continue

        batch_index = batch_end_index

        if static_batch_size is None:
            batch_size *= 2

In [158]:
data_batch = get_data_batch(142_999)

In [181]:
tokenizer.decode(data_batch['input_ids'][0])

', somebody who’s not very bright, kind of the court jester. But in Native American mythology, the trickster is somebody who plays practical jokes, but he or she may also be the person who brings fire to the human species.\n\nYou can’t just walk in and say, okay, take me to your shaman. I can’t say, oh well, this trip makes 37 shamans that I’ve worked with, because there are people that I’ve worked with who have yet to reveal themselves to be shaman,and there’s people I’ve worked with who claim they’re shamans who I truly believe are not.\n\nMiranda: When did you meet the Jaguar Shaman?Mark: The old Jaguar Shaman was one of my first mentors down here, who was described as the greatest healer there. In December ‘82, civil war broke out in Suriname. They closed the borders of the country, no Americans, nobody could leave. The leftists were claiming the CIA was fomenting an anti-leftist coup and they were looking for American troublemakers. Well we know for a fact that was one of the few 

In [160]:
grad_batch = { 
    "labels": data_batch["input_ids"].clone().detach(), 
    **data_batch
} 

In [164]:
forward_pass(model, grad_batch)

In [179]:
projection_matrices_dict = dict()

for name, param in model.named_parameters():
    # only do this for the weight matrix of the target_layers_suffix
    if any(suff_name in name for suff_name in target_layers_suffix) and "weight" in name:
        projection_matrices_dict[name] = param

In [184]:
NUM_TOP_SINGULAR_VECTORS = 10
NUM_TOKENS = 20 

for global_idx, (proj_name, proj_matrix) in enumerate(projection_matrices_dict.items()):
    print("Analysis for projection matrix: ", proj_name)

    U, S, Vh = svd(proj_matrix.grad, full_matrices=False)

    explained_variances = compute_explained_variance(S)
    condition_number = compute_conditon_number(S)
    print("Explained variance: ", explained_variances[:10])
    print("Condition number: ", condition_number)

    results = []

    for vec_idx in range(NUM_TOP_SINGULAR_VECTORS):
        row = [f"Top singular vector {vec_idx}"]
        
        if "dense_4h_to_h" == proj_name:
            singular_vec = Vh[vec_idx, :]
        else:
            singular_vec = U[:, vec_idx]

        logits = de_embedding_matrix @ singular_vec.T
        top_indices = torch.topk(logits, NUM_TOKENS, dim=0).indices

        tokens = [tokenizer.decode(tok_idx) for tok_idx in top_indices]
        row.extend(tokens)
        results.append(row)

    # Printing the results as a table
    print(tabulate(results, headers=['Vector Index'] + [f"Token {i+1}" for i in range(NUM_TOKENS)]))

    print("\n\n")


Analysis for projection matrix:  gpt_neox.layers.0.attention.dense.weight
Explained variance:  tensor([0.1732, 0.0742, 0.0374, 0.0332, 0.0244, 0.0178, 0.0148, 0.0124, 0.0107,
        0.0099], device='cuda:0')
Condition number:  tensor(5.5590e+08, device='cuda:0')
Vector Index           Token 1     Token 2       Token 3    Token 4        Token 5    Token 6    Token 7     Token 8        Token 9     Token 10    Token 11    Token 12    Token 13    Token 14    Token 15    Token 16    Token 17      Token 18    Token 19    Token 20
---------------------  ----------  ------------  ---------  -------------  ---------  ---------  ----------  -------------  ----------  ----------  ----------  ----------  ----------  ----------  ----------  ----------  ------------  ----------  ----------  ----------
Top singular vector 0  ames        orus          thereto    ised           iflu       (_         oured       orche          mad         Authority   hell        CONTROL     ymphony     ени         hed 