# $\beta$-Pic b planet training
This notebook trains the models to reproduce the posteriors reported in the paper by [Sun et al.](https://arxiv.org/pdf/2201.08506.pdf).

## Import the necessary packages 

In [1]:
import orbitize
import torch.nn as nn
import wandb
import zuko
import random 

from training import train
from orbitize import read_input
from lampe.data import H5Dataset

Connect to WandB

In [2]:
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: matteo-ruth (rutje). Use `wandb login --relogin` to force relogin


True

## Load the dataset with the observations

In [3]:
data_table = read_input.read_file('{}/betaPic.csv'.format(orbitize.DATADIR))

data_table = data_table[:-1]  # Discard the RV observation (last row): it is not yet clear how to take it into account
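
The slice above works only because the RV measurement happens to be the last row. A minimal sketch of a label-based alternative follows, using a plain list of dictionaries as a hypothetical stand-in for the observation table. The rows and the helper `drop_quantity` are illustrative and not part of the original notebook. It assumes each row carries a `quant_type` label, as the tables returned by `read_input.read_file` are understood to do.

```python
# Hypothetical stand-in for the observation table: each row carries a
# "quant_type" label that tags the kind of measurement it holds.
rows = [
    {"epoch": 1, "quant_type": "seppa"},
    {"epoch": 2, "quant_type": "seppa"},
    {"epoch": 3, "quant_type": "rv"},
]


def drop_quantity(rows, unwanted="rv"):
    """Return the rows whose quant_type differs from `unwanted`."""
    return [row for row in rows if row["quant_type"] != unwanted]


astrometry = drop_quantity(rows)
print(len(astrometry), "rows kept")
```

With the real table, the analogous operation would presumably be a boolean mask such as `data_table[data_table['quant_type'] != 'rv']`, which keeps working if the row order of the CSV changes.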

## Load the generated datasets

If needed, change the scaling factor and the priors on $\omega$ and $\tau$, since the current settings may not work with the simulator.

In [4]:
trainset = H5Dataset('datasets/beta-pic-train-2.h5', batch_size=2048, shuffle=True)  
validset = H5Dataset('datasets/beta-pic-val-2.h5', batch_size=2048, shuffle=True)

## Train the NPE model

In [5]:
train(
    trainset,
    validset,
    epochs = 512, 
    num_obs = len(data_table),

    # Embedding network architecture
    embedding_output_dim = 32,
    embedding_hidden_features = [256] * 3,

    # activation function for the embedding & NPE
    activation = nn.ELU,

    # NPE architecture
    transforms = 5, 
    flow = zuko.flows.spline.NSF,
    NPE_hidden_features = [512] * 5,

    # Training parameters
    initial_lr = 1e-3,
    weight_decay = 1e-2,
    clip = 1)

100%|██████████| 512/512 [5:57:10<00:00, 41.86s/epoch, loss=-14.4, val_loss=-14.6]   


Run history:
lr,████████▄▄▄▄▄▄▄▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▇▅▅▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁
valid_loss,█▇▅▅▅▄▅▄▃▃▄▃▃▃▃▃▃▃▃▃▂▄▂▃▂▂▂▃▂▂▂▂▂▂▂▂▁▂▁▁

Run summary:
lr,6e-05
train_loss,-14.37477
valid_loss,-14.58765
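
The final losses are negative, which can look surprising. The internals of `train` are not shown here, so the following is only a sketch under the assumption that it minimizes the usual NPE objective, the average negative log-density of the true parameters under the estimated posterior. A 1-D Gaussian stands in for the normalizing flow, and all names are illustrative. The point is that a sharply peaked density takes values above 1, so its negative log-density falls below zero.

```python
import math
import random


def gaussian_log_prob(theta, mean, std):
    """Log-density of a 1-D normal distribution evaluated at theta."""
    z = (theta - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2 * math.pi)


def mean_nll(samples, mean, std):
    """Average negative log-density of the samples under the Gaussian."""
    return -sum(gaussian_log_prob(t, mean, std) for t in samples) / len(samples)


random.seed(0)
narrow = [random.gauss(0.0, 0.01) for _ in range(10_000)]
wide = [random.gauss(0.0, 10.0) for _ in range(10_000)]

nll_narrow = mean_nll(narrow, 0.0, 0.01)  # peaked density, values above 1
nll_wide = mean_nll(wide, 0.0, 10.0)      # broad density, values below 1
print(f"narrow: {nll_narrow:.2f}, wide: {nll_wide:.2f}")
```

Under this reading, a more negative loss means the model places higher density on the true parameters. The absolute value also depends on how the parameters are scaled, so losses are comparable only between runs that use the same datasets and scaling.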
