## Objectives

I want to investigate the following questions:
- How many monolingual features are there? And how much are exclusive for English?
- Are multilingual features focused on related languages (like Spanish, Portuguese, Italian, etc.)? Are there massive non-English multilingual features (i.e., features that identifies texts that are not in English?)

## Setup

In [1]:
import torch as t
import altair as alt
import polars as pl
import umap
import zlib
import base64

from datasets import load_dataset
from transformer_lens import HookedTransformer
from sae_lens import SAE
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

In [2]:
device = "cuda"

lang_map = {
    "eng_Latn": "English",
    "deu_Latn": "German",
    "rus_Cyrl": "Russian",
    "isl_Latn": "Icelandic",
    "spa_Latn": "Spanish",
    "por_Latn": "Portuguese",
    "fra_Latn": "French",
    "zho_Hans": "Chinese",
    "jpn_Jpan": "Japanese",
    "kor_Hang": "Korean",
    "hin_Deva": "Hindi",
    "arb_Arab": "Arabic"
}

# langs = sorted(list(lang_map.values()))
langs = list(lang_map.values())

# alt.data_transformers.enable("vegafusion")


In [3]:
df_pythia = pl.read_json("../results/metrics-pythia.json")
df_gpt2 = pl.read_json("../results/metrics-gpt2.json").with_columns(pl.lit("gpt2").alias("model")).select(df_pythia.columns)

df = (
    pl
    .concat([df_gpt2, df_pythia])
    .with_columns(
        pl.col("layer").str.extract(r"\.(\d+)", 1).str.to_integer(),
        pl.col("lang").replace_strict(lang_map),
    )
)

df.head()

model,lang,layer,latent_acts,latent_count,recon_mse,n_tokens
str,str,i64,list[f64],list[i64],f64,i64
"""gpt2""","""English""",0,"[1.216241, 0.0, … 0.14242]","[7, 0, … 1]",7149.613281,632610816
"""gpt2""","""German""",0,"[0.710374, 0.007594, … 0.0]","[9, 2, … 0]",20682.03125,1353129984
"""gpt2""","""Russian""",0,"[0.0, 1.004786, … 0.422731]","[0, 391, … 48]",160057.734375,3643318272
"""gpt2""","""Icelandic""",0,"[1.842339, 0.00663, … 0.0]","[30, 2, … 0]",35016.292969,1530200064
"""gpt2""","""Spanish""",0,"[6.635237, 0.0, … 0.004174]","[208, 0, … 3]",23529.304688,1255219200


In [4]:
def encode_chart(chart):
    json_string = chart.properties(width="container", height="container").to_json()
    compressed_data = zlib.compress(json_string.encode('utf-8'))
    encoded_data = base64.b64encode(compressed_data).decode('utf-8')
    return encoded_data

## Analysis

### How many tokens were produced for each language?


In [5]:
def plot_lang_token_count(model):
    chart_df = (
        df
        .filter(pl.col("model") == model)
        .group_by("model", "lang")
        .agg(pl.first("n_tokens"))
    )
    
    return alt.Chart(chart_df).mark_bar(
    ).encode(
        x=alt.X("n_tokens", title="Number of tokens").axis(format=".1"),
        y=alt.Y("lang", title="Language").sort("-x"),
    ).properties(
        width=800,
        height=400,
    )

chart_gpt2 = plot_lang_token_count("gpt2")
chart_pythia = plot_lang_token_count("pythia")

print("GPT2:", encode_chart(chart_gpt2))
print("Pythia:", encode_chart(chart_pythia))

alt.vconcat(chart_gpt2, chart_pythia)

