# Train a Concatenated Multi-View Convolutional Variational Autoencoder

This notebook trains a Convolutional Variational Autoencoder (Conv-VAE) that handles
multiple camera views by concatenating the images horizontally before processing.

We extend the existing Julian-8897-Conv-VAE-PyTorch implementation to work with
concatenated multi-view images, allowing us to create a unified latent representation
from multiple camera perspectives.

The parameters for training are specified in an experiment run of type 
"sensorprocessing_conv_vae_concat_multiview". The resulting model files are stored in 
the experiment directory.

After training a satisfactory model, copy the model directory name and the checkpoint file name
(printed by the checkpoint-info cell) into the model_subdir and model_checkpoint fields of the experiment/run yaml file.


In [5]:
import sys
sys.path.append("..")
from settings import Config
import pathlib
from pprint import pprint
import shutil
import json
from pathlib import Path
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sensorprocessing.sp_conv_vae_concat_multiview import ConcatConvVaeSensorProcessing

# Adding the Conv-VAE-PyTorch code to the path
sys.path.append(Config()["conv_vae"]["code_dir"])

# Import functions from the original Conv-VAE module
from sensorprocessing.conv_vae import get_conv_vae_config, create_configured_vae_json, train

# Import needed modules from the Conv-VAE package
import data_loader.data_loaders as module_data
from parse_config import ConfigParser

In [6]:
# Dry-run switch: set to True to go through the notebook without copying files.
dry_run = False

# Experiment / run selection. Pick the run that matches the desired latent size:
#   "proprio_128_concat_multiview" -> 128-dimensional latent space
#   "proprio_256_concat_multiview" -> 256-dimensional latent space
experiment = "sensorprocessing_conv_vae_concat_multiview"
run = "proprio_128_concat_multiview"
exp = Config().get_experiment(experiment, run)

print("Experiment configuration:")
pprint(exp)




No system dependent experiment file
 /home/ssheikholeslami/SaharaBerryPickerData/experiments-Config/sensorprocessing_conv_vae_concat_multiview/proprio_128_concat_multiview_sysdep.yaml,
 that is ok, proceeding.
