# Train a proprioception-tuned Multi-View Vision Transformer (ViT)
This notebook creates a sensor processing model built from multiple Vision Transformer (ViT) based visual encoders,
fine-tuned with proprioception.
We start with pretrained ViT models, then train them to:
1. Create a meaningful 128-dimensional latent representation from multiple camera views
2. Learn to map this representation to robot positions (proprioception)
The sensor processing object associated with the trained model is in `sensorprocessing/sp_vit_multiview.py`.


In [None]:
import sys
sys.path.append("..")

from exp_run_config import Config
Config.PROJECTNAME = "BerryPicker"

import pathlib
import torch
import torch.nn as nn
from torchvision import models, transforms
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

from demonstration.demonstration_helper import BCDemonstration
from sensorprocessing.sp_vit_multiview import MultiViewVitSensorProcessing
from robot.al5d_position_controller import RobotPosition

# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:

# The experiment/run we are going to run: the specified model will be created
experiment = "sensorprocessing_propriotuned_Vit_multiview"


# Other possible configurations, grouped by the fusion strategy named in the run.
# Uncomment exactly one `run` line. The fusion type actually used is taken from the
# run's configuration (exp["fusion_type"], printed in the summary at the end of the notebook).

## concat_proj
# run = "vit_base_multiview"  # ViT Base
run = "vit_large_multiview" # ViT Large
# run = "vit_huge_multiview" # ViT Huge

##  indiv_proj
# run = "vit_base_multiview_indiv_proj"  # ViT Base_indiv_proj
# run = "vit_large_multiview_indiv_proj" # ViT Large_indiv_proj
# run = "vit_huge_multiview_indiv_proj" # ViT Huge_indiv_proj


##  attention
# run = "vit_base_multiview_attention"  # ViT Base_attention
# run = "vit_large_multiview_attention" # ViT Large_attention
# run = "vit_huge_multiview_attention" # ViT Huge_attention

##  weighted_sum
# run = "vit_base_multiview_weighted_sum"  # ViT Base_weighted_sum
# run = "vit_large_multiview_weighted_sum" # ViT Large_weighted_sum
# run = "vit_huge_multiview_weighted_sum" # ViT Huge_weighted_sum

##  gated
# run = "vit_base_multiview_gated"  # ViT Base_gated
# run = "vit_large_multiview_gated" # ViT Large_gated
# run = "vit_huge_multiview_gated" # ViT Huge_gated

# Load the configuration of the selected experiment/run; later cells read the model type,
# data paths and training hyperparameters (e.g. "data_dir", "epochs") from `exp`.
exp = Config().get_experiment(experiment, run)


Loading pointer config file: /home/ssheikholeslami/.config/BerryPicker/mainsettings.yaml
Loading machine-specific config file: /home/ssheikholeslami/SaharaBerryPickerData/settings-sahara.yaml
No system dependent experiment file
 /home/ssheikholeslami/SaharaBerryPickerData/experiments-Config/sensorprocessing_propriotuned_Vit_multiview/vit_large_multiview_sysdep.yaml,
 that is ok, proceeding.
Configuration for experiment: sensorprocessing_propriotuned_Vit_multiview/vit_large_multiview successfully loaded


In [None]:
# Make sure the experiment's output directory exists; creating it is a no-op when it
# is already there, and any missing parent directories are created as well.
data_dir = pathlib.Path(exp["data_dir"])
data_dir.mkdir(exist_ok=True, parents=True)
print(f"Data directory: {data_dir}")

Data directory: /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit_multiview/vit_large_multiview


### Create regression training data (image to proprioception)
The training data (X, Y) consists of the multi-view images (one image per camera, `num_views` cameras, 2 by default) from the demonstrations of the training task, paired with the corresponding proprioception data.