GPT2: eJydlU2PmzAQhu/7K5DVY0L8gcHkVlVttx+qqvbQQ1WtDDjgLtgpNtmsVvnvNZAQdtNKkAtgv/Y8M2M883TjeeCVSQtRcbD2QGHt1qxXq53IuZ9LWzSJL/WqX9DNLktpxWpHfQx95P82WoFFayTVaiNzZ+PJjdx4J8XDMOp1K1WjG3MrZF5YpxEIF5fyD5nZolc78eCehw6RccvPAMUr0brczi4jLCDLGIdhmGSQpBhDmMYIUZZkAYkJeGbECGvOhqYZWHs/j76eQnJbS67akMEnXQveJ+KoVDoTZSvlW4vHgrqz+l6olk9QTGgUBRE+yofF/xEfUuE+M5nOpSBKoAsGhsEEyntRV/MDQYQShOOYTUG8KaQSRsxlYIgpYhQzOIFxK12m5hKCCLIgRnhSol7XPJl/FjgKifuv2STEV13bJm+uyBXCiDlOiKdg3tVCpcV8BA0gYxGcgvjIt/yaM0csCihlcRROgHxrjJFXXMIwIAQxPOkSvlV5Kc3sZIUEhwgyNCWM7y5VVyDceVCMYgyH69G9fw11zx2yzqQaVej9uDzzvTSjsZvZaFcK2jINfAReOg42UpRZKw4+DJKVtuwq85emSkTt6Y13seRx263403BX9i23cid6xpEAHsfODbAuSYMV4y5IO7vc/wP+2S1teC4uoUpXUvESvOguxaktdb2IuyJV952t4vX9OWsnKwmvz03l4dizRltvDn8BUuG5QA==
Pythia: eJydlU2PmzAQhu/7K5DVY0KMjcHkVlVttx+qqvbQQ7WqDBhwF+wUm2yiVf77GpIQdtNKLhfAfs08M2N75vHG88ArnVW8YWDtgcqYjV6vVlteMr8UpupSX6jVccEwu6yF4ast8RH0A/+3VhIseiOZkoUorY1HO7LjreAP4+ioGyE71elbLsrKWA1DuLiWf4jcVEd1EA/2eRgQOTPsApCs4b3L/ewygShJI5wSGlNepBENElbgnAUZpYhlEXhmR

### How many monolingual and multilingual features are there?

In [6]:
def get_mono_latent_langs(data: pl.Series):
    latent_count = data.struct.field("latent_count")
    latent_count = t.Tensor(latent_count.to_list()).float().to(device)

    n_tokens = data.struct.field("n_tokens")
    n_tokens = t.Tensor(n_tokens.to_list()).float().to(device)

    latent_count = latent_count / n_tokens[:, None]
    latent_count_norm = latent_count / latent_count.sum(dim=0, keepdim=True) 
    mono_latent_langs = ((latent_count_norm >= 0.5)).nonzero()[:, 0].tolist()
    mono_latent_langs = [langs[i] for i in mono_latent_langs]
    
    return mono_latent_langs

def plot_mono_chart(model):
    
    chart_df = (
        df
        .filter(pl.col("model") == model)
        .group_by("model", "layer")
        .agg(
            pl.struct(["latent_count", "n_tokens"])
            .map_elements(get_mono_latent_langs, return_dtype=pl.List(pl.String))
            .alias("mono_latent_langs")
        )
        .explode("mono_latent_langs")
        .group_by("model", "layer", "mono_latent_langs")
        .agg(pl.len().alias("count"))
        .rename({"mono_latent_langs": "lang"})
    )

    selection = alt.selection_point(fields=['lang'], bind='legend')
    
    return (
        alt.Chart(chart_df).mark_area(
            line=True,
            interpolate="monotone",
        )
        .encode(
            x=alt.X("layer:Q", title="Layer").axis(values=list(range(12 if model == "gpt2" else 6))),
            y=alt.Y("count", title="Count"),
            color=alt.Color("lang:N", title="Language").scale(scheme="rainbow", domain=sorted(langs)),
        )
        .add_params(selection)
        .transform_filter(selection)
        .properties(
            width=900,
            height=400,
        )
    )

chart_gpt2 = plot_mono_chart("gpt2")
chart_pythia = plot_mono_chart("pythia")

