In [1]:
%load_ext autoreload
%autoreload 2

In [96]:
import goodfire
import os
from typing import Dict, Callable, List
import numpy as np
import torch

from src.halligan import iterative_parameter_search
from src.halligan.feature_importance import get_feature_importance

client = goodfire.Client(os.environ["GOODFIRE_API_KEY"])

base_variant = goodfire.Variant("meta-llama/Llama-3.3-70B-Instruct")

In [3]:
PROMPT = "If you had a favorite color, what would it be? Answer in one word."

In [4]:
logits = client.chat.logits(
    model=base_variant,
    messages=[
        {"role": "user", "content": PROMPT},
        {"role": "assistant", "content": ""},
    ],
)


In [5]:
def create_variant_from_feature_boosts(feature_boosts: Dict[int, float]) -> goodfire.Variant:
    if not feature_boosts:
        return base_variant
    variant = goodfire.Variant("meta-llama/Llama-3.3-70B-Instruct")
    features = client.features.lookup(list(feature_boosts.keys()), base_variant)
    for feature_id, boost in feature_boosts.items():
        variant.set(features[feature_id], boost)
    return variant

In [6]:
create_variant_from_feature_boosts({0: 1.0, 1: -1.0})

Variant(
   base_model=meta-llama/Llama-3.3-70B-Instruct,
   edits={
      Feature("Explanatory statements about AI's impacts and applications"): 1.0,
      Feature("The assistant wishes enjoyment of food/recipes to users"): -1.0,
   }
   scopes={
   }
)

In [7]:
def variant_answer_logprob(variant, answer, top_k=10):
    logits_topk = client.chat.logits([{"role": "user", "content": PROMPT}, {"role": "assistant", "content": ""}], variant, top_k=top_k)
    logits_color = client.chat.logits([{"role": "user", "content": PROMPT}, {"role": "assistant", "content": ""}], variant, filter_vocabulary=[answer])
    logits_topk_values = np.array(list(logits_topk.logits.values()))
    max_logit = np.max(logits_topk_values)
    exp_logits = np.exp(logits_topk_values - max_logit)
    sum_exp_logits = np.sum(exp_logits)
    log_denominator = np.log(sum_exp_logits) + max_logit
    
    return logits_color.logits[answer] - log_denominator

In [8]:
def get_candidate_parameters(current_parameters: List[str], tried_parameters: List[str], num_features_to_replace: int) -> List[Dict]:
    features = client.features.lookup(current_parameters, base_variant)
    neighboring_features = client.features.neighbors(features.values(), base_variant, top_k=num_features_to_replace + len(tried_parameters))
    neighboring_ids = [str(f.index_in_sae) for f in neighboring_features]
    new_candidate_ids = [id for id in neighboring_ids if id not in tried_parameters]
    new_candidate_ids = new_candidate_ids[:num_features_to_replace]
    return [{"name": id, "type": "range", "bounds": [-1.0, 1.0]} for id in new_candidate_ids]

In [9]:
def objective_function(parameters: Dict) -> float:
    variant = create_variant_from_feature_boosts({int(k): v for k, v in parameters.items()})
    return variant_answer_logprob(variant, "Red")

In [10]:
red_features = client.features.search("color red", base_variant)
red_features

FeatureGroup([
   0: "The color red",
   1: "Explanatory descriptions of the color red and its associations",
   2: "The color yellow",
   3: "Enterprise software products with Red in their name",
   4: "Descriptions of redheaded people and associated descriptive language patterns",
   5: "Descriptions of characters blushing or their faces turning red from emotion",
   6: "The tomato turning red joke and its sexual innuendo",
   7: "CSS color property declarations",
   8: "The word orange/Orange when used as a significant token or keyword",
   9: "CSS color property declarations using hex codes"
])

In [11]:
initial_parameters = [{"name": str(f.index_in_sae), "type": "range", "bounds": [-1.0, 1.0]} for f in red_features[:4]]
initial_parameters

