# Evaluate BED

Evaluate the ALINE model on Bayesian experimental design (BED) tasks.

In [None]:
import torch

import numpy as np
from matplotlib import pyplot as plt
import matplotlib as mpl


import random
from hydra import initialize, compose
from hydra.utils import instantiate
from omegaconf import OmegaConf



from utils import set_seed, load_state_dict
from utils.eval import eval_boed

from utils import create_target_mask, select_targets_by_mask, compute_ll

from model import BaseTransformer

# Hex color palette for the figures below; plot_policy2d uses COLORS[3] for the theta markers.
COLORS = ['#0072B2', '#009E73', '#D55E00', '#CC79A7', '#F0E442', '#56B4E9']
# Apply the 'seaborn-v0_8-colorblind' matplotlib style sheet to all subsequent plots.
plt.style.use(['seaborn-v0_8-colorblind'])

In [None]:
def load(config_overrides=None, state_dict=None, config_name="train_bed", verbose=False):
    """Compose the Hydra config and build the experiment and model from it.

    Args:
        config_overrides: Optional iterable of Hydra override strings
            (e.g. ``["task=ces"]``). ``None`` means no overrides.
        state_dict: Optional checkpoint identifier forwarded to
            ``load_state_dict`` together with ``cfg.output_dir``. When
            ``None`` the model keeps its freshly initialised weights.
        config_name: Name of the config composed from ``./config``.
        verbose: If true, print the composed config as YAML.

    Returns:
        A ``(cfg, experiment, model)`` tuple.

    Side effects:
        Sets torch's global default device (and the default dtype when
        running on CUDA), and either seeds the RNGs via ``set_seed`` or
        stores a freshly drawn seed in ``cfg.seed``.
    """
    # Copy into a new list so no list object is shared between calls.
    overrides = [] if config_overrides is None else list(config_overrides)

    # Initialise the hyper params
    with initialize(version_base=None, config_path="./config"):
        cfg = compose(config_name=config_name, overrides=overrides)

    if verbose:
        print(OmegaConf.to_yaml(cfg))

    # Setting device: fall back to CPU when CUDA is unavailable
    if not torch.cuda.is_available():
        cfg.device = "cpu"
    torch.set_default_device(cfg.device)
    if cfg.device == "cuda":
        torch.set_default_dtype(torch.float32)

    # Setting random seed
    if cfg.fix_seed:
        set_seed(cfg.seed)
    else:
        # torch.random.seed() reseeds torch non-deterministically and returns
        # the seed it used; record it in the config.
        cfg.seed = torch.random.seed()

    # Data
    experiment = instantiate(cfg.task)

    # Model
    embedder = instantiate(cfg.embedder)
    encoder = instantiate(cfg.encoder)
    head = instantiate(cfg.head)
    model = BaseTransformer(embedder, encoder, head)

    if state_dict is not None:
        model = load_state_dict(model, cfg.output_dir, state_dict)

    return cfg, experiment, model

## Location Finding

### EIG Bounds

In [None]:
# Value used for the task.n_query_init override below
# (presumably the size of the candidate query pool — confirm in the task config).
n_query = 2000

# Compose the default config with that override and load the "aline_loc.pth" checkpoint.
cfg, experiment, model = load([f"task.n_query_init={n_query}"], "aline_loc.pth")

# NOTE(review): the positional arguments follow eval_boed's signature in utils.eval;
# confirm there what cfg.T-1, int(1e3), 2000 and 40 stand for.
bounds = eval_boed(model, experiment, cfg.T-1, int(1e3), 2000, 40, cfg.time_token)