print("GPT2:", encode_chart(chart_gpt2))
print("Pythia:", encode_chart(chart_pythia))

alt.vconcat(chart_gpt2, chart_pythia)

GPT2: eJy9XE1v40YMvedXCEKP2USjb+VWLNpuPw5Fe+ihWBQTe2KrlSVXkpNdLPLfK8tfcvyG6eClvWRXkiWK5CP5yBn7y5Xn+V91s6VZaf/O85d9v+7ubm8fzULfLMp+ubm/KZvb3QfGs++qsje3j8lNGNyomz+7pvavtw+ZNfVDuRie8WU4Go4fS/N0PNpd78t602y6D6ZcLPvhWhQE15eXfyvn/XJ3dbz4PPx9HkXMda9PAmq9MttX3p5995AU2VzPCh3NjbkPwjzLTRHdq7RIwyCPY//sIZ3pu9OD/t0D7rzf9+96UGl860291USlR02Gs5Wut4bwf13ruuyW/tmlz6bd3jD9/KqZm2p7w2Ldh/7+/PO1XVycREDct62pZ1AaKUwBYV+3+r6cIWEhJSxGwj6U9bxEsgpGVhgiWd/PzPDfOdaNc5tKkdt+btp+s9iYzry560IVZy7mpFwXBSmQ9WPTGl2/vTCVA2HfmXaFhXFAyQonYRnnM5RKftl0XYmlcZjMUoSQH/SQuiyIzBlxcYxAIsYchROVIpxI1vwfsxeFkyhEycQeborSLImdMxeFEwVxIvgtorQLXYOAtaZb9uJSZZI4CUsYYUWMgPJ+WdoMGXMxgHKlII3CZIGEyTHAwUQlCJZCEFA4CeMExZw1e/0HVEisA1S2TAK3IODSSeDaDTDSclUgevJNvags4igyFGVIOXs7wDEvyE4ES1LpRMUIJCIoqYSSwAbEbkouV2bIlEKu5Hjl0E47llQOKDA5C9pR0lSmXI1J0coiBNLsyYvzXJqh1GxHJVfiFOILQoCnlLQCBZxcwkmCgoiePGLgoJIjewpQ4bznmp4pbqmCCCFTiDqON4eorApxQEVdBHUTTEnOatyaR67nh+MTwW1UjEeQDQnSuGkNpAx2Q1K8Mi0c8U+FWxTGSJzAKynlFAw3OTdT+mVwXmkfonPa5bAnEIzJjb3gEPGVXpUbNUN6KeQvypxhCENBnN

In [14]:
def get_lang_count_var(data: pl.Series):
    latent_count = data.struct.field("latent_count")
    latent_count = t.Tensor(latent_count.to_list()).float().to(device)

    n_tokens = data.struct.field("n_tokens")
    n_tokens = t.Tensor(n_tokens.to_list()).float().to(device)

    latent_count = latent_count / n_tokens[:, None]

    latent_count_norm = latent_count / latent_count.sum(dim=0, keepdim=True)
    latent_count_norm[latent_count_norm.isnan()] = 0
    latent_count_var = latent_count_norm.var(dim=0).mean()

    return latent_count_var.item()

def plot_latent_count_var(model):
    chart_df = (
        df
        .filter(pl.col("model") == model)
        .group_by("layer")
        .agg(
            pl.struct(["latent_count", "n_tokens"])
            .map_elements(get_lang_count_var, return_dtype=pl.Float64)
            .alias("latent_count_var")
        )
    )
    
    return alt.Chart(chart_df).mark_area(
        line=True,
        interpolate="monotone",
    ).encode(
        alt.X("layer:Q", title="Layer").axis(values=list(range(12 if model == "gpt2" else 6))),
        y=alt.Y("latent_count_var", title="Variance"),
    ).properties(
        width=900,
        height=300,
    )


chart_gpt2 = plot_latent_count_var("gpt2")
chart_pythia = plot_latent_count_var("pythia")

print("GPT2:", encode_chart(chart_gpt2))
print("Pythia:", encode_chart(chart_pythia))

