In [3]:
from typing import List, Tuple
import argparse
from collections import deque

import numpy as np
import torch
import yaml
from epsilon_transformers.process.MixedStateTree import (MixedStateTree,
                                                         MixedStateTreeNode)
from epsilon_transformers.process.Process import (
    Process, _compute_emission_probabilities, _compute_next_distribution)
from epsilon_transformers.process.processes import Mess3
from tqdm import tqdm

from src.utils import get_cached_belief_filename, MODEL_PATH_005_085, MODEL_PATH_015_06
from src.generate_paths_and_beliefs import generate_mess3_beliefs, save_beliefs
from typing import Tuple, Set, List
from pathlib import Path
from transformer_lens import HookedTransformer
from src.experiment import run_activation_to_beliefs_regression, r_squared, load_model
import random
import time

In [9]:
# Load the trained checkpoint stored under MODEL_PATH_005_085
# (the name suggests Mess3 x=0.05, a=0.85 — NOTE(review): confirm).
CHECKPOINT_FILE = "684806400.pt"  # NOTE(review): presumably a training step/token count — confirm
# Fall back to CPU so the notebook can still run on a machine without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = load_model(
    MODEL_PATH_005_085 / CHECKPOINT_FILE,
    MODEL_PATH_005_085 / "train_config.json",
    DEVICE,
)

In [18]:
def eval_rsq(
    model: HookedTransformer,
    x: float,
    a: float,
    denom: int = 10,
    layer: int = 3,
    verbose: bool = True,
    seed: "int | None" = None,
):
    """Regress one layer's residual-stream activations onto Mess3 belief states and score the fit.

    Generates inputs and their belief states with
    ``generate_mess3_beliefs(x, a, sort_pairs=True)``, runs ``model`` on a random
    ``1/denom`` subsample of those inputs, fits ``run_activation_to_beliefs_regression``
    from the ``blocks.{layer}.hook_resid_post`` activations to the beliefs, and
    returns ``r_squared`` of the predicted vs. true beliefs.

    Args:
        model: transformer whose residual stream is probed.
        x, a: process parameters forwarded to ``generate_mess3_beliefs``.
        denom: keep ``len(inputs) // denom`` randomly chosen inputs (at least one).
        layer: block whose ``hook_resid_post`` is regressed. The default of 3 matches
            the previously hard-coded ``blocks.3.hook_resid_post``.
        verbose: print the cumulative elapsed time after each stage.
        seed: if given, seeds a private ``random.Random`` so the subsample is
            reproducible. ``None`` uses the global ``random`` module as before.

    Returns:
        Whatever ``r_squared`` returns (a scalar tensor judging by this notebook's
        output — NOTE(review): confirm).

    Raises:
        ValueError: if ``denom`` is less than 1.
    """
    if denom < 1:
        raise ValueError(f"denom must be >= 1, got {denom}")

    hook_name = f"blocks.{layer}.hook_resid_post"
    start = time.perf_counter()

    def _log(stage):
        # Timing output is kept behind `verbose` so long sweeps can silence it.
        if verbose:
            print(stage, time.perf_counter() - start)

    inputs, input_beliefs = generate_mess3_beliefs(x, a, sort_pairs=True)
    _log("gen beliefs")

    # max(1, ...) so a small input set or a large denom never yields an empty
    # subsample, which would otherwise fail later inside the regression.
    sampler = random if seed is None else random.Random(seed)
    n_keep = max(1, len(inputs) // denom)
    indices = sampler.sample(range(len(inputs)), n_keep)
    _log("gen indices")

    # Inference only, so no autograd graph is needed. Cache just the one hook we read
    # instead of every layer's resid_post.
    with torch.no_grad():
        _, activations = model.run_with_cache(
            inputs[indices], names_filter=lambda name: name == hook_name
        )
    _log("run transformer")

    acts = activations[hook_name].cpu().detach().numpy()
    target_beliefs = input_beliefs[indices]
    _, belief_predictions = run_activation_to_beliefs_regression(acts, target_beliefs)
    _log("regression")

    rsq = r_squared(target_beliefs, belief_predictions)
    _log("rsq calc")

    return rsq

tensor(0.9966)

In [31]:
# Sweep the Mess3 parameter grid and record the belief-regression R^2 for every (x, a).
# NOTE: the ranges are half-open, so the grids stop at x=0.48 and a=0.95.
SUBSAMPLE_DENOM = 20  # each evaluation regresses on a random 1/20 of the generated inputs

x_values = [i / 100 for i in range(2, 50, 2)]  # x values: 0.02, 0.04, ..., 0.48
a_values = [i / 20 for i in range(0, 20)]      # a values: 0.0, 0.05, ..., 0.95

# rsq_grid[i, j] holds the R^2 for (x_values[i], a_values[j]). Keeping the results,
# instead of discarding them each iteration, lets later cells plot or tabulate them.
rsq_grid = np.full((len(x_values), len(a_values)), np.nan)
for i, x in enumerate(x_values):
    for j, a in enumerate(a_values):
        rsq = eval_rsq(model, x, a, SUBSAMPLE_DENOM)
        rsq_grid[i, j] = rsq.item()
        print(x, a, rsq_grid[i, j])


0.02 0.0 0.8870023488998413
0.02 0.05 0.8997015953063965
0.02 0.1 0.8720460534095764
0.02 0.15 0.86314457654953
0.02 0.2 0.8565720319747925
0.02 0.25 0.8633589744567871
0.02 0.3 0.8667371273040771
0.02 0.35 0.8689465522766113
0.02 0.4 0.8723419904708862
0.02 0.45 0.8824520111083984
0.02 0.5 0.8924626111984253
0.02 0.55 0.9036968350410461
0.02 0.6 0.9140172004699707
0.02 0.65 0.9245727062225342
0.02 0.7 0.9371442198753357
0.02 0.75 0.9478135704994202
0.02 0.8 0.9593974947929382
0.02 0.85 0.9707686305046082
0.02 0.9 0.9819194674491882
0.02 0.95 0.9920350909233093
0.04 0.0 0.9032108783721924
0.04 0.05 0.9408770203590393
0.04 0.1 0.9245535731315613
0.04 0.15 0.9153757691383362
0.04 0.2 0.9097493886947632
0.04 0.25 0.9105076789855957
0.04 0.3 0.9119764566421509
0.04 0.35 0.9128133654594421
0.04 0.4 0.9165185689926147
0.04 0.45 0.9256211519241333
0.04 0.5 0.9376742839813232
0.04 0.55 0.9472992420196533
0.04 0.6 0.9581027030944824
0.04 0.65 0.9671405553817749
0.04 0.7 0.9752365946769714
0.04 