In [None]:
@torch.no_grad()
def plot_policy2d(model, experiment, T=30, N=200, title="", posterior=True):
    """Roll out the design policy for ``T`` steps on one sampled task and plot it.

    The figure shows the chosen designs colored by time step, the true
    ``theta`` locations as stars and, optionally, a filled contour of the
    model's final posterior log-density over the unit square.

    Note:
        The target mask is built from the notebook-level ``cfg`` (its
        ``cfg.task`` section), so ``cfg`` must match ``model``/``experiment``.

    Args:
        model: Model called as ``model.forward(batch)``; its output must
            expose ``design_out.idx`` and ``posterior_out.mixture_*``.
        experiment: Task object providing ``sample_batch``, ``update_batch``,
            ``K`` and ``dim_x``. Its ``n_query_init`` attribute is overwritten
            with ``N``.
        T: Number of design steps to roll out; must be at least 1.
        N: Value assigned to ``experiment.n_query_init`` before sampling.
        title: Title of the axes.
        posterior: Whether to draw the posterior contour.

    Raises:
        ValueError: If ``T`` is smaller than 1.
    """
    if T < 1:
        # Without at least one step there is no prediction to plot.
        raise ValueError("T must be at least 1 to roll out and plot a policy.")

    model.eval()
    experiment.n_query_init = N
    batch = experiment.sample_batch(1)

    # Attach a target mask of a randomly chosen type to the batch.
    mask_type = random.choice(cfg.task.mask_type)
    batch.target_mask = create_target_mask(mask_type,
                                           cfg.task.embedding_type,
                                           cfg.task.n_target_data,
                                           cfg.task.n_target_theta,
                                           cfg.task.n_selected_targets,
                                           cfg.task.predefined_masks,
                                           cfg.task.predefined_mask_weights,
                                           cfg.task.mask_index,
                                           cfg.task.attend_to)

    # T-step experiment: query the policy, then fold the chosen design into the batch.
    for _ in range(T):
        pred = model.forward(batch)
        batch = experiment.update_batch(batch, pred.design_out.idx)

    # Map time steps 1..T onto (0, 1] for the scatter colormap.
    norm = plt.Normalize(0, 1)
    color_values = np.arange(1, T + 1, 1) / T

    fig, ax = plt.subplots(figsize=(6, 4))

    if posterior:
        # Evaluate the final posterior on a regular grid over the unit square.
        num_point = 200
        x = torch.linspace(0, 1, num_point)
        y = torch.linspace(0, 1, num_point)
        # Spell out "ij" (the long-standing default) to avoid the deprecation
        # warning and keep X/Y aligned with the row-major reshape below.
        X, Y = torch.meshgrid(x, y, indexing="ij")
        pos = torch.stack([X, Y], dim=-1).reshape(1, -1, 2, 1)

        # Summed over the last axis and taken for the single batch element
        # (presumably combining the two coordinates — confirm against compute_ll).
        prob = compute_ll(pos,
                          pred.posterior_out.mixture_means,
                          pred.posterior_out.mixture_stds,
                          pred.posterior_out.mixture_weights).sum(-1)[0]

        contourf = ax.contourf(X.cpu().numpy(),
                               Y.cpu().numpy(),
                               prob.reshape(num_point, num_point).cpu().numpy(),
                               16, cmap=mpl.cm.bone)

        cbar_pos = fig.colorbar(contourf, ax=ax)
        cbar_pos.set_label(r'Posterior $\log q(\theta \, | \,  \mathcal{D}_T)$',
                           rotation=270, verticalalignment='baseline')

    # Designs: the first T rows of the context are plotted as the queried points
    # (assumes update_batch stores them there — confirm).
    scatter = ax.scatter(
        batch.context_x[0, :T, 0].cpu().numpy(),
        batch.context_x[0, :T, 1].cpu().numpy(),
        c=color_values, cmap='summer', norm=norm, label=r'$\xi_t$'
    )

    # Colorbar for the time step encoding
    cbar1 = fig.colorbar(scatter, ax=ax)
    cbar1.set_label(r'Time step $t/T$', rotation=270, verticalalignment='baseline', fontsize=12)

    # True theta locations
    theta = batch.target_theta.reshape(-1, experiment.K, experiment.dim_x)
    ax.scatter(theta[0, :, 0].cpu().numpy(), theta[0, :, 1].cpu().numpy(),
               color=COLORS[3], label=r'$\theta$', marker='*', s=120)

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    ax.set_title(title)
    ax.legend()
    fig.tight_layout()
    plt.show()

In [None]:
# Roll out 30 design steps on a freshly sampled task and plot the chosen designs.
plot_policy2d(model, experiment, 30, title="Location Finding", N=2000)

## CES

### EIG Bounds

In [None]:
# Switch the task config group to "ces" and load the "aline_ces.pth" checkpoint.
# This rebinds the notebook-level cfg, experiment and model.
cfg, experiment, model = load(["task=ces", f"task.n_query_init={n_query}"], "aline_ces.pth")

# NOTE(review): the positional arguments follow eval_boed's signature in utils.eval;
# confirm there what cfg.T-1, int(1e7), 2000 and 10 stand for.
bounds = eval_boed(model, experiment, cfg.T-1, int(1e7), 2000, 10, cfg.time_token)

