In [3]:
import sys
# use line-buffering for both stdout and stderr
# sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
# sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)

import hydra
from omegaconf import OmegaConf
import os
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
import pathlib
from diffusion_policy.workspace.base_workspace import BaseWorkspace
 
import torch


from torch.utils.data import DataLoader
import copy
import random
import numpy as np

from diffusion_policy.common.checkpoint_util import TopKCheckpointManager
from diffusion_policy.common.json_logger import JsonLogger
from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to

import tqdm

from diffusion_policy.policy.diffusion_unet_hybrid_image_policy import DiffusionUnetHybridImagePolicy

from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to
from diffusion_policy.model.diffusion.ema_model import EMAModel
from diffusion_policy.model.common.lr_scheduler import get_scheduler

from diffusion_policy.dataset.robomimic_replay_image_dataset import RobomimicReplayImageDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Sampler 

import h5py


# Register `${eval:...}` so config values can be computed with Python
# expressions (e.g. derived horizons). SECURITY: this evaluates arbitrary
# code from the config files — only use trusted configs.
OmegaConf.register_new_resolver(name="eval", resolver=eval, replace=True)

In [6]:
# Hydra config lookup: directory (relative to this notebook) and YAML file name.
config_path, config_name = '.', "image_franka_rlbench.yaml"

In [7]:
# Compose the training config through Hydra's compose API (no @hydra.main),
# print the composed config, then resolve all interpolations in place so
# later cells see concrete values.
hydra_overrides = [
    "hydra.run.dir=data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}",
    "training.seed=42",
    "training.device=cuda:0",
]
with initialize(version_base=None, config_path=config_path):
    cfg_org = compose(config_name=config_name, overrides=hydra_overrides)
    print(cfg_org)

OmegaConf.resolve(cfg_org)

print('resume: ', cfg_org.training.resume)

