# Train a proprioception-tuned multi-view CNN

We create a sensor-processing model that encodes several camera views with a CNN
and fine-tunes the encoder on proprioception data.

This notebook trains on several synchronized camera views (two by default, set by the `num_views` experiment parameter).
It fuses their features into a more robust representation of the robot state. Both VGG19 and ResNet50 are supported as base models.

In [1]:
import sys
sys.path.append("..")
from settings import Config

import pathlib
import torch
import torch.nn as nn
from torchvision import models, transforms
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset
import numpy as np

from behavior_cloning.demo_to_trainingdata import BCDemonstration
from robot.al5d_position_controller import RobotPosition
# Import the model classes from the sensor processing module
from sensorprocessing.sp_propriotuned_cnn_multiview import MultiViewVGG19Model, MultiViewResNetModel
from sensorprocessing.sp_propriotuned_cnn_multiview import MultiViewVGG19SensorProcessing, MultiViewResNetSensorProcessing

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:


# The experiment/run we are going to run: the specified model will be created
experiment = "sensorprocessing_propriotuned_cnn_multiview"
# run = "vgg19_128_multiview"
# run = "resnet50_128_multiview"
run = "vgg19_256_multiview"
# run = "resnet50_256_multiview"


exp = Config().get_experiment(experiment, run)

Loading pointer config file: /home/ssheikholeslami/.config/BerryPicker/mainsettings.yaml
Loading machine-specific config file: /home/ssheikholeslami/SaharaBerryPickerData/settings-sahara.yaml
No system dependent experiment file
 /home/ssheikholeslami/SaharaBerryPickerData/experiments-Config/sensorprocessing_propriotuned_cnn_multiview/vgg19_256_multiview_sysdep.yaml,
 that is ok, proceeding.
Configuration for experiment: sensorprocessing_propriotuned_cnn_multiview/vgg19_256_multiview successfully loaded


### Create regression training data (image to proprioception)
The training data (X, Y) consists of the camera images from all demonstrations of the task (X), each paired with the corresponding proprioception data (Y).

In [3]:
# Make sure the directory holding this run's data (cached tensors, models) exists
data_dir = pathlib.Path(exp["data_dir"])
data_dir.mkdir(exist_ok=True, parents=True)
print(f"Data directory: {data_dir}")

Data directory: /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_cnn_multiview/vgg19_256_multiview


In [4]:

# Define a custom dataset for multi-view images
class MultiViewDataset(Dataset):
    """
    Dataset pairing synchronized samples from several camera views with a regression target.

    Item ``idx`` is ``(views, target)``, where ``views`` is a list holding sample ``idx``
    of every view, in the order the views were given.
    """

    def __init__(self, view_inputs, targets):
        """
        Dataset for handling multiple camera views

        Args:
            view_inputs: Sequence of tensors, one per view [view1, view2, ...]
                Each view tensor has shape [num_samples, C, H, W]
            targets: Tensor of target values [num_samples, output_dim]

        Raises:
            ValueError: if any view does not hold exactly ``len(targets)`` samples.
        """
        # Materialize once so a generator argument is not exhausted by the length check below
        self.view_inputs = list(view_inputs)
        self.targets = targets
        self.num_samples = len(targets)

        # Explicit check instead of `assert`, which is stripped when Python runs with -O
        for view_idx, view in enumerate(self.view_inputs):
            if len(view) != self.num_samples:
                raise ValueError(
                    "All views must have the same number of samples: "
                    f"view {view_idx} has {len(view)}, targets have {self.num_samples}"
                )

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Same sample index from every view, so the views stay synchronized with the target
        views = [view[idx] for view in self.view_inputs]
        return views, self.targets[idx]

def _collect_multiview_demonstrations(task, num_views):
    """Read every demonstration of ``task``; return ``(view_tensors, targets)`` with row-aligned samples."""
    demos_dir = pathlib.Path(Config()["demos"]["directory"])
    task_dir = pathlib.Path(demos_dir, "demos", task)

    cameras = None    # the num_views cameras used, fixed by the first demonstration
    view_lists = {}   # camera -> list of images, one per kept timestep
    targetlist = []

    print(f"Loading demonstrations from {task_dir}")
    # Sorted so the cached dataset does not depend on filesystem listing order
    for demo_dir in sorted(task_dir.iterdir()):
        if not demo_dir.is_dir():
            continue

        print(f"Processing demonstration: {demo_dir.name}")
        # Create BCDemonstration with multi-camera support
        bcd = BCDemonstration(
            demo_dir,
            sensorprocessor=None
        )

        if cameras is None:
            cameras = sorted(bcd.cameras)[:num_views]
            if len(cameras) < num_views:
                raise ValueError(
                    f"Demonstration {demo_dir.name} has only {len(cameras)} cameras, "
                    f"but {num_views} views were requested"
                )
            view_lists = {camera: [] for camera in cameras}

        for i in range(bcd.trim_from, bcd.trim_to):
            all_images = bcd.get_all_images(i)

            # Keep a timestep only if every selected camera delivered an image, so that
            # row k of each view tensor and of the target tensor describe the same moment.
            missing = [camera for camera in cameras if camera not in all_images]
            if missing:
                print(f"  Skipping timestep {i} - missing views {missing}")
                continue

            for camera in cameras:
                sensor_readings, _ = all_images[camera]
                view_lists[camera].append(sensor_readings[0])

            # The normalized robot action for this timestep is the regression target
            a = bcd.get_a(i)
            anorm = RobotPosition.from_vector(a).to_normalized_vector()
            targetlist.append(torch.from_numpy(anorm))

    if not targetlist:
        raise ValueError(f"No timesteps with all {num_views} views found in {task_dir}")

    view_tensors = [torch.stack(view_lists[camera]) for camera in cameras]
    return view_tensors, torch.stack(targetlist)