[{'name': '31815', 'type': 'range', 'bounds': [-1.0, 1.0]},
 {'name': '28087', 'type': 'range', 'bounds': [-1.0, 1.0]},
 {'name': '29816', 'type': 'range', 'bounds': [-1.0, 1.0]},
 {'name': '53668', 'type': 'range', 'bounds': [-1.0, 1.0]}]

In [12]:
def display_function(best_importance_scores, current_parameters):
    features = client.features.lookup(list(best_importance_scores.keys()) + [f["name"] for f in current_parameters], base_variant)
    readable_best_importance_scores = {features[int(k)].label: v for k, v in sorted(best_importance_scores.items(), key=lambda x: x[1], reverse=True)}
    readable_current_parameters = [features[int(f["name"])].label for f in current_parameters]
    print("\n\n")
    print(f"Best importance scores: {readable_best_importance_scores}")
    print(f"Current parameters: {readable_current_parameters}")
    print("\n\n")

In [14]:
result = iterative_parameter_search(
    parameter_generator=get_candidate_parameters,
    objective_function=objective_function,
    initial_parameters=initial_parameters,
    n_best_known_params=2,
    n_new_candidate_params=2,
    outer_iterations=5,
    inner_iterations=20,
    maximize=True,
    neutral_value=0.0,
    display_function=display_function,
)

2025-02-18 17:11:22,083 - src.halligan.parameter_search - INFO - Starting outer iteration 1/5
[INFO 02-18 17:11:22] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 31815. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:11:22] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 28087. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:11:22] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 29816. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:11:22] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 53668. If that is not the expected value t




Best importance scores: {}
Current parameters: ['The color red', 'Explanatory descriptions of the color red and its associations', 'The color yellow', 'Enterprise software products with Red in their name']





  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))


Importance scores: {'31815': np.float64(3.027441981891481), '28087': np.float64(2.2739584793593424), '29816': np.float64(-0.19167260380625883), '53668': np.float64(0.03420946299148486)}


2025-02-18 17:12:13,806 - src.halligan.parameter_search - INFO - Retained parameters: ['31815', '28087']
2025-02-18 17:12:13,808 - src.halligan.parameter_search - INFO - New parameters added: ['63676', '56422']
2025-02-18 17:12:13,808 - src.halligan.parameter_search - INFO - Starting outer iteration 2/5
[INFO 02-18 17:12:13] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 31815. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:12:13] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 28087. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:12:13] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 63676. If that is not the expected value type, you can explicitly specify 'value_t




Best importance scores: {'The color red': np.float64(3.027441981891481), 'Explanatory descriptions of the color red and its associations': np.float64(2.2739584793593424), 'Enterprise software products with Red in their name': np.float64(0.03420946299148486), 'The color yellow': np.float64(-0.19167260380625883)}
Current parameters: ['The color red', 'Explanatory descriptions of the color red and its associations', 'Color terms used as descriptive adjectives', 'Discussion of Blue Ocean Strategy business framework']





  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))


Importance scores: {'31815': np.float64(1.3910348699795576), '28087': np.float64(2.7019012342540787), '63676': np.float64(0.08467533303854768), '56422': np.float64(-0.054502441706694604)}


2025-02-18 17:13:02,201 - src.halligan.parameter_search - INFO - Retained parameters: ['31815', '28087']
2025-02-18 17:13:02,201 - src.halligan.parameter_search - INFO - New parameters added: ['1256', '22844']
2025-02-18 17:13:02,202 - src.halligan.parameter_search - INFO - Starting outer iteration 3/5
[INFO 02-18 17:13:02] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 31815. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:13:02] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 28087. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:13:02] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 1256. If that is not the expected value type, you can explicitly specify 'value_typ