In [None]:
def load_multiview_images_as_proprioception_training(task, proprioception_input_file, proprioception_target_file):
    """Load the multi-view images of a task with their proprioception targets, and split
    them into shuffled training and validation sets.

    On the first call every demonstration directory under ``<demos directory>/demos/<task>``
    is read. The processed data is cached into ``proprioception_input_file`` (a list with
    one image tensor per camera view) and ``proprioception_target_file`` (the stacked
    normalized robot positions). Later calls load the cache; remove those files to
    recalculate.

    Only timesteps for which *every* selected camera has an image are used, so sample ``k``
    of each view tensor always corresponds to target ``k``.

    Args:
        task: name of the demonstration task directory.
        proprioception_input_file: pathlib.Path of the cached view inputs.
        proprioception_target_file: pathlib.Path of the cached targets.

    Returns:
        dict with keys "view_inputs", "targets", "view_inputs_training",
        "targets_training", "view_inputs_validation" and "targets_validation". The
        "view_inputs*" entries are lists with one tensor per view, indexed by sample along
        dimension 0. The split is 67% training / 33% validation after a random shuffle.

    Raises:
        ValueError: when no demonstration provides enough cameras or no complete
            multi-view sample is found.
    """
    retval = {}
    if proprioception_input_file.exists():
        print(f"Loading cached data from {proprioception_input_file}")
        retval["view_inputs"] = torch.load(proprioception_input_file, weights_only=True)
        retval["targets"] = torch.load(proprioception_target_file, weights_only=True)
    else:
        demos_dir = pathlib.Path(Config()["demos"]["directory"])
        task_dir = pathlib.Path(demos_dir, "demos", task)
        num_views = exp.get("num_views", 2)

        # The cameras used as views (sorted names, first `num_views`), fixed by the first
        # demonstration so that every sample stores its views in the same order.
        cameras = None
        view_lists = {}
        targetlist = []

        print(f"Loading demonstrations from {task_dir}")
        for demo_dir in task_dir.iterdir():
            if not demo_dir.is_dir():
                continue

            print(f"Processing demonstration: {demo_dir.name}")
            # cameras=None lets BCDemonstration detect all available cameras
            bcd = BCDemonstration(
                demo_dir,
                sensorprocessor=None,
                cameras=None
            )

            if cameras is None:
                cameras = sorted(bcd.cameras)[:num_views]
                if len(cameras) < num_views:
                    raise ValueError(
                        f"Demonstration {demo_dir.name} has only {len(cameras)} cameras, "
                        f"but {num_views} views are required")
                view_lists = {camera: [] for camera in cameras}

            for i in range(bcd.trim_from, bcd.trim_to):
                all_images = bcd.get_all_images(i)

                # Require every selected camera, not just enough images: appending only
                # the cameras that happen to be present would shift the per-view lists
                # relative to the targets.
                missing = [camera for camera in cameras if camera not in all_images]
                if missing:
                    print(f"  Skipping timestep {i} - missing views {missing}")
                    continue

                for camera in cameras:
                    sensor_readings, _ = all_images[camera]
                    view_lists[camera].append(sensor_readings[0])

                # The target is the normalized robot position for this timestep
                a = bcd.get_a(i)
                rp = RobotPosition.from_vector(a)
                anorm = rp.to_normalized_vector()
                targetlist.append(torch.from_numpy(anorm))

        if not targetlist:
            raise ValueError(f"No complete multi-view samples found in {task_dir}")

        # A list with one tensor per view, each indexed by sample along dimension 0
        retval["view_inputs"] = [torch.stack(view_lists[camera]) for camera in cameras]
        retval["targets"] = torch.stack(targetlist)

        # Save processed data
        torch.save(retval["view_inputs"], proprioception_input_file)
        torch.save(retval["targets"], proprioception_target_file)
        print(f"Saved {len(targetlist)} training examples with {len(cameras)} views each")

    # Shuffle targets and every view with the same permutation so samples stay aligned
    length = len(retval["targets"])
    rows = torch.randperm(length)
    shuffled_targets = retval["targets"][rows]
    shuffled_view_inputs = [view_tensor[rows] for view_tensor in retval["view_inputs"]]

    # Split into training (67%) and validation (33%) sets
    training_size = int(length * 0.67)

    retval["view_inputs_training"] = [view[:training_size] for view in shuffled_view_inputs]
    retval["targets_training"] = shuffled_targets[:training_size]

    retval["view_inputs_validation"] = [view[training_size:] for view in shuffled_view_inputs]
    retval["targets_validation"] = shuffled_targets[training_size:]

    print(f"Created {training_size} training examples and {length - training_size} validation examples")
    return retval