def _split_multiview(view_inputs, targets, training_fraction=0.67, seed=None):
    """Shuffle samples (same permutation for every view and the targets) and split into train/validation."""
    length = len(targets)
    # A dedicated generator makes the split reproducible without touching the global RNG
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    rows = torch.randperm(length, generator=generator)

    shuffled_targets = targets[rows]
    shuffled_views = [view_tensor[rows] for view_tensor in view_inputs]

    training_size = int(length * training_fraction)
    print(f"Created {training_size} training examples and {length - training_size} validation examples")
    return {
        "view_inputs_training": [view[:training_size] for view in shuffled_views],
        "targets_training": shuffled_targets[:training_size],
        "view_inputs_validation": [view[training_size:] for view in shuffled_views],
        "targets_validation": shuffled_targets[training_size:],
    }


def load_multiview_demonstrations(task, proprioception_input_file, proprioception_target_file, num_views=2,
                                  seed=None, training_fraction=0.67):
    """
    Loads all the images of a task from multiple camera views and processes them into
    per-view input tensors plus a target tensor for proprioception training.

    The processed tensors are cached. If ``proprioception_input_file`` exists, inputs and
    targets are loaded from the two files instead of re-reading the demonstrations.
    Otherwise they are built from the demonstrations and saved there.

    Args:
        task: Task name to load demonstrations from
        proprioception_input_file: Path (or str) to save/load processed inputs
        proprioception_target_file: Path (or str) to save/load processed targets
        num_views: Number of camera views to process
        seed: Optional seed for the train/validation shuffle. ``None`` uses the global
            torch RNG, so the split then differs between calls.
        training_fraction: Fraction of the samples assigned to the training split

    Returns:
        Dictionary with the full data (``view_inputs``, ``targets``) and the splits
        ``view_inputs_training``, ``targets_training``, ``view_inputs_validation``,
        ``targets_validation``. View entries are lists with one tensor per view.

    Raises:
        ValueError: when building from demonstrations, if fewer than ``num_views`` cameras
            exist or no timestep has all of them.
    """
    proprioception_input_file = pathlib.Path(proprioception_input_file)
    proprioception_target_file = pathlib.Path(proprioception_target_file)

    if proprioception_input_file.exists():
        print(f"Loading cached data from {proprioception_input_file}")
        view_inputs = torch.load(proprioception_input_file, weights_only=True)
        targets = torch.load(proprioception_target_file, weights_only=True)
    else:
        view_inputs, targets = _collect_multiview_demonstrations(task, num_views)
        torch.save(view_inputs, proprioception_input_file)
        torch.save(targets, proprioception_target_file)
        print(f"Saved {len(targets)} training examples with {num_views} views each")

    retval = {"view_inputs": view_inputs, "targets": targets}
    retval.update(_split_multiview(view_inputs, targets, training_fraction, seed))
    return retval

In [5]:
# Main execution

# Load (or build and cache) the multi-view training data for the configured task
task = exp["proprioception_training_task"]
proprioception_input_file = pathlib.Path(exp["data_dir"]) / exp["proprioception_input_file"]
proprioception_target_file = pathlib.Path(exp["data_dir"]) / exp["proprioception_target_file"]
num_views = exp.get("num_views", 2)

tr = load_multiview_demonstrations(task, proprioception_input_file, proprioception_target_file, num_views)
view_inputs_training, targets_training, view_inputs_validation, targets_validation = (
    tr[key]
    for key in ("view_inputs_training", "targets_training", "view_inputs_validation", "targets_validation")
)

Loading cached data from /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_cnn_multiview/vgg19_256_multiview/train_inputs.pt


Created 1205 training examples and 594 validation examples


### Create a model that performs proprioception regression

In [6]:

# Map the configured model name to the multi-view model class that implements it.
# (Alternatively, construct MultiViewVGG19SensorProcessing / MultiViewResNetSensorProcessing
#  with (exp, device) and use its `.enc` attribute as the model.)
_multiview_model_classes = {
    'VGG19ProprioTunedRegression_multiview': MultiViewVGG19Model,
    'ResNetProprioTunedRegression_multiview': MultiViewResNetModel,
}
model_name = exp['model']
if model_name not in _multiview_model_classes:
    raise Exception(f"Unknown model {model_name}")
model = _multiview_model_classes[model_name](exp, device)



