# Train models for visual proprioception

Train a regression model for visual proprioception. The input is sensory data (e.g. a camera image). This is encoded by a predefined sensorprocessing component into a latent representation. What we are training and saving here is a regressor that maps the latent representation to the position of the robot (e.g. a vector of 6 degrees of freedom).

This regressor is specified in an experiment of the type "visual_proprioception". Running this notebook will train and save the model.

In [1]:
import sys
sys.path.append("..")
from settings import Config

import pathlib
from pprint import pprint
import matplotlib.pyplot as plt

import numpy as np
import torch
import torch.nn as nn
#import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(1)


from visual_proprioception.visproprio_helper import load_demonstrations_as_proprioception_training, get_visual_proprioception_sp, load_multiview_demonstrations_as_proprioception_training
from visual_proprioception.visproprio_models import VisProprio_SimpleMLPRegression


# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


Loading pointer config file: /home/ssheikholeslami/.config/BerryPicker/mainsettings.yaml
Loading machine-specific config file: /home/ssheikholeslami/SaharaBerryPickerData/settings-sahara.yaml


In [None]:
# Experiment group whose run configuration is loaded below.
experiment = "visual_proprioception"

# Select the run to train by leaving exactly ONE `run = ...` line uncommented.

##############################################
#                 SingleView                 #
##############################################

# runs with latent size 128
# run = "vp_aruco_128"
# run = "vp_convvae_128"
# run = "vp_ptun_vgg19_128"
# run = "vp_ptun_resnet50_128"

# runs with latent size 256
# run = "vp_convvae_256"
# run = "vp_ptun_vgg19_256"
# run = "vp_ptun_resnet50_256"

# ViT runs
# run ="vit_base"
# run ="vit_large"
# run ="vit_huge"

##############################################
#                 MultiViews                 #
##############################################

# concat_proj

# run ="vit_base_multiview"
# run ="vit_large_multiview"
# run ="vit_huge_multiview"


##  indiv_proj
# run = "vit_base_multiview_indiv_proj"  # ViT Base_indiv_proj
# run = "vit_large_multiview_indiv_proj" # ViT Large_indiv_proj
# run = "vit_huge_multiview_indiv_proj" # ViT Huge_indiv_proj

##  attention
# run = "vit_base_multiview_attention"  # ViT Base_attention
run = "vit_large_multiview_attention" # ViT Large_attention
# run = "vit_huge_multiview_attention" # ViT Huge_attention


##  weighted_sum
# run = "vit_base_multiview_weighted_sum"  # ViT Base_weighted_sum
# run = "vit_large_multiview_weighted_sum" # ViT Large_weighted_sum
# run = "vit_huge_multiview_weighted_sum" # ViT Huge_weighted_sum

##  gated
# run = "vit_base_multiview_gated"  # ViT Base_gated
# run = "vit_large_multiview_gated" # ViT Large_gated
# run = "vit_huge_multiview_gated" # ViT Huge_gated

# Load the configuration of the selected experiment/run and print it for reference.
exp = Config().get_experiment(experiment, run)
pprint(exp)

# Sensor processing object used by the later cells (`sp.process(...)`).
# NOTE(review): presumably this is the encoder named by the experiment - confirm in visproprio_helper.
sp = get_visual_proprioception_sp(exp, device)


No system dependent experiment file
 /home/ssheikholeslami/SaharaBerryPickerData/experiments-Config/visual_proprioception/vit_large_multiview_attention_sysdep.yaml,
 that is ok, proceeding.