In [None]:

# Load (or build and cache) the multi-view training data and unpack the splits
task = exp["proprioception_training_task"]
proprioception_input_file = pathlib.Path(exp["data_dir"]) / exp["proprioception_input_file"]
proprioception_target_file = pathlib.Path(exp["data_dir"]) / exp["proprioception_target_file"]

tr = load_multiview_images_as_proprioception_training(
    task, proprioception_input_file, proprioception_target_file
)
view_inputs_training, targets_training = tr["view_inputs_training"], tr["targets_training"]
view_inputs_validation, targets_validation = tr["view_inputs_validation"], tr["targets_validation"]

Loading cached data from /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit_multiview/vit_large_multiview/train_inputs.pt


KeyboardInterrupt: 

### Create the multi-view ViT model


In [None]:

# Create the multi-view ViT model
sp = MultiViewVitSensorProcessing(exp, device)
model = sp.enc  # The encoder module that gets trained (it provides encode() and forward())

print("Model created successfully")

# Count parameters directly; parameters with requires_grad=False (e.g. frozen layers)
# are included in the total, so the trainable ones are reported separately.
param_count = sum(p.numel() for p in model.parameters())
trainable_param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {param_count}")
print(f"Trainable parameters: {trainable_param_count}")

# Select loss function; unknown names fall back to MSE, but report it instead of silently
loss_classes = {"MSELoss": nn.MSELoss, "L1Loss": nn.L1Loss}
loss_type = exp.get('loss', 'MSELoss')
if loss_type not in loss_classes:
    print(f"Unknown loss '{loss_type}', defaulting to MSELoss")
criterion = loss_classes.get(loss_type, nn.MSELoss)()

# Set up optimizer with appropriate learning rate and weight decay.
# All parameters are passed (not only the trainable ones) so the optimizer's parameter
# groups match those stored in previously saved training checkpoints.
optimizer = optim.Adam(
    model.parameters(),
    lr=exp.get('learning_rate', 0.001),
    weight_decay=exp.get('weight_decay', 0.01)
)

# Halve the learning rate when the validation loss stops improving for 3 epochs
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3, verbose=True
)


Initializing Multi-View ViT Sensor Processing:
  Model: vit_l_16
  Number of views: 2
  Fusion type: gated
  Latent dimension: 128
  Image size: 224x224


Using 2 x vit_l_16 with output dimension 1024
Created gated fusion network
Created proprioceptor: 128 → 64 → 64 → 6
Feature extractors frozen. Projection and proprioceptor layers are trainable.
Model created successfully
Parameters accessed successfully
Total parameters: 607803144




### Custom dataset for multi-view data

In [None]:

class MultiViewDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs the per-view samples with their target.

    ``view_inputs`` is a list with one tensor (or other indexable) per camera view;
    item ``idx`` is ``([view_0[idx], view_1[idx], ...], targets[idx])``.
    """

    def __init__(self, view_inputs, targets):
        self.view_inputs = view_inputs
        self.targets = targets
        self.num_samples = len(targets)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        per_view_samples = list(map(lambda per_view: per_view[idx], self.view_inputs))
        return per_view_samples, self.targets[idx]

In [None]:

# Batch the training split (reshuffled every epoch) and the validation split (fixed order)
batch_size = exp.get('batch_size', 32)
train_dataset, test_dataset = (
    MultiViewDataset(view_inputs_training, targets_training),
    MultiViewDataset(view_inputs_validation, targets_validation),
)
train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
)

In [None]:
def _checkpoint_epoch(checkpoint_file):
    """Return the epoch encoded in an ``epoch_XXXXXX.pth`` file name, or 0 if unparsable."""
    parts = checkpoint_file.stem.split("_")
    try:
        return int(parts[1])
    except (IndexError, ValueError):
        return 0


def _sorted_epoch_checkpoints(checkpoint_dir):
    """Return the ``epoch_*.pth`` files of ``checkpoint_dir`` sorted by epoch number."""
    return sorted(checkpoint_dir.glob("epoch_*.pth"), key=_checkpoint_epoch)


def _cleanup_old_checkpoints(checkpoint_dir, max_checkpoints):
    """Delete all but the ``max_checkpoints`` most recent epoch checkpoints."""
    checkpoint_files = _sorted_epoch_checkpoints(checkpoint_dir)
    excess = len(checkpoint_files) - max_checkpoints
    for file_path in checkpoint_files[:max(excess, 0)]:
        try:
            file_path.unlink()
            print(f"Deleted old checkpoint: {file_path.name}")
        except OSError as e:
            print(f"Failed to delete {file_path.name}: {e}")


def _evaluate_multiview_loss(model, criterion, loader, device):
    """Return the mean per-batch loss of ``model`` over ``loader`` (eval mode, no grad)."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch_views, batch_y in loader:
            batch_views = [view.to(device) for view in batch_views]
            batch_y = batch_y.to(device)
            total_loss += criterion(model(batch_views), batch_y).item()
    return total_loss / len(loader)


