In [1]:
from typing import List, Tuple
import argparse
from collections import deque

import numpy as np
import torch
import yaml
from epsilon_transformers.process.MixedStateTree import (MixedStateTree,
                                                         MixedStateTreeNode)
from epsilon_transformers.process.Process import (
    Process, _compute_emission_probabilities, _compute_next_distribution)
from epsilon_transformers.process.processes import Mess3
from tqdm import tqdm

from src.utils import get_cached_belief_filename, MODEL_PATH_005_085, MODEL_PATH_015_06
from src.generate_paths_and_beliefs import generate_mess3_beliefs, save_beliefs
from typing import Tuple, Set, List
from pathlib import Path
from transformer_lens import HookedTransformer
from src.experiment import run_activation_to_beliefs_regression, r_squared, load_model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# class EvolutionaryAlgorithm:
#     def __init__(param_ranges: List[Tuple[float, float]])

In [7]:
def get_beliefs(x: float, a: float):
    file_path = Path(get_cached_belief_filename(x, a))
    if file_path.exists():
        return torch.load(file_path)
    else:
        inputs, input_beliefs = generate_mess3_beliefs(x, a, sort_pairs=True)
        # save_beliefs(inputs, input_beliefs, x, a)
        return {
            "params": {"x": x, "a": a},
            "inputs": inputs,
            "input_beliefs": input_beliefs,
        }

def evaluate(model: HookedTransformer, inputs: torch.Tensor, input_beliefs: torch.Tensor):
    _, activations = model.run_with_cache(
            inputs, names_filter=lambda x: "resid_post" in x
        )
    acts = activations["blocks.3.hook_resid_post"].cpu().detach().numpy()
    regression, belief_predictions = run_activation_to_beliefs_regression(
        acts, input_beliefs
    )

    rsq = r_squared(input_beliefs, belief_predictions)

    return rsq

In [12]:
p = get_beliefs(0.01, 0.9)
device = torch.device("cuda:1")
model = load_model(MODEL_PATH_015_06 / "998406400.pt", MODEL_PATH_015_06 / "train_config.json", device)
model = load_model(MODEL_PATH_005_085 / "684806400.pt", MODEL_PATH_015_06 / "train_config.json", device)
evaluate(model, p["inputs"], p["input_beliefs"])

tensor(0.9542, dtype=torch.float64)

In [15]:
import random
import numpy as np

# Define parameters
population_size = 5
generations = 5
mutation_rate = 0.7
crossover_rate = 0.7
elitism_count = 2

# Define ranges for x and a
x_range = (0, 0.5)
a_range = (0, 1)

# Fitness function
def evaluate_fitness(chromosome):
    x, a = chromosome
    belief_dict = get_beliefs(x, a)
    score = evaluate(model, belief_dict["inputs"], belief_dict["input_beliefs"])
    return score

# Initialize population
def initialize_population(size):
    population = []
    for _ in range(size):
        x = random.uniform(*x_range)
        a = random.uniform(*a_range)
        population.append([x, a])
    return population

# Selection (Tournament Selection)
def select_parents(population, fitness, k=3):
    selected = []
    for _ in range(len(population)):
        tournament = random.sample(list(zip(population, fitness)), k)
        winner = max(tournament, key=lambda ind: ind[1])[0]
        selected.append(winner)
    return selected

# Crossover (Single-Point Crossover)
def crossover(parent1, parent2):
    if random.random() < crossover_rate:
        point = random.randint(1, 1)
        child1 = [parent1[0], parent2[1]]
        child2 = [parent2[0], parent1[1]]
        return child1, child2
    else:
        return parent1, parent2

# Mutation
def mutate(chromosome, mutation_rate):
    if random.random() < mutation_rate:
        chromosome[0] = np.clip(chromosome[0] + np.random.normal(0, 0.05), *x_range)
    if random.random() < mutation_rate:
        chromosome[1] = np.clip(chromosome[1] + np.random.normal(0, 0.1), *a_range)
    return chromosome