alt.vconcat(chart_gpt2, chart_pythia)

GPT2: eJydlc1u4jAQx+88BYp6pMGfY5sn2MOedw+ramWCIW6DwyYOLap493UCBENzIM0h0X/G/s1kPBp/TqbT5KnOcrPVyWKa5N7v6sV8vjcbnW6sz5tlasv5aUFnfS6sN/M9TwlKcfpaly6ZtZCsdGu7CYzPoILeW/Peq5PfW9eUTf3D2E3ug48iNPvq/m1XPj95O+cxvI9diJX2+hrA6a1pU26tz7Baci3RGjIhtFgzRrVmwJdyKbOMCpHcQGrj6yvoMcBi+uec6+WXwtZCe+P836xswnuvq7AKpYgwxRUTGAkQBBSms3jHwbTL6Nl0nD1M5QRAMQwUJApoGKCS8VSJgFApQ7Wl4qAGoHI0lCIQSBCiQgUJAzIAxeMzFRyIpEphRKVQaqiqbDw15Ei5QByYkjzEGKCq8VRQhArFSEgTGBqqKnynAJQiwWQAIkrkUFnFeCrjGGHMqeSIc86GqOgbuRIpuOQASOLQAwNQPr6tJMLAgWPGQtOiISge31eUKhoaKpwWC71Ahs4K9wXovi/9MDEuK1fWRWPvI555+sPWkW6noi4aU0fDpH3QLBI4FiQWNBYsFjwWEAsRCxkLdRP0NgXci5f7ciZra4pVO3dPxent3vqim8c/7+yHXWf+1+gw4b32dm+SScRMDnHJIvzdYX2N9EtXVrvMPBSsP7P8cv10d4627pxtstXV2/UgrfOm2pVtGu3abelKX7pLrKQI+4LdV405Wy6xdWX09bp5P99mUbDJ8T/687xQ
Pythia: eJyVlMuOmzAUhvfzFBHqkiG+4FueoIuu20U1qgx4gltip2CYiUZ599oEiJthkYBE8p9jf+ciH388bTbJl66s1UEmu01SO3fsdtvtoPYy22tX90Wm7fayYLQ+N9qp7UAyBDKY/e6sSdIAKa151XvP+PDK60Grt0Vd/E6b3vbdV6X3tfM+DED62f1DV66+eEfn2X/PY4hKOnkNYORBhZSD9ZkKKTArOa+Ko

In [9]:
(t.Tensor(latent_multi_count) >= 11).nonzero()

tensor([[    5],
        [    9],
        [   11],
        ...,
        [32761],
        [32763],
        [32765]])

### Others

In [16]:
for lang in langs:
    mean = lang_latent_acts[lang] / lang_n_tokens[lang]
    mean[t.isnan(mean)] = 0
    mean_sorted, mean_indices = mean.sort()
    print(lang)
    print(mean_sorted[-5:])
    print(mean_indices[-5:])
    print(lang_n_latents[lang][mean_indices[-5:]])

eng_Latn
tensor([1.0528e-05, 1.3369e-05, 2.1907e-05, 2.9826e-05, 3.6364e-05],
       device='cuda:0')
tensor([  228, 19198, 18831, 20750, 12659], device='cuda:0')
tensor([ 6715.,  5926., 10625.,  1465.,  7970.], device='cuda:0')
deu_Latn
tensor([9.4514e-05, 1.1934e-04, 1.7432e-04, 2.0605e-04, 3.0229e-04],
       device='cuda:0')
tensor([10684, 18336, 24316, 13675, 22131], device='cuda:0')
tensor([15392., 31961., 40504., 46913., 43946.], device='cuda:0')
rus_Cyrl
tensor([0.0002, 0.0002, 0.0002, 0.0003, 0.0003], device='cuda:0')
tensor([10863, 24452, 11531,  8897, 22020], device='cuda:0')
tensor([26906., 57506., 85119., 81226., 90698.], device='cuda:0')
isl_Latn
tensor([8.5571e-05, 9.2705e-05, 1.4611e-04, 1.9067e-04, 2.5000e-04],
       device='cuda:0')