def train_and_save_multiview_proprioception_model(model, criterion, optimizer, modelfile,
                                                device="cpu", epochs=20, scheduler=None,
                                                log_interval=1, train_loader=None,
                                                test_loader=None):
    """Train the multi-view proprioception model with resumable checkpoints, then save it.

    After every epoch a checkpoint ``checkpoints/epoch_XXXXXX.pth`` is written next to
    ``modelfile`` (only the 2 most recent are kept). Whenever the validation loss improves,
    ``checkpoints/best_model.pth`` is overwritten. If epoch checkpoints already exist,
    training resumes after the most recent one, restoring model, optimizer and (when
    stored) scheduler state. At the end the best checkpoint's weights are loaded into
    ``model`` and its state dict is saved to ``modelfile``.

    Args:
        model: Multi-view model, called as ``model(list_of_view_batches)``.
        criterion: Loss function comparing predictions with the target batch.
        optimizer: Optimizer over the model's parameters.
        modelfile: pathlib.Path where the final state dict is saved.
        device: Training device (e.g. "cpu" or "cuda").
        epochs: Total number of epochs, including epochs completed before a resume.
        scheduler: Optional scheduler whose ``step`` takes the validation loss
            (e.g. ReduceLROnPlateau); its state is stored in the checkpoints.
        log_interval: Print the epoch summary every ``log_interval`` epochs.
        train_loader: Loader of ``(views, targets)`` training batches. Defaults to the
            global ``train_loader``.
        test_loader: Loader of validation batches. Defaults to the global ``test_loader``.

    Returns:
        The model, with the best-validation-loss weights loaded when a best checkpoint exists.

    Raises:
        ValueError: if a loader is missing or empty.
    """
    # Fall back to the notebook-level loaders for callers that do not pass them
    if train_loader is None:
        train_loader = globals().get("train_loader")
    if test_loader is None:
        test_loader = globals().get("test_loader")
    if train_loader is None or test_loader is None:
        raise ValueError("train_loader and test_loader must be given or defined globally")
    if len(train_loader) == 0 or len(test_loader) == 0:
        raise ValueError("train_loader and test_loader must not be empty")

    checkpoint_dir = modelfile.parent / "checkpoints"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    best_model_path = checkpoint_dir / "best_model.pth"

    # Maximum number of epoch checkpoints to keep (excluding the best model)
    max_checkpoints = 2

    model = model.to(device)
    criterion = criterion.to(device)

    best_val_loss = float('inf')
    start_epoch = 0

    # Resume from the most recent epoch checkpoint, if any
    checkpoint_files = _sorted_epoch_checkpoints(checkpoint_dir)
    if checkpoint_files:
        latest_checkpoint = checkpoint_files[-1]
        print(f"Found checkpoint from epoch {_checkpoint_epoch(latest_checkpoint)}. Resuming training...")

        checkpoint = torch.load(latest_checkpoint, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        # Checkpoints from older runs stored the best loss *before* their own epoch was
        # considered, so include that epoch's validation loss as well.
        best_val_loss = min(checkpoint.get('best_val_loss', float('inf')),
                            checkpoint.get('val_loss', float('inf')))

        print(f"Resuming from epoch {start_epoch}/{epochs} with best validation loss: {best_val_loss:.4f}")

    for epoch in range(start_epoch, epochs):
        print(f"Starting epoch {epoch+1}/{epochs}")

        # Training phase
        model.train()
        total_loss = 0.0
        for batch_idx, (batch_views, batch_y) in enumerate(train_loader):
            batch_views = [view.to(device) for view in batch_views]
            batch_y = batch_y.to(device)

            predictions = model(batch_views)
            loss = criterion(predictions, batch_y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if (batch_idx + 1) % 10 == 0:
                print(f"  Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item():.4f}")

        avg_train_loss = total_loss / len(train_loader)

        # Validation phase
        avg_val_loss = _evaluate_multiview_loss(model, criterion, test_loader, device)

        if scheduler is not None:
            scheduler.step(avg_val_loss)

        # Update the best loss *before* writing the checkpoint, so a run resumed from this
        # checkpoint compares against the true best (including this epoch).
        is_best = avg_val_loss < best_val_loss
        if is_best:
            best_val_loss = avg_val_loss

        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'best_val_loss': best_val_loss,
        }
        if scheduler is not None:
            state['scheduler_state_dict'] = scheduler.state_dict()

        # Zero-padded epoch numbers keep the file names sortable
        checkpoint_path = checkpoint_dir / f"epoch_{epoch:06d}.pth"
        torch.save(state, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

        _cleanup_old_checkpoints(checkpoint_dir, max_checkpoints)

        if is_best:
            torch.save(state, best_model_path)
            print(f"  New best model saved with validation loss: {best_val_loss:.4f}")

        if (epoch + 1) % log_interval == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

    # Load the best model for the final save
    if best_model_path.exists():
        best_checkpoint = torch.load(best_model_path, map_location=device, weights_only=True)
        model.load_state_dict(best_checkpoint['model_state_dict'])
        print(f"Loaded best model from epoch {best_checkpoint['epoch']+1} with loss {best_checkpoint['val_loss']:.4f}")

    torch.save(model.state_dict(), modelfile)
    print(f"Final model saved to {modelfile}")

    return model


# Set up model file path and epochs
modelfile = pathlib.Path(exp["data_dir"], exp["proprioception_mlp_model_file"])
epochs = exp.get("epochs", 20)

# First check for existing final model
if modelfile.exists() and exp.get("reload_existing_model", True):
    print(f"Loading existing final model from {modelfile}")
    # The training function moves the model to `device`; the reload path must do it too,
    # since the batches below are moved to `device`.
    model = model.to(device)
    model.load_state_dict(torch.load(modelfile, map_location=device, weights_only=True))

    # Evaluate the loaded model.
    # NOTE(review): the train/validation split is redrawn with torch.randperm every time the
    # data is loaded, so this "validation" set likely overlaps the data the model was
    # trained on; treat the number as optimistic.
    model.eval()
    with torch.no_grad():
        val_loss = 0
        for batch_views, batch_y in test_loader:
            batch_views = [view.to(device) for view in batch_views]
            batch_y = batch_y.to(device)

            predictions = model(batch_views)
            loss = criterion(predictions, batch_y)
            val_loss += loss.item()

        avg_val_loss = val_loss / len(test_loader)
        print(f"Loaded model validation loss: {avg_val_loss:.4f}")
else:
    # Report whether training will resume from checkpoints; the training function itself
    # performs the actual resume.
    checkpoint_dir = modelfile.parent / "checkpoints"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def get_epoch_number(checkpoint_file):
        """Return the epoch encoded in an ``epoch_XXXXXX.pth`` file name, or 0 if unparsable."""
        try:
            return int(checkpoint_file.stem.split('_')[1])
        except (IndexError, ValueError):
            return 0

    checkpoint_files = sorted(checkpoint_dir.glob("epoch_*.pth"), key=get_epoch_number)
    if checkpoint_files:
        latest_epoch = get_epoch_number(checkpoint_files[-1])
        print(f"Found checkpoints up to epoch {latest_epoch}. Will resume training from last checkpoint.")
    else:
        print(f"Starting new training for {epochs} epochs")

    # Train the model with checkpoint support
    model = train_and_save_multiview_proprioception_model(
        model, criterion, optimizer, modelfile,
        device=device, epochs=epochs, scheduler=lr_scheduler
    )

Starting new training for 300 epochs
Starting epoch 1/300


  Batch 10/37, Loss: 0.3241
  Batch 20/37, Loss: 0.2129
  Batch 30/37, Loss: 0.1449
Checkpoint saved to /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit_multiview/vit_large_multiview_gated/checkpoints/epoch_000000.pth
  New best model saved with validation loss: 0.1244
Epoch [1/300], Train Loss: 0.2276, Val Loss: 0.1244
Starting epoch 2/300
  Batch 10/37, Loss: 0.0622
  Batch 20/37, Loss: 0.0477
  Batch 30/37, Loss: 0.0425
Checkpoint saved to /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit_multiview/vit_large_multiview_gated/checkpoints/epoch_000001.pth
  New best model saved with validation loss: 0.0407
Epoch [2/300], Train Loss: 0.0595, Val Loss: 0.0407
Starting epoch 3/300
  Batch 10/37, Loss: 0.0371
  Batch 20/37, Loss: 0.0294
  Batch 30/37, Loss: 0.0306
Checkpoint saved to /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit_multiview/vit_large_multiv

  best_checkpoint = torch.load(best_model_path, map_location=device)


Loaded best model from epoch 187 with loss 0.0078
Final model saved to /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit_multiview/vit_large_multiview_gated/proprioception_mlp.pth


### Test the trained model


In [None]:

# Build a fresh sensor processing object; it loads the trained weights from the model file
sp = MultiViewVitSensorProcessing(exp, device)


def test_multiview_sensor_processing(sp, test_view_inputs, test_targets, n_samples=5):
    """Print the view shapes, the latent shape and the target for random validation samples.

    Args:
        sp: sensor processing object whose ``process`` takes a list of view batches.
        test_view_inputs: list with one tensor per view, indexed by sample along dimension 0.
        test_targets: target tensor with one row per sample.
        n_samples: number of random samples to show (capped at the number available).
    """
    n_samples = min(n_samples, len(test_targets))
    chosen = torch.randperm(len(test_targets))[:n_samples]

    print("\nTesting multi-view sensor processing on random examples:")
    print("-" * 60)

    for example_number, sample_idx in enumerate(chosen, start=1):
        # unsqueeze(0) gives every view a batch dimension of 1
        batch_views = [
            per_view[sample_idx].unsqueeze(0).to(device)
            for per_view in test_view_inputs
        ]
        target = test_targets[sample_idx].cpu().numpy()

        latent = sp.process(batch_views)

        print(f"Example {example_number}:")
        for view_number, batch_view in enumerate(batch_views, start=1):
            print(f"  View {view_number} shape: {batch_view.shape}")
        print(f"  Latent shape: {latent.shape}")
        print(f"  Target position: {target}")
        print()


# Test the sensor processing
test_multiview_sensor_processing(sp, view_inputs_validation, targets_validation)

Initializing Multi-View ViT Sensor Processing:
  Model: vit_l_16
  Number of views: 2
  Fusion type: gated
  Latent dimension: 128
  Image size: 224x224


Using 2 x vit_l_16 with output dimension 1024
Created gated fusion network
Created proprioceptor: 128 → 64 → 64 → 6
Feature extractors frozen. Projection and proprioceptor layers are trainable.
Loading Multi-View ViT encoder weights from /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit_multiview/vit_large_multiview_gated/proprioception_mlp.pth


  self.enc.load_state_dict(torch.load(modelfile, map_location=device))



Testing multi-view sensor processing on random examples:
------------------------------------------------------------
Example 1:
  View 1 shape: torch.Size([1, 3, 256, 256])
  View 2 shape: torch.Size([1, 3, 256, 256])
  Latent shape: (128,)
  Target position: [0.8857948  0.45419306 0.66338944 0.91739106 0.51170576 0.7944365 ]

Example 2:
  View 1 shape: torch.Size([1, 3, 256, 256])
  View 2 shape: torch.Size([1, 3, 256, 256])
  Latent shape: (128,)
  Target position: [0.11979929 0.64397836 0.28895995 0.37931362 0.10096735 0.35096252]

Example 3:
  View 1 shape: torch.Size([1, 3, 256, 256])
  View 2 shape: torch.Size([1, 3, 256, 256])
  Latent shape: (128,)
  Target position: [0.46494654 0.85770124 0.21145414 0.5182245  0.62526083 0.9698937 ]

Example 4:
  View 1 shape: torch.Size([1, 3, 256, 256])
  View 2 shape: torch.Size([1, 3, 256, 256])
  Latent shape: (128,)
  Target position: [0.8515704  0.9489891  0.69166666 0.65402603 0.15616164 0.22596249]

Example 5:
  View 1 shape: torch.

### Verify the model's encoding and forward methods

In [None]:

model.eval()
with torch.no_grad():
    # First validation sample of every view, each with a batch dimension of 1
    sample_views = [per_view[0].unsqueeze(0).to(device) for per_view in view_inputs_validation]

    # encode() yields the latent representation
    latent = model.encode(sample_views)
    print(f"Latent representation shape: {latent.shape}")

    # forward() yields the predicted robot position
    position = model.forward(sample_views)
    print(f"Robot position prediction shape: {position.shape}")

    # Both outputs must match the sizes declared in the experiment configuration
    expected_latent_size = exp["latent_size"]
    assert latent.shape[1] == expected_latent_size, (
        f"Expected latent size {expected_latent_size}, got {latent.shape[1]}"
    )

    expected_output_size = exp["output_size"]
    assert position.shape[1] == expected_output_size, (
        f"Expected output size {expected_output_size}, got {position.shape[1]}"
    )

    print("Verification successful!")


Latent representation shape: torch.Size([1, 128])
Robot position prediction shape: torch.Size([1, 6])
Verification successful!


### Test the process_file method with cache handling

In [None]:
def test_process_file(sp, view_inputs_validation, n_samples=3):
    """Exercise ``sp.process_file`` on validation images written to a temporary directory.

    For each of the first ``n_samples`` validation timesteps, every view is saved as a JPEG
    named ``<timestep>_camera<k>.jpg`` and passed to ``sp.process_file`` with
    ``camera_id="camera<k>"``. When there is more than one view, the caches are then
    cleared (only if ``sp`` has a ``_view_cache`` attribute). Views 1 and 2 are processed
    again one at a time, so the log shows whether view 1 is reused.

    Args:
        sp: sensor processing object exposing ``process_file(path, camera_id=...)``.
        view_inputs_validation: list with one image tensor per view, indexed by sample
            along dimension 0.
        n_samples: maximum number of timesteps to test.
    """
    import os
    import tempfile
    from torchvision.utils import save_image

    if not view_inputs_validation:
        print("No views available to test process_file.")
        return
    n_samples = min(n_samples, len(view_inputs_validation[0]))

    print("\nTesting process_file method with view caching:")
    print("-" * 60)

    # TemporaryDirectory removes the files even if processing raises
    with tempfile.TemporaryDirectory() as temp_dir:
        for idx in range(n_samples):
            print(f"\nProcessing timestep {idx+1}:")

            # Write one file per view for this timestep
            file_paths = []
            for view_idx, view_tensor in enumerate(view_inputs_validation):
                file_path = os.path.join(temp_dir, f"{idx+1:05d}_camera{view_idx+1}.jpg")
                # NOTE(review): save_image presumably expects values in [0, 1]; confirm the
                # cached tensors are in that range, otherwise the saved images are distorted.
                save_image(view_tensor[idx], file_path)
                file_paths.append(file_path)

            print("  Testing with all views available:")
            for view_idx, file_path in enumerate(file_paths):
                latent = sp.process_file(file_path, camera_id=f"camera{view_idx+1}")
                print(f"    Processed view {view_idx+1}, latent shape: {latent.shape}")

            if len(file_paths) > 1:
                print("  Testing with one missing view:")
                # Clear the private caches to simulate a fresh scenario
                if hasattr(sp, '_view_cache'):
                    sp._view_cache = {}
                    sp._timestep_cache = {}

                latent = sp.process_file(file_paths[0], camera_id="camera1")
                print(f"    Processed only view 1, latent shape: {latent.shape}")

                # The second view should now be combined with view 1 from the cache
                latent = sp.process_file(file_paths[1], camera_id="camera2")
                print(f"    Processed view 2 with view 1 from cache, latent shape: {latent.shape}")


# Test the process_file method with caching
test_process_file(sp, view_inputs_validation)


Testing process_file method with view caching:
------------------------------------------------------------

Processing timestep 1:
  Testing with all views available:
Using 1 unique views (with 1 duplicated)
    Processed view 1, latent shape: (128,)
Using complete set of 2 views from timestep 1
    Processed view 2, latent shape: (128,)
  Testing with one missing view:
Using 1 unique views (with 1 duplicated)
    Processed only view 1, latent shape: (128,)
Using complete set of 2 views from timestep 1
    Processed view 2 with view 1 from cache, latent shape: (128,)

Processing timestep 2:
  Testing with all views available:
Using 2 unique views (with 0 duplicated)
    Processed view 1, latent shape: (128,)
Using complete set of 2 views from timestep 2
    Processed view 2, latent shape: (128,)
  Testing with one missing view:
Using 1 unique views (with 1 duplicated)
    Processed only view 1, latent shape: (128,)
Using complete set of 2 views from timestep 2
    Processed view 2 wi

### Save final model and print summary

In [None]:

# Persist the current model weights and print a summary of the configuration used
final_modelfile = pathlib.Path(exp["data_dir"]) / exp["proprioception_mlp_model_file"]
torch.save(model.state_dict(), final_modelfile)
print(f"Model saved to {final_modelfile}")

print("\nTraining complete!")
print("Vision Transformer type: " + f"{exp['vit_model']}")
print("Number of views: " + f"{exp.get('num_views', 2)}")
print("Fusion type: " + f"{exp.get('fusion_type', 'concat_proj')}")
print("Latent space dimension: " + f"{exp['latent_size']}")
print("Output dimension (robot DOF): " + f"{exp['output_size']}")
print("Use the MultiViewVitSensorProcessing class to load and use this model for inference.")

Model saved to /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit_multiview/vit_large_multiview_gated/proprioception_mlp.pth

Training complete!
Vision Transformer type: vit_l_16
Number of views: 2
Fusion type: gated
Latent space dimension: 128
Output dimension (robot DOF): 6
Use the MultiViewVitSensorProcessing class to load and use this model for inference.
