[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lexiconium/diffusion-planner/blob/main/examples/training/run_diffusion_training.ipynb)

In [None]:
#@title ## Install Dependencies

!pip install git+https://github.com/lexiconium/diffusion-planner > /dev/null 2>&1

In [None]:
#@title ## Import Dependencies

import ast
import os

from diffusers.schedulers import DDPMScheduler
from transformers import Trainer, TrainerCallback, TrainerControl, TrainerState, TrainingArguments
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

from diffusion_planner.diffusers import TemporalUnetDiffuserForDDPM
from diffusion_planner.models import TemporalUnet
from diffusion_planner.utils.data import DatasetForD4RL, DynamicCollatorWithPadding

In [None]:
#@title ## Configuration

# Colab form: every `#@param` value below is a plain module-level variable consumed by the
# training cell. Keep each assignment and its `#@param` comment on one line so the form renders.

#@markdown ---

#@markdown #### Data configuration

# Passed to DatasetForD4RL in the training cell (pad_to_multiple_of goes to
# DynamicCollatorWithPadding instead). The default URL is the D4RL hopper-expert-v2 HDF5 file.
# NOTE(review): horizon / max_horizon default to None even though the form declares them as
# integers; presumably None means "unset" for DatasetForD4RL — confirm before relying on it.
path_or_url = "https://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_expert-v2.hdf5"  #@param {type:"string"}
pad_to_multiple_of = 8  #@param {type:"integer"}
trajectory_type = "dynamic"  #@param {type:"string"}
horizon = None  #@param {type:"integer"}
max_horizon = None  #@param {type:"integer"}
min_horizon = 256  #@param {type:"integer"}

#@markdown #### Model configuration

# Hyper-parameters forwarded to TemporalUnet. block_out_channels is entered as a string
# (e.g. "(32, 64, 128, 256)") and converted to a Python sequence in the training cell;
# presumably one entry per U-Net resolution level — TODO confirm against TemporalUnet.
block_out_channels = "(32, 64, 128, 256)"  #@param {type:"string"}
num_layers_per_block = 2  #@param {type:"integer"}
norm_eps = 1e-5  #@param {type:"number"}
num_groups = 8  #@param {type:"integer"}
dropout = 0.0  #@param {type:"number"}

#@markdown #### Training configuration

# Forwarded to transformers.TrainingArguments under the same names; checkpoints are written
# under <output_dir>/checkpoint-<global_step>.
# NOTE(review): per the TrainingArguments docs, a nonzero warmup_steps overrides warmup_ratio.
output_dir = "outputs"  #@param {type:"string"}
do_train = True  #@param {type:"boolean"}
per_device_train_batch_size = 8  #@param {type:"integer"}
learning_rate = 2e-4  #@param {type:"number"}
weight_decay = 1e-2  #@param {type:"number"}
adam_beta1 = 0.9  #@param {type:"number"}
adam_beta2 = 0.999  #@param {type:"number"}
adam_epsilon = 1e-8  #@param {type:"number"}
num_train_epochs = 1  #@param {type:"integer"}
lr_scheduler_type = "linear"  #@param ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
warmup_ratio = 0.2  #@param {type:"number"}
warmup_steps = 0  #@param {type:"integer"}
save_total_limit = 5  #@param {type:"integer"}
seed = 42  #@param {type:"integer"}
fp16 = True  #@param {type:"boolean"}

In [None]:
class SaveConfigCallback(TrainerCallback):
    """Write the denoising U-Net's config into every checkpoint directory the Trainer saves.

    This keeps the architecture hyper-parameters next to the weights in each
    ``<output_dir>/checkpoint-<global_step>`` directory.
    """

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Save ``model.unet``'s config as ``model_config.json`` in the current checkpoint dir.

        Args:
            args: The Trainer's TrainingArguments (provides ``output_dir`` and ``should_save``).
            state: Current TrainerState (provides ``global_step``).
            control: Unused; returned implicitly unchanged.
            **kwargs: Callback payload from the Trainer; ``kwargs["model"]`` is the trained model.
        """
        # In distributed runs the callback can fire on every process. Only the process(es)
        # that actually write checkpoint files should write the config, mirroring the
        # Trainer's own `args.should_save` gate.
        if not args.should_save:
            return

        # Same directory naming the Trainer uses for checkpoints: <output_dir>/checkpoint-<step>.
        checkpoint_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        # The diffuser wrapper exposes its denoiser as `.unet` (it is constructed with unet=...).
        # NOTE(review): confirm whether TemporalUnet.save_config expects a file path (as passed
        # here) or a directory (the diffusers ConfigMixin convention, which would create a
        # directory named "model_config.json").
        kwargs["model"].unet.save_config(os.path.join(checkpoint_dir, "model_config.json"))


# Build the dataset, the temporal U-Net denoiser, and the DDPM diffuser wrapper from the
# form values above, then train with the Hugging Face Trainer.

dataset = DatasetForD4RL(
    path_or_url,
    trajectory_type=trajectory_type,
    horizon=horizon,
    max_horizon=max_horizon,
    min_horizon=min_horizon
)

# The U-Net's channel count is observation_dim + action_dim (presumably each step stacks the
# observation and action vectors — confirm against DatasetForD4RL).
transition_dim = dataset.observation_dim + dataset.action_dim

scheduler = DDPMScheduler()
unet = TemporalUnet(
    in_channels=transition_dim,
    out_channels=transition_dim,
    # The form supplies a string such as "(32, 64, 128, 256)". Parse it as a Python literal
    # rather than eval() so arbitrary expressions typed into the form are never executed.
    block_out_channels=ast.literal_eval(block_out_channels),
    num_layers_per_block=num_layers_per_block,
    norm_eps=norm_eps,
    num_groups=num_groups,
    dropout=dropout
)

model = TemporalUnetDiffuserForDDPM(unet=unet, scheduler=scheduler)

training_args = TrainingArguments(
    output_dir=output_dir,
    do_train=do_train,
    per_device_train_batch_size=per_device_train_batch_size,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    adam_beta1=adam_beta1,
    adam_beta2=adam_beta2,
    adam_epsilon=adam_epsilon,
    num_train_epochs=num_train_epochs,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
    warmup_steps=warmup_steps,
    save_total_limit=save_total_limit,
    seed=seed,
    fp16=fp16
)

trainer = Trainer(
    model,
    args=training_args,
    data_collator=DynamicCollatorWithPadding(
        pad_to_multiple_of=pad_to_multiple_of
    ),
    train_dataset=dataset,
    # Trainer accepts a callback class and instantiates it itself.
    callbacks=[SaveConfigCallback]
)

# TrainingArguments.do_train is not read by Trainer itself; it is meant for the calling script,
# so the form's "do_train" toggle is honoured here.
if training_args.do_train:
    trainer.train()