In [7]:
# Wrap the splits into batched loaders; only the training batches are reshuffled each epoch
batch_size = exp.get('batch_size', 32)
train_dataset, test_dataset = (
    MultiViewDataset(view_inputs_training, targets_training),
    MultiViewDataset(view_inputs_validation, targets_validation),
)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)


In [8]:
def _checkpoint_epoch(checkpoint_file):
    """Return the epoch encoded in a checkpoint name such as ``epoch_000012.pth`` (0 if unparsable)."""
    parts = pathlib.Path(checkpoint_file).stem.split('_')
    if len(parts) < 2:
        return 0
    try:
        return int(parts[1])
    except ValueError:
        return 0


def _prune_old_checkpoints(checkpoint_dir, max_checkpoints):
    """Delete all but the ``max_checkpoints`` highest-epoch ``epoch_*.pth`` files in checkpoint_dir."""
    checkpoint_files = sorted(checkpoint_dir.glob("epoch_*.pth"), key=_checkpoint_epoch)
    stale = checkpoint_files[:max(len(checkpoint_files) - max_checkpoints, 0)]
    for file_path in stale:
        try:
            file_path.unlink()
            print(f"Deleted old checkpoint: {file_path.name}")
        except OSError as e:
            # Best effort: a checkpoint that cannot be deleted must not abort training
            print(f"Failed to delete {file_path.name}: {e}")