Configuration for experiment: sensorprocessing_conv_vae_concat_multiview/proprio_128_concat_multiview successfully loaded
Experiment configuration:
{'data_dir': PosixPath('/home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_conv_vae_concat_multiview/proprio_128_concat_multiview'),
 'epochs': 300,
 'exp_run_sys_indep_file': PosixPath('/lustre/fs1/home/ssheikholeslami/BerryPicker/src/experiment_configs/sensorprocessing_conv_vae_concat_multiview/proprio_128_concat_multiview.yaml'),
 'group_name': 'sensorprocessing_conv_vae_concat_multiview',
 'image_size': 128,
 'json_template_name': 'conv-vae-config-default.json',
 'latent_size': 128,
 'min_views_required': 1,
 'model_checkpoint': 'checkpoint-epoch300.pth',
 'model_dir': 'models',
 'model_name': 'VAE_Robot'

In [2]:
# Lower-case file suffixes (dot included) that are treated as camera images
# when scanning demonstration directories.
TEMPLATE_EXTENSIONS = set(".jpg .jpeg .png .bmp".split())


In [3]:


def build_concat_training_set(exp: dict) -> int:
    """Build the Conv-VAE training set for the multi-view experiment *exp*.

    For every demonstration directory under ``<demos dir>/demos/<training_task>``
    the image files (suffixes in ``TEMPLATE_EXTENSIONS``) are grouped by
    timestep, i.e. the part of the file stem before the first ``_``
    (``00001_dev2`` → ``00001``).  For every timestep with at least
    ``exp["num_views"]`` images, the first ``num_views`` of them (in sorted
    path order) are pasted side by side and saved as
    ``<demo name>_<timestep>.jpg`` in ``<data_dir>/<training_data_dir>/Images``.

    * ``stack_mode="width"`` (also used when the key is absent) → save
      width‑concatenated RGBs (identical to v3 behaviour).
    * ``stack_mode="channel"`` → **TODO** save paired views as .pt tensors or
      factor the stacking into the Conv‑VAE dataloader.

    Timesteps whose views differ in width or height are skipped and counted.
    Existing files in the output directory are never removed, so composites
    from an earlier build with different settings may still be present.

    Args:
        exp: experiment configuration; reads ``data_dir``,
            ``training_data_dir``, ``training_task``, ``num_views`` and
            (optionally) ``stack_mode``.

    Returns:
        The number of composite images written.

    Raises:
        NotImplementedError: for ``stack_mode="channel"``.
        ValueError: for an unknown ``stack_mode`` or ``num_views < 1``.
        FileNotFoundError: if the task's demonstration directory is missing.
    """
    stack_mode = exp.get("stack_mode", "width")
    if stack_mode == "channel":
        raise NotImplementedError(
            "Dataset builder for stack_mode='channel' is not yet implemented – "
            "we'll add it in the next iteration."
        )
    if stack_mode != "width":
        # Previously any typo silently fell through to width mode.
        raise ValueError(
            f"Unknown stack_mode {stack_mode!r}; expected 'width' or 'channel'"
        )

    num_views = exp["num_views"]
    if num_views < 1:
        raise ValueError(f"num_views must be >= 1, got {num_views!r}")

    demos_base = Path(Config()["demos"]["directory"]) / "demos"
    task_dir = demos_base / exp["training_task"]
    if not task_dir.is_dir():
        raise FileNotFoundError(f"Demonstration task directory not found: {task_dir}")

    training_img_dir = Path(exp["data_dir"], exp["training_data_dir"], "Images")
    training_img_dir.mkdir(parents=True, exist_ok=True)

    count = 0
    skipped = 0
    # Sorted iteration keeps the build order independent of the filesystem.
    for demo in sorted(task_dir.iterdir()):
        if not demo.is_dir():
            continue

        by_timestep: dict[str, list[Path]] = {}
        for img in demo.iterdir():
            if img.suffix.lower() not in TEMPLATE_EXTENSIONS:
                continue
            key = img.stem.split("_")[0]  # 00001_dev2 → 00001
            by_timestep.setdefault(key, []).append(img)

        for t, imgs in sorted(by_timestep.items()):
            if len(imgs) < num_views:
                continue

            pil_imgs = []
            for path in sorted(imgs)[:num_views]:
                # The context manager closes the file; convert() returns an
                # independent in-memory image that outlives it.
                with Image.open(path) as im:
                    pil_imgs.append(im.convert("RGB"))

            if len({im.size for im in pil_imgs}) > 1:
                skipped += 1
                continue  # inconsistent sizes

            width, height = pil_imgs[0].size
            concat = Image.new("RGB", (width * len(pil_imgs), height))
            for i, im in enumerate(pil_imgs):
                concat.paste(im, (i * width, 0))
            concat.save(training_img_dir / f"{demo.name}_{t}.jpg")
            count += 1

    print(f"Prepared {count} training composites in {training_img_dir}")
    if skipped:
        print(f"Skipped {skipped} timesteps whose views had inconsistent image sizes")
    return count


In [5]:
# Build the width-concatenated training composites. Building writes image
# files into the experiment directory, so it is skipped on a dry run.
if dry_run:
    print("DRY_RUN enabled – skipping dataset build")
else:
    build_concat_training_set(exp)

Prepared 1803 training composites in /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_conv_vae_concat_multiview/proprio_128_concat_multiview/vae-training-data/Images


In [6]:
# Generate the Conv-VAE JSON configuration for this experiment run and
# report where it came from.
vae_json = create_configured_vae_json(exp)
print(f"Generated VAE config json: {vae_json}")

if dry_run:
    # NOTE(review): this cell only announces dry-run mode; whether training is
    # actually skipped depends on the training cell honouring `dry_run`.
    print("DRY_RUN enabled – skipping training")


/lustre/fs1/home/ssheikholeslami/BerryPicker/src/sensorprocessing/conv-vae-config-default.json
{'name': 'VAE_Robot', 'n_gpu': 1, 'arch': {'type': 'VanillaVAE', 'args': {'in_channels': 3, 'latent_dims': 128, 'flow': False}}, 'data_loader': {'type': 'CelebDataLoader', 'args': {'data_dir': '/home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_conv_vae_concat_multiview/proprio_128_concat_multiview/vae-training-data', 'batch_size': 64, 'shuffle': True, 'validation_split': 0.2, 'num_workers': 2}}, 'optimizer': {'type': 'Adam', 'args': {'lr': 0.005, 'weight_decay': 0.0, 'amsgrad': True}}, 'loss': 'elbo_loss', 'metrics': [], 'lr_scheduler': {'type': 'StepLR', 'args': {'step_size': 50, 'gamma': 0.1}}, 'trainer': {'epochs': 300, 'save_dir': '/home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_conv_vae_concat_multiview/proprio_128_concat_multiview/models', 'save_period': 5, 'verbosity': 2, 'monitor': 'min val_loss', 'early_stop': 10, 'tensorboard':

In [7]:
# Parse the generated JSON into a Conv-VAE config and train, unless this is a
# dry run. Previously `dry_run` was announced but training ran regardless.
# `trainer` is None when training was skipped.
cfg = get_conv_vae_config(vae_json)
if dry_run:
    trainer = None
    print("DRY_RUN enabled – Conv-VAE training skipped (trainer = None)")
else:
    print(
        f"▶ Training {exp['latent_size']}‑dim latent Conv‑VAE on "
        f"{exp['num_views']}‑view composites …"
    )
    trainer = train(cfg)

▶ Training 128‑dim latent Conv‑VAE on 2‑view composites …


2025-05-07 05:10:57.551725: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-07 05:10:57.823001: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746609057.949339 3529479 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746609057.994831 3529479 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-07 05:10:58.199031: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [8]:
# 3) checkpoint info for yaml: values for the run's model_subdir and
#    model_checkpoint fields.
if trainer is None:
    print("No training was run (dry run) – no checkpoint to record.")
else:
    ckpt_dir = Path(trainer.checkpoint_dir)
    ckpt_prefix = "checkpoint-epoch"
    # With early stopping and periodic saving, the final nominal epoch's
    # checkpoint may not exist, so report the newest one actually on disk.
    saved_epochs = sorted(
        int(p.stem[len(ckpt_prefix):])
        for p in ckpt_dir.glob(f"{ckpt_prefix}*.pth")
        if p.stem[len(ckpt_prefix):].isdigit()
    )
    if saved_epochs:
        last_epoch = saved_epochs[-1]
    else:
        last_epoch = trainer.epochs
        print(f"WARNING: no {ckpt_prefix}*.pth files found in {ckpt_dir}")
    print("model_subdir:", repr(ckpt_dir.name))
    print("model_checkpoint:", repr(f"{ckpt_prefix}{last_epoch}.pth"))



model_subdir: '0507_051054'
model_checkpoint: 'checkpoint-epoch300.pth'


In [12]:
# 4) Sanity check: feed the same random 3-channel 64x64 tensor into every view
#    slot and confirm the encoder's output width equals the configured
#    latent size.
sp = ConcatConvVaeSensorProcessing(exp)
dummy = torch.randn(1, 3, 64, 64)
n_views = exp["num_views"]
latent = sp.encode([dummy for _ in range(n_views)])
expected_latent = exp["latent_size"]
assert latent.shape[-1] == expected_latent, f"latent mismatch: {latent.shape}"
print("✓ Sanity check passed – latent size correct")

✓ Sanity check passed – latent size correct


__Important__ After training finishes, to use the resulting model you need to edit the run file (e.g., proprio_128_concat_multiview.yaml) and enter the location of the checkpoint into it. These values (model_subdir and model_checkpoint) are printed by the checkpoint-info code cell above.