In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [18]:
import goodfire
import os
from typing import Dict, Callable, List
import numpy as np

from src.halligan import iterative_parameter_search
from src.halligan.feature_importance import get_feature_importance

# API client; the key is read from the environment, so a missing
# GOODFIRE_API_KEY raises KeyError here rather than later.
client = goodfire.Client(os.environ["GOODFIRE_API_KEY"])

# Variant with no edits applied. It is passed to every feature lookup/search
# below and returned as-is when no feature boosts are requested.
base_variant = goodfire.Variant("meta-llama/Llama-3.3-70B-Instruct")

In [3]:
# User message sent with every logits query in this notebook.
PROMPT = "If you had a favorite color, what would it be? Answer in one word."

In [4]:
# Query logits from the unedited base variant for PROMPT followed by an empty
# assistant turn (presumably the logits of the first answer token -- confirm
# against the SDK docs).
# NOTE(review): `logits` is not referenced again in the visible cells.
logits = client.chat.logits(
    model=base_variant,
    messages=[
        {"role": "user", "content": PROMPT},
        {"role": "assistant", "content": ""},
    ],
)


In [5]:
def create_variant_from_feature_boosts(feature_boosts: Dict[int, float]) -> goodfire.Variant:
    """Build a model variant with one edit per entry of ``feature_boosts``.

    Args:
        feature_boosts: Mapping of feature index -> value handed to
            ``Variant.set`` for that feature.

    Returns:
        The shared module-level ``base_variant`` when ``feature_boosts`` is
        empty; otherwise a newly constructed variant carrying the edits.
    """
    if not feature_boosts:
        return base_variant

    steered = goodfire.Variant("meta-llama/Llama-3.3-70B-Instruct")

    # Resolve all indices to Feature objects with a single lookup call; the
    # result is indexed by the same ids that were passed in.
    feature_ids = list(feature_boosts.keys())
    resolved = client.features.lookup(feature_ids, base_variant)

    for feature_id in feature_ids:
        steered.set(resolved[feature_id], feature_boosts[feature_id])
    return steered

In [6]:
# Quick check of the helper: set feature 0 to 1.0 and feature 1 to -1.0 and
# display the resulting variant.
create_variant_from_feature_boosts({0: 1.0, 1: -1.0})

Variant(
   base_model=meta-llama/Llama-3.3-70B-Instruct,
   edits={
      Feature("Explanatory statements about AI's impacts and applications"): 1.0,
      Feature("The assistant wishes enjoyment of food/recipes to users"): -1.0,
   }
   scopes={
   }
)

In [7]:
def variant_answer_logprob(variant, answer, top_k=10):
    """Approximate log-probability that ``variant`` answers PROMPT with ``answer``.

    The softmax normaliser is approximated by a log-sum-exp over the ``top_k``
    highest logits instead of the full vocabulary. The logit of ``answer`` is
    always included in that sum: without it, an ``answer`` that falls outside
    the top-k would be missing from its own denominator and the returned value
    could exceed 0, which is not a valid log-probability.

    Args:
        variant: Model variant to query.
        answer: Token string whose log-probability is wanted. It must be a key
            of the logits mapping returned for ``filter_vocabulary=[answer]``.
        top_k: Number of highest-logit tokens used to approximate the
            normaliser.

    Returns:
        ``logit(answer) - logsumexp(top-k logits together with logit(answer))``,
        which is always <= 0.

    Raises:
        KeyError: If the filtered logits response has no entry for ``answer``.
    """
    messages = [{"role": "user", "content": PROMPT}, {"role": "assistant", "content": ""}]
    logits_topk = client.chat.logits(messages, variant, top_k=top_k)
    logits_answer = client.chat.logits(messages, variant, filter_vocabulary=[answer])
    answer_logit = logits_answer.logits[answer]

    # Merge by token so the answer is counted exactly once, whether or not it
    # already appears in the top-k (assumes both responses key tokens the
    # same way -- TODO confirm against the SDK).
    candidate_logits = dict(logits_topk.logits)
    candidate_logits[answer] = answer_logit

    # Numerically stable log-sum-exp: shift by the maximum before exponentiating.
    values = np.array(list(candidate_logits.values()))
    max_logit = np.max(values)
    log_denominator = np.log(np.sum(np.exp(values - max_logit))) + max_logit

    return answer_logit - log_denominator

