In [1]:
import torch

import seaborn as sns

from nll_to_po.models.dn_policy import MLPPolicy
from nll_to_po.training.utils import train_single_policy
import nll_to_po.training.loss as L

# Global plotting theme for every figure in this notebook.
# Top/right spines are hidden via rcParams instead of calling `sns.despine()` here:
# despine() only acts on the *current* figure, so calling it at config time (before
# any plot exists) just creates an empty figure and leaves later plots unchanged.
sns.set_theme(
    style="whitegrid",
    font_scale=1.5,
    rc={"axes.spines.top": False, "axes.spines.right": False},
)
sns.set_palette("colorblind")

<Figure size 640x480 with 0 Axes>

In [3]:
import wandb as wb

In [None]:
# Experiment parameters — the single place to tune this study.
n_experiments = 10  # Number of independent repetitions (fresh policy + fresh data each time)
n_updates = 200  # Forwarded to train_single_policy(n_updates=...)
input_dim = 4
output_dim = 2
hidden_sizes = [64, 64]

fixed_logstd = False
# NOTE(review): init_dist_loc multiplies randn below, i.e. it is the spread of the target
# mean around a hard-coded location of 2, not the location itself — confirm intent.
init_dist_loc = 1.0
init_dist_scale = 0.1  # Std of the per-sample noise added around the target mean
init_dist_n_samples = 25  # Number of (X, y) pairs per repetition
rsample_for_grpo = False  # Forwarded to L.PO(use_rsample=...)

learning_rate = 0.001

# PO loss settings, defined once so the loss object and the logged W&B config cannot drift.
# (Alternative baseline for comparison: L.NLL().)
loss_name = "PO"
n_generations = 100
reward_transform = "normalize"  # "normalize", "rbf", "none"
rbf_gamma = None

# Each repetition gets its own deterministic seed so Restart & Run All reproduces every run.
base_seed = 0

for experiment_idx in range(n_experiments):
    seed = base_seed + experiment_idx
    # Seeds policy init, the synthetic data below, and the DataLoader shuffle order.
    torch.manual_seed(seed)

    policy = MLPPolicy(input_dim, output_dim, hidden_sizes, fixed_logstd)

    # Generate new random data for each experiment: one context vector shared by all
    # samples, with targets drawn as noise around a random mean.
    x_context = torch.randn(1, input_dim)
    mean_y = 2 + torch.randn(1, output_dim) * init_dist_loc
    y = mean_y + torch.randn(init_dist_n_samples, output_dim) * init_dist_scale
    X = x_context.repeat(init_dist_n_samples, 1)  # Repeat the context for each sample

    train_dataset = torch.utils.data.TensorDataset(X, y)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=True
    )

    loss_function = L.PO(
        n_generations=n_generations,
        use_rsample=rsample_for_grpo,
        reward_transform=reward_transform,
        rbf_gamma=rbf_gamma,
    )

    wandb_run = wb.init(
        project="mse_nll_po",
        name=loss_name,
        config={
            "fixed_logstd": fixed_logstd,
            "init_dist_loc": init_dist_loc,
            "loss": loss_name,
            "learning_rate": learning_rate,
            "n_generations": n_generations,
            "reward_transform": reward_transform,
            "init_dist_n_samples": init_dist_n_samples,
            # Extra fields so every run-defining parameter is recorded with the run.
            "init_dist_scale": init_dist_scale,
            "n_updates": n_updates,
            "hidden_sizes": hidden_sizes,
            "use_rsample": rsample_for_grpo,
            "rbf_gamma": rbf_gamma,
            "seed": seed,
        },
    )

    # NOTE(review): every run emits a warning from F.mse_loss inside project code (see
    # the cell output) — possibly an input/target shape mismatch; investigate.
    try:
        train_single_policy(
            policy=policy,
            train_dataloader=train_dataloader,
            loss_function=loss_function,
            n_updates=n_updates,
            learning_rate=learning_rate,
            wandb_run=wandb_run,
        )
    finally:
        # Close the run even if training raises, so the next wb.init starts cleanly.
        wandb_run.finish()

  return F.mse_loss(input, target, reduction=self.reduction)
Training epochs: 100%|██████████| 200/200 [00:00<00:00, 344.58it/s]


0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇█████
train/NLL,█████▇▇▇▇▆▆▄▄▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/grad_norm,▄█▆▄▄█▅▆▆▆▅▅▅▆▄▂▂▂▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▇▆▆▆▅▆▅▅▄▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/mean_error,███▇▇▆▆▅▅▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,199.0
train/NLL,-0.97552
train/grad_norm,0.44027
train/loss,-0.07402
train/mean_error,0.0


  return F.mse_loss(input, target, reduction=self.reduction)
Training epochs: 100%|██████████| 200/200 [00:00<00:00, 341.17it/s]


