In [None]:
import torch as th
import sys
import os
import certifi
import matplotlib.pyplot as plt
from lpe.methods import MHIS
from lpe.method_utils import *
from lpe.utils import Transformer
from lpe.utils import datasets as lpe_datasets
from lpe.utils.model_adapter import ModelAdapter
from lpe.mo_mhis import MO_MHIS
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# Pick the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if th.cuda.is_available():
    device = th.device('cuda')
else:
    device = th.device('cpu')
# Bare expression so the notebook displays which device was chosen.
device

In [None]:
# Name of the model under study. It is passed to `from_pretrained` when the
# model is loaded, used as the key into RECOMMENDED_TEMPS when estimating,
# and shown in the plot title.
model_name: str = 'gelu-1l'

In [None]:
# Toy-model path: load the pretrained lpe Transformer named by `model_name`,
# then move it onto the selected device.
model = Transformer.from_pretrained(model_name)
model = model.to(device)

In [None]:
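# Run this cell instead of the toy-model cell when using a Hugging Face causal LM:
# uncomment the lines below and set model_name to the checkpoint identifier.
# NOTE: later cells look up RECOMMENDED_TEMPS[model_name], so that table
# presumably needs an entry for the chosen model -- confirm before running.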
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# tokenizer.pad_token = tokenizer.eos_token
# hf_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
# model = ModelAdapter(hf_model, tokenizer, device=device)

In [None]:
# Choose the input distribution / behavior to study.
# NOTE(review): `distribution_registry` and `N_REPS_DICT` are not imported by
# name in this notebook; presumably they come from
# `from lpe.method_utils import *` -- confirm.
dist_name = 'camel'

# Look up the behavior constructor registered under `dist_name` and build it
# with the model's tokenizer on the model's device.
make_behavior = distribution_registry[dist_name]
behavior = make_behavior(model.tokenizer, device=model.device)

# Input distributions for this behavior, using the repetition count
# configured for `dist_name`.
orig_dists = behavior.input_dists(n_reps=N_REPS_DICT[dist_name])

In [None]:
# Generate estimates
#
# For every method and every target, run the estimator and store the result
# in estimates[method][target].
#
# NOTE(review): `targets` (and `gt_probs`, used by the plotting cell) are not
# defined in any visible cell, and `tqdm` is not imported by name; presumably
# they come from a missing cell / the `lpe.method_utils` star import. Confirm
# that this notebook survives Restart Kernel -> Run All.
methods = ["MHIS", "MO_MHIS"]

# Dispatch table: method name -> estimator callable. An unknown method name
# raises KeyError instead of silently leaving an empty result dict.
estimator_by_method = {"MHIS": MHIS, "MO_MHIS": MO_MHIS}

# Settings shared by both estimators, looked up once instead of once per
# target so the two methods are compared under identical sampling budgets.
# NOTE(review): MO_MHIS reuses the "MHIS" temperature entry, as in the
# original code -- presumably intentional; confirm.
sampler_kwargs = dict(
    temp=RECOMMENDED_TEMPS[model_name]["MHIS"][dist_name],
    n_samples=2**16,
    burn_in=2**10,
)

estimates = {}

for method in methods:
    estimator = estimator_by_method[method]
    print(f"Computing estimates for {method}")
    estimates[method] = {}
    for target in tqdm(list(targets)):
        estimates[method][target] = estimator(model, orig_dists, target, **sampler_kwargs)


In [None]:
# Plot the (unfit) estimates against the ground-truth probabilities.
# NOTE(review): `gt_probs` and `targets` are not defined in the visible cells;
# `gt_probs` is presumably a torch tensor indexable by `targets` -- confirm.
fig, ax = plt.subplots(figsize=(12, 6))
colors = {'MO_MHIS': 'orange', 'MHIS': 'purple'}
for method in methods:
    per_target = estimates[method]

    # One point per target: x = ground truth, y = this method's estimate.
    estimates_for_method = [per_target[target] for target in targets]
    ax.scatter(gt_probs[targets].cpu().numpy(), estimates_for_method, label=method, color=colors[method])

    # Plot the 0s at the bottom: both axes are set to log scale below, so
    # zero estimates are re-drawn as 'x' markers at y = 1e-9.
    zero_targets = [target for target in targets if per_target[target] == 0]
    ax.scatter(gt_probs[zero_targets].cpu().numpy(), [1e-9] * len(zero_targets), color=colors[method], marker='x')

# y = x reference line: a perfect estimator would land on it.
ax.plot([1e-9, 1e-5], [1e-9, 1e-5], label='ground truth', color='black')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('Ground Truth Probability')
ax.set_ylabel('Estimate')
ax.set_title(f"Estimates for {dist_name} on {model_name}")
ax.legend()
plt.show()