# Visualizing the Structure of Embeddings via Dimensionality Reduction

See `configs/config.yaml` for the selected model, `layer_idx`, and other parameters.


Side note: we denote tensor shapes via variable-name suffixes (e.g. `act_BD` has shape Batch × Model embedding dimension).
- B: Batch
- T: Time / Sequence position / context length
- D: Model embedding dimension

In [1]:
# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before each cell execution so edits to the project's `src` package
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from src.activation_cache import load_labeled_acts

from sklearn.decomposition import PCA

  from .autonotebook import tqdm as notebook_tqdm


## Days of the Week


In [3]:
from src.config import load_config

# Load the project configuration registered under the name "gpt2".
cfg = load_config("gpt2")
# Override the sequence aggregation method for this notebook only.
# NOTE(review): "final" presumably keeps the activation at the last sequence
# position of each text — confirm against the activation-cache implementation.
cfg.llm.sequence_aggregation_method = "final"
# Bare last expression: lets the notebook display the effective configuration.
cfg

Config(
  env=EnvironmentConfig(
    dtype=torch.bfloat16
    device='cuda'
    hf_cache_dir=PosixPath('/home/can/models')
    texts_dir='data/texts'
    tokens_dir='data/tokens'
    activations_dir='data/activations'
    debug=False
  )
  llm=LLMConfig(
    hf_name='openai-community/gpt2'
    layer_idx=8
    batch_size=100
    sequence_aggregation_method='final'
  )
  data=DataConfig(
    name='days_filtered'
  )
  filter=FilterConfig(
    corpus='HuggingFaceFW/fineweb'
    regex_file='days'
    num_occurences=20
  )
)

In [4]:
# Load the labels, the source texts and the activations for the current config.
# force_recompute=True is kept as in the original cell; since `cfg` is modified
# inside this notebook, this presumably avoids reusing a cache built with other
# settings — TODO confirm.
labels, texts, act_BD = load_labeled_acts(cfg, force_recompute=True)

# Small constant added to the norm so all-zero rows do not divide by zero.
EPS = 1e-8

# Upcast to float32 *before* normalizing. The displayed config requests
# torch.bfloat16, which carries only ~3 significant decimal digits; dividing in
# that dtype would round every unit vector before the later float32 cast and
# before PCA sees it. If the activations are already float32 this is a no-op.
act_f32_BD = act_BD.float()
act_normalized_BD = act_f32_BD / (act_f32_BD.norm(dim=-1, keepdim=True) + EPS)

# Move to host memory and convert to a NumPy array for scikit-learn.
act_normalized_BD = act_normalized_BD.cpu().numpy()

Caching Activations: 100%|██████████| 2/2 [00:00<00:00,  8.09it/s]


In [5]:
# Project the normalized activations onto their first three principal
# components (one per axis of a 3D plot).
#
# The seed is pinned because scikit-learn's default svd_solver="auto" may pick
# the randomized SVD solver for inputs with a dimension above 500 and few
# requested components; without a fixed random_state the projection could then
# differ between runs.
PCA_SEED = 0
pca = PCA(n_components=3, random_state=PCA_SEED)

# NOTE: despite the `_BD` suffix, the second axis of the result has length 3
# (n_components), not the model embedding dimension.
act_pca_BD = pca.fit_transform(act_normalized_BD)

In [6]:
import json
import plotly.graph_objs as go

# Assign every distinct label an integer index in first-seen order; the index
# is what the marker colour encodes.
unique_labels = list(dict.fromkeys(labels))
label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
label_indices = [label_to_idx[label] for label in labels]

# Palette taken from matplotlib's tab10; it repeats when there are more than
# ten distinct labels.
tab10_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
                '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
num_labels = len(unique_labels)

# Build a truly discrete colorscale: label i owns the flat band
# [i / n, (i + 1) / n], expressed as two stops with the same colour. This
# avoids a colour gradient between labels in the colorbar and always yields a
# scale spanning 0..1, including when there is only a single label.
colorscale = [
    [edge / num_labels, tab10_colors[i % len(tab10_colors)]]
    for i in range(num_labels)
    for edge in (i, i + 1)
]

# Hover payload: the label plus a JSON dump of label and text for each point.
dict_strings = [
    json.dumps({"label": label, "text": text}, ensure_ascii=False)
    for label, text in zip(labels, texts)
]
customdata = list(zip(labels, dict_strings))

fig = go.Figure(
    data=[
        go.Scatter3d(
            x=act_pca_BD[:, 0],
            y=act_pca_BD[:, 1],
            z=act_pca_BD[:, 2],
            mode="markers",
            customdata=customdata,
            hovertemplate="Label: %{customdata[0]}<br>Metadata: %{customdata[1]}<extra></extra>",
            marker=dict(
                size=6,
                color=label_indices,
                colorscale=colorscale,
                # Pin the colour range so integer index i maps to the centre
                # of band i, i.e. (i + 0.5) / n on the 0..1 scale.
                cmin=-0.5,
                cmax=num_labels - 0.5,
                opacity=0.8,
                showscale=True,
                colorbar=dict(
                    title="Label",
                    tickmode="array",
                    tickvals=list(range(len(unique_labels))),
                    ticktext=unique_labels,
                ),
            ),
        )
    ]
)
fig.update_layout(
    scene=dict(
        xaxis_title="PC 1",
        yaxis_title="PC 2",
        zaxis_title="PC 3"
    ),
    margin=dict(l=0, r=0, b=0, t=30),
    title="PCA 3D Scatter of Activations"
)
fig.show()