Configuration for experiment: visual_proprioception/vit_large_multiview_attention successfully loaded
{'batch_size': 8,
 'data_dir': PosixPath('/home/ssheikholeslami/SaharaBerryPickerData/experiment_data/visual_proprioception/vit_large_multiview_attention'),
 'encoding_size': 128,
 'epochs': 1000,
 'exp_run_sys_indep_file': PosixPath('/lustre/fs1/home/ssheikholeslami/BerryPicker/src/experiment_configs/visual_proprioception/vit_large_multiview_attention.yaml'),
 'freeze_backbone': False,
 'freeze_feature_extractor': True,
 'fusion_type': 'attention',
 'group_name': 'visual_proprioception',
 'image_size': 224,
 'latent_size': 128,
 'learning_rate': 0.0001,
 'loss': 'MSE',
 'model_type': 'ViTProprioTunedRegression',
 'name': 'vit-large-128-multiview-attention',
 'num_views': 2,
 'output_size': 6,
 'pro

Exception: Missing experiment system independent config file /lustre/fs1/home/ssheikholeslami/BerryPicker/src/experiment_configs/sensorprocessing_propriotuned_Vit_multiview/vit_large_multiview-attention.yaml

In [None]:
# Create the regression model, the loss criterion and the optimizer.
# All three are notebook-level objects that the training function below uses.

model = VisProprio_SimpleMLPRegression(exp)

# The loss is chosen by the experiment configuration; only "MSE" and "L1"
# are supported, anything else is a configuration error.
if exp["loss"] == "MSE":
    criterion = nn.MSELoss()
elif exp["loss"] == "L1":
    criterion = nn.L1Loss()
else:
    raise Exception(f'Unknown loss type {exp["loss"]}')

# Honor the learning rate from the experiment configuration instead of a
# hard-coded value. Experiments that do not define "learning_rate" keep the
# previous default of 0.001.
optimizer = optim.Adam(model.parameters(), lr=exp.get("learning_rate", 0.001))

### Load and cache the training data. 
* Iterate through the images and process them into latent encodings. 
* Iterate through the json files describing the robot position
* Save the input and target values into files in the experiment directory. These will act as caches for later runs
* Create the training and validation splits

In [None]:

# Resolve the training task and the two cache files (encoded inputs and
# target positions) that live in the experiment's data directory.
task = exp["proprioception_training_task"]
proprioception_input_file = pathlib.Path(exp["data_dir"]) / exp["proprioception_input_file"]
proprioception_target_file = pathlib.Path(exp["data_dir"]) / exp["proprioception_target_file"]

# An experiment counts as multi-view when its sensor processing name ends in
# "_multiview" or when it declares more than one view.
is_multiview = (
    exp.get("sensor_processing", "").endswith("_multiview")
    or exp.get("num_views", 1) > 1
)

# Both branches batch the same way, so the batch size is read once.
batch_size = exp.get('batch_size', 32)

if not is_multiview:
    print("Using single-view approach")

    # Single-view loading goes through the sensor processing object.
    tr = load_demonstrations_as_proprioception_training(
        sp, task, proprioception_input_file, proprioception_target_file
    )

    inputs_training, targets_training = tr["inputs_training"], tr["targets_training"]
    inputs_validation, targets_validation = tr["inputs_validation"], tr["targets_validation"]

    # Plain tensor datasets are enough for single-view data.
    train_dataset = TensorDataset(inputs_training, targets_training)
    test_dataset = TensorDataset(inputs_validation, targets_validation)

else:
    print(f"Using multi-view approach with {exp.get('num_views', 2)} views")

    # Multi-view loading returns one input collection per view.
    tr = load_multiview_demonstrations_as_proprioception_training(
        task,
        proprioception_input_file,
        proprioception_target_file,
        num_views=exp.get("num_views", 2)
    )

    class MultiViewDataset(torch.utils.data.Dataset):
        """Dataset that pairs the idx-th sample of every view with the idx-th target."""

        def __init__(self, view_inputs, targets):
            # view_inputs: one indexable collection per view
            self.view_inputs = view_inputs
            self.targets = targets
            self.num_samples = len(targets)

        def __len__(self):
            return self.num_samples

        def __getitem__(self, idx):
            # One entry per view, followed by the shared target.
            return [per_view[idx] for per_view in self.view_inputs], self.targets[idx]

    train_dataset = MultiViewDataset(tr["view_inputs_training"], tr["targets_training"])
    test_dataset = MultiViewDataset(tr["view_inputs_validation"], tr["targets_validation"])

# Wrap the datasets in DataLoaders; only the training data is shuffled.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Using multi-view approach with 2 views
Loading cached data from /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/visual_proprioception/vit_large_multiview_indiv_proj/train_inputs.pt
Created 1169 training examples and 576 validation examples


### Perform the training

In [None]:
def _checkpoint_epoch_number(checkpoint_file):
    """Return the epoch number encoded in a checkpoint file name.

    The expected name format is ``epoch_XXXXXX.pth``. Names that do not follow
    it yield 0, so they sort before all real checkpoints.
    """
    parts = checkpoint_file.stem.split('_')
    if len(parts) < 2:
        return 0
    try:
        return int(parts[1])  # the number after "epoch_"
    except ValueError:
        return 0


def _prune_old_checkpoints(checkpoint_dir, max_checkpoints):
    """Delete all but the ``max_checkpoints`` newest ``epoch_*.pth`` files."""
    # Sort by the actual epoch number, not by file name.
    checkpoint_files = sorted(checkpoint_dir.glob("epoch_*.pth"),
                              key=_checkpoint_epoch_number)
    surplus = len(checkpoint_files) - max_checkpoints
    if surplus <= 0:
        return
    for file_path in checkpoint_files[:surplus]:
        try:
            file_path.unlink()
            print(f"Deleted old checkpoint: {file_path.name}")
        except OSError as e:
            # Best effort: a checkpoint that cannot be removed must not stop training.
            print(f"Failed to delete {file_path.name}: {e}")


def _predict_batch(batch_data, is_multiview):
    """Run the notebook-level ``model`` on one batch from a DataLoader.

    Args:
        batch_data: the ``(inputs, targets)`` pair yielded by the loader. In
            the multi-view case ``inputs`` is a sequence with one batch per
            view (assumed to be tensors indexed by sample - TODO confirm).
        is_multiview: whether the batch has the multi-view layout.

    Returns:
        ``(predictions, targets)`` with the targets moved to ``device``.
    """
    if is_multiview:
        batch_views, batch_y = batch_data
        samples_in_batch = batch_views[0].size(0)
        batch_features = []
        # The sensor processing is called one sample at a time, with that
        # sample's views each given a leading batch dimension of 1.
        for i in range(samples_in_batch):
            sample_views = [view[i].unsqueeze(0).to(device) for view in batch_views]
            sample_features = sp.process(sample_views)
            # presumably sp.process returns a numpy array; torch.tensor copies
            # it, so no gradient flows back into sp - TODO confirm
            batch_features.append(
                torch.tensor(sample_features, dtype=torch.float32).to(device))
        batch_X = torch.stack(batch_features).to(device)
    else:
        batch_X, batch_y = batch_data
        batch_X = batch_X.to(device)

    predictions = model(batch_X)
    # The targets must be on the same device as the predictions for the loss.
    return predictions, batch_y.to(device)


def _evaluate_loader(loader, is_multiview):
    """Return the mean per-batch ``criterion`` loss of ``model`` over ``loader``.

    Puts the model in eval mode and disables gradient tracking. Returns 0 for
    an empty loader.
    """
    model.eval()
    total_loss = 0
    batch_count = 0
    with torch.no_grad():
        for batch_data in loader:
            predictions, batch_y = _predict_batch(batch_data, is_multiview)
            total_loss += criterion(predictions, batch_y).item()
            batch_count += 1
    return total_loss / max(batch_count, 1)


def train_and_save_proprioception_model(exp):
    """Trains and saves the proprioception model, handling both single and multi-view inputs
    with checkpoint support for resuming interrupted training.

    The function works on the notebook-level objects ``model``, ``criterion``,
    ``optimizer``, ``train_loader``, ``test_loader``, ``sp`` and ``device``;
    they must exist before it is called.

    Behavior:
      * If the final model file exists and ``exp["reload_existing_model"]`` is
        not false, the weights are loaded, evaluated on the validation data and
        returned without training.
      * Otherwise training starts, or resumes from the newest
        ``checkpoints/epoch_*.pth`` file, and runs up to ``exp["epochs"]``.
      * After every epoch a checkpoint is written, and only the newest two are
        kept. The model with the lowest validation loss is stored as
        ``checkpoints/best_model.pth``.
      * At the end the best model is loaded and its state dict is saved to the
        final model file.

    Args:
        exp: experiment configuration; reads "data_dir",
            "proprioception_mlp_model_file", "epochs" and the optional keys
            "sensor_processing", "num_views" and "reload_existing_model".

    Returns:
        The trained (or reloaded) notebook-level ``model``.
    """
    final_modelfile = pathlib.Path(exp["data_dir"], exp["proprioception_mlp_model_file"])
    checkpoint_dir = pathlib.Path(exp["data_dir"], "checkpoints")
    # parents=True so that a not-yet-created data directory is not an error.
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    best_model_path = checkpoint_dir / "best_model.pth"

    # Maximum number of checkpoints to keep (excluding the best model)
    max_checkpoints = 2

    # Check if we're using a multi-view approach
    is_multiview = exp.get("sensor_processing", "").endswith("_multiview") or exp.get("num_views", 1) > 1

    # First check for existing final model
    if final_modelfile.exists() and exp.get("reload_existing_model", True):
        print(f"Loading existing final model from {final_modelfile}")
        model.load_state_dict(torch.load(final_modelfile, map_location=device))
        # The inputs are moved to `device` during evaluation, so the model
        # has to be there as well.
        model.to(device)

        avg_loss = _evaluate_loader(test_loader, is_multiview)
        print(f"Loaded model evaluation loss: {avg_loss:.4f}")
        return model

    # Make sure model is on the correct device
    model.to(device)
    print(f"Model moved to {device}")

    # Set training parameters
    num_epochs = exp["epochs"]
    start_epoch = 0
    best_loss = float('inf')

    # Check for existing checkpoints to resume from
    checkpoint_files = sorted(checkpoint_dir.glob("epoch_*.pth"),
                              key=_checkpoint_epoch_number)
    if checkpoint_files:
        # Resume from the checkpoint with the highest epoch number.
        latest_checkpoint = checkpoint_files[-1]
        epoch_num = _checkpoint_epoch_number(latest_checkpoint)

        print(f"Found checkpoint from epoch {epoch_num}. Resuming training...")

        checkpoint = torch.load(latest_checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint.get('best_loss', float('inf'))

        print(f"Resuming from epoch {start_epoch}/{num_epochs} with best loss: {best_loss:.4f}")
    else:
        print(f"Starting new training for {num_epochs} epochs")

    # Start or resume training
    for epoch in range(start_epoch, num_epochs):
        print(f"Starting epoch {epoch+1}/{num_epochs}")
        model.train()
        total_loss = 0
        batch_count = 0

        # Training loop handles both single and multi-view cases
        for batch_idx, batch_data in enumerate(train_loader):
            try:
                predictions, batch_y = _predict_batch(batch_data, is_multiview)
                loss = criterion(predictions, batch_y)

                # Backward pass and optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                batch_count += 1

                # Print progress every few batches
                if (batch_idx + 1) % 10 == 0:
                    print(f"  Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item():.4f}")

            except Exception as e:
                # Deliberate best effort: a failing batch is reported, the
                # current state is saved, and training moves on to the next batch.
                print(f"Error in batch {batch_idx}: {e}")
                save_path = checkpoint_dir / f"emergency_epoch_{epoch:06d}_batch_{batch_idx:06d}.pth"
                torch.save({
                    'epoch': epoch,
                    'batch': batch_idx,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': total_loss / max(batch_count, 1),
                    'best_loss': best_loss
                }, save_path)
                print(f"Emergency checkpoint saved to {save_path}")
                continue

        # Calculate average loss for the epoch
        avg_loss = total_loss / max(batch_count, 1)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

        # Evaluate the model
        avg_test_loss = _evaluate_loader(test_loader, is_multiview)
        print(f'Validation Loss: {avg_test_loss:.4f}')

        # Update the best loss BEFORE writing the checkpoint. Otherwise the
        # checkpoint stores the previous best, and a resumed run could replace
        # best_model.pth with a worse model.
        improved = avg_test_loss < best_loss
        if improved:
            best_loss = avg_test_loss

        epoch_state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': avg_loss,
            'test_loss': avg_test_loss,
            'best_loss': best_loss
        }

        # Save checkpoint after each epoch - using formatted epoch numbers for reliable sorting
        checkpoint_path = checkpoint_dir / f"epoch_{epoch:06d}.pth"
        torch.save(epoch_state, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

        # Clean up old checkpoints to save space
        _prune_old_checkpoints(checkpoint_dir, max_checkpoints)

        # Update best model if improved
        if improved:
            torch.save(epoch_state, best_model_path)
            print(f"New best model saved with test loss: {best_loss:.4f}")

    # Training completed successfully
    print(f"Training complete. Best test loss: {best_loss:.4f}")

    # Load the best model for final save
    if best_model_path.exists():
        best_checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(best_checkpoint['model_state_dict'])
        print(f"Loaded best model from epoch {best_checkpoint['epoch']+1} with test loss {best_checkpoint['test_loss']:.4f}")

    # Save the final model
    torch.save(model.state_dict(), final_modelfile)
    print(f"Final model saved to {final_modelfile}")

    return model

In [None]:
# Legacy manual reload, kept for reference only. It is no longer needed because
# train_and_save_proprioception_model itself checks for an existing final model.
# modelfile = pathlib.Path(Config()["explorations"]["proprioception_mlp_model_file"])

#if modelfile.exists():
#    model.load_state_dict(torch.load(modelfile))
#else:
# Train (or reload) the model; the returned model is displayed as the cell output.
train_and_save_proprioception_model(exp)

Model moved to cuda
Found checkpoint from epoch 598. Resuming training...
Resuming from epoch 599/1000 with best loss: 0.0020
Starting epoch 600/1000


  checkpoint = torch.load(latest_checkpoint, map_location=device)


  Batch 10/147, Loss: 0.0002
  Batch 20/147, Loss: 0.0001
  Batch 30/147, Loss: 0.0003
  Batch 40/147, Loss: 0.0002
  Batch 50/147, Loss: 0.0002
  Batch 60/147, Loss: 0.0002
  Batch 70/147, Loss: 0.0002
  Batch 80/147, Loss: 0.0002
  Batch 90/147, Loss: 0.0002
  Batch 100/147, Loss: 0.0001
  Batch 110/147, Loss: 0.0002
  Batch 120/147, Loss: 0.0001
  Batch 130/147, Loss: 0.0001
  Batch 140/147, Loss: 0.0003
Epoch [600/1000], Loss: 0.0002
Validation Loss: 0.0021
Checkpoint saved to /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/visual_proprioception/vit_large_multiview_indiv_proj/checkpoints/epoch_000599.pth
Deleted old checkpoint: epoch_000597.pth
Starting epoch 601/1000
  Batch 10/147, Loss: 0.0002
  Batch 20/147, Loss: 0.0002
  Batch 30/147, Loss: 0.0001
  Batch 40/147, Loss: 0.0001
  Batch 50/147, Loss: 0.0002
  Batch 60/147, Loss: 0.0002
  Batch 70/147, Loss: 0.0002
  Batch 80/147, Loss: 0.0002
  Batch 90/147, Loss: 0.0002
  Batch 100/147, Loss: 0.0001
  Batch 110/147,

  best_checkpoint = torch.load(best_model_path, map_location=device)


VisProprio_SimpleMLPRegression(
  (model): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=6, bias=True)
  )
)