In [11]:
import torch
import transformer_lens

# Pick an accelerator: Apple-silicon MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load mGPT as a TransformerLens HookedTransformer on the selected device.
model = transformer_lens.HookedTransformer.from_pretrained("ai-forever/mGPT", device=device)
print(model.cfg.device)


Loaded pretrained model ai-forever/mGPT into HookedTransformer
mps


In [46]:
# Show the full HookedTransformerConfig; d_model and n_layers read below come from here.
print(str(model.cfg))

HookedTransformerConfig:
{'NTK_by_parts_factor': 8.0,
 'NTK_by_parts_high_freq_factor': 4.0,
 'NTK_by_parts_low_freq_factor': 1.0,
 'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': np.float64(11.313708498984761),
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 128,
 'd_mlp': 8192,
 'd_model': 2048,
 'd_vocab': 100000,
 'd_vocab_out': 100000,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': device(type='mps'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': np.float64(0.017677669529663688),
 'load_in_4bit': False,
 'model_name': 'mGPT',
 'n_ctx': 2048,
 'n_devices': 1,
 'n_heads': 16,
 'n_key_value_heads': None,
 'n_layers': 24,
 'n_params': 1207959552,
 'normalization_typ

In [51]:
from typing import List

from transformer_lens import ActivationCache


def collect_activation_cache(data: List[dict[str, str]]):
    """Run the model on every text and keep the resulting activation caches, grouped by language.

    Args:
        data: List of entries; each entry maps a language code to a text in that language.

    Returns:
        Dict mapping each language code to the list of caches returned by
        ``model.run_with_cache``, in the order the texts appear in ``data``.
        Relies on the notebook-global ``model``.
    """
    caches_by_language: dict[str, List[ActivationCache]] = {}
    # Inference only: no autograd graph is needed for any of the forward passes.
    with torch.no_grad():
        for sample in data:
            for lang, sentence in sample.items():
                token_ids = model.to_tokens(sentence)
                _, cache = model.run_with_cache(token_ids)  # logits are not used here
                caches_by_language.setdefault(lang, []).append(cache)
    return caches_by_language


In [58]:
import numpy as np


def collect_hidden_space_by_language(
    activation_cache: "dict[str, List[ActivationCache]]",
) -> "dict[str, np.ndarray]":
    """Collect the last-token hidden state at every layer for each prompt, grouped by language.

    For each cache, ``accumulated_resid(apply_ln=True)`` is computed. The first
    stacked component is dropped, and the vector at batch index 0 / last position
    is kept for each remaining component.

    Args:
        activation_cache: Mapping from language code to a list of caches (one per
            prompt), as produced by ``collect_activation_cache``. Each prompt is
            assumed to have been run with batch size 1 — TODO confirm for other callers.

    Returns:
        Mapping from language code to a float array of shape
        ``(d_model, n_prompts, n_layers)``, with dimensions taken from the
        notebook-global ``model.cfg``.
    """
    hidden_space_for_language = {}

    for language, language_caches in activation_cache.items():
        # d_model, n_prompts, n_layers
        hidden_space_for_language_by_layer = np.zeros(
            (model.cfg.d_model, len(language_caches), model.cfg.n_layers)
        )

        for cache_i, cache in enumerate(language_caches):
            # layer, batch, pos, d_model
            accum_resid = cache.accumulated_resid(apply_ln=True)
            # The assignment below requires n_layers + 1 stacked components. Per the
            # TransformerLens docs, the first one is the stream entering layer 0
            # (the embeddings), so [1:] keeps the output of each layer.
            # .cpu() is required before .numpy() for MPS/CUDA tensors.
            hidden_space_for_language_by_layer[:, cache_i, :] = accum_resid[1:, 0, -1, :].cpu().numpy().T

        hidden_space_for_language[language] = hidden_space_for_language_by_layer

    # Previously missing: without this the function returned None.
    return hidden_space_for_language

In [59]:
# A single parallel prompt: the same greeting in English and in Russian.
toy_data = [
    {
        "EN": "Hello",
        "RU": "Здравствуйте",
    }
]
activation_cache = collect_activation_cache(toy_data)
hidden_space_by_language = collect_hidden_space_by_language(
    activation_cache
)
