# Examples of how to build concept-based explanations

Load GPT-2 and print its submodules to find where to split it. Split points are given as dotted module paths, such as `transformer.h.1.mlp`.

In [1]:
from transformers import AutoModelForCausalLM

# Load GPT-2 only to inspect its module tree. The nested names printed here
# ("transformer" -> "h" -> block index -> "mlp", ...) are the dotted paths
# that `split_points` refers to later.
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
print([*gpt2.named_children()])

# Drop this copy; the split-point wrapper below is built from the repo id.
del gpt2

[('transformer', GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)), ('lm_head', Linear(in_features=768, out_features=50257, bias=False))]


### Split the model using the `ModelWithSplitPoints` class

In [3]:
from interpreto.commons import ModelWithSplitPoints

# Wrap GPT-2 so activations can be read at the MLP of its second transformer
# block: "h.1" is index 1 of the `h` ModuleList printed above.
# `AutoModelForCausalLM` comes from the import in the first code cell.
splitted_model = ModelWithSplitPoints(
    model_autoclass=AutoModelForCausalLM,
    split_points="transformer.h.1.mlp",
    model_or_repo_id="gpt2",
)

### Load the dataset and compute activations

In [4]:
from datasets import load_dataset

# Review texts from the training split of the Rotten Tomatoes dataset.
rotten_tomatoes = (
    load_dataset("cornell-movie-review-data/rotten_tomatoes")
    ['train']['text']
)

# NOTE(review): `[1000:]` keeps every review AFTER the first 1000, not a
# 1000-review sample. Confirm this is intended (use `[:1000]` for a subset).
activations = splitted_model.get_activations(rotten_tomatoes[1000:])

print(activations.shape)

### Create and fit the concept explainer

In [7]:
from interpreto.concepts import ICAConcepts

# ICA-based concept explainer on top of `splitted_model`, asked for 10
# concepts and fitted on the activations computed in the previous cell.
# NOTE(review): no random seed is set. If the underlying ICA solver is
# stochastic, the learned concepts may differ between runs. Confirm, and seed
# if reproducibility matters.
concept_explainer = ICAConcepts(
    splitted_model,
    nb_concepts=10,
)
concept_explainer.fit(activations)



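### Interpret the concepts

For each concept, score the tokens of a hand-picked list of sentiment words and keep the five highest-scoring tokens.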

In [15]:
from interpreto.concepts.interpretations import Granularities, InterpretationSources, TopKInputs

# Candidate sentiment words/phrases scored against each concept:
# 30 positive entries followed by 30 negative ones.
# Every entry needs a trailing comma. Python silently concatenates adjacent
# string literals, so a missing comma fuses two entries into one
# (e.g. "friendly" "terrible" -> "friendlyterrible").
# NOTE(review): some phrases ("fast delivery", "will buy again") are
# product-review vocabulary, while the activations come from movie reviews.
some_words = [
    # positive
    "excellent", "amazing", "fantastic", "outstanding", "great",
    "perfect", "wonderful", "superb", "love", "loved",
    "satisfied", "recommend", "best", "delightful", "pleasant",
    "enjoyed", "happy", "awesome", "brilliant", "incredible",
    "flawless", "impressive", "top-notch", "well-done", "five stars",
    "worth it", "will buy again", "high quality", "fast delivery", "friendly",
    # negative
    "terrible", "awful", "bad", "poor", "disappointing",
    "hate", "horrible", "worst", "boring", "rude",
    "unsatisfied", "never again", "waste", "problem", "issue",
    "slow", "cheap", "low quality", "broken", "not recommended",
    "annoying", "frustrating", "one star", "unacceptable", "defective",
    "dirty", "late delivery", "noisy", "difficult", "didn't work",
]

# For every concept, score the tokens of `some_words` (INPUTS source, TOKENS
# granularity) and keep the k=5 highest-scoring ones. The result maps
# concept id -> {token: score}, as printed in the next cell.
interpretations = concept_explainer.interpret(
    TopKInputs,
    concepts_indices="all",
    source=InterpretationSources.INPUTS,
    granularity=Granularities.TOKENS,
    inputs=some_words,
    k=5,
)

In [16]:
# One line per concept: its top tokens and their scores.
for cid, top_tokens in interpretations.items():
    print(f"Concept {cid}: {top_tokens}")

Concept 0: {'onder': 1.7090288400650024, 'ful': 1.6273586750030518, 'ied': 1.359918236732483, 'ed': 1.2922464609146118, 'ude': 1.2332704067230225}
Concept 1: {'top': 13.391587257385254, 'dis': 10.805965423583984, 'didn': 8.963983535766602, 'br': 7.61521053314209, 'worth': 7.578242301940918}
Concept 2: {'ude': -1.0211806297302246, 'acceptable': -1.0793287754058838, 'aw': -1.0856738090515137, 'isy': -1.1359949111938477, 'azing': -1.1628979444503784}
Concept 3: {'high': 3.893566608428955, 'fl': 3.5968098640441895, 'inc': 3.3934872150421143, 'ex': 3.312361717224121, 'fr': 3.283090114593506}
Concept 4: {'Ġdelivery': 1.8142285346984863, 'dis': 1.7381315231323242, 'azing': 1.6744166612625122, 'acceptable': 1.648449420928955, 'Ġit': 1.5403884649276733}
Concept 5: {'friendly': 10.231117248535156, 'hor': 9.910233497619629, 'br': 9.903297424316406, 'aw': 9.84243106842041, 'inc': 9.8255615234375}
Concept 6: {'Ġwork': 1.7911837100982666, "'t": 1.4406126737594604, 'astic': 1.2444143295288086, 'frien