## Concept Similarity RSA v1.1

Model: Qwen2.5-1.5B

#### Changes made since v1:
- Changed the model from GPT-2 to Qwen2.5-1.5B
- Replaced the human similarity matrix with machine-generated data
- Replaced mean pooling with last-token pooling over each word's subtoken sequence

#### NOTE: This is the second iteration of concept-sim. Later versions can be found in `/notebooks`.


#### Prerequisites: install dependencies, import libraries, set up visualization functions, and load model


In [1]:
%pip install rsatoolbox
%pip install transformer_lens
%pip install circuitsvis
#%pip install numpy==1.26.4
# Install Node.js (the 16.x setup script is deprecated; use the 20.x LTS one)
!curl -fsSL https://deb.nodesource.com/setup_20.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

# Import utils

import circuitsvis as cv
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import pandas as pd
import plotly.express as px
from numpy.linalg import norm
import json
import numpy as np
from scipy.stats import pearsonr

from jaxtyping import Float
from functools import partial

# Import transformer_lens

import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix

# Disable autograd to save memory: this notebook only runs inference, no training
# TODO: re-enable if training, fine-tuning, or intervening on the model
torch.set_grad_enabled(False)

# Plotting helper functions

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

#### Model load

In [2]:
# Model setup

device = utils.get_device()
model = HookedTransformer.from_pretrained("Qwen2.5-1.5B", device=device)

# Sanity check: run one prompt and cache all activations
# (pass names_filter to run_with_cache to cache fewer hooks and save memory)
sample_text = "What is the difference between a pet fish and a fish?"
sample_tokens = model.to_tokens(sample_text)
print(sample_tokens.device)
sample_logits, sample_cache = model.run_with_cache(sample_tokens, remove_batch_dim=True)

Loaded pretrained model Qwen2.5-1.5B into HookedTransformer
cpu


#### 1: Create a mapping from word pairs to human similarity ratings

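Define a mapping from word pairs to similarity ratings and store it in *pair_similarities*. Since v1.1 these ratings come from machine-generated data (`tree_sim_10_iter.txt`) rather than human judgments; the `human_*` names used below are carried over from v1.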

In [None]:
def read_pair_similarities(file_name: str) -> dict:
  pair_similarities = {}

  with open(file_name, 'r') as f:
    for line in f:
      pair_1, pair_2, similiarity_score = line.split()
      pair_similarities[(pair_1.split('#', 1)[0], pair_2.split('#', 1)[0])] = float(similiarity_score)

  return pair_similarities

pair_similarities = read_pair_similarities('tree_sim_10_iter.txt')
pair_similarities

{('pine', 'redwood'): 4.0,
 ('pine', 'oak'): 3.5,
 ('pine', 'maple'): 3.5,
 ('pine', 'birch'): 3.6,
 ('pine', 'willow'): 2.8,
 ('redwood', 'oak'): 6.4,
 ('redwood', 'maple'): 3.3,
 ('redwood', 'birch'): 2.6,
 ('redwood', 'willow'): 3.9,
 ('oak', 'maple'): 4.9,
 ('oak', 'birch'): 3.7,
 ('oak', 'willow'): 3.5,
 ('maple', 'birch'): 5.1,
 ('maple', 'willow'): 4.2,
 ('birch', 'willow'): 5.5}
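The parser expects every line to hold two words followed by a score, and drops anything after a `#` in each word. The exact `word#tag` format of `tree_sim_10_iter.txt` is an assumption inferred from the `split('#', 1)` call. The sketch below exercises the same parsing logic on a hypothetical in-memory sample and also skips blank lines, which would otherwise break the three-way unpack:

```python
import io

def parse_pair_similarities(lines) -> dict:
    """Parse 'word1#tag word2#tag score' lines into {(word1, word2): score}."""
    pairs = {}
    for line in lines:
        if not line.strip():
            continue  # a blank line would make the 3-way unpack fail
        left, right, score = line.split()
        # Keep only the text before the first '#', dropping any tag suffix
        pairs[(left.split('#', 1)[0], right.split('#', 1)[0])] = float(score)
    return pairs

# Hypothetical sample in the assumed file format (trailing blank line included)
sample = io.StringIO("pine#n redwood#n 4.0\noak#n maple#n 4.9\n\n")
print(parse_pair_similarities(sample))
# {('pine', 'redwood'): 4.0, ('oak', 'maple'): 4.9}
```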

Normalize the values in *pair_similarities* to [0, 1] by dividing each one by 10 (the raw ratings are assumed to lie on a 0-10 scale), which preserves their relative spacing. We'll call the normalized mapping *n_pair_similarities*.

In [None]:
def normalize_similarities(pair_similarities: dict) -> dict:
  return {k: round(v/10, 3) for k, v in pair_similarities.items()}

n_pair_similarities = normalize_similarities(pair_similarities)
n_pair_similarities

{('pine', 'redwood'): 0.4,
 ('pine', 'oak'): 0.35,
 ('pine', 'maple'): 0.35,
 ('pine', 'birch'): 0.36,
 ('pine', 'willow'): 0.28,
 ('redwood', 'oak'): 0.64,
 ('redwood', 'maple'): 0.33,
 ('redwood', 'birch'): 0.26,
 ('redwood', 'willow'): 0.39,
 ('oak', 'maple'): 0.49,
 ('oak', 'birch'): 0.37,
 ('oak', 'willow'): 0.35,
 ('maple', 'birch'): 0.51,
 ('maple', 'willow'): 0.42,
 ('birch', 'willow'): 0.55}
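Dividing by a hard-coded 10 silently produces values outside [0, 1] if any rating exceeds the assumed scale. A slightly more defensive sketch makes the scale explicit and fails loudly instead; the `scale_max` parameter and the `normalize_by_scale` name are illustrative additions, not part of the original notebook:

```python
def normalize_by_scale(pair_similarities: dict, scale_max: float = 10.0, ndigits: int = 3) -> dict:
    """Divide every rating by scale_max so values land in [0, 1], preserving ratios."""
    out = {}
    for pair, score in pair_similarities.items():
        if not 0.0 <= score <= scale_max:
            raise ValueError(f"{pair} has score {score}, outside [0, {scale_max}]")
        out[pair] = round(score / scale_max, ndigits)
    return out

print(normalize_by_scale({('pine', 'redwood'): 4.0, ('redwood', 'oak'): 6.4}))
# {('pine', 'redwood'): 0.4, ('redwood', 'oak'): 0.64}
```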

Helpers for saving this pair-to-score mapping to a JSON file and loading it back

In [None]:
def encode_dict_to_json(data: dict, filename: str) -> None:
  """
  Encode the pair_similarities so that we can store tuples
  while following JSON formatting by seperating keys with '|'
  """
  encoded_dict = {k1 + '|' + k2: v for (k1,k2), v in data.items()}
  with open(filename, 'w') as f:
    json.dump(encoded_dict, f, indent=4)
    print("Dictionary saved as ", filename)

def decode_json_to_dict(filename: str) -> dict:
  """
  Decode the pair_similarities, removing the '|' from the keys
  and turning the keys back into a tuple format
  """
  with open(filename, 'r') as f:
    data = json.load(f)
  decoded_dict = {tuple(k.split('|')): val for k, val in data.items()}
  return decoded_dict

encode_dict_to_json(n_pair_similarities, 'n_pair_similarities')
decode_json_to_dict("n_pair_similarities")

Dictionary saved as  n_pair_similarities


{('salmon', 'tuna'): 0.43,
 ('salmon', 'haddock'): 0.33,
 ('salmon', 'eel'): 0.34,
 ('salmon', 'goldfish'): 0.1,
 ('salmon', 'koi'): 0.33,
 ('salmon', 'shark'): 0.14,
 ('salmon', 'mackerel'): 0.62,
 ('tuna', 'haddock'): 0.33,
 ('tuna', 'eel'): 0.3,
 ('tuna', 'goldfish'): 0.1,
 ('tuna', 'koi'): 0.11,
 ('tuna', 'shark'): 0.1,
 ('tuna', 'mackerel'): 0.33,
 ('haddock', 'eel'): 0.16,
 ('haddock', 'goldfish'): 0.11,
 ('haddock', 'koi'): 0.14,
 ('haddock', 'shark'): 0.14,
 ('haddock', 'mackerel'): 0.26,
 ('eel', 'goldfish'): 0.1,
 ('eel', 'koi'): 0.16,
 ('eel', 'shark'): 0.12,
 ('eel', 'mackerel'): 0.16,
 ('goldfish', 'koi'): 0.42,
 ('goldfish', 'shark'): 0.1,
 ('goldfish', 'mackerel'): 0.16,
 ('koi', 'shark'): 0.22,
 ('koi', 'mackerel'): 0.11,
 ('shark', 'mackerel'): 0.12}
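The `'|'` encoding only round-trips if no word contains `'|'` itself. A self-contained sketch of the same encode/decode round trip that checks for that case and uses a temporary file; the helper names `encode_pairs`/`decode_pairs` and the temp path are illustrative:

```python
import json
import os
import tempfile

def encode_pairs(data: dict, sep: str = "|") -> dict:
    """Turn {(w1, w2): score} into JSON-safe {'w1|w2': score}."""
    for w1, w2 in data:
        if sep in w1 or sep in w2:
            raise ValueError(f"separator {sep!r} appears in {(w1, w2)}")
    return {f"{w1}{sep}{w2}": v for (w1, w2), v in data.items()}

def decode_pairs(data: dict, sep: str = "|") -> dict:
    """Inverse of encode_pairs: split each key back into a 2-tuple."""
    return {tuple(k.split(sep, 1)): v for k, v in data.items()}

original = {("pine", "redwood"): 0.4, ("oak", "maple"): 0.49}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "n_pair_similarities.json")
    with open(path, "w") as f:
        json.dump(encode_pairs(original), f, indent=4)
    with open(path) as f:
        restored = decode_pairs(json.load(f))

print(restored == original)  # True
```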

#### 2: Create an ordered vector of the human similarity ratings, cache activations

In [None]:
ordered_pairs = sorted(n_pair_similarities.keys()) # sorted lexicographically (first word, then second)

human_sim_vec = [n_pair_similarities[pair] for pair in ordered_pairs]
ovector_mapping = {pair: i for i, pair in enumerate(ordered_pairs)}

print("        Length:", len(human_sim_vec), "pairs")
print("Ordered vector:", human_sim_vec)
print("       Mapping:", ovector_mapping)

        Length: 15 pairs
Ordered vector: [0.55, 0.51, 0.42, 0.37, 0.49, 0.35, 0.36, 0.35, 0.35, 0.4, 0.28, 0.26, 0.33, 0.64, 0.39]
       Mapping: {('birch', 'willow'): 0, ('maple', 'birch'): 1, ('maple', 'willow'): 2, ('oak', 'birch'): 3, ('oak', 'maple'): 4, ('oak', 'willow'): 5, ('pine', 'birch'): 6, ('pine', 'maple'): 7, ('pine', 'oak'): 8, ('pine', 'redwood'): 9, ('pine', 'willow'): 10, ('redwood', 'birch'): 11, ('redwood', 'maple'): 12, ('redwood', 'oak'): 13, ('redwood', 'willow'): 14}
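The ordered vector is the pair-list form of a similarity matrix: each entry is one off-diagonal cell. As a sketch of that relationship (assuming a self-similarity of 1.0 on the diagonal and the same score for both orientations of a pair; `pairs_to_matrix` and `upper_triangle` are hypothetical helpers), the pair dictionary can be expanded into a symmetric matrix and read back in a fixed row-major upper-triangle order. Whatever order is chosen, the human and machine vectors must use the same one:

```python
import numpy as np

def pairs_to_matrix(pair_scores: dict):
    """Build a symmetric similarity matrix (diagonal = 1.0) from {(w1, w2): score}."""
    words = sorted({w for pair in pair_scores for w in pair})
    index = {w: i for i, w in enumerate(words)}
    mat = np.eye(len(words))
    for (w1, w2), score in pair_scores.items():
        i, j = index[w1], index[w2]
        mat[i, j] = mat[j, i] = score  # same score in both orientations
    return words, mat

def upper_triangle(mat: np.ndarray) -> np.ndarray:
    """Off-diagonal upper triangle, row by row: a common vector form for RSA."""
    rows, cols = np.triu_indices(mat.shape[0], k=1)
    return mat[rows, cols]

scores = {("maple", "birch"): 0.51, ("birch", "willow"): 0.55, ("maple", "willow"): 0.42}
words, mat = pairs_to_matrix(scores)
print(words)                # ['birch', 'maple', 'willow']
print(upper_triangle(mat))  # [0.51 0.55 0.42]
```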


In [None]:
# normalized human similarity values in the same order as the word pairs
human_similarities = human_sim_vec
print("Ordered vector:", human_similarities)
print("Has", len(human_similarities), "values, one per word pair")

Ordered vector: [0.55, 0.51, 0.42, 0.37, 0.49, 0.35, 0.36, 0.35, 0.35, 0.4, 0.28, 0.26, 0.33, 0.64, 0.39]
Has 15 values, one per word pair


In [None]:
# Extract every unique word that appears in any pair
items = sorted({w for pair in ovector_mapping for w in pair})
print("All unique items:", items)
print("Has", len(items), "items; fewer than", 2 * len(ovector_mapping), "because words repeat across pairs")

All unique items: ['birch', 'maple', 'oak', 'pine', 'redwood', 'willow']
Has 6 items; fewer than 30 because words repeat across pairs


In [None]:
# Logits aren't used for RSA, so discard them and keep only
# each word's activation cache in activation_maps
activation_maps = {}
for word in items:
  toks = model.to_tokens(" " + word, prepend_bos=False)
  _, cache = model.run_with_cache(toks, remove_batch_dim=True)
  activation_maps[word] = cache

#### 3: Compute cosine similarities between word vectors for each hook and layer

In [None]:
def get_vector_from_cache(cache, name, layer):
  """
  Grabs the activation matrix for a chosen hook and layer from the word's cache,
  returns as NumPy vector, suitable for cosine similarity analysis

  cache: the full set of intermediate activations produced by some model for some word
  name:  {"resid_post", "mlp_out", "attn_out", "resid_pre", "hook_embed"}
  layer: 0-11 (since we're using GPT-2)
  """
  vector = cache[name, layer][-1] # <- last-token pooling
  return vector.detach().float().cpu().numpy()
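With `remove_batch_dim=True`, `cache[name, layer]` has shape `[seq_len, d_model]`, so `[-1]` picks the final subtoken's vector. The toy sketch below contrasts this v1.1 last-token pooling with the mean pooling used in v1; the small integer array is a stand-in, not real model activations. For words that tokenize to a single subtoken the two methods give the same vector:

```python
import numpy as np

def last_token_pool(acts: np.ndarray) -> np.ndarray:
    """Keep the final position of a [seq_len, d_model] activation matrix (v1.1)."""
    return acts[-1]

def mean_pool(acts: np.ndarray) -> np.ndarray:
    """Average over all positions of a [seq_len, d_model] activation matrix (v1)."""
    return acts.mean(axis=0)

# Toy activations: a word split into 3 subtokens, d_model = 4
acts = np.arange(12, dtype=float).reshape(3, 4)
print(last_token_pool(acts))  # [ 8.  9. 10. 11.]
print(mean_pool(acts))        # [4. 5. 6. 7.]
```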

In [None]:
# Hooks of interest
hooks = (
    [("resid_pre",  l) for l in range(model.cfg.n_layers)] +
    [("attn_out",   l) for l in range(model.cfg.n_layers)] +
    [("mlp_out",    l) for l in range(model.cfg.n_layers)] +
    [("resid_post", l) for l in range(model.cfg.n_layers)]
)

# word_to_hook_vecs is a dictionary of dictionaries:
# each word in `items` maps to a dictionary from hook tags
# ("name@layer") to that hook's last-token activation vector
#
# word_to_hook_vecs[word] = {
#    'resid_pre@0' :   array,
#    'attn_out@0'  :   array,
#    'mlp_out@0'   :   array,
#    'resid_post@0':   array,
#    ...               (one entry per hook name per layer)
# }

word_to_hook_vecs = {}
for word, cache in activation_maps.items():
  activations = {}
  for name, layer in hooks:
    hook = f"{name}@{layer}"
    activations[hook] = get_vector_from_cache(cache, name, layer)
  word_to_hook_vecs[word] = activations

In [None]:
# human_vec: human_sim_vec as a NumPy array,
# for use in the correlation computations
human_vec = np.array(human_sim_vec)
human_vec

array([0.55, 0.51, 0.42, 0.37, 0.49, 0.35, 0.36, 0.35, 0.35, 0.4 , 0.28,
       0.26, 0.33, 0.64, 0.39])

In [None]:
# Sorted list of hook tags used for cosine similarity
# (sorted as strings, so '@10' comes before '@2')
#['attn_out@0',
# 'attn_out@1',
# 'attn_out@10',
# 'attn_out@11',
# 'attn_out@2',
#     ...
# 'resid_pre@4',
# 'resid_pre@5',
# 'resid_pre@6',
# 'resid_pre@7',
# 'resid_pre@8',
# 'resid_pre@9']
hook_tags = sorted(next(iter(word_to_hook_vecs.values())).keys())

# List of word pairs, in ovector_mapping order
#[('birch', 'willow'),
# ('maple', 'birch'),
# ('maple', 'willow'),
#           ...
# ('redwood', 'maple'),
# ('redwood', 'oak'),
# ('redwood', 'willow')]
pairs = list(ovector_mapping.keys())
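Because `hook_tags` is sorted as strings, `attn_out@10` lands before `attn_out@2`. When layer order matters (for example, plotting RSA scores across depth), a sort key that parses the layer number keeps hooks in model order. `hook_sort_key` is a hypothetical helper, assuming tags follow the `name@layer` pattern built above:

```python
def hook_sort_key(tag: str):
    """Sort 'name@layer' tags by hook name, then by integer layer."""
    name, layer = tag.rsplit("@", 1)
    return (name, int(layer))

tags = ["attn_out@10", "attn_out@2", "resid_pre@0", "attn_out@0", "mlp_out@1"]
print(sorted(tags, key=hook_sort_key))
# ['attn_out@0', 'attn_out@2', 'attn_out@10', 'mlp_out@1', 'resid_pre@0']
```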

In [None]:
def cosine(u, v):
  nu, nv = norm(u), norm(v)
  return float(np.dot(u, v)/(nu*nv)) if (nu and nv) else np.nan

cosine_by_hook = {}

# Generate machine_sim for each hook
for tag in hook_tags:
  machine_sim = []
  for (w1, w2) in pairs:
    v1 = word_to_hook_vecs[w1][tag]
    v2 = word_to_hook_vecs[w2][tag]
    machine_sim.append(cosine(v1, v2))
  cosine_by_hook[tag] = np.asarray(machine_sim, dtype=float)

# Display result as table
cosine_df = pd.DataFrame(
  {tag: cosine_by_hook[tag] for tag in hook_tags},
  index=pairs
)
display(cosine_df)

| pair | attn_out@0 | attn_out@1 | attn_out@10 | attn_out@11 | attn_out@12 | attn_out@13 | attn_out@14 | attn_out@15 | attn_out@16 | attn_out@17 | ... | resid_pre@25 | resid_pre@26 | resid_pre@27 | resid_pre@3 | resid_pre@4 | resid_pre@5 | resid_pre@6 | resid_pre@7 | resid_pre@8 | resid_pre@9 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| (birch, willow) | 0.222638 | 0.523583 | 0.833642 | 0.935951 | 0.984914 | 0.934264 | 0.880547 | 0.942156 | 0.994949 | 0.986238 | ... | 0.850123 | 0.877208 | 0.910448 | 0.370993 | 0.378831 | 0.476673 | 0.554773 | 0.615682 | 0.637316 | 0.694555 |
| (maple, birch) | 0.147769 | 0.436758 | 0.744685 | 0.916756 | 0.984554 | 0.866869 | 0.232233 | 0.867065 | 0.989408 | 0.951285 | ... | 0.181567 | 0.194926 | 0.313689 | 0.09896 | 0.06955 | 0.10657 | 0.096681 | 0.093469 | 0.108606 | 0.12756 |
| (maple, willow) | 0.150997 | 0.494177 | 0.935467 | 0.845077 | 0.992328 | 0.96805 | 0.483743 | 0.89211 | 0.99407 | 0.960319 | ... | 0.183391 | 0.192876 | 0.311542 | 0.078303 | 0.091522 | 0.113761 | 0.091118 | 0.104607 | 0.113761 | 0.125141 |
| (oak, birch) | 0.214227 | 0.480579 | 0.744516 | 0.916895 | 0.984573 | 0.866679 | 0.231948 | 0.866975 | 0.989374 | 0.95091 | ... | 0.180732 | 0.194151 | 0.288897 | 0.09889 | 0.069242 | 0.106197 | 0.096493 | 0.09327 | 0.108246 | 0.126986 |
| (oak, maple) | 0.44286 | 0.740768 | 0.999915 | 0.99996 | 0.999901 | 0.999879 | 0.999805 | 0.999918 | 0.999904 | 0.99966 | ... | 0.999954 | 0.999953 | 0.898889 | 0.99995 | 0.999951 | 0.999952 | 0.999953 | 0.999954 | 0.999954 | 0.999954 |
| (oak, willow) | 0.184401 | 0.475981 | 0.935401 | 0.845262 | 0.992186 | 0.968051 | 0.483864 | 0.892002 | 0.994111 | 0.960202 | ... | 0.182663 | 0.192229 | 0.288768 | 0.078362 | 0.091309 | 0.113346 | 0.090864 | 0.10436 | 0.113404 | 0.124658 |
| (pine, birch) | 0.156997 | 0.489823 | 0.744242 | 0.91675 | 0.984626 | 0.866509 | 0.232288 | 0.867152 | 0.989387 | 0.950879 | ... | 0.180084 | 0.193596 | 0.251886 | 0.09858 | 0.069065 | 0.10589 | 0.095966 | 0.092618 | 0.107563 | 0.126278 |
| (pine, maple) | 0.413119 | 0.738661 | 0.999886 | 0.999942 | 0.999898 | 0.999857 | 0.999771 | 0.999907 | 0.999877 | 0.999661 | ... | 0.999919 | 0.999919 | 0.862865 | 0.999917 | 0.999917 | 0.999918 | 0.99992 | 0.99992 | 0.999919 | 0.999919 |
| (pine, oak) | 0.509576 | 0.758789 | 0.999915 | 0.999961 | 0.999924 | 0.999915 | 0.99984 | 0.999941 | 0.99992 | 0.999824 | ... | 0.999953 | 0.999953 | 0.910562 | 0.999951 | 0.999952 | 0.999953 | 0.999954 | 0.999954 | 0.999954 | 0.999954 |
| (pine, redwood) | 0.282021 | 0.593063 | 0.827755 | 0.921653 | 0.960832 | 0.790866 | -0.267171 | 0.799596 | 0.98494 | 0.927677 | ... | 0.185675 | 0.19831 | 0.258865 | 0.087677 | 0.090958 | 0.114672 | 0.098219 | 0.118466 | 0.140941 | 0.146118 |
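The per-pair loop can also be expressed as one matrix product per hook: L2-normalize each word vector, stack them into a matrix `X`, and take `X @ X.T`. A sketch assuming every word has a vector of the same length for a given hook (`cosine_matrix` is a hypothetical helper); zero-norm vectors yield NaN rows, matching the behavior of `cosine()`:

```python
import numpy as np

def cosine_matrix(vectors: dict):
    """Return (words, cosine-similarity matrix) for {word: 1-D vector}."""
    words = sorted(vectors)
    X = np.stack([np.asarray(vectors[w], dtype=float) for w in words])
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    with np.errstate(invalid="ignore", divide="ignore"):
        Xn = X / norms  # zero-norm rows become NaN
    return words, Xn @ Xn.T

vecs = {"oak": [1.0, 0.0], "pine": [1.0, 1.0], "birch": [0.0, 2.0]}
words, sims = cosine_matrix(vecs)
idx = {w: i for i, w in enumerate(words)}
print(round(float(sims[idx["oak"], idx["pine"]]), 4))  # 0.7071
```

To recover the notebook's pair-ordered vector for one hook, index the matrix per pair, e.g. `[sims[idx[w1], idx[w2]] for (w1, w2) in pairs]`.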


#### 4: Compute RSA (Pearson correlation) between human and machine similarities

In [None]:
rsa_scores = {}

for hook, machine_vec in cosine_by_hook.items():
  r, _ = pearsonr(machine_vec, human_vec)
  rsa_scores[hook] = r

top_50 = sorted(rsa_scores.items(), key=lambda x: x[1], reverse=True)[:50]

for i, (hook, score) in enumerate(top_50):
  print(f"{i+1} - {hook}: {score:.4f}")

1 - attn_out@0: 0.2657
2 - resid_pre@0: 0.2159
3 - attn_out@27: 0.2137
4 - resid_post@0: 0.1665
5 - resid_pre@1: 0.1665
6 - attn_out@5: 0.1419
7 - attn_out@11: 0.0635
8 - resid_post@27: 0.0559
9 - attn_out@1: 0.0486
10 - mlp_out@0: 0.0379
11 - attn_out@7: 0.0281
12 - mlp_out@2: 0.0192
13 - attn_out@9: 0.0119
14 - mlp_out@27: 0.0078
15 - mlp_out@4: 0.0074
16 - attn_out@4: 0.0041
17 - mlp_out@8: -0.0059
18 - mlp_out@5: -0.0129
19 - resid_post@2: -0.0135
20 - resid_pre@3: -0.0135
21 - resid_post@26: -0.0190
22 - resid_pre@27: -0.0190
23 - mlp_out@7: -0.0194
24 - mlp_out@21: -0.0268
25 - resid_post@3: -0.0283
26 - resid_pre@4: -0.0283
27 - resid_post@1: -0.0286
28 - resid_pre@2: -0.0286
29 - mlp_out@18: -0.0287
30 - resid_post@5: -0.0292
31 - resid_pre@6: -0.0292
32 - mlp_out@3: -0.0305
33 - mlp_out@17: -0.0305
34 - mlp_out@6: -0.0313
35 - resid_post@6: -0.0315
36 - resid_pre@7: -0.0315
37 - resid_post@8: -0.0329
38 - resid_pre@9: -0.0329
39 - resid_post@19: -0.0329
40 - resid_pre@20: -0.0
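`pearsonr` returns NaN whenever either vector contains NaN, which `cosine()` produces for zero-norm activations. A hedged sketch that drops such pairs before correlating, and optionally switches to Spearman's rank correlation (a common choice in RSA, but not what this notebook uses); `rsa_score` is a hypothetical helper:

```python
import numpy as np
from scipy.stats import pearsonr, spearmanr

def rsa_score(machine_vec, human_vec, method: str = "pearson") -> float:
    """Correlate two pair-aligned similarity vectors, ignoring NaN entries."""
    corr = {"pearson": pearsonr, "spearman": spearmanr}[method]  # KeyError on typos
    m = np.asarray(machine_vec, dtype=float)
    h = np.asarray(human_vec, dtype=float)
    keep = ~(np.isnan(m) | np.isnan(h))
    if keep.sum() < 3:
        return float("nan")  # too few valid pairs for a meaningful correlation
    r, _ = corr(m[keep], h[keep])
    return float(r)

machine = [0.9, np.nan, 0.5, 0.2, 0.7]
human = [0.8, 0.4, 0.45, 0.1, 0.6]
print(rsa_score(machine, human))                     # Pearson on the 4 valid pairs
print(rsa_score(machine, human, method="spearman"))  # ranks agree, so close to 1
```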

#### 5: Results

TBD:
- top 50
- visualize
- save info
- 3 output files:
  - mapping from word pair to human similarity score
  - list of all unique words in human data
  - last table/mapping of model hooks to r values
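The three planned output files could be written along these lines. This is only a sketch: the file names, the plain-text format for the word list, and the `save_outputs` helper are assumptions, and pair keys reuse the `'|'` encoding from step 1. Note that `json.dump` writes NaN scores as the non-standard token `NaN` by default:

```python
import json
import os
import tempfile

def save_outputs(pair_scores: dict, items: list, rsa_scores: dict, out_dir: str) -> list:
    """Write the three result files and return their paths."""
    os.makedirs(out_dir, exist_ok=True)
    paths = [
        os.path.join(out_dir, "pair_similarities.json"),
        os.path.join(out_dir, "unique_words.txt"),
        os.path.join(out_dir, "rsa_by_hook.json"),
    ]
    with open(paths[0], "w") as f:
        json.dump({f"{a}|{b}": v for (a, b), v in pair_scores.items()}, f, indent=4)
    with open(paths[1], "w") as f:
        f.write("\n".join(items) + "\n")  # one word per line
    with open(paths[2], "w") as f:
        # Order hooks by descending r, like the ranking printed in step 4
        ranked = sorted(rsa_scores.items(), key=lambda kv: kv[1], reverse=True)
        json.dump({hook: float(r) for hook, r in ranked}, f, indent=4)
    return paths

with tempfile.TemporaryDirectory() as tmp:
    written = save_outputs({("pine", "oak"): 0.35}, ["oak", "pine"], {"attn_out@0": 0.2657}, tmp)
    print([os.path.basename(p) for p in written])
    # ['pair_similarities.json', 'unique_words.txt', 'rsa_by_hook.json']
```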