Best importance scores: {'The color red': np.float64(3.027441981891481), 'Explanatory descriptions of the color red and its associations': np.float64(2.7019012342540787), 'Color terms used as descriptive adjectives': np.float64(0.08467533303854768), 'Enterprise software products with Red in their name': np.float64(0.03420946299148486), 'Discussion of Blue Ocean Strategy business framework': np.float64(-0.054502441706694604), 'The color yellow': np.float64(-0.19167260380625883)}
Current parameters: ['The color red', 'Explanatory descriptions of the color red and its associations', 'The word pink appearing in text', 'Descriptions of redheaded people and associated descriptive language patterns']





  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))


Importance scores: {'31815': np.float64(-0.23682635802807184), '28087': np.float64(5.307829172437381), '1256': np.float64(3.627944473216731), '22844': np.float64(1.6746537345537398)}


2025-02-18 17:13:48,851 - src.halligan.parameter_search - INFO - Retained parameters: ['28087', '1256']
2025-02-18 17:13:48,851 - src.halligan.parameter_search - INFO - New parameters added: ['59234', '61432']
2025-02-18 17:13:48,852 - src.halligan.parameter_search - INFO - Starting outer iteration 4/5
[INFO 02-18 17:13:48] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 28087. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:13:48] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 1256. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:13:48] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 59234. If that is not the expected value type, you can explicitly specify 'value_typ




Best importance scores: {'Explanatory descriptions of the color red and its associations': np.float64(5.307829172437381), 'The word pink appearing in text': np.float64(3.627944473216731), 'The color red': np.float64(3.027441981891481), 'Descriptions of redheaded people and associated descriptive language patterns': np.float64(1.6746537345537398), 'Color terms used as descriptive adjectives': np.float64(0.08467533303854768), 'Enterprise software products with Red in their name': np.float64(0.03420946299148486), 'Discussion of Blue Ocean Strategy business framework': np.float64(-0.054502441706694604), 'The color yellow': np.float64(-0.19167260380625883)}
Current parameters: ['Explanatory descriptions of the color red and its associations', 'The word pink appearing in text', 'Descriptions of characters blushing or their faces turning red from emotion', "Setup phrases in simple jokes, especially 'Because it' in punchlines"]





  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))


Importance scores: {'28087': np.float64(11.316747386404476), '1256': np.float64(0.036396728423003566), '59234': np.float64(4.927358803301345), '61432': np.float64(0.3297991792932562)}


2025-02-18 17:14:35,287 - src.halligan.parameter_search - INFO - Retained parameters: ['28087', '59234']
2025-02-18 17:14:35,287 - src.halligan.parameter_search - INFO - New parameters added: ['52155', '51771']
2025-02-18 17:14:35,288 - src.halligan.parameter_search - INFO - Starting outer iteration 5/5
[INFO 02-18 17:14:35] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 28087. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:14:35] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 59234. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:14:35] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 52155. If that is not the expected value type, you can explicitly specify 'value_t




Best importance scores: {'Explanatory descriptions of the color red and its associations': np.float64(11.316747386404476), 'Descriptions of characters blushing or their faces turning red from emotion': np.float64(4.927358803301345), 'The word pink appearing in text': np.float64(3.627944473216731), 'The color red': np.float64(3.027441981891481), 'Descriptions of redheaded people and associated descriptive language patterns': np.float64(1.6746537345537398), "Setup phrases in simple jokes, especially 'Because it' in punchlines": np.float64(0.3297991792932562), 'Color terms used as descriptive adjectives': np.float64(0.08467533303854768), 'Enterprise software products with Red in their name': np.float64(0.03420946299148486), 'Discussion of Blue Ocean Strategy business framework': np.float64(-0.054502441706694604), 'The color yellow': np.float64(-0.19167260380625883)}
Current parameters: ['Explanatory descriptions of the color red and its associations', 'Descriptions of characters blushi

  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))


Importance scores: {'28087': np.float64(-1.7882960541582984), '59234': np.float64(-0.8717733132995378), '52155': np.float64(-2.1855158349514756), '51771': np.float64(-2.164001047829384)}


2025-02-18 17:15:21,664 - src.halligan.parameter_search - INFO - Retained parameters: ['28087', '59234']
2025-02-18 17:15:21,665 - src.halligan.parameter_search - INFO - New parameters added: ['62593', '14735']