def train_and_save_multiview_model(model, criterion, optimizer, modelfile, device="cpu", epochs=20, scheduler=None,
                                   log_interval=1, train_dataloader=None, validation_dataloader=None,
                                   max_checkpoints=2):
    """
    Trains and saves the multi-view CNN proprioception model, resuming from checkpoints if present.

    After every epoch a checkpoint ``epoch_<NNNNNN>.pth`` is written to
    ``<modelfile dir>/checkpoints``, and only the newest ``max_checkpoints`` are kept.
    When the validation loss improves, ``best_model.pth`` is also updated. At the end the
    best weights, if any were recorded, are loaded into ``model``, and its ``state_dict``
    is written to ``modelfile``.

    Args:
        model: Multi-view CNN model, called as ``model(list_of_view_batches)``
        criterion: Loss function
        optimizer: Optimizer
        modelfile: Path (or str) where the final model state_dict is saved
        device: Training device (cpu/cuda)
        epochs: Total number of training epochs, counting epochs done before a resume
        scheduler: Optional learning rate scheduler. ``step`` is called with the
            validation loss, as ``ReduceLROnPlateau`` expects. Its state is saved in
            checkpoints and restored on resume.
        log_interval: How often (in epochs) to print the epoch summary
        train_dataloader: Sized iterable of ``(views, targets)`` batches; defaults to
            the notebook-level ``train_loader``
        validation_dataloader: Same for validation; defaults to the notebook-level ``test_loader``
        max_checkpoints: Number of per-epoch checkpoints to keep

    Returns:
        The trained model.

    Raises:
        ValueError: if training is still to be done and a loader yields no batches.
    """
    # Default to the loaders created at notebook level so existing calls keep working
    if train_dataloader is None:
        train_dataloader = train_loader
    if validation_dataloader is None:
        validation_dataloader = test_loader

    modelfile = pathlib.Path(modelfile)
    checkpoint_dir = modelfile.parent / "checkpoints"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    best_model_path = checkpoint_dir / "best_model.pth"

    # Move model and loss to the device *before* restoring the optimizer, so the restored
    # optimizer state ends up on the same device as the parameters it belongs to
    model = model.to(device)
    criterion = criterion.to(device)

    best_val_loss = float('inf')
    start_epoch = 0

    # Resume from the most recent checkpoint, if any
    checkpoint_files = sorted(checkpoint_dir.glob("epoch_*.pth"), key=_checkpoint_epoch)
    if checkpoint_files:
        latest_checkpoint = checkpoint_files[-1]
        print(f"Found checkpoint from epoch {_checkpoint_epoch(latest_checkpoint)}. Resuming training...")

        checkpoint = torch.load(latest_checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler is not None and checkpoint.get('scheduler_state_dict') is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint.get('best_val_loss', float('inf'))

        print(f"Resuming from epoch {start_epoch}/{epochs} with best validation loss: {best_val_loss:.4f}")

    if start_epoch < epochs and (len(train_dataloader) == 0 or len(validation_dataloader) == 0):
        raise ValueError("Training and validation loaders must each yield at least one batch")

    for epoch in range(start_epoch, epochs):
        print(f"Starting epoch {epoch+1}/{epochs}")

        # Training phase
        model.train()
        total_loss = 0
        for batch_idx, (batch_views, batch_y) in enumerate(train_dataloader):
            batch_views = [view.to(device) for view in batch_views]
            batch_y = batch_y.to(device)

            predictions = model(batch_views)
            loss = criterion(predictions, batch_y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if (batch_idx + 1) % 10 == 0:
                print(f"  Batch {batch_idx+1}/{len(train_dataloader)}, Loss: {loss.item():.4f}")

        avg_train_loss = total_loss / len(train_dataloader)

        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch_views, batch_y in validation_dataloader:
                batch_views = [view.to(device) for view in batch_views]
                batch_y = batch_y.to(device)
                val_loss += criterion(model(batch_views), batch_y).item()

        avg_val_loss = val_loss / len(validation_dataloader)

        if scheduler is not None:
            scheduler.step(avg_val_loss)

        # Update the running best *before* writing the checkpoint, so a resumed run
        # restarts from the true best loss (including this epoch) and cannot overwrite
        # best_model.pth with a worse model
        is_best = avg_val_loss < best_val_loss
        if is_best:
            best_val_loss = avg_val_loss

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'best_val_loss': best_val_loss,
        }
        checkpoint_path = checkpoint_dir / f"epoch_{epoch:06d}.pth"
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

        _prune_old_checkpoints(checkpoint_dir, max_checkpoints)

        if is_best:
            torch.save(checkpoint, best_model_path)
            print(f"  New best model saved with validation loss: {best_val_loss:.4f}")

        if (epoch + 1) % log_interval == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

    # Load the best model for the final save
    if best_model_path.exists():
        best_checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(best_checkpoint['model_state_dict'])
        print(f"Loaded best model from epoch {best_checkpoint['epoch']+1} with loss {best_checkpoint['val_loss']:.4f}")

    torch.save(model.state_dict(), modelfile)
    print(f"Final model saved to {modelfile}")

    return model


In [9]:
# Test functions
def test_multiview_model(model, test_view_inputs, test_targets, device, n_samples=5):
    """Test the multi-view CNN model on random examples."""
    if n_samples > len(test_targets):
        n_samples = len(test_targets)

    # Get random indices
    indices = torch.randperm(len(test_targets))[:n_samples]

    print("\nTesting multi-view CNN model on random examples:")
    print("-" * 60)

    # Ensure model is in evaluation mode
    model.eval()

    with torch.no_grad():
        for i, idx in enumerate(indices):
            # Get views and target
            views = [view[idx].unsqueeze(0).to(device) for view in test_view_inputs]
            target = test_targets[idx].cpu().numpy()

            # Get latent representation
            latent = model.encode_views(views)

            # Get prediction
            prediction = model.proprioceptor(latent)
            prediction = prediction.cpu().numpy().flatten()

            # Print the results
            print(f"Example {i+1}:")
            for j, view in enumerate(views):
                print(f"  View {j+1} shape: {view.shape}")
            print(f"  Latent shape: {latent.shape}")
            print(f"  Target position: {target}")
            print(f"  Predicted position: {prediction}")
            print(f"  Mean squared error: {np.mean((prediction - target) ** 2):.6f}")
            print()

In [10]:

def test_model_encoding(model, view_inputs_validation, device, n_samples=3):
    """Test the encoding functionality of the multi-view CNN model."""
    if n_samples > len(view_inputs_validation[0]):
        n_samples = len(view_inputs_validation[0])

    # Get first few samples
    indices = range(n_samples)

    print("\nTesting model encoding functionality:")
    print("-" * 60)

    # Get expected dimensions from model configuration
    latent_size = model.latent_size
    output_size = model.output_size

    # Ensure model is in evaluation mode
    model.eval()

    with torch.no_grad():
        for i in indices:
            # Get views
            views = [view[i].unsqueeze(0).to(device) for view in view_inputs_validation]

            print(f"\nSample {i+1}:")

            # Test feature extraction
            features_list = []
            for j, view in enumerate(views):
                features = model.feature_extractors[j](view)
                flat_features = model.flatten(features)
                features_list.append(flat_features)
                print(f"  View {j+1} features shape: {flat_features.shape}")

            # Test concatenated features
            concat_features = torch.cat(features_list, dim=1)
            print(f"  Concatenated features shape: {concat_features.shape}")

            # Test latent encoding
            latent = model.encode_views(views)
            print(f"  Latent representation shape: {latent.shape}")
            # Assert latent has correct shape
            assert latent.shape[1] == latent_size, f"Expected latent size {latent_size}, got {latent.shape[1]}"

            # Test full forward pass
            output = model.forward(views)
            print(f"  Model output shape: {output.shape}")
            # Assert output has correct shape
            assert output.shape[1] == output_size, f"Expected output size {output_size}, got {output.shape[1]}"

            # Test that the proprioceptor gives the expected output when fed the latent
            proprio_output = model.proprioceptor(latent)
            print(f"  Proprioceptor output shape: {proprio_output.shape}")
            # Assert proprioceptor output has correct shape
            assert proprio_output.shape[1] == output_size, f"Expected proprioceptor output size {output_size}, got {proprio_output.shape[1]}"

            # Verify that direct proprioceptor output matches the forward pass
            is_close = torch.allclose(output, proprio_output, rtol=1e-5, atol=1e-5)
            print(f"  Forward pass matches proprioceptor: {is_close}")
            assert is_close, "Forward pass output does not match proprioceptor output"

            print("  All shape assertions passed!")


In [11]:
def test_process_file(model, view_inputs_validation, n_samples=3):
    """Test the process_file method which handles view caching."""
    # Create a temporary sensor processing wrapper
    class TempSP:
        def __init__(self, model):
            self.enc = model
            self._view_cache = {}
            self._timestep_cache = {}
            self._current_timestep = None

        def process(self, views_list):
            self.enc.eval()
            with torch.no_grad():
                z = self.enc.encode_views(views_list)
            z = torch.squeeze(z)
            return z.cpu().numpy()

        def process_file(self, file_path, camera_id=None):
            # Simulate loading an image
            # In a real implementation, this would load from file_path
            # For testing, we'll just sample from our validation set

            # Set a camera ID if none provided
            if camera_id is None:
                camera_id = f"camera_{np.random.randint(0, model.num_views)}"

            # Get timestep from filename if possible
            try:
                filename = pathlib.Path(file_path).stem
                timestep = int(filename.split('_')[0])
                self._current_timestep = timestep
            except:
                timestep = None

            # Get a validation image
            idx = np.random.randint(0, len(view_inputs_validation[0]))
            view_idx = int(camera_id.split('_')[-1]) if camera_id.startswith('camera_') else 0
            view_idx = min(view_idx, len(view_inputs_validation)-1)
            sensor_readings = view_inputs_validation[view_idx][idx].unsqueeze(0)

            # Update cache
            if timestep is not None:
                if timestep not in self._timestep_cache:
                    self._timestep_cache[timestep] = {}
                self._timestep_cache[timestep][camera_id] = sensor_readings

            self._view_cache[camera_id] = sensor_readings

            # Prepare views for processing
            views_list = []
            required_cameras = model.num_views

            # Try timestep cache first
            if timestep is not None and timestep in self._timestep_cache:
                timestep_views = self._timestep_cache[timestep]
                if len(timestep_views) == required_cameras:
                    for cam in sorted(timestep_views.keys())[:required_cameras]:
                        views_list.append(timestep_views[cam])
                    return self.process(views_list)

            # Use global cache
            views_list = [sensor_readings]
            other_cameras = [cam for cam in sorted(self._view_cache.keys()) if cam != camera_id]

            while len(views_list) < required_cameras and other_cameras:
                camera = other_cameras.pop(0)
                views_list.append(self._view_cache[camera])

            # Fill with duplicates if needed
            while len(views_list) < required_cameras:
                views_list.append(sensor_readings)

            return self.process(views_list)

    if n_samples > len(view_inputs_validation[0]):
        n_samples = len(view_inputs_validation[0])

    print("\nTesting process_file method with view caching:")
    print("-" * 60)

    # Create temporary sensor processor
    sp = TempSP(model)

    # Create mock file paths
    for i in range(n_samples):
        print(f"\nTest {i+1}:")

        # Test with one camera first
        first_camera = "camera_0"
        file_path_1 = f"{i+1:05d}_{first_camera}.jpg"
        latent_1 = sp.process_file(file_path_1, camera_id=first_camera)
        print(f"  Processed with camera 0, latent shape: {latent_1.shape}")

        # Now test with second camera, should use first camera from cache
        second_camera = "camera_1"
        file_path_2 = f"{i+1:05d}_{second_camera}.jpg"
        latent_2 = sp.process_file(file_path_2, camera_id=second_camera)
        print(f"  Processed with camera 1, latent shape: {latent_2.shape}")

        # Test cache recovery - use a different timestep
        file_path_3 = f"{i+2:05d}_{first_camera}.jpg"
        latent_3 = sp.process_file(file_path_3, camera_id=first_camera)
        print(f"  Processed with camera 0 (new timestep), latent shape: {latent_3.shape}")

        # Test cross-timestep behavior
        file_path_4 = f"{i+2:05d}_{second_camera}.jpg"
        latent_4 = sp.process_file(file_path_4, camera_id=second_camera)
        print(f"  Processed with camera 1 (new timestep), latent shape: {latent_4.shape}")

        # Verify that latent representations have correct dimensions
        assert latent_1.shape[0] == model.latent_size, f"Expected latent size {model.latent_size}, got {latent_1.shape[0]}"
        assert latent_2.shape[0] == model.latent_size, f"Expected latent size {model.latent_size}, got {latent_2.shape[0]}"

        print("  All shape assertions passed!")


In [12]:

# Test data visualization - show some predictions vs targets
def visualize_model_predictions(model, view_inputs_validation, targets_validation, device, n_samples=5):
    """
    Print a per-field comparison of model predictions and ground-truth targets.

    Randomly chosen samples are run through ``model(views)``. For each sample this prints
    the ground truth, the prediction and their difference per output field, plus the
    sample's MSE/MAE, followed by the errors averaged over all shown samples. The first
    six fields carry robot-position names; any further fields are labelled ``field_<index>``.

    Args:
        model: Multi-view model called as ``model(list_of_views)``
        view_inputs_validation: List of per-view tensors, indexed ``[view][sample]``
        targets_validation: Target tensor indexed by sample
        device: Device the views are moved to
        n_samples: Number of random samples (clamped to the number available)
    """
    n_samples = min(n_samples, len(targets_validation))
    indices = torch.randperm(len(targets_validation))[:n_samples]

    print("\nComparing model predictions with ground truth:")
    print("-" * 60)

    # Nothing to compare (also avoids dividing by zero in the averages below)
    if n_samples <= 0:
        print("No samples to compare.")
        return

    # Field names for robot position
    field_names = ["height", "distance", "heading", "wrist_angle", "wrist_rotation", "gripper"]

    total_mse = 0
    total_mae = 0

    model.eval()
    with torch.no_grad():
        for i, idx in enumerate(indices):
            views = [view[idx].unsqueeze(0).to(device) for view in view_inputs_validation]
            ground_truth = targets_validation[idx].cpu().numpy()

            prediction = model(views).cpu().numpy().flatten()

            errors = prediction - ground_truth
            mse = np.mean(errors ** 2)
            mae = np.mean(np.abs(errors))
            total_mse += mse
            total_mae += mae

            print(f"\nSample {i+1}:")
            print("  Field                 | Ground Truth | Prediction  | Difference")
            print("  ---------------------|--------------|-------------|------------")

            # zip keeps the table correct whatever the output dimension is
            for j, (gt_value, pred_value) in enumerate(zip(ground_truth, prediction)):
                field = field_names[j] if j < len(field_names) else f"field_{j}"
                diff = pred_value - gt_value
                print(f"  {field.ljust(20)} | {gt_value:12.4f} | {pred_value:11.4f} | {diff:10.4f}")

            print(f"  MSE: {mse:.6f}, MAE: {mae:.6f}")

    avg_mse = total_mse / n_samples
    avg_mae = total_mae / n_samples
    print("\nAverage errors across samples:")
    print(f"  Mean Squared Error: {avg_mse:.6f}")
    print(f"  Mean Absolute Error: {avg_mae:.6f}")


In [13]:

# Test for the model's ability to handle noisy or missing views
def test_model_robustness(model, view_inputs_validation, targets_validation, device, n_samples=3,
                          noise_std=0.1):
    """Test the model's robustness to imperfect view data.

    For ``n_samples`` randomly chosen validation samples, the MSE between the
    model prediction and the ground truth is printed under three conditions:

    1. clean views (input unchanged);
    2. Gaussian noise with standard deviation ``noise_std`` added to the first
       view only;
    3. every view replaced by the first view, simulating the loss of the other
       camera(s). This works for any number of views, not just two.

    Args:
        model: multi-view model, called as ``model(list_of_view_tensors)``.
        view_inputs_validation: sequence with one tensor per view; each is
            indexed by sample along dim 0.
        targets_validation: target tensor indexed by sample along dim 0.
        device: device the views are moved to before inference.
        n_samples: number of random samples to evaluate. It is clamped to the
            number of available samples.
        noise_std: standard deviation of the noise used in case 2.

    Results are printed; nothing is returned.
    """
    if n_samples > len(targets_validation):
        n_samples = len(targets_validation)

    def _pct_increase(new, base):
        # A zero baseline (perfect clean prediction) would otherwise produce
        # inf/nan plus a numpy RuntimeWarning, so report it as "n/a" instead.
        if base == 0:
            return "n/a"
        return f"{((new - base) / base) * 100:.2f}%"

    # Get random indices
    indices = torch.randperm(len(targets_validation))[:n_samples]

    print("\nTesting model robustness to imperfect views:")
    print("-" * 60)

    # Ensure model is in evaluation mode
    model.eval()

    with torch.no_grad():
        for i, idx in enumerate(indices):
            # Get views (as a batch of one) and ground truth
            views_original = [view[idx].unsqueeze(0).to(device) for view in view_inputs_validation]
            ground_truth = targets_validation[idx].cpu().numpy()

            # Case 1: Normal prediction with clean views
            prediction_clean = model(views_original).cpu().numpy().flatten()

            # Case 2: Add Gaussian noise to the first view. The list copy is
            # shallow, but the noisy tensor is a new object, so views_original
            # is left untouched.
            views_noisy = list(views_original)
            views_noisy[0] = views_noisy[0] + torch.randn_like(views_noisy[0]) * noise_std
            prediction_noisy = model(views_noisy).cpu().numpy().flatten()

            # Case 3: Feed the first view in every slot (simulating missing
            # views). Keep the view count equal to what the model expects.
            views_duplicate = [views_original[0]] * len(views_original)
            prediction_duplicate = model(views_duplicate).cpu().numpy().flatten()

            # Calculate errors
            mse_clean = np.mean((prediction_clean - ground_truth) ** 2)
            mse_noisy = np.mean((prediction_noisy - ground_truth) ** 2)
            mse_duplicate = np.mean((prediction_duplicate - ground_truth) ** 2)

            # Print results
            print(f"\nSample {i+1}:")
            print(f"  Clean views MSE: {mse_clean:.6f}")
            print(f"  Noisy view MSE: {mse_noisy:.6f}")
            print(f"  Duplicate view MSE: {mse_duplicate:.6f}")
            print(f"  Noise impact: {_pct_increase(mse_noisy, mse_clean)} increase in error")
            print(f"  Missing view impact: {_pct_increase(mse_duplicate, mse_clean)} increase in error")


# This is a diagnostic helper, not a pytest test. Stop pytest from collecting it
# (and failing on the missing "model" fixture) despite its "test_" prefix.
test_model_robustness.__test__ = False


In [14]:

# Summarize the freshly built model: total parameter count (trainable and
# frozen alike) plus the view/latent/output sizes taken from the config.
param_count = sum(param.numel() for param in model.parameters())
print(f"Created {exp['model']} with {param_count} parameters")
print(f"Number of views: {num_views}")
print(f"Latent dimension: {exp['latent_size']}")
print(f"Output dimension: {exp['output_size']}")


Created VGG19ProprioTunedRegression_multiview with 73776582 parameters
Number of views: 2
Latent dimension: 256
Output dimension: 6


In [15]:

# Cross-check the model against the configured number of camera views.
# "num_views" falls back to 2 when the experiment config does not set it.
expected_views = exp.get("num_views", 2)

# A multi-view model is recognised by a `feature_extractors` ModuleList with
# one entry per view; anything else is reported as single-view.
extractors = getattr(model, 'feature_extractors', None)
if not isinstance(extractors, nn.ModuleList):
    print(f"WARNING: This appears to be a single-view model, not a {expected_views}-view model")
else:
    actual_views = len(extractors)
    if actual_views != expected_views:
        print(f"WARNING: Model supports {actual_views} views, but configuration specifies {expected_views} views")
    else:
        print(f"Model confirmed as {expected_views}-view configuration")

Model confirmed as 2-view configuration


In [16]:


# Select the loss function named by the experiment's "loss" entry
# (default "MSELoss"). Unknown names fall back to MSE, with a visible warning
# rather than silently.
_LOSS_FUNCTIONS = {
    'MSELoss': nn.MSELoss,
    'L1Loss': nn.L1Loss,
}
loss_type = exp.get('loss', 'MSELoss')
if loss_type not in _LOSS_FUNCTIONS:
    print(f"WARNING: Unsupported loss '{loss_type}' in configuration, falling back to MSELoss")
criterion = _LOSS_FUNCTIONS.get(loss_type, nn.MSELoss)()

# Set up optimizer with learning rate and weight decay from the config.
# NOTE: Adam's weight_decay is coupled L2 regularization; torch.optim.AdamW is
# the decoupled variant if that is what the config intends.
optimizer = optim.Adam(
    model.parameters(),
    lr=exp.get('learning_rate', 0.001),
    weight_decay=exp.get('weight_decay', 0.01)
)

# Halve the learning rate after 3 epochs without validation-loss improvement.
# The `verbose` argument is deprecated since PyTorch 2.2 and is no longer
# passed. To log the current rate, read optimizer.param_groups[0]["lr"].
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)



