In [None]:
# import seaborn as sns
# sns.set_theme(style="whitegrid", font_scale=1.5)
# sns.set_palette("colorblind")
# sns.despine()

In [1]:
import torch

from nll_to_po.models.dn_policy import MLPPolicy
from nll_to_po.training.utils import train_single_policy, setup_logger
import nll_to_po.training.loss as L

import wandb

In [2]:
import os
import logging

logger = logging.getLogger("wandb")
logger.setLevel(logging.ERROR)
os.environ["WANDB_SILENT"] = "true"  # Suppress WandB output

### config

In [3]:
# Experiment parameters
n_experiments = 1
n_updates = 500
learning_rate = 0.001
use_wandb = True
wandb_project = "tractable"

# Policy architecture
input_dim = 4
output_dim = 2
hidden_sizes = [64, 64]
fixed_logstd = False

# Data generating dist q
init_dist_loc = 5.0
init_dist_scale = 0.75
init_dist_n_samples = 25

### NLL

In [5]:
for exp_idx in range(n_experiments):
    policy = MLPPolicy(input_dim, output_dim, hidden_sizes, fixed_logstd)

    # Generate new random data for each experiment
    X = torch.randn(1, input_dim)
    mean_y = torch.ones((1, output_dim)) * init_dist_loc
    y = mean_y + torch.randn(init_dist_n_samples, output_dim) * init_dist_scale
    X = X.repeat(init_dist_n_samples, 1)  # Repeat X for each sample
    batch_size = X.shape[0]

    # Create a DataLoader
    train_dataset = torch.utils.data.TensorDataset(X, y)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )

    # Define the loss function
    loss_function = L.NLL()

    config = {
        "batch_size": batch_size,
        "fixed_logstd": fixed_logstd,
        "init_dist_loc": init_dist_loc,
        "init_dist_scale": init_dist_scale,
        "init_dist_n_samples": init_dist_n_samples,
        "learning_rate": learning_rate,
        "loss": loss_function.name,
    }
    if use_wandb:
        wandb_run = wandb.init(
            project=wandb_project,
            config=config,
        )
    else:
        wandb_run = None
    logger, _, ts_writer = setup_logger(
        logger_name="nll_to_po",
        log_dir="../logs",
        env_id="test_theory",
        exp_name=f"NLL_{exp_idx}",
    )
    # Log the configuration
    logger.info(f"%%%%%%%%%%%%%%%%%%%\nconfig:\n{config}\n%%%%%%%%%%%%%%%%%%%%")

    # Run comparison
    train_single_policy(
        policy=policy,
        train_dataloader=train_dataloader,
        loss_function=loss_function,
        n_updates=n_updates,
        learning_rate=learning_rate,
        wandb_run=wandb_run,
        tensorboard_writer=ts_writer,
        logger=logger,
    )

    wandb_run.finish()

2025-08-13 13:38:05 - nll_to_po - INFO - %%%%%%%%%%%%%%%%%%%
config:
{'batch_size': 25, 'fixed_logstd': False, 'init_dist_loc': 5.0, 'init_dist_scale': 0.75, 'init_dist_n_samples': 25, 'learning_rate': 0.001, 'loss': 'NLL'}
%%%%%%%%%%%%%%%%%%%%
2025-08-13 13:38:05 - nll_to_po - INFO - %%%%%%%%%%%%%%%%%%%
config:
{'batch_size': 25, 'fixed_logstd': False, 'init_dist_loc': 5.0, 'init_dist_scale': 0.75, 'init_dist_n_samples': 25, 'learning_rate': 0.001, 'loss': 'NLL'}
%%%%%%%%%%%%%%%%%%%%
2025-08-13 13:38:05 - nll_to_po - INFO - Starting training for 500 epochs
2025-08-13 13:38:05 - nll_to_po - INFO - Starting training for 500 epochs
Training epochs:   0%|          | 0/500 [00:00<?, ?it/s]2025-08-13 13:38:05 - nll_to_po - INFO - %%%%%%%%%%%%%%%%%%%
config:
{'batch_size': 25, 'fixed_logstd': False, 'init_dist_loc': 5.0, 'init_dist_scale': 0.75, 'init_dist_n_samples': 25, 'learning_rate': 0.001, 'loss': 'NLL'}
%%%%%%%%%%%%%%%%%%%%
2025-08-13 13:38:05 - nll_to_po - INFO - Starting training fo

### PG + entropy

In [4]:
n_generations = 5
use_rsample = True
reward_transform = "normalize"
entropy_weight = 0.01

for _ in range(n_experiments):
    policy = MLPPolicy(input_dim, output_dim, hidden_sizes, fixed_logstd)

    # Generate new random data for each experiment
    X = torch.randn(1, input_dim)
    mean_y = torch.ones((1, output_dim)) * init_dist_loc
    y = mean_y + torch.randn(init_dist_n_samples, output_dim) * init_dist_scale
    X = X.repeat(init_dist_n_samples, 1)  # Repeat X for each sample
    batch_size = X.shape[0]

    # Create a DataLoader
    train_dataset = torch.utils.data.TensorDataset(X, y)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )

    # Define the loss function
    loss_function = L.PO_Entropy(
        n_generations=n_generations,
        use_rsample=use_rsample,
        reward_transform=reward_transform,
        entropy_weight=entropy_weight,
    )

    wandb_run = wandb.init(
        project=wandb_project,
        config={
            "batch_size": batch_size,
            "fixed_logstd": fixed_logstd,
            "init_dist_loc": init_dist_loc,
            "init_dist_scale": init_dist_scale,
            "init_dist_n_samples": init_dist_n_samples,
            "learning_rate": learning_rate,
            "loss": loss_function.name,
            "n_generations": n_generations,
            "use_rsample": use_rsample,
            "reward_transform": reward_transform,
            "entropy_weight": entropy_weight,
        },
    )

    # Run comparison
    train_single_policy(
        policy=policy,
        train_dataloader=train_dataloader,
        loss_function=loss_function,
        n_updates=n_updates,
        learning_rate=learning_rate,
        wandb_run=wandb_run,
    )

    wandb_run.finish()

  return F.mse_loss(input, target, reduction=self.reduction)
Training epochs: 100%|██████████| 500/500 [00:01<00:00, 336.06it/s]