In [15]:
result

(AxClient(experiment=Experiment(None)),
 ['28087', '59234'],
 ['29816',
  '53668',
  '63676',
  '56422',
  '31815',
  '22844',
  '1256',
  '61432',
  '52155',
  '51771'],
 {'31815': np.float64(3.027441981891481),
  '28087': np.float64(11.316747386404476),
  '29816': np.float64(-0.19167260380625883),
  '53668': np.float64(0.03420946299148486),
  '63676': np.float64(0.08467533303854768),
  '56422': np.float64(-0.054502441706694604),
  '1256': np.float64(3.627944473216731),
  '22844': np.float64(1.6746537345537398),
  '59234': np.float64(4.927358803301345),
  '61432': np.float64(0.3297991792932562),
  '52155': np.float64(-2.1855158349514756),
  '51771': np.float64(-2.164001047829384)})

In [16]:
result[0].experiment.parameters

{'28087': RangeParameter(name='28087', parameter_type=FLOAT, range=[-1.0, 1.0]),
 '59234': RangeParameter(name='59234', parameter_type=FLOAT, range=[-1.0, 1.0]),
 '52155': RangeParameter(name='52155', parameter_type=FLOAT, range=[-1.0, 1.0]),
 '51771': RangeParameter(name='51771', parameter_type=FLOAT, range=[-1.0, 1.0])}

In [52]:
client.features.lookup(result[1] + result[2], base_variant)

{63676: Feature("Color terms used as descriptive adjectives"),
 26702: Feature("Describing physical or abstract properties and characteristics"),
 31815: Feature("The color red"),
 53668: Feature("Enterprise software products with Red in their name"),
 56422: Feature("Discussion of Blue Ocean Strategy business framework"),
 47548: Feature("Making bold aesthetic statements or attention-grabbing choices"),
 61432: Feature("Setup phrases in simple jokes, especially 'Because it' in punchlines"),
 52155: Feature("Descriptions of cheeks and cheekbones in beauty and cosmetic contexts"),
 59234: Feature("Descriptions of characters blushing or their faces turning red from emotion"),
 29816: Feature("The color yellow"),
 28087: Feature("Explanatory descriptions of the color red and its associations"),
 22844: Feature("Descriptions of redheaded people and associated descriptive language patterns")}

In [17]:
get_feature_importance(result[0], neutral_value=0.0, maximize=True)

{'28087': np.float64(0.4270054284564986),
 '59234': np.float64(1.3435281690673704),
 '52155': np.float64(0.029785647649950064),
 '51771': np.float64(0.051300438404772386)}

In [18]:
get_feature_importance(result[0], neutral_value=0.0, maximize=True)

{'28087': np.float64(0.4270054289224636),
 '59234': np.float64(1.3435281697638604),
 '52155': np.float64(0.029785648112705232),
 '51771': np.float64(0.05130043559755748)}

In [34]:
get_feature_importance(result[0], neutral_value=0.0, maximize=True)

{'28087': np.float64(0.4270054281771358),
 '59234': np.float64(1.3435281686483478),
 '52155': np.float64(0.02978564722700483),
 '51771': np.float64(0.051300434362269165)}

In [35]:
result[0].get_trials_data_frame()



Unnamed: 0,trial_index,arm_name,trial_status,generation_method,objective,28087,59234,52155,51771
0,0,0_0,COMPLETED,Sobol,-2.795823,-0.712859,-0.918937,-0.822593,-0.015031
1,1,1_0,COMPLETED,Sobol,-8.801599,0.825829,0.231215,0.588376,0.130167
2,2,2_0,COMPLETED,Sobol,-6.769167,0.055169,-0.018237,-0.126197,-0.531794
3,3,3_0,COMPLETED,Sobol,-6.916401,-0.41815,0.830638,0.391676,0.666629
4,4,4_0,COMPLETED,Sobol,-13.095476,-0.06289,-0.423244,0.987969,0.751937
5,5,5_0,COMPLETED,Sobol,-0.005582,0.457243,0.735689,-0.722185,-0.886558
6,6,6_0,COMPLETED,Sobol,-0.001953,0.658141,-0.514861,0.055428,0.294987
7,7,7_0,COMPLETED,Sobol,-6.166072,-0.802483,0.327186,-0.28995,-0.410398
8,8,8_0,COMPLETED,BoTorch,-1.479064,1.0,-1.0,-0.045774,0.65652
9,9,9_0,COMPLETED,BoTorch,-0.025442,0.05766,0.588517,-0.923838,-0.68456


