# Examples of how to build concept-based explanations

Load the model and tokenizer, then list the model's modules to choose the split point, i.e. the module whose activations will be extracted.

In [1]:
from transformers import AutoModelForMaskedLM, AutoTokenizer

# The model weights and the tokenizer are fetched from the same repository id.
repo_id = "EuroBERT/EuroBERT-210m"

# Dotted path of the module whose activations the cells below extract:
# model -> layers[10] -> mlp, as it appears in the listing printed by this cell.
split_point = "model.layers.10.mlp"

# Note: trust_remote_code=True is passed for the model only, not the tokenizer.
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Print the top-level children so the path stored in `split_point` can be checked.
print(list(model.named_children()))

[('model', EuroBertModel(
  (embed_tokens): Embedding(128256, 768, padding_idx=128001)
  (layers): ModuleList(
    (0-11): 12 x EuroBertDecoderLayer(
      (self_attn): EuroBertAttention(
        (q_proj): Linear(in_features=768, out_features=768, bias=False)
        (k_proj): Linear(in_features=768, out_features=768, bias=False)
        (v_proj): Linear(in_features=768, out_features=768, bias=False)
        (o_proj): Linear(in_features=768, out_features=768, bias=False)
      )
      (mlp): EuroBertMLP(
        (gate_proj): Linear(in_features=768, out_features=3072, bias=False)
        (up_proj): Linear(in_features=768, out_features=3072, bias=False)
        (down_proj): Linear(in_features=3072, out_features=768, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): EuroBertRMSNorm((768,), eps=1e-05)
      (post_attention_layernorm): EuroBertRMSNorm((768,), eps=1e-05)
    )
  )
  (norm): EuroBertRMSNorm((768,), eps=1e-05)
  (rotary_emb): EuroBertRotaryEmbedding()
)), (

### Split the model using the `ModelWithSplitPoints` class

In [2]:
import torch

from interpreto import ModelWithSplitPoints

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs from a fresh kernel on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Wrap the loaded model so that activations can be read at `split_point`.
splitted_model = ModelWithSplitPoints(
    model_or_repo_id=model,
    tokenizer=tokenizer,
    split_points=split_point,
    device_map=device,
    batch_size=64,
)

### Load the dataset and compute activations

In [3]:
from datasets import load_dataset

# Keep only the review texts of the training split.
train_split = load_dataset("cornell-movie-review-data/rotten_tomatoes")["train"]
rotten_tomatoes = train_split["text"]

# Request activations at WORD granularity from the wrapped model.
word_granularity = ModelWithSplitPoints.activation_granularities.WORD
activations = splitted_model.get_activations(
    rotten_tomatoes,
    activation_granularity=word_granularity,
)

# The result is indexed by split-point name; show the shape of that entry.
print(activations[split_point].shape)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


torch.Size([187890, 768])


### Create and fit the concept explainer

In [4]:
from interpreto.concepts import ICAConcepts

# Number of concepts the ICA-based explainer is asked to find.
n_concepts = 50

concept_explainer = ICAConcepts(splitted_model, nb_concepts=n_concepts)

# Fit the explainer on the activations computed in the previous cell.
concept_explainer.fit(activations)



### Interpret the concepts

In [5]:
from interpreto.concepts.interpretations import TopKInputs

# Options for the TopKInputs interpretation, named once for readability.
topk_source = TopKInputs.sources.LATENT_ACTIVATIONS
topk_granularity = TopKInputs.granularities.WORD

# Interpret every concept ("all"), reusing the texts and the activations computed
# earlier; k=10 matches the ten words per concept shown in the output of the next cell.
interpretations = concept_explainer.interpret(
    TopKInputs,
    concepts_indices="all",
    source=topk_source,
    granularity=topk_granularity,
    inputs=rotten_tomatoes,
    latent_activations=activations,
    k=10,
)

In [6]:
# One line per concept: the words kept for it, or the text "None" when the
# interpretation holds no entry for that concept.
for concept_id, words_importance in interpretations.items():
    if words_importance is None:
        shown = "None"
    else:
        shown = list(words_importance.keys())
    print(f"Concept {concept_id}: {shown}")

Concept 0: [' all', ' everything', ' two', ' under', ' beyond', ' below', ' above', ' self', ' sibling', ' southern']
Concept 1: [' see', ' watching', ' watch', ' view', ' seen', ' find', ' seeing', ' admire', ' sees', ' look']
Concept 2: [' to', ' about', ' will', ' in', ' and', ' with', ' of', ' that', ' out', ' feels']
Concept 3: [' doesn', ' does', ' believes', ' thing', ' to', ' into', ' of', ' claims', ' promises', ' cannot']
Concept 4: [' in', '-in', 'in', ' en', ' into', ' em', ' inside', ' within', ' on', ' at']
Concept 5: [' of', 'of', ' de', ' or', ' del', ' da', ' to', ' dos', '-of', ' do']
Concept 6: [' like', ' as', ' be', ' apparently', ' corrupt', ' cult', ' supernatural', ' shot', ' towering', ' skin']
Concept 7: [' one', ' on', ' end', ' upon', ' photographed', '-on', ' next', ' start', ' may', ' first']
Concept 8: [' reason', ' admirable', ' reasons', ' genesis', ' love', ' thing', ' success', ' passion', ' wonderful', ' emerged']
Concept 9: [' on', ' through', ' int

In [None]:
import os

from interpreto.concepts.interpretations.llm_labels import SAMPLING_METHOD, LLMLabels
from interpreto.model_wrapping.llm_interface import OpenAILLM
from interpreto.model_wrapping.model_with_split_points import ActivationGranularity

# The API key is read from the environment so it never appears in the notebook;
# a KeyError is raised here if OPENAI_API_KEY is not set.
openai_llm = OpenAILLM(api_key=os.environ["OPENAI_API_KEY"])

# Label only concepts 0, 1 and 2, using the first 20 reviews as inputs.
# Note: this rebinds `interpretations`, replacing the TopKInputs result above.
interpretations = concept_explainer.interpret(
    LLMLabels,
    concepts_indices=[0, 1, 2],
    activation_granularity=ActivationGranularity.SAMPLE,
    llm_interface=openai_llm,
    sampling_method=SAMPLING_METHOD.TOP,
    inputs=rotten_tomatoes[:20],
    k_context=3,
)

In [None]:
# Print the label returned for each interpreted concept, one per line.
for concept_id, label in interpretations.items():
    line = f"Concept {concept_id}: {label}"
    print(line)