In [17]:


# Train the model, or reload a previously trained final model if one exists
# and the config allows reuse ("reload_existing_model", default True).
modelfile = pathlib.Path(exp["data_dir"], exp["proprioception_mlp_model_file"])
epochs = exp.get("epochs", 20)


def get_epoch_number(checkpoint_file):
    """Return the epoch encoded in a checkpoint name such as ``epoch_000080.pth``.

    Names that do not follow the ``<prefix>_<int>`` pattern yield 0, so they
    sort before every real checkpoint.
    """
    parts = checkpoint_file.stem.split('_')
    if len(parts) < 2:
        return 0
    try:
        return int(parts[1])
    except ValueError:
        return 0


if modelfile.exists() and exp.get("reload_existing_model", True):
    print(f"Loading existing final model from {modelfile}")
    # The file holds a plain state_dict, so weights_only=True is sufficient.
    # It also avoids unpickling arbitrary objects and torch's FutureWarning.
    model.load_state_dict(torch.load(modelfile, map_location=device, weights_only=True))

    # Evaluate the loaded model: mean of per-batch losses over the test loader
    model.eval()
    with torch.no_grad():
        val_loss = 0.0
        for batch_views, batch_y in test_loader:
            # Move views and targets to device
            batch_views = [view.to(device) for view in batch_views]
            batch_y = batch_y.to(device)

            predictions = model(batch_views)
            val_loss += criterion(predictions, batch_y).item()

        num_batches = len(test_loader)
        if num_batches > 0:
            avg_val_loss = val_loss / num_batches
            print(f"Loaded model validation loss: {avg_val_loss:.4f}")
        else:
            print("WARNING: test_loader is empty; skipping validation of the loaded model")

    # Test the loaded model
    test_multiview_model(model, view_inputs_validation, targets_validation, device)