In [39]:
from ax.service.ax_client import AxClient
from ax.service.utils.instantiation import ObjectiveProperties

In [48]:
[{"name": p.name, "parameter_type": "FLOAT", "range": [-1.0, 1.0]} for p in result[0].experiment.parameters.values()]

[{'name': '28087', 'parameter_type': 'FLOAT', 'range': [-1.0, 1.0]},
 {'name': '59234', 'parameter_type': 'FLOAT', 'range': [-1.0, 1.0]},
 {'name': '52155', 'parameter_type': 'FLOAT', 'range': [-1.0, 1.0]},
 {'name': '51771', 'parameter_type': 'FLOAT', 'range': [-1.0, 1.0]}]

In [93]:
ax_client = AxClient()
ax_client.create_experiment(parameters=[{"name": p.name, "type": "range", "bounds": [-1.0, 1.0]} for p in result[0].experiment.parameters.values()], objectives={"objective": ObjectiveProperties(minimize=False)})

[INFO 02-18 17:56:58] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 28087. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:56:58] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 59234. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:56:58] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 52155. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 02-18 17:56:58] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter 51771. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter di

In [94]:
for _ in range(8):
    trial, trial_index = ax_client.get_next_trial()
    objective = objective_function(trial)
    ax_client.complete_trial(trial_index, raw_data={"objective": objective})

  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))
  warn("Encountered exception in computing model fit quality: " + str(e))


In [95]:
for _ in range(1):
    trial, trial_index = ax_client.get_next_trial()
    objective = objective_function(trial)
    ax_client.complete_trial(trial_index, raw_data={"objective": objective})

In [97]:
model = ax_client.generation_strategy.model.model.surrogate.model

# Before optimization:
model.train()  # First put in train mode
model.eval()   # Then put in eval mode
model.likelihood.eval()  # Make sure likelihood is in eval mode too

# Clear the GPyTorch caches
if hasattr(model, "prediction_strategy"):
    model.prediction_strategy = None  # Clear the cached predictions

# Force a recomputation of the cache
dummy_X = torch.zeros((1, model.train_inputs[0].shape[-1]), device=model.train_inputs[0].device)
with torch.no_grad():
    model.posterior(dummy_X)

In [98]:
get_feature_importance(ax_client)

{'28087': np.float64(3.3714687635920786),
 '59234': np.float64(-0.9437166073682057),
 '52155': np.float64(0.430833179489996),
 '51771': np.float64(-0.17982463099368484)}

In [99]:
get_feature_importance(ax_client)

{'28087': np.float64(4.341524385061879),
 '59234': np.float64(0.026339014086603996),
 '52155': np.float64(1.4008888013092955),
 '51771': np.float64(0.7902309907181051)}

In [100]:
get_feature_importance(ax_client)

{'28087': np.float64(4.341524386520801),
 '59234': np.float64(0.026339014793803628),
 '52155': np.float64(1.400888802354065),
 '51771': np.float64(0.790230991111101)}

In [69]:
ax_client.generation_strategy.model.model.surrogate

<Surrogate botorch_model_class=None mll_class=<class 'gpytorch.mlls.exact_marginal_log_likelihood.ExactMarginalLogLikelihood'> outcome_transform_classes=None input_transform_classes=None 

Where I'm at here is I notice that get_feature_importance returns bogus nonsense values the first time it's called, but returns much more reasonable values the second time and beyond.