In [9]:
def get_candidate_parameters(current_parameters: List[str], tried_parameters: List[str], num_features_to_replace: int) -> List[Dict]:
    """Propose new search parameters from neighbours of the current features.

    Args:
        current_parameters: Feature ids currently being optimised.
        tried_parameters: Feature ids that must not be proposed again.
        num_features_to_replace: Maximum number of new parameters to return.

    Returns:
        Up to ``num_features_to_replace`` parameter dicts, each of type
        ``"range"`` with bounds ``[-1.0, 1.0]`` and named by the neighbour's
        ``index_in_sae`` as a string. They are in the order the neighbours
        were returned.
    """
    current_features = client.features.lookup(current_parameters, base_variant)

    # Request extra neighbours so that enough can remain after the tried ones
    # are filtered out.
    fetch_count = num_features_to_replace + len(tried_parameters)
    neighbours = client.features.neighbors(current_features.values(), base_variant, top_k=fetch_count)

    untried_ids = []
    for neighbour in neighbours:
        feature_id = str(neighbour.index_in_sae)
        if feature_id in tried_parameters:
            continue
        untried_ids.append(feature_id)

    selected_ids = untried_ids[:num_features_to_replace]
    return [{"name": feature_id, "type": "range", "bounds": [-1.0, 1.0]} for feature_id in selected_ids]

In [10]:
def objective_function(parameters: Dict) -> float:
    """Score one parameter assignment by the log-probability of answering "Red".

    Args:
        parameters: Mapping of feature id (convertible with ``int``) -> boost
            value.

    Returns:
        The value of ``variant_answer_logprob`` for the variant built from
        these boosts.
    """
    boosts_by_index = {}
    for name, boost in parameters.items():
        boosts_by_index[int(name)] = boost
    steered_variant = create_variant_from_feature_boosts(boosts_by_index)
    return variant_answer_logprob(steered_variant, "Red")

In [11]:
# Search the base variant's features for the query "color red" and display
# the matches.
red_features = client.features.search("color red", base_variant)
red_features

FeatureGroup([
   0: "The color red",
   1: "Explanatory descriptions of the color red and its associations",
   2: "The color yellow",
   3: "Enterprise software products with Red in their name",
   4: "Descriptions of redheaded people and associated descriptive language patterns",
   5: "Descriptions of characters blushing or their faces turning red from emotion",
   6: "The tomato turning red joke and its sexual innuendo",
   7: "CSS color property declarations",
   8: "The word orange/Orange when used as a significant token or keyword",
   9: "CSS color property declarations using hex codes"
])

In [12]:
# Seed the search with the first four search hits. Each becomes a "range"
# parameter on [-1.0, 1.0] named by the feature's index_in_sae (as a string).
initial_parameters = [{"name": str(f.index_in_sae), "type": "range", "bounds": [-1.0, 1.0]} for f in red_features[:4]]
initial_parameters

[{'name': '31815', 'type': 'range', 'bounds': [-1.0, 1.0]},
 {'name': '28087', 'type': 'range', 'bounds': [-1.0, 1.0]},
 {'name': '29816', 'type': 'range', 'bounds': [-1.0, 1.0]},
 {'name': '53668', 'type': 'range', 'bounds': [-1.0, 1.0]}]

In [48]:
def display_function(best_importance_scores, current_parameters):
    """Print importance scores and current parameters using feature labels.

    Args:
        best_importance_scores: Mapping of feature id -> importance score.
        current_parameters: Parameter dicts; only each dict's ``"name"``
            (a feature id) is read.

    Prints the scores keyed by feature label in descending score order, then
    the labels of the current parameters, with blank lines before and after.
    """
    current_names = [param["name"] for param in current_parameters]
    # One lookup covers both the scored features and the current parameters.
    features = client.features.lookup(list(best_importance_scores.keys()) + current_names, base_variant)

    ranked_scores = sorted(best_importance_scores.items(), key=lambda item: item[1], reverse=True)
    readable_best_importance_scores = {}
    for feature_id, score in ranked_scores:
        # Lookup results are indexed by integer feature id.
        readable_best_importance_scores[features[int(feature_id)].label] = score

    readable_current_parameters = []
    for name in current_names:
        readable_current_parameters.append(features[int(name)].label)

    print("\n\n")
    print(f"Best importance scores: {readable_best_importance_scores}")
    print(f"Current parameters: {readable_current_parameters}")
    print("\n\n")

