# Tutorial - Monitoring experiments with wandb

In this notebook, we will see how to monitor your experiments using the integrated **wandb** callbacks.

In [None]:
# Install the library (pythae).
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# NOTE(review): consider pinning a version (e.g. `pythae==X.Y.Z`) so the trainer/model
# config arguments used later in this notebook stay valid — TODO confirm the version
# this tutorial targets.
%pip install pythae

## Train your Pythae model

In [None]:
import torchvision.datasets as datasets

%load_ext autoreload
%autoreload 2

In [None]:
# Number of images held out (from the end of the MNIST training split) for evaluation.
N_EVAL = 10_000

mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=None)

# Split with an explicit index rather than negative slicing: `data[:-N_EVAL]` would
# silently produce an empty training set if N_EVAL were ever set to 0.
assert 0 < N_EVAL < len(mnist_trainset.data), "N_EVAL must leave images for both splits"
n_train = len(mnist_trainset.data) - N_EVAL

# Reshape to (N, 1, 28, 28) and divide by 255 (presumably uint8 pixels -> [0, 1]).
train_dataset = mnist_trainset.data[:n_train].reshape(-1, 1, 28, 28) / 255.
eval_dataset = mnist_trainset.data[n_train:].reshape(-1, 1, 28, 28) / 255.

# Sanity checks on the split before handing it to the training pipeline.
assert tuple(train_dataset.shape[1:]) == (1, 28, 28)
assert tuple(eval_dataset.shape[1:]) == (1, 28, 28)
assert len(train_dataset) + len(eval_dataset) == len(mnist_trainset.data)

In [None]:
from pythae.models import BetaVAE, BetaVAEConfig
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines.training import TrainingPipeline
from pythae.models.nn.benchmarks.mnist import Encoder_ResNet_VAE_MNIST, Decoder_ResNet_AE_MNIST

In [None]:
# Training hyper-parameters; run outputs are written under `output_dir`.
# NOTE(review): some pythae releases may name the batch-size argument differently —
# confirm `batch_size` against the installed version.
training_config = BaseTrainerConfig(
    output_dir="my_model",
    learning_rate=1e-4,
    batch_size=100,
    num_epochs=10,  # increase this to train the model a bit longer
)

# Beta-VAE on 1x28x28 inputs with a 16-dimensional latent space and beta = 2.
model_config = BetaVAEConfig(input_dim=(1, 28, 28), latent_dim=16, beta=2.0)

# Plug in the ResNet encoder/decoder from pythae's MNIST benchmark networks.
model = BetaVAE(
    model_config=model_config,
    encoder=Encoder_ResNet_VAE_MNIST(model_config),
    decoder=Decoder_ResNet_AE_MNIST(model_config),
)

## Before launching the pipeline, you will need to build your `WandbCallback`

To use this feature you will need:
- a valid wandb account
- the `wandb` package installed in your virtual env. You can install it with `pip install wandb`
- to be logged in, which you can do by running `wandb login`

In [None]:
# Before you can monitor your experiments, you may need to install wandb and log in.
# Uncomment the lines below and run this cell once.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# Alternatively, `import wandb; wandb.login()` logs in from Python inside the notebook.
# %pip install wandb
# !wandb login

In [None]:
# Create your callback
from pythae.trainers.training_callbacks import WandbCallback

wandb_cb = WandbCallback()  # Build the callback

# Set up the callback with both configs and the wandb project/entity to log the run under.
wandb_cb.setup(
    training_config=training_config,  # training config
    model_config=model_config,  # model config
    project_name="your_wandb_project",  # specify your wandb project
    entity_name="your_wandb_entity",  # specify your wandb entity
)

# The TrainingPipeline expects a list of callbacks. Build it only after setup succeeds.
# If this cell fails earlier (e.g. wandb not installed or not logged in), `callbacks` is
# then not left behind as an empty list that would let training run silently unmonitored.
callbacks = [wandb_cb]

In [None]:
# Bundle the model and its training configuration into a ready-to-run pipeline.
pipeline = TrainingPipeline(model=model, training_config=training_config)

In [None]:
# Launch training. Passing the callbacks list is all it takes to enable wandb monitoring.
pipeline(
    train_data=train_dataset,
    eval_data=eval_dataset,
    callbacks=callbacks,
)
# Go to https://wandb.ai/your_wandb_entity/your_wandb_project to monitor your training