{'_target_': 'diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace', 'checkpoint': {'save_last_ckpt': True, 'save_last_snapshot': False, 'topk': {'format_str': 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt', 'k': 5, 'mode': 'max', 'monitor_key': 'test_mean_score'}}, 'dataloader': {'batch_size': 64, 'num_workers': 8, 'persistent_workers': False, 'pin_memory': True, 'shuffle': True}, 'dataset_obs_steps': 2, 'ema': {'_target_': 'diffusion_policy.model.diffusion.ema_model.EMAModel', 'inv_gamma': 1.0, 'max_value': 0.9999, 'min_value': 0.0, 'power': 0.75, 'update_after_step': 0}, 'exp_name': 'default', 'horizon': 16, 'keypoint_visible_rate': 1.0, 'logging': {'group': None, 'id': None, 'mode': 'online', 'name': '2022.12.29-22.31.41_train_diffusion_unet_hybrid_square_image', 'project': 'diffusion_policy_debug', 'resume': True, 'tags': ['train_diffusion_unet_hybrid', 'square_image', 'default']}, 'multi_run': {'run_dir': 'data/outputs

In [8]:
# Placeholder for a checkpoint directory to resume from; not read by any
# later cell in this notebook.
last_checkpoint_dir = None

In [9]:
class TrainDiffusionUnetHybridWorkspace(BaseWorkspace):
    """Notebook-local workspace for the hybrid diffusion-UNet image policy.

    Only construction lives here. It seeds the RNGs, builds the policy from
    ``cfg.policy`` (plus a deep-copied EMA twin when ``cfg.training.use_ema``)
    and builds the optimizer from ``cfg.optimizer``. The training loop is
    driven by later notebook cells.
    """

    # Plain attributes passed to BaseWorkspace as extra state to include
    # (presumably in checkpoints — verify against BaseWorkspace).
    include_keys = ['global_step', 'epoch']

    def __init__(self, cfg: OmegaConf, output_dir=None):
        super().__init__(cfg, output_dir=output_dir)

        # Seed torch, numpy and the stdlib RNG with the same value.
        seed = cfg.training.seed
        for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
            seed_fn(seed)

        # Policy network and, when EMA is enabled, an independent copy whose
        # weights are tracked separately.
        self.model: DiffusionUnetHybridImagePolicy = hydra.utils.instantiate(cfg.policy)
        print("configuring done")
        self.ema_model: DiffusionUnetHybridImagePolicy = (
            copy.deepcopy(self.model) if cfg.training.use_ema else None
        )

        # Optimizer over the live (non-EMA) model parameters only.
        self.optimizer = hydra.utils.instantiate(
            cfg.optimizer, params=self.model.parameters())

        # Counters advanced by the training loop.
        self.global_step = 0
        self.epoch = 0

In [10]:
import datetime

### Recreating the workspace

In [12]:
# Create a fresh, timestamped output directory and build the workspace in it.
# The output root can be overridden with the DIFFUSION_POLICY_OUTPUT_ROOT
# environment variable; it defaults to the original lab-machine path, so
# existing setups keep writing to the same place.
OUTPUT_ROOT = os.environ.get(
    "DIFFUSION_POLICY_OUTPUT_ROOT",
    "/home/carl_lab/policy_training/diffusion_policy/diffusion_policy/data/outputs",
)
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
output_dir = os.path.join(OUTPUT_ROOT, f"custom{timestamp}")
# exist_ok: re-running within the same second must not fail.
os.makedirs(output_dir, exist_ok=True)
workspace = TrainDiffusionUnetHybridWorkspace(cfg_org, output_dir=output_dir)

# The remaining cells are written in terms of `self` (they mirror the
# workspace's run method), so alias the workspace at notebook scope.
self = workspace
print('output dir: ', output_dir)



using obs modality: low_dim with keys: ['joint_states', 'ee_states', 'gripper_states']
using obs modality: rgb with keys: ['eye_in_hand_rgb', 'agentview_rgb']
using obs modality: depth with keys: []
using obs modality: scan with keys: []




Diffusion params: 2.564722e+08
Vision params: 2.239418e+07
configuring done
output dir:  /home/carl_lab/policy_training/diffusion_policy/diffusion_policy/data/outputs/custom2025_03_23_12_55_48


In [13]:
# Private deep copy of the workspace config for the cells below, so edits to
# `cfg` never leak back into `self.cfg`.
cfg = copy.deepcopy(self.cfg)

# NOTE(review): checkpoint resumption is disabled in this notebook, even
# though the composed config carries a `training.resume` flag (printed when
# the config was composed). To re-enable it at this point: if
# `cfg.training.resume`, take `self.get_checkpoint_path()` and, when that
# file exists, call `self.load_checkpoint(path=...)`.

In [14]:
# Build the training dataset directly from the task's dataset config. The
# hydra `_target_` entry names the class rather than a constructor argument,
# so it is dropped (raising KeyError if absent) before unpacking.
dataset_kwargs = OmegaConf.to_container(cfg.task.dataset, resolve=True)
dataset_kwargs.pop('_target_')

dataset = RobomimicReplayImageDataset(**dataset_kwargs)
len(dataset)

Acquiring lock on cache.
Cache does not exist. Creating!


Loading lowdim data: 100%|██████████| 4/4 [00:00<00:00, 23.01it/s]
Loading image data: 100%|██████████| 13214/13214 [00:08<00:00, 1622.92it/s]


Saving cache to disk.


5774

In [15]:
# Plain-dict copy of the training DataLoader kwargs (batch size, workers, ...).
cfg_dataloader = dict(cfg.dataloader.items())

In [16]:
# Show each DataLoader setting on its own line.
for option_name in cfg_dataloader:
    print(f"Key: {option_name}, Value: {cfg_dataloader[option_name]}")

Key: batch_size, Value: 64
Key: num_workers, Value: 8
Key: persistent_workers, Value: False
Key: pin_memory, Value: True
Key: shuffle, Value: True


In [17]:
# Dump every attribute of the dataset (replay buffer layout, sampler, shape
# metadata, train mask, ...) for inspection. NOTE: the output is very long.
print(vars(dataset))

{'replay_buffer': /
 ├── data
 │   ├── action (6607, 10) float32
 │   ├── agentview_rgb (6607, 128, 128, 3) uint8
 │   ├── ee_states (6607, 16) float32
 │   ├── eye_in_hand_rgb (6607, 128, 128, 3) uint8
 │   ├── gripper_states (6607, 1) float32
 │   └── joint_states (6607, 7) float32
 └── meta
     └── episode_ends (100,) int64, 'sampler': <diffusion_policy.common.sampler.SequenceSampler object at 0x79b6d2916130>, 'shape_meta': {'action': {'shape': [10]}, 'obs': {'agentview_rgb': {'shape': [3, 128, 128], 'type': 'rgb'}, 'ee_states': {'shape': [16]}, 'joint_states': {'shape': [7]}, 'eye_in_hand_rgb': {'shape': [3, 128, 128], 'type': 'rgb'}, 'gripper_states': {'shape': [1]}}}, 'rgb_keys': ['agentview_rgb', 'eye_in_hand_rgb'], 'lowdim_keys': ['ee_states', 'joint_states', 'gripper_states'], 'abs_action': True, 'n_obs_steps': 2, 'train_mask': array([ True,  True,  True,  True,  True,  True,  True,  True, False,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        Tr

In [18]:
# Training loader plus the normalizer the policy will be configured with.
train_dataloader = DataLoader(dataset, **cfg_dataloader)
normalizer = dataset.get_normalizer()

# Validation loader over the dataset's held-out split, with its own settings.
val_dataset = dataset.get_validation_dataset()
val_loader_kwargs = dict(cfg.val_dataloader)
val_dataloader = DataLoader(val_dataset, **val_loader_kwargs)

In [19]:
# Pull one batch from the training loader to sanity-check its structure.
# The iterator is not kept, so its worker processes are not held alive.
batch =  next(iter(train_dataloader))
batch.keys()

dict_keys(['obs', 'action'])

In [20]:
# Action tensor dimensions of the sample batch.
batch['action'].size()

torch.Size([64, 16, 10])

In [21]:
# Agent-view RGB observation dimensions of the sample batch.
batch['obs']['agentview_rgb'].size()

torch.Size([64, 2, 3, 128, 128])

In [22]:
# Give the live policy and (if enabled) its EMA copy the dataset's normalizer.
self.model.set_normalizer(normalizer)
if cfg.training.use_ema:
    self.ema_model.set_normalizer(normalizer)

# The scheduler is stepped once per optimizer step (huggingface-diffusers
# convention), not once per epoch as plain PyTorch schedulers assume, so its
# horizon is counted in optimizer steps over the whole run.
num_training_steps = (
    len(train_dataloader) * cfg.training.num_epochs
) // cfg.training.gradient_accumulate_every
lr_scheduler = get_scheduler(
    cfg.training.lr_scheduler,
    optimizer=self.optimizer,
    num_warmup_steps=cfg.training.lr_warmup_steps,
    num_training_steps=num_training_steps,
    last_epoch=self.global_step - 1,
)

# EMA helper driving `self.ema_model`; None when EMA is disabled.
ema: EMAModel = (
    hydra.utils.instantiate(cfg.ema, model=self.ema_model)
    if cfg.training.use_ema
    else None
)

In [23]:
# Top-k checkpoint bookkeeping under <output_dir>/checkpoints. NOTE(review):
# the training loop in this notebook does not consult it; it saves tagged
# checkpoints on a fixed epoch cadence instead.
topk_checkpoint_dir = os.path.join(self.output_dir, 'checkpoints')
topk_manager = TopKCheckpointManager(save_dir=topk_checkpoint_dir, **cfg.checkpoint.topk)

# Move the policy, its EMA copy (when present) and the optimizer state to the
# training device; the optimizer is echoed as the cell output.
device = torch.device(cfg.training.device)
self.model.to(device)
if self.ema_model is not None:
    self.ema_model.to(device)
optimizer_to(self.optimizer, device)

AdamW (
Parameter Group 0
    amsgrad: False
    betas: [0.95, 0.999]
    capturable: False
    eps: 1e-08
    foreach: None
    initial_lr: 0.0001
    lr: 0.0
    maximize: False
    weight_decay: 1e-06
)

In [24]:
# Remind where checkpoints and logs for this run are written.
print(f"output dir:  {output_dir}")

output dir:  /home/carl_lab/policy_training/diffusion_policy/diffusion_policy/data/outputs/custom2025_03_23_12_55_48


In [26]:
# Number of epochs the training loop will run.
cfg.training['num_epochs']

1000

In [25]:
# ========= training loop =========
# Adapted from the upstream workspace run loop. Each epoch:
#   1. trains over `train_dataloader` (gradient accumulation + optional EMA),
#   2. every `cfg.training.val_every` epochs computes the mean validation loss,
#   3. every `cfg.training.sample_every` epochs measures the MSE between sampled
#      and ground-truth actions on a fixed training batch,
#   4. every CHECKPOINT_EVERY_EPOCHS epochs saves a tagged checkpoint,
# and appends JSON records to `<output_dir>/logs.json.txt`.

# Epoch cadence for tagged checkpoints (epoch 0 included). This was a bare
# literal in the loop body before.
CHECKPOINT_EVERY_EPOCHS = 100

train_sampling_batch = None
log_path = os.path.join(self.output_dir, 'logs.json.txt')
with JsonLogger(log_path) as json_logger:
    for local_epoch_idx in range(cfg.training.num_epochs):
        step_log = dict()
        # ========= train for this epoch ==========
        train_losses = list()
        with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
                leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
            for batch_idx, batch in enumerate(tepoch):
                # device transfer
                batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
                # keep the first batch ever seen for the periodic sampling check
                if train_sampling_batch is None:
                    train_sampling_batch = batch

                # compute loss; scaled so accumulated gradients average over
                # `gradient_accumulate_every` batches
                raw_loss = self.model.compute_loss(batch)
                loss = raw_loss / cfg.training.gradient_accumulate_every
                loss.backward()

                # step optimizer
                # NOTE(review): global_step starts at 0, so with
                # gradient_accumulate_every > 1 the first optimizer step happens
                # after a single batch (same as upstream) — confirm intended.
                if self.global_step % cfg.training.gradient_accumulate_every == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    lr_scheduler.step()

                # update ema
                if cfg.training.use_ema:
                    ema.step(self.model)

                # logging
                raw_loss_cpu = raw_loss.item()
                tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
                train_losses.append(raw_loss_cpu)
                step_log = {
                    'train_loss': raw_loss_cpu,
                    'global_step': self.global_step,
                    'epoch': self.epoch,
                    'lr': lr_scheduler.get_last_lr()[0]
                }

                is_last_batch = (batch_idx == (len(train_dataloader)-1))
                if not is_last_batch:
                    # log of last step is combined with validation and rollout
                    json_logger.log(step_log)
                    self.global_step += 1

                if (cfg.training.max_train_steps is not None) \
                    and batch_idx >= (cfg.training.max_train_steps-1):
                    break

        # at the end of each epoch
        # replace train_loss with epoch average; guarded because np.mean([])
        # warns and yields NaN if the loader produced no batches
        if len(train_losses) > 0:
            train_loss = np.mean(train_losses)
            step_log['train_loss'] = train_loss

        # ========= eval for this epoch ==========
        policy = self.model
        if cfg.training.use_ema:
            policy = self.ema_model
        policy.eval()

        # run validation
        # NOTE(review): the validation loss uses `self.model`, which is only put
        # in eval mode above when EMA is disabled; with EMA enabled it runs in
        # train mode (same as upstream) — confirm this is intended.
        if (self.epoch % cfg.training.val_every) == 0:
            with torch.no_grad():
                val_losses = list()
                with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
                        leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
                    for batch_idx, batch in enumerate(tepoch):
                        batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
                        loss = self.model.compute_loss(batch)
                        # keep a Python float rather than a device tensor
                        val_losses.append(loss.item())
                        if (cfg.training.max_val_steps is not None) \
                            and batch_idx >= (cfg.training.max_val_steps-1):
                            break
                if len(val_losses) > 0:
                    val_loss = float(np.mean(val_losses))
                    # log epoch average validation loss
                    step_log['val_loss'] = val_loss

        # run diffusion sampling on a training batch
        if (self.epoch % cfg.training.sample_every) == 0:
            with torch.no_grad():
                # sample trajectory from training set, and evaluate difference
                batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
                obs_dict = batch['obs']
                gt_action = batch['action']

                result = policy.predict_action(obs_dict)
                pred_action = result['action_pred']
                mse = torch.nn.functional.mse_loss(pred_action, gt_action)
                step_log['train_action_mse_error'] = mse.item()
                # release references to device tensors before the next epoch
                del batch
                del obs_dict
                del gt_action
                del result
                del pred_action
                del mse

        # checkpoint
        if (self.epoch % CHECKPOINT_EVERY_EPOCHS) == 0:
            self.save_checkpoint(tag=f'epoch_{self.epoch}')

        # ========= eval end for this epoch ==========
        policy.train()

        # end of epoch
        # log of last step is combined with validation and rollout
        json_logger.log(step_log)
        self.global_step += 1
        self.epoch += 1

                                                                              

In [None]:
# Persist the final model state, tagged with the number of completed epochs;
# the returned value is shown as the cell output.
final_checkpoint_tag = f"after_train_{self.epoch}_epochs"
self.save_checkpoint(tag=final_checkpoint_tag)