In [49]:
# Run the iterative search over feature boosts that maximises the
# log-probability of answering "Red". Per the recorded log output, each of the
# 5 outer iterations replaces 2 of the current parameters with new candidates
# from get_candidate_parameters.
# NOTE(review): the exact meaning of inner_iterations and neutral_value is
# defined in src.halligan and is not visible here -- confirm there.
result = iterative_parameter_search(
    parameter_generator=get_candidate_parameters,
    objective_function=objective_function,
    initial_parameters=initial_parameters,
    num_features_to_replace=2,
    outer_iterations=5,
    inner_iterations=20,
    maximize=True,  # objective is a log-probability, so larger is better
    neutral_value=0.0,
    display_function=display_function,
)

2025-02-18 17:00:21,903 - src.halligan.parameter_search - INFO - Starting outer iteration 1/5
2025-02-18 17:00:21,903 - src.halligan.parameter_search - INFO - Starting outer iteration 1/5


[INFO 02-18 17:00:21] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 31815. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:00:21] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 28087. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:00:21] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 29816. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:00:21] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 53668. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter di




Best importance scores: {}
Current parameters: ['The color red', 'Explanatory descriptions of the color red and its associations', 'The color yellow', 'Enterprise software products with Red in their name']





  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
2025-02-18 17:01:12,086 - src.halligan.parameter_search - INFO - Replacing parameters: ['53668', '29816']
2025-02-18 17:01:12,086 - src.halligan.parameter_search - INFO - Replacing parameters: ['53668', '29816']


Importance scores: {'31815': np.float64(4.239456565765613), '28087': np.float64(7.363976116846589), '29816': np.float64(2.191765250831251), '53668': np.float64(0.07053365664174027)}


2025-02-18 17:01:12,638 - src.halligan.parameter_search - INFO - New parameters added: ['63676', '56422']
2025-02-18 17:01:12,638 - src.halligan.parameter_search - INFO - New parameters added: ['63676', '56422']
2025-02-18 17:01:12,639 - src.halligan.parameter_search - INFO - Starting outer iteration 2/5
2025-02-18 17:01:12,639 - src.halligan.parameter_search - INFO - Starting outer iteration 2/5
[INFO 02-18 17:01:12] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 31815. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:01:12] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 28087. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:01:12] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT f




Best importance scores: {'Explanatory descriptions of the color red and its associations': np.float64(7.363976116846589), 'The color red': np.float64(4.239456565765613), 'The color yellow': np.float64(2.191765250831251), 'Enterprise software products with Red in their name': np.float64(0.07053365664174027)}
Current parameters: ['The color red', 'Explanatory descriptions of the color red and its associations', 'Color terms used as descriptive adjectives', 'Discussion of Blue Ocean Strategy business framework']





  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
2025-02-18 17:01:59,823 - src.halligan.parameter_search - INFO - Replacing parameters: ['63676', '31815']
2025-02-18 17:01:59,823 - src.halligan.parameter_search - INFO - Replacing parameters: ['63676', '31815']


Importance scores: {'31815': np.float64(0.3000242275169249), '28087': np.float64(1.4202571638161663), '63676': np.float64(-0.00027838327366502824), '56422': np.float64(0.7992918773849551)}


2025-02-18 17:02:00,178 - src.halligan.parameter_search - INFO - New parameters added: ['61432', '22844']
2025-02-18 17:02:00,178 - src.halligan.parameter_search - INFO - New parameters added: ['61432', '22844']
2025-02-18 17:02:00,178 - src.halligan.parameter_search - INFO - Starting outer iteration 3/5
2025-02-18 17:02:00,178 - src.halligan.parameter_search - INFO - Starting outer iteration 3/5
[INFO 02-18 17:02:00] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 28087. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:02:00] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 56422. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:02:00] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT f




