In [1]:
import sys

# Make the project root importable so the `src` package used below resolves.
# The membership check keeps this cell idempotent: re-running it does not
# append ".." to sys.path a second time.
# NOTE(review): ".." is resolved against the kernel's current working directory,
# so this assumes the notebook is launched from a direct subdirectory of the
# project root — confirm.
if ".." not in sys.path:
    sys.path.append("..")

Fine-tune a pretrained DistilBERT model on the `conll2003` dataset for named entity recognition (NER). This will be our basic starting model.

In [None]:
from src.models.train_model import train_model
from src.features.build_features import load_token_class_dataset

# Load the token-classification dataset; later cells evaluate on its "validation" split.
dataset = load_token_class_dataset()
# Trainer object shared by every later get_*_model / prune_model call.
# NOTE(review): whether train_model() itself runs fine-tuning or only builds the
# trainer is not visible here — confirm in src.models.train_model.
model_trainer = train_model()

In [None]:
from src.models.evaluate_model import evaluate
from src.visualization.visualize import get_basic_model

# Train and evaluate a basic model on pretrained DistilBert
# NOTE(review): get_basic_model only receives the trainer; whether it trains from
# scratch or reuses a saved checkpoint is decided inside src.visualization — confirm.
basic_model = get_basic_model(model_trainer)
# Last expression in the cell, so Jupyter displays whatever evaluate() returns.
evaluate(basic_model, dataset["validation"])

Identify the concept neurons in the basic model that are associated with location names.

In [None]:
from src import basic_model_path, basic_activations_path
from src.visualization.analyse_model import analyse_model

# Identify neurons in the basic model to ablate
# The analyser is built from basic_model_path and basic_activations_path; the next
# cells use it to pick and inspect the concept neurons that get pruned.
basic_analyser = analyse_model(basic_model_path, basic_activations_path)

In [None]:
# Concept neurons in the basic model; these are the candidates passed to prune_model.
# NOTE(review): identify_concept_neurons() takes no arguments, so the target concept
# (location names, per the notebook text) is presumably configured inside the analyser — confirm.
neurons_to_prune = basic_analyser.identify_concept_neurons()

In [None]:
# Sanity check before pruning: show the top words for the selected neurons
# (presumably the tokens that activate them most strongly — confirm).
basic_analyser.show_top_words(neurons_to_prune)

Prune (ablate) the identified concept neurons from the model by setting their weights to zero.

In [None]:
import os

from src import basic_model_path, pruned_model_path
from src.models.prune_model import prune_model
from src.models.evaluate_model import evaluate
from src.visualization.visualize import get_pruned_model

# Reuse the pruned model if it has already been saved; otherwise ablate the
# concept neurons found above and save the result. This replaces toggling
# between the two paths by commenting code in and out, and keeps
# Restart & Run All working on a fresh checkout where pruned_model_path
# does not exist yet.
if os.path.exists(pruned_model_path):
    pruned_model = get_pruned_model(pruned_model_path, model_trainer)
else:
    # Ablate neurons
    pruned_model = prune_model(basic_model_path, model_trainer, neurons_to_prune)
    # Persist so later runs (and the retraining step, which reads
    # pruned_model_path) can load it instead of pruning again.
    pruned_model.save_pretrained(pruned_model_path)

# Last expression in the cell, so Jupyter displays whatever evaluate() returns.
evaluate(pruned_model, dataset["validation"])

Retrain the pruned model until performance recovers. Examine:

- In which neurons does the concept of location names reappear in the pruned model?
- Is there any relation between the new concept (location names) and the concepts these neurons originally encoded?

In [None]:
from src import pruned_model_path, retrained_model_path
from src.visualization.visualize import get_retrained_model

# Retrain the pruned model
# NOTE(review): get_retrained_model receives both the retrained-model path and the
# pruned-model path; presumably it loads an existing retrained model or retrains from
# the pruned checkpoint and saves to retrained_model_path — confirm.
# The returned object is not used by the later cells, which analyse the model
# via retrained_model_path instead.
retrained_model = get_retrained_model(retrained_model_path, pruned_model_path, model_trainer)

In [None]:
from src import retrained_model_path, retrained_activations_path
from src.visualization.analyse_model import analyse_model

# Examine the retrained model for concepts and compare with old model
# Same analysis as for the basic model, but built from the retrained model's
# path and activations path.
retrained_analyser = analyse_model(retrained_model_path, retrained_activations_path)

In [None]:
# Concept neurons in the retrained model, i.e. where the pruned concept re-appeared.
new_concept_neurons = retrained_analyser.identify_concept_neurons()

In [None]:
# Compare the same neurons across both models: the retrained model shows the
# re-emerged concept, and the basic model shows what these neurons encoded before pruning.
# NOTE(review): if show_top_words returns its result rather than printing/plotting it,
# Jupyter only auto-displays the last expression, so the first call's output would be
# lost — confirm, and wrap in display(...) if needed.
retrained_analyser.show_top_words(new_concept_neurons)
basic_analyser.show_top_words(new_concept_neurons)