# Replacement (Elitism)
def replace_population(old_population, new_population, fitness):
    combined = list(zip(old_population, fitness)) + [(child, evaluate_fitness(child)) for child in new_population]
    combined.sort(key=lambda ind: ind[1], reverse=True)
    return [ind[0] for ind in combined[:population_size]]

# Main evolutionary algorithm
def evolutionary_algorithm():
    all_generations = []
    population = initialize_population(population_size)
    for generation in range(generations):
        print(population)
        fitness = [evaluate_fitness(chromosome) for chromosome in population]
        
        parents = select_parents(population, fitness)
        offspring = []
        while len(offspring) < population_size:
            parent1, parent2 = random.sample(parents, 2)
            child1, child2 = crossover(parent1, parent2)
            offspring.extend([mutate(child1, mutation_rate), mutate(child2, mutation_rate)])
        
        all_generations.append((population, fitness.item()))

        population = replace_population(population, offspring, fitness)
    
    return all_generations

# Run the algorithm and print the best solution
all_generations = evolutionary_algorithm()
#    best_solution = max(population, key=evaluate_fitness)
# print("Best solution:", best_solution)
# print("Best fitness:", evaluate_fitness(best_solution))


[[0.2925360457645261, 0.7509657805634686], [0.22253596692168948, 0.4102281671689537], [0.48336412427955855, 0.23656749593851345], [0.1267778474967992, 0.029779804942649202], [0.19425291157280122, 0.2226747029976075], [0.39365798206641084, 0.34960216514450326], [0.2591685002367581, 0.8026364522424402], [0.0201282706124864, 0.8772216372347319], [0.2163790772485742, 0.43004018203142624], [0.42502116399839907, 0.9217305102527351]]
[[0.22253596692168948, 1.0], [0.2925360457645261, 0.7509657805634686], [0.2591685002367581, 0.8026364522424402], [0.2925360457645261, 0.4102281671689537], [0.19425291157280122, 0.8026364522424402], [0.27293241186333145, 0.2226747029976075], [0.22253596692168948, 0.7509657805634686], [0.261197998789791, 0.6613878037732864], [0.42502116399839907, 0.9217305102527351], [0.2591685002367581, 0.2226747029976075]]
[[0.22253596692168948, 1.0], [0.2925360457645261, 1.0], [0.2925360457645261, 1.0], [0.22253596692168948, 1.0], [0.2591685002367581, 0.8955037361935247], [0.259

In [17]:
chromosomes, fitnesses = zip(*all_generations)

In [23]:
for chromosomes, fitnesses in zip(sum(chromosomes, []), sum(fitnesses, [])):
    print(chromosomes, fitnesses)

[0.2925360457645261, 0.7509657805634686] tensor(0.9928, dtype=torch.float64)
[0.22253596692168948, 0.4102281671689537] tensor(0.9886, dtype=torch.float64)
[0.48336412427955855, 0.23656749593851345] tensor(0.9092, dtype=torch.float64)
[0.1267778474967992, 0.029779804942649202] tensor(0.9659, dtype=torch.float64)
[0.19425291157280122, 0.2226747029976075] tensor(0.9886, dtype=torch.float64)
[0.39365798206641084, 0.34960216514450326] tensor(0.9826, dtype=torch.float64)
[0.2591685002367581, 0.8955037361935247] tensor(0.9922, dtype=torch.float64)
[0.0201282706124864, 0.8772216372347319] tensor(0.9771, dtype=torch.float64)
[0.2163790772485742, 0.43004018203142624] tensor(0.9885, dtype=torch.float64)
[0.42502116399839907, 0.9217305102527351] tensor(0.9907, dtype=torch.float64)
[0.22253596692168948, 1.0] tensor(0.9934, dtype=torch.float64)
[0.2925360457645261, 0.7509657805634686] tensor(0.9928, dtype=torch.float64)
[0.2591685002367581, 0.8955037361935247] tensor(0.9922, dtype=torch.float64)
[0.