In [None]:
def _style_marginal_axis(ax, title, xlabel):
    """Apply the shared title, x-label and tick formatting to one marginal panel."""
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%d'))
    ax.xaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
    # Add minor ticks
    ax.xaxis.set_minor_locator(mpl.ticker.AutoMinorLocator())
    ax.yaxis.set_minor_locator(mpl.ticker.AutoMinorLocator())


@torch.no_grad()
def plot_posterior(model, experiment, T=30, N=200, title="CES",
                   save_path='outputs/figures/bed_pos_ces.pdf'):
    """Roll out ``T`` design steps and plot the posterior marginals of theta.

    Three panels are drawn: component 0 of theta (labelled rho), components
    1-3 (labelled alpha) and component 4 (labelled log(u)). Each panel shows
    the exponentiated ``compute_ll`` of the model's final posterior mixture on
    a 100-point grid, with the true value marked by a dashed vertical line.

    Args:
        model: Model called as ``model.forward(batch)``; its output must
            expose ``design_out.idx`` and ``posterior_out.mixture_*``.
        experiment: Task object providing ``sample_batch`` and
            ``update_batch``. Its ``n_query_init`` attribute is overwritten
            with ``N``.
        T: Number of design steps to roll out; must be at least 1.
        N: Value assigned to ``experiment.n_query_init`` before sampling.
        title: Figure super-title; nothing is drawn when it is ``None`` or empty.
        save_path: Where to save the figure. Missing parent directories are
            created. Pass ``None`` to skip saving.

    Raises:
        ValueError: If ``T`` is smaller than 1.
    """
    from pathlib import Path

    if T < 1:
        # Without at least one step there is no posterior prediction to plot.
        raise ValueError("T must be at least 1 to produce a posterior.")

    # Evaluate model and generate data
    model.eval()
    experiment.n_query_init = N
    batch = experiment.sample_batch(1)
    theta = batch.target_theta.squeeze(-1)

    # Run T steps; only the prediction from the last step is plotted.
    for _ in range(T):
        pred = model.forward(batch)
        batch = experiment.update_batch(batch, pred.design_out.idx)
    post = pred.posterior_out

    # Ground-truth values, split by component index
    rho = theta[..., 0]
    alpha = theta[..., 1:4]
    log_u = theta[..., 4]

    fig, axes = plt.subplots(1, 3, figsize=(9, 3))
    line_width = 2
    truth_color = '#f55f51'

    def draw_marginal(ax, grid, component, truth):
        """Plot one posterior component's density on ``grid`` and mark ``truth``."""
        density = compute_ll(grid,
                             post.mixture_means[:, component],
                             post.mixture_stds[:, component],
                             post.mixture_weights[:, component]).exp()
        ax.plot(grid.cpu().numpy(), density.cpu().numpy(), linewidth=line_width)
        ax.axvline(truth.cpu().numpy(), color=truth_color, linestyle='--')

    unit_grid = torch.linspace(0, 1, 100).unsqueeze(-1)

    # --- Plot rho ---
    draw_marginal(axes[0], unit_grid, 0, rho)
    _style_marginal_axis(axes[0], r"$\rho$", r"$\rho$")

    # --- Plot alpha (three components share one panel) ---
    for offset, true_alpha in enumerate(alpha.squeeze(0)):
        draw_marginal(axes[1], unit_grid, offset + 1, true_alpha)
    _style_marginal_axis(axes[1], r"$\alpha$", r"$\alpha$")

    # --- Plot log(u) ---
    log_u_grid = torch.linspace(-6, 8, 100).unsqueeze(-1)
    draw_marginal(axes[2], log_u_grid, 4, log_u)
    _style_marginal_axis(axes[2], r"$u$", r"$\log(u)$")

    # Add ONE y-label for the whole figure
    fig.text(0.05, 0.5, r"$p(\theta)$", va='center', rotation='vertical', fontsize=14)

    # Final layout tweaks; only add a super-title when one is given.
    if title:
        fig.suptitle(title, fontweight='bold')
    fig.tight_layout(rect=[0.05, 0, 1, 1])

    if save_path is not None:
        # Make sure the output directory exists before writing the file.
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.show()

In [None]:
# Roll out 30 design steps on a freshly sampled task and plot the posterior marginals (no figure title).
plot_posterior(model, experiment, T=30, N=n_query, title=None)