else:
    # Report whether training will resume from a checkpoint. The resume logic
    # itself lives in train_and_save_multiview_model.
    checkpoint_dir = modelfile.parent / "checkpoints"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    checkpoint_files = list(checkpoint_dir.glob("epoch_*.pth"))
    if checkpoint_files:
        latest_epoch = max(get_epoch_number(f) for f in checkpoint_files)
        print(f"Found checkpoints up to epoch {latest_epoch}. Will resume training from last checkpoint.")
    else:
        print(f"Starting new training for {epochs} epochs")

    # Train the model
    model = train_and_save_multiview_model(
        model, criterion, optimizer, modelfile,
        device=device, epochs=epochs, scheduler=lr_scheduler
    )

    # Test the trained model
    test_multiview_model(model, view_inputs_validation, targets_validation, device)


Found checkpoints up to epoch 79. Will resume training from last checkpoint.
Found checkpoint from epoch 79. Resuming training...


  checkpoint = torch.load(latest_checkpoint, map_location=device)


Resuming from epoch 80/300 with best validation loss: 0.0154
Starting epoch 81/300
  Batch 10/38, Loss: 0.0170
  Batch 20/38, Loss: 0.0159
  Batch 30/38, Loss: 0.0136
Checkpoint saved to /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_cnn_multiview/vgg19_256_multiview/checkpoints/epoch_000080.pth
Deleted old checkpoint: epoch_000078.pth
  New best model saved with validation loss: 0.0145