tensor([22131, 13459, 10684, 11554, 24316], device='cuda:0')
tensor([37025.,  8171., 20413., 49716., 48692.], device='cuda:0')
spa_Latn
tensor([7.1881e-05, 1.0783e-04, 1.0886e-04, 2.2747e-04, 3.3821e-04],
       device='cuda:0')
tensor([ 

In [17]:
recon_mse = [lang_recon_mse[lang] / lang_n_tokens[lang] for lang in langs]

alt.Chart(pl.DataFrame(dict(
    lang=langs,
    recon_mse=recon_mse
))).mark_bar().encode(
    x=alt.X("lang"),
    y=alt.Y("recon_mse"),
).properties(
    width=600,
    height=400,
)



## Analysis (legacy)

In [5]:
@t.inference_mode()
def get_features_from_sentences(sentences):
    tokens = model.tokenizer(sentences, return_tensors="pt", padding=True)
    mask = tokens.attention_mask.unsqueeze(-1).repeat(1, 1, sae.cfg.d_sae).to(device)
    n_tokens = mask.sum().item()

    _, cache = model.run_with_cache(sentences)
    inputs = cache[sae.cfg.hook_name]

    feature_acts = sae.encode(inputs)
    recons = sae.decode(feature_acts)
    recon_mse = (inputs - recons).norm(dim=-1)

    return (feature_acts[:, 1:]) * mask, recon_mse, n_tokens

lang_features = dict()
lang_acts = dict()
lang_recon_mse = dict()

for lang, sentences in lang_sentences.items():
    lang_features[lang] = t.zeros(sae.cfg.d_sae).to(device)
    lang_acts[lang] = t.zeros(sae.cfg.d_sae).to(device)
    lang_recon_mse[lang] = 0
    n_tokens = 0

    dataloader = DataLoader(sentences, batch_size=20)
    
    for batch in tqdm(dataloader):
        batch_feature_acts, batch_recon_mse, batch_n_tokens = get_features_from_sentences(batch)
        
        lang_features[lang] += (batch_feature_acts > 0).sum(dim=(0,1))
        lang_acts[lang] += batch_feature_acts.sum(dim=(0,1))
        lang_recon_mse[lang] += batch_recon_mse.sum()
        
        n_tokens += batch_n_tokens
    
    lang_features[lang] /= n_tokens
    lang_acts[lang] /= n_tokens
    lang_recon_mse[lang] /= n_tokens


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 32.54it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 37.44it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 35.76it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:09<00:00, 20.37it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 30.63it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 37.10it/s]
100%|█████████████████████████████████████████████████████

In [22]:
# https://www.kaggle.com/code/danielkorth/visualizing-neural-networks-using-t-sne-and-umap

@t.inference_mode()
def reduce_dec_dim():
    dec = sae.W_dec
    umap_embedding = umap.UMAP(
        n_components=2,
    ).fit_transform(dec.cpu().numpy())
    return umap_embedding

dec_reduced = t.from_numpy(reduce_dec_dim())
dec_reduced, dec_reduced.shape

(tensor([[7.1770, 6.5576],
         [2.2450, 8.5244],
         [0.0622, 6.1438],
         ...,
         [5.5297, 5.7176],
         [8.1164, 8.3765],
         [5.2378, 8.9840]]),
 torch.Size([24576, 2]))

In [23]:
latent_acts = t.stack(list(lang_latent_acts.values()))
max_feature_lang = [langs[idx] for idx in latent_acts.argmax(dim=0)]

latent_count = t.stack(list(lang_latent_count.values())).sum(dim=0).tolist()

df = pl.DataFrame(dict(
    index=range(latent_acts.size(1)),
    x=dec_reduced[:, 0].tolist(),
    y=dec_reduced[:, 1].tolist(),
    max_feature_lang=max_feature_lang,
    latent_count=latent_count,
    latent_multi_count=latent_multi_count,
))

interval = alt.selection_interval()

point_chart = alt.Chart(df).mark_point(filled=True).encode(
    x=alt.X("x"),
    y=alt.Y("y"),  
    size=alt.Size("latent_count"),
    color=alt.Color("latent_multi_count:O").scale(scheme="turbo"),
    # shape=alt.Shape("max_feature_lang"),
    opacity=alt.condition(interval, alt.value(0.8), alt.value(0.1)),
    tooltip=["index", "latent_count", "latent_multi_count", "max_feature_lang"]
).add_selection(
    interval,
).properties(
    width=800,
    height=600,
)

bar_chart = alt.Chart(df).mark_bar().encode(
    x=alt.X("count()"),
    y=alt.Y("max_feature_lang")
).transform_filter(
    interval,
).properties(
    width=800,
    height=100,
)

(point_chart & bar_chart)

  ).add_selection(


In [9]:

active_features = (features > 0).float()
confusion = active_features @ active_features.T

lang_x, lang_y = t.meshgrid(t.arange(0, len(langs)), t.arange(0, len(langs)))
lang_x = [langs[i] for i in lang_x.ravel().tolist()]
lang_y = [langs[i] for i in lang_y.ravel().tolist()]

df = pl.DataFrame(dict(
    lang_x=lang_x,
    lang_y=lang_y,
    value=confusion.ravel().tolist()
))

alt.Chart(df).mark_rect().encode(
    x='lang_x',
    y='lang_y',
    color=alt.Color("value").scale(scheme="purples")
).properties(
    width=400,
    height=400,
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [11]:
recon_mse = t.stack(list(lang_recon_mse.values())).tolist()

alt.Chart(pl.DataFrame(dict(
    lang=langs,
    recon_mse=recon_mse
))).mark_bar().encode(
    x=alt.X("lang"),
    y=alt.Y("recon_mse"),
).properties(
    width=600,
    height=400,
)


In [31]:
acts = t.stack(list(lang_acts.values()))
act_sorted, act_indices = (acts > 0.0001).sum(dim=0).sort()
act_sorted[-5:], act_indices[-5:]

(tensor([3, 3, 3, 4, 5], device='cuda:0'),
 tensor([ 8015,  8229, 10737, 21868,  7978], device='cuda:0'))

In [36]:
lang_sentences["rus_Cyrl"][42]
lang_sentences["eng_Latn"][42]

'The storm, situated about 645 miles (1040 km) west of the Cape Verde islands, is likely to dissipate before threatening any land areas, forecasters say.'

In [90]:
def get_lang_count(series: pl.Series):
    latent_count = t.Tensor(series.to_list()).float().to(device)
    latent_count_norm = latent_count / latent_count.sum(dim=0, keepdim=True)
    lang_count = ((0.05 <= latent_count_norm)).sum(dim=0).tolist()
    
    return lang_count

chart_df = (
    df.group_by("model", "layer")
    .agg(
        pl.col("latent_count")
        .map_elements(get_lang_count, return_dtype=pl.List(pl.Int64))
        .alias("lang_count")
    )
    .explode("lang_count")
    .group_by("model", "layer", "lang_count")
    .agg(pl.len().alias("count"))
    .filter(pl.col("lang_count") != 0)
)

chart_df.head()

alt.Chart(chart_df).mark_area(
    line=True,
    interpolate="monotone",
).encode(
    x=alt.X("layer:O"),
    y=alt.Y("count"),#.stack("normalize"),
    color=alt.Color("lang_count:N"),
    row=alt.Row("model:N"),
).properties(
    width=900,
    height=300,
)

## Findings

- Some features encode text in foreign (i.e., non-english) texts, like 21868.
- Another example is 7978, which activates for non-english languages that uses the Latin alphabet.
- Other features represent random characters from non-english scripts, probably related to problems in tokenization phase, like 8015,  8229, 10737.