Best importance scores: {'Explanatory descriptions of the color red and its associations': np.float64(7.363976116846589), 'The color red': np.float64(4.239456565765613), 'The color yellow': np.float64(2.191765250831251), 'Discussion of Blue Ocean Strategy business framework': np.float64(0.7992918773849551), 'Enterprise software products with Red in their name': np.float64(0.07053365664174027), 'Color terms used as descriptive adjectives': np.float64(-0.00027838327366502824)}
Current parameters: ['Explanatory descriptions of the color red and its associations', 'Discussion of Blue Ocean Strategy business framework', "Setup phrases in simple jokes, especially 'Because it' in punchlines", 'Descriptions of redheaded people and associated descriptive language patterns']





  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
2025-02-18 17:02:46,446 - src.halligan.parameter_search - INFO - Replacing parameters: ['61432', '56422']
2025-02-18 17:02:46,446 - src.halligan.parameter_search - INFO - Replacing parameters: ['61432', '56422']


Importance scores: {'28087': np.float64(4.8624415744867955), '56422': np.float64(0.032580473787295006), '61432': np.float64(-0.09893582611958518), '22844': np.float64(0.25924905841506174)}


2025-02-18 17:02:46,972 - src.halligan.parameter_search - INFO - New parameters added: ['59234', '47548']
2025-02-18 17:02:46,972 - src.halligan.parameter_search - INFO - New parameters added: ['59234', '47548']
2025-02-18 17:02:46,973 - src.halligan.parameter_search - INFO - Starting outer iteration 4/5
2025-02-18 17:02:46,973 - src.halligan.parameter_search - INFO - Starting outer iteration 4/5
[INFO 02-18 17:02:46] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 28087. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:02:46] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 22844. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:02:46] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT f




Best importance scores: {'Explanatory descriptions of the color red and its associations': np.float64(7.363976116846589), 'The color red': np.float64(4.239456565765613), 'The color yellow': np.float64(2.191765250831251), 'Discussion of Blue Ocean Strategy business framework': np.float64(0.7992918773849551), 'Descriptions of redheaded people and associated descriptive language patterns': np.float64(0.25924905841506174), 'Enterprise software products with Red in their name': np.float64(0.07053365664174027), 'Color terms used as descriptive adjectives': np.float64(-0.00027838327366502824), "Setup phrases in simple jokes, especially 'Because it' in punchlines": np.float64(-0.09893582611958518)}
Current parameters: ['Explanatory descriptions of the color red and its associations', 'Descriptions of redheaded people and associated descriptive language patterns', 'Descriptions of characters blushing or their faces turning red from emotion', 'Making bold aesthetic statements or attention-gra

  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
2025-02-18 17:03:34,796 - src.halligan.parameter_search - INFO - Replacing parameters: ['28087', '22844']
2025-02-18 17:03:34,796 - src.halligan.parameter_search - INFO - Replacing parameters: ['28087', '22844']


Importance scores: {'28087': np.float64(0.45371350927343634), '22844': np.float64(0.5394283943009253), '59234': np.float64(4.188523305977174), '47548': np.float64(7.383063547993656)}


2025-02-18 17:03:35,385 - src.halligan.parameter_search - INFO - New parameters added: ['52155', '26702']
2025-02-18 17:03:35,385 - src.halligan.parameter_search - INFO - New parameters added: ['52155', '26702']
2025-02-18 17:03:35,386 - src.halligan.parameter_search - INFO - Starting outer iteration 5/5
2025-02-18 17:03:35,386 - src.halligan.parameter_search - INFO - Starting outer iteration 5/5
[INFO 02-18 17:03:35] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 59234. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:03:35] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 47548. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:03:35] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT f