0,1
epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
train/NLL,█████▇▇▆▆▆▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
train/grad_norm,█▆▆▃▅▂▂▃▂▂▂▂▁▂▂▂▁▁▂▂▁▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁
train/loss,█▄▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/mean_error,██▇▅▄▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,199.0
train/NLL,-1.35511
train/grad_norm,0.30358
train/loss,-0.0636
train/mean_error,9e-05


  return F.mse_loss(input, target, reduction=self.reduction)
Training epochs: 100%|██████████| 200/200 [00:00<00:00, 350.37it/s]


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇████
train/NLL,███████▇▇▇▆▆▆▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁
train/grad_norm,▆▄█▆▆▅▆▇▄▅▄▅▅▄▅▁▃▂▄▃▂▁▃▂▂▂▁▁▂▂▁▂▁▁▁▁▁▂▂▁
train/loss,▆▆▆█▇▅▅▄▆▃▄▂▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/mean_error,█▅▅▅▆▆▆▅▅▄▄▄▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,199.0
train/NLL,-1.51733
train/grad_norm,0.1017
train/loss,-0.0377
train/mean_error,0.00033


  return F.mse_loss(input, target, reduction=self.reduction)
Training epochs: 100%|██████████| 200/200 [00:00<00:00, 323.47it/s]


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇████
train/NLL,███▇▇▆▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
train/grad_norm,▇█▆▅▆▆▄▆▄▆▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▂▁▁▁▁▁▁▁
train/loss,█▆▇▅▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/mean_error,█▇▇▅▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,199.0
train/NLL,-1.52604
train/grad_norm,0.19822
train/loss,-0.04226
train/mean_error,0.0


  return F.mse_loss(input, target, reduction=self.reduction)
Training epochs: 100%|██████████| 200/200 [00:00<00:00, 332.83it/s]


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇███
train/NLL,█▇▇▇▇▇▆▆▆▅▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
train/grad_norm,▆▅▄█▆▅▅▄▄▃▂▂▂▂▂▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,▇█▅▆▅▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/mean_error,██▇▆▆▅▅▄▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,199.0
train/NLL,-1.1676
train/grad_norm,0.18267
train/loss,-0.03388
train/mean_error,0.00075


  return F.mse_loss(input, target, reduction=self.reduction)
Training epochs: 100%|██████████| 200/200 [00:00<00:00, 353.38it/s]


0,1
epoch,▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
train/NLL,████████████▇▇▇▆▆▅▅▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
train/grad_norm,▄▆▄▄▄█▄▅▅▅▅▅▄▃▂▃▂▂▁▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▆▅▅▅▅▃▄▃▅▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/mean_error,█▇▇▇▇▆▆▅▅▅▄▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,199.0
train/NLL,-1.03639
train/grad_norm,0.18833
train/loss,-0.04377
train/mean_error,0.00074


  return F.mse_loss(input, target, reduction=self.reduction)
Training epochs: 100%|██████████| 200/200 [00:00<00:00, 343.22it/s]


0,1
epoch,▁▁▁▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇███
train/NLL,█████▇▇▇▇▆▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
train/grad_norm,▇▆▅▅▅▅▆█▃▃▂▂▂▂▂▂▂▁▂▂▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁
train/loss,▇▅▆█▆▆▄▃▂▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/mean_error,█▇▆▅▅▄▄▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,199.0
train/NLL,-0.97244
train/grad_norm,0.33574
train/loss,-0.06192
train/mean_error,0.00085


  return F.mse_loss(input, target, reduction=self.reduction)
Training epochs: 100%|██████████| 200/200 [00:00<00:00, 350.55it/s]


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇██
train/NLL,███▇▇▇▇▇▆▆▅▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁
train/grad_norm,█▇▆█▆▇█▅▅▆▄▃▂▂▃▂▂▂▂▁▁▁▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▆▄▄█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/mean_error,██▆▆▅▅▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,199.0
train/NLL,-1.86127
train/grad_norm,0.17779
train/loss,-0.02732
train/mean_error,1e-05


  return F.mse_loss(input, target, reduction=self.reduction)
Training epochs: 100%|██████████| 200/200 [00:00<00:00, 308.53it/s]


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇█
train/NLL,██▇▇▇▅▅▅▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/grad_norm,▅▅▆▇▅▇█▆▆█▃▆▃▁▂▂▂▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▆▅▆▄▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/mean_error,██▇▇▇▆▆▆▅▅▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,199.0
train/NLL,-1.14729
train/grad_norm,0.18963
train/loss,-0.04884
train/mean_error,0.00022


  return F.mse_loss(input, target, reduction=self.reduction)
Training epochs: 100%|██████████| 200/200 [00:00<00:00, 315.45it/s]


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇███
train/NLL,████▇▇▇▇▇▆▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁
train/grad_norm,▆▇▃▆▆▇█▅▂▂▄▃▂▁▄▂▁▂▂▂▂▂▂▁▁▁▂▁▁▂▁▁▂▁▁▂▁▁▁▂
train/loss,▇█▆▅▆▄▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/mean_error,██▇▇▆▅▅▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,199.0
train/NLL,-1.70201
train/grad_norm,0.17388
train/loss,-0.03488
train/mean_error,6e-05