Epoch [81/300], Train Loss: 0.0163, Val Loss: 0.0145
Starting epoch 82/300
  Batch 10/38, Loss: 0.0163
  Batch 20/38, Loss: 0.0135
  Batch 30/38, Loss: 0.0168
Checkpoint saved to /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_cnn_multiview/vgg19_256_multiview/checkpoints/epoch_000081.pth
Deleted old checkpoint: epoch_000079.pth
Epoch [82/300], Train Loss: 0.0163, Val Loss: 0.0146
Starting epoch 83/300
  Batch 10/38, Loss: 0.0147
  Batch 20/38, Loss: 0.0162
  Batch 30/38, Loss: 0.0145
Checkpoint saved to /home/ssheikhol

  best_checkpoint = torch.load(best_model_path, map_location=device)


Loaded best model from epoch 81 with loss 0.0145
Final model saved to /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_cnn_multiview/vgg19_256_multiview/proprioception_mlp.pth

Testing multi-view CNN model on random examples:
------------------------------------------------------------
Example 1:
  View 1 shape: torch.Size([1, 3, 256, 256])
  View 2 shape: torch.Size([1, 3, 256, 256])
  Latent shape: torch.Size([1, 256])
  Target position: [0.2709384  0.65022457 0.29910257 0.48742312 0.73724264 0.23908657]
  Predicted position: [0.3599518  0.67725927 0.22808337 0.43484753 0.441798   0.47818425]
  Mean squared error: 0.026820