Best importance scores: {'Making bold aesthetic statements or attention-grabbing choices': np.float64(7.383063547993656), 'Explanatory descriptions of the color red and its associations': np.float64(7.363976116846589), 'The color red': np.float64(4.239456565765613), 'Descriptions of characters blushing or their faces turning red from emotion': np.float64(4.188523305977174), 'The color yellow': np.float64(2.191765250831251), 'Discussion of Blue Ocean Strategy business framework': np.float64(0.7992918773849551), 'Descriptions of redheaded people and associated descriptive language patterns': np.float64(0.5394283943009253), 'Enterprise software products with Red in their name': np.float64(0.07053365664174027), 'Color terms used as descriptive adjectives': np.float64(-0.00027838327366502824), "Setup phrases in simple jokes, especially 'Because it' in punchlines": np.float64(-0.09893582611958518)}
Current parameters: ['Descriptions of characters blushing or their faces turning red from e

  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
2025-02-18 17:04:21,198 - src.halligan.parameter_search - INFO - Replacing parameters: ['47548', '52155']
2025-02-18 17:04:21,198 - src.halligan.parameter_search - INFO - Replacing parameters: ['47548', '52155']


Importance scores: {'59234': np.float64(2.1234157520965766), '47548': np.float64(-0.12775927153969668), '52155': np.float64(0.9288147605617647), '26702': np.float64(2.0610769207697937)}


2025-02-18 17:04:21,556 - src.halligan.parameter_search - INFO - New parameters added: ['62593', '27270']
2025-02-18 17:04:21,556 - src.halligan.parameter_search - INFO - New parameters added: ['62593', '27270']


In [50]:
# Display the search result. Per the output below it is a 4-tuple: an AxClient,
# two lists of feature ids, and a dict of feature id -> importance score.
result

(AxClient(experiment=Experiment(None)),
 ['26702', '59234'],
 ['53668',
  '29816',
  '63676',
  '31815',
  '61432',
  '56422',
  '28087',
  '22844',
  '47548',
  '52155'],
 {'31815': np.float64(4.239456565765613),
  '28087': np.float64(7.363976116846589),
  '29816': np.float64(2.191765250831251),
  '53668': np.float64(0.07053365664174027),
  '63676': np.float64(-0.00027838327366502824),
  '56422': np.float64(0.7992918773849551),
  '61432': np.float64(-0.09893582611958518),
  '22844': np.float64(0.5394283943009253),
  '59234': np.float64(4.188523305977174),
  '47548': np.float64(7.383063547993656),
  '52155': np.float64(0.9288147605617647),
  '26702': np.float64(2.0610769207697937)})

In [52]:
# Resolve every feature id in the two id lists of `result` to its Feature
# object, so the ids can be read as labels.
client.features.lookup(result[1] + result[2], base_variant)

{63676: Feature("Color terms used as descriptive adjectives"),
 26702: Feature("Describing physical or abstract properties and characteristics"),
 31815: Feature("The color red"),
 53668: Feature("Enterprise software products with Red in their name"),
 56422: Feature("Discussion of Blue Ocean Strategy business framework"),
 47548: Feature("Making bold aesthetic statements or attention-grabbing choices"),
 61432: Feature("Setup phrases in simple jokes, especially 'Because it' in punchlines"),
 52155: Feature("Descriptions of cheeks and cheekbones in beauty and cosmetic contexts"),
 59234: Feature("Descriptions of characters blushing or their faces turning red from emotion"),
 29816: Feature("The color yellow"),
 28087: Feature("Explanatory descriptions of the color red and its associations"),
 22844: Feature("Descriptions of redheaded people and associated descriptive language patterns")}

In [51]:
# Recompute feature importance from the AxClient returned by the search
# (result[0]), using the same neutral_value/maximize settings as the search.
get_feature_importance(result[0], neutral_value=0.0, maximize=True)

{'59234': np.float64(2.446542714951483),
 '47548': np.float64(0.19536769128414466),
 '52155': np.float64(1.2519417233859897),
 '26702': np.float64(2.384203883590482)}

In [53]:
# Repeat of the previous call with identical arguments. The recorded outputs
# agree only to about 10 significant digits, so the computation is presumably
# not bit-for-bit deterministic -- TODO confirm.
get_feature_importance(result[0], neutral_value=0.0, maximize=True)

{'59234': np.float64(2.446542714946408),
 '47548': np.float64(0.1953676913105582),
 '52155': np.float64(1.2519417234122026),
 '26702': np.float64(2.384203883616096)}