Example 2:
  View 1 shape: torch.Size([1, 3, 256, 256])
  View 2 shape: torch.Size([1, 3, 256, 256])
  Latent shape: torch.Size([1, 256])
  Target position: [0.8884927  0.59413934 0.27151477 0.6784262  0.7311824  0.17874762]
  Predicted position: [0.84530586 0.6002718  0.2861594  0.6194405  0.7351022  0.17558551]
  Mean squ

In [18]:

# Run the supplementary diagnostics on the validation split, in order:
# encoding shapes, file processing, prediction-vs-truth table, robustness.
print("\nRunning additional model tests...\n")
val_views = view_inputs_validation
test_model_encoding(model, val_views, device)
test_process_file(model, val_views)
visualize_model_predictions(model, val_views, targets_validation, device)
test_model_robustness(model, val_views, targets_validation, device)



Running additional model tests...


Testing model encoding functionality:
------------------------------------------------------------

Sample 1:
  View 1 features shape: torch.Size([1, 32768])
  View 2 features shape: torch.Size([1, 32768])
  Concatenated features shape: torch.Size([1, 65536])
  Latent representation shape: torch.Size([1, 256])
  Model output shape: torch.Size([1, 6])
  Proprioceptor output shape: torch.Size([1, 6])
  Forward pass matches proprioceptor: True
  All shape assertions passed!

Sample 2:
  View 1 features shape: torch.Size([1, 32768])
  View 2 features shape: torch.Size([1, 32768])
  Concatenated features shape: torch.Size([1, 65536])
  Latent representation shape: torch.Size([1, 256])
  Model output shape: torch.Size([1, 6])
  Proprioceptor output shape: torch.Size([1, 6])
  Forward pass matches proprioceptor: True
  All shape assertions passed!

Sample 3:
  View 1 features shape: torch.Size([1, 32768])
  View 2 features shape: torch.Size([1, 32768])
  C

In [19]:

# Closing summary of the trained multi-view model and its key dimensions.
print("\nMulti-view CNN training and testing complete!")
for label, value in (
    ("Model type", exp['model']),
    ("Number of views", num_views),
    ("Latent space dimension", exp['latent_size']),
    ("Output dimension (robot DOF)", exp['output_size']),
):
    print(f"{label}: {value}")
print("Use the appropriate MultiView sensor processing class to load and use this model for inference.")


Multi-view CNN training and testing complete!
Model type: VGG19ProprioTunedRegression_multiview
Number of views: 2
Latent space dimension: 256
Output dimension (robot DOF): 6
Use the appropriate MultiView sensor processing class to load and use this model for inference.
