# Train a proprioception-tuned Concatenated Image Vision Transformer (ViT)

This notebook creates a sensor processing model that concatenates
multiple camera-view images before processing them through a single Vision Transformer.
The model is trained to:
1. Create a meaningful 128-dimensional latent representation from concatenated camera views
2. Learn to map this representation to robot positions (proprioception)



In [None]:
import sys
sys.path.append("..")

from settings import Config

import pathlib
import torch
import torch.nn as nn
from torchvision import models, transforms
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

from demonstration.demonstration_helper import BCDemonstration
from sensorprocessing.sp_vit_concat_images import ConcatImageVitSensorProcessing
from robot.al5d_position_controller import RobotPosition

# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:

# Experiment to run; the model described by the selected run will be created.
experiment = "sensorprocessing_propriotuned_Vit_concat_multiview"

# Select exactly one run:
# concat_proj
#   "vit_base_concat_multiview"   -> ViT Base
#   "vit_large_concat_multiview"  -> ViT Large
#   "vit_huge_multiview"          -> ViT Huge
run = "vit_large_concat_multiview"  # ViT Large

exp = Config().get_experiment(experiment, run)


No system dependent experiment file
 /home/ssheikholeslami/SaharaBerryPickerData/experiments-Config/sensorprocessing_propriotuned_Vit_concat_multiview/vit_large_concat_multiview_sysdep.yaml,
 that is ok, proceeding.
Configuration for experiment: sensorprocessing_propriotuned_Vit_concat_multiview/vit_large_concat_multiview successfully loaded


In [4]:
# Make sure the experiment's data directory exists before anything is cached there.
data_dir = pathlib.Path(exp["data_dir"])
data_dir.mkdir(exist_ok=True, parents=True)
print("Data directory: " + str(data_dir))

Data directory: /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit_concat_multiview/vit_large_concat_multiview


### Create regression training data (image to proprioception)
The training data (X, Y) consists of all the two-view images from the task's demonstrations (X), paired with the corresponding normalized robot positions used as proprioception targets (Y).

In [5]:
def _read_multiview_demonstrations(task, num_views):
    """Read all demonstrations of `task` into (list of per-view image tensors, target tensor).

    Only timesteps for which every selected camera has an image are kept, so each
    view tensor and the target tensor stay index-aligned.

    Raises:
        ValueError: if the first demonstration has fewer than `num_views` cameras,
            or if no usable timestep is found.
    """
    demos_dir = pathlib.Path(Config()["demos"]["directory"])
    task_dir = pathlib.Path(demos_dir, "demos", task)

    cameras = None      # cameras used as views; fixed by the first demonstration
    view_lists = None   # one list of images per entry of `cameras`
    targetlist = []

    print(f"Loading demonstrations from {task_dir}")
    for demo_dir in task_dir.iterdir():
        if not demo_dir.is_dir():
            continue

        print(f"Processing demonstration: {demo_dir.name}")
        # Create BCDemonstration with multi-camera support
        bcd = BCDemonstration(
            demo_dir,
            sensorprocessor=None,
            cameras=None  # This will detect all available cameras
        )

        if cameras is None:
            # The views are the first `num_views` cameras in sorted order.
            cameras = sorted(bcd.cameras)[:num_views]
            if len(cameras) < num_views:
                raise ValueError(
                    f"Demonstration {demo_dir.name} has only {len(cameras)} cameras, "
                    f"but {num_views} views were requested")
            view_lists = [[] for _ in cameras]

        for i in range(bcd.trim_from, bcd.trim_to):
            all_images = bcd.get_all_images(i)

            # Skip the whole timestep unless every selected view is present. Appending
            # a partial set of views would shift them out of alignment with the targets.
            missing = [camera for camera in cameras if camera not in all_images]
            if missing:
                print(f"  Skipping timestep {i} - missing views {missing}")
                continue

            for camera, view_list in zip(cameras, view_lists):
                sensor_readings, _ = all_images[camera]
                view_list.append(sensor_readings[0])

            # The target is the normalized robot action for this timestep
            rp = RobotPosition.from_vector(bcd.get_a(i))
            targetlist.append(torch.from_numpy(rp.to_normalized_vector()))

    if not targetlist:
        raise ValueError(f"No usable multi-view timesteps found in {task_dir}")

    view_tensors = [torch.stack(view_list) for view_list in view_lists]
    return view_tensors, torch.stack(targetlist)


def _split_multiview_data(view_inputs, targets, seed=None, training_fraction=0.67):
    """Shuffle samples identically across views and targets, then split off a training prefix.

    Args:
        view_inputs: List of tensors, one per view, indexed by sample on dim 0.
        targets: Tensor indexed by sample on dim 0.
        seed: Seed for the permutation; None uses the global torch RNG (not reproducible).
        training_fraction: Fraction of samples (rounded down) used for training.

    Returns:
        Dict with "view_inputs_training", "targets_training",
        "view_inputs_validation" and "targets_validation".
    """
    length = len(targets)
    if seed is None:
        rows = torch.randperm(length)
    else:
        rows = torch.randperm(length, generator=torch.Generator().manual_seed(seed))

    training_size = int(length * training_fraction)
    training_rows, validation_rows = rows[:training_size], rows[training_size:]

    split = {
        "view_inputs_training": [view[training_rows] for view in view_inputs],
        "targets_training": targets[training_rows],
        "view_inputs_validation": [view[validation_rows] for view in view_inputs],
        "targets_validation": targets[validation_rows],
    }
    print(f"Created {training_size} training examples and {length - training_size} validation examples")
    return split


def load_multiview_demonstrations(task, proprioception_input_file, proprioception_target_file,
                                  num_views=2, seed=42):
    """
    Load all multi-view images of a task with their proprioception targets, and split
    them into training (67%) and validation (33%) sets.

    The unsplit data is cached. If both cache files exist they are loaded. Otherwise the
    demonstrations are processed and the result is written to the two cache files.

    Args:
        task: Task name; demonstrations are read from <demos directory>/demos/<task>.
        proprioception_input_file: pathlib.Path of the cached list of per-view image tensors.
        proprioception_target_file: pathlib.Path of the cached target tensor.
        num_views: Number of camera views to keep (the first `num_views` cameras in
            sorted order). A timestep is used only if all of these views are present.
        seed: Seed for the train/validation shuffle. A fixed seed makes the split
            reproducible across kernel restarts, so a model reloaded from disk is
            evaluated on the same held-out examples it was trained against.
            Pass None for an unseeded (non-reproducible) shuffle.

    Returns:
        Dictionary with the full data ("view_inputs", "targets") and the splits
        ("view_inputs_training", "targets_training", "view_inputs_validation",
        "targets_validation"). View entries are lists with one tensor per view.

    Raises:
        ValueError: if the demonstrations yield no usable multi-view timestep.
    """
    retval = {}
    if proprioception_input_file.exists() and proprioception_target_file.exists():
        print(f"Loading cached data from {proprioception_input_file}")
        retval["view_inputs"] = torch.load(proprioception_input_file, weights_only=True)
        retval["targets"] = torch.load(proprioception_target_file, weights_only=True)
        if len(retval["view_inputs"]) != num_views:
            print(f"Warning: cached data has {len(retval['view_inputs'])} views, expected {num_views}")
    else:
        retval["view_inputs"], retval["targets"] = _read_multiview_demonstrations(task, num_views)

        # Save processed data
        torch.save(retval["view_inputs"], proprioception_input_file)
        torch.save(retval["targets"], proprioception_target_file)
        print(f"Saved {len(retval['targets'])} training examples with {num_views} views each")

    retval.update(_split_multiview_data(retval["view_inputs"], retval["targets"], seed))
    return retval


In [7]:

# Paths of the cached training data, taken from the experiment configuration
task = exp["proprioception_training_task"]
proprioception_input_file = pathlib.Path(exp["data_dir"], exp["proprioception_input_file"])
proprioception_target_file = pathlib.Path(exp["data_dir"], exp["proprioception_target_file"])

tr = load_multiview_demonstrations(task, proprioception_input_file, proprioception_target_file)

# Unpack the splits used by the rest of the notebook
view_inputs_training, targets_training = tr["view_inputs_training"], tr["targets_training"]
view_inputs_validation, targets_validation = tr["view_inputs_validation"], tr["targets_validation"]

Loading cached data from /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit_concat_multiview/vit_large_concat_multiview/train_inputs.pt
Created 1169 training examples and 576 validation examples


### Create the multi-view ViT model


In [8]:

# Create the multi-view ViT model. Per the log below, constructing the sensor
# processor also loads previously saved encoder weights from the experiment directory.
sp = ConcatImageVitSensorProcessing(exp, device)
model = sp.enc  # the encoder nn.Module that is actually trained
print("Model created successfully")

# Report the model size. Only parameters with requires_grad=True are updated
# (the log reports the ViT feature extractor as frozen).
param_count = sum(p.numel() for p in model.parameters())
trainable_param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {param_count}")
print(f"Trainable parameters: {trainable_param_count}")

# Select the loss function. An unrecognized name falls back to MSE, with a visible warning.
loss_classes = {'MSELoss': nn.MSELoss, 'L1Loss': nn.L1Loss}
loss_type = exp.get('loss', 'MSELoss')
if loss_type not in loss_classes:
    print(f"Warning: unknown loss type '{loss_type}', falling back to MSELoss")
criterion = loss_classes.get(loss_type, nn.MSELoss)()

# Adam with L2 weight decay. Frozen parameters never receive gradients, so Adam skips them.
# All parameters are still passed so that the optimizer state stays compatible with
# checkpoints written by earlier runs.
optimizer = optim.Adam(
    model.parameters(),
    lr=exp.get('learning_rate', 0.001),
    weight_decay=exp.get('weight_decay', 0.01)
)

# Halve the learning rate when the validation loss has not improved for 3 epochs.
# `verbose` is deprecated in recent PyTorch; read optimizer.param_groups[0]['lr'] instead.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)


Initializing Concatenated Image ViT Sensor Processing:
  Model: vit_l_16
  Number of views: 2
  Latent dimension: 128
  Image size: 224x224
Using vit_l_16 with output dimension 1024
Created projection network: 1024 → 512 → 256 → 128
Created proprioceptor: 128 → 64 → 64 → 6
Feature extractor frozen. Projection and proprioceptor layers are trainable.
Loading Concatenated Image ViT encoder weights from /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit_concat_multiview/vit_large_concat_multiview/proprioception_mlp.pth


  self.enc.load_state_dict(torch.load(modelfile, map_location=device))


Model created successfully
Parameters accessed successfully
Total parameters: 304004998




### Custom dataset for multi-view data

In [11]:
# Define a custom dataset for multi-view images
class MultiViewDataset(torch.utils.data.Dataset):
    """
    Map-style dataset pairing several index-aligned camera views with one target.

    Item ``idx`` is ``([view_0[idx], view_1[idx], ...], targets[idx])``.
    """

    def __init__(self, view_inputs, targets):
        """
        Args:
            view_inputs: List with one tensor per view [view1, view2, ...]; each tensor is
                indexed by sample along its first dimension.
            targets: Tensor of target values, indexed by sample along its first dimension.
        """
        self.view_inputs = view_inputs
        self.targets = targets
        self.num_samples = len(targets)

        # Every view must provide exactly one entry per target
        assert all(len(view) == self.num_samples for view in view_inputs), \
            "All views must have the same number of samples"

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return [view[idx] for view in self.view_inputs], self.targets[idx]


In [12]:

# Wrap both splits in DataLoaders; only the training data is reshuffled every epoch.
batch_size = exp.get('batch_size', 32)
train_dataset = MultiViewDataset(view_inputs_training, targets_training)
test_dataset = MultiViewDataset(view_inputs_validation, targets_validation)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [13]:
# For testing with the encoder model
def test_concat_model_direct(model, test_view_inputs, test_targets, device, n_samples=5):
    """Test the concatenated image ViT encoder model directly."""
    if n_samples > len(test_targets):
        n_samples = len(test_targets)

    # Get random indices
    indices = torch.randperm(len(test_targets))[:n_samples]

    print("\nTesting concatenated image ViT model on random examples:")
    print("-" * 60)

    # Ensure model is in evaluation mode
    model.eval()

    with torch.no_grad():
        for i, idx in enumerate(indices):
            # Get views and target
            views = [view[idx].unsqueeze(0).to(device) for view in test_view_inputs]
            target = test_targets[idx].cpu().numpy()

            # Get latent representation directly
            latent = model.encode(views)

            # Get prediction
            prediction = model.proprioceptor(latent)
            prediction = prediction.cpu().numpy().flatten()

            # Print the results
            print(f"Example {i+1}:")
            for j, view in enumerate(views):
                print(f"  View {j+1} shape: {view.shape}")
            print(f"  Latent shape: {latent.shape}")
            print(f"  Target position: {target}")
            print(f"  Predicted position: {prediction}")
            print(f"  Mean squared error: {np.mean((prediction - target) ** 2):.6f}")
            print()

In [15]:
def _checkpoint_epoch(checkpoint_file):
    """Epoch number encoded in an ``epoch_XXXXXX.pth`` file name (0 if it cannot be parsed)."""
    try:
        return int(checkpoint_file.stem.split('_')[1])
    except (IndexError, ValueError):
        return 0


def _prune_checkpoints(checkpoint_dir, keep):
    """Delete all but the ``keep`` newest ``epoch_*.pth`` files; ``best_model.pth`` is untouched."""
    checkpoint_files = sorted(checkpoint_dir.glob("epoch_*.pth"), key=_checkpoint_epoch)
    for file_path in checkpoint_files[:max(len(checkpoint_files) - keep, 0)]:
        try:
            file_path.unlink()
            print(f"Deleted old checkpoint: {file_path.name}")
        except OSError as e:
            print(f"Failed to delete {file_path.name}: {e}")


def _train_one_epoch(model, criterion, optimizer, loader, device):
    """Run one optimization pass over ``loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for batch_idx, (batch_views, batch_y) in enumerate(loader):
        batch_views = [view.to(device) for view in batch_views]
        batch_y = batch_y.to(device)

        # Forward pass through the full model (including proprioceptor)
        loss = criterion(model(batch_views), batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if (batch_idx + 1) % 10 == 0:
            print(f"  Batch {batch_idx+1}/{len(loader)}, Loss: {loss.item():.4f}")
    return total_loss / len(loader)


def _evaluate(model, criterion, loader, device):
    """Return the mean batch loss of ``model`` over ``loader``, computed without gradients."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch_views, batch_y in loader:
            batch_views = [view.to(device) for view in batch_views]
            batch_y = batch_y.to(device)
            total_loss += criterion(model(batch_views), batch_y).item()
    return total_loss / len(loader)


def train_and_save_concat_vit_model(model, criterion, optimizer, modelfile,
                                  device="cpu", epochs=20, scheduler=None,
                                  log_interval=1, train_loader=None, test_loader=None):
    """
    Train the concatenated-image ViT proprioception model with resumable checkpoints,
    then save the best weights to ``modelfile``.

    After every epoch a checkpoint ``checkpoints/epoch_XXXXXX.pth`` is written next to
    ``modelfile``. Only the 2 newest are kept. ``checkpoints/best_model.pth`` is written
    whenever the validation loss improves. If epoch checkpoints already exist, training
    resumes after the newest one.

    Args:
        model: Module whose forward takes a list of view tensors and returns predictions.
        criterion: Loss function.
        optimizer: Optimizer over the model's parameters.
        modelfile: pathlib.Path where the final (best) state_dict is saved.
        device: Training device (cpu/cuda).
        epochs: Total number of epochs, including epochs completed before a resume.
        scheduler: Optional scheduler stepped with the validation loss (e.g.
            ReduceLROnPlateau). Its state is checkpointed and restored on resume.
        log_interval: Print an epoch summary every ``log_interval`` epochs.
        train_loader: Iterable of (list_of_views, targets) training batches. Defaults to
            the notebook-global ``train_loader`` for backward compatibility.
        test_loader: Same, for validation. Defaults to the global ``test_loader``.

    Returns:
        The model, with the best checkpoint's weights loaded if one exists.

    Raises:
        ValueError: if a loader is missing or empty.
    """
    if train_loader is None:
        train_loader = globals().get("train_loader")
    if test_loader is None:
        test_loader = globals().get("test_loader")
    if train_loader is None or test_loader is None:
        raise ValueError("train_loader and test_loader must be passed (or defined globally)")
    if len(train_loader) == 0 or len(test_loader) == 0:
        raise ValueError("train_loader and test_loader must not be empty")

    checkpoint_dir = modelfile.parent / "checkpoints"
    checkpoint_dir.mkdir(exist_ok=True)
    best_model_path = checkpoint_dir / "best_model.pth"
    max_checkpoints = 2  # epoch checkpoints kept, excluding the best model

    model = model.to(device)
    criterion = criterion.to(device)

    best_val_loss = float('inf')
    start_epoch = 0

    # Resume from the newest epoch checkpoint, if any
    checkpoint_files = sorted(checkpoint_dir.glob("epoch_*.pth"), key=_checkpoint_epoch)
    if checkpoint_files:
        latest_checkpoint = checkpoint_files[-1]
        print(f"Found checkpoint from epoch {_checkpoint_epoch(latest_checkpoint)}. Resuming training...")

        checkpoint = torch.load(latest_checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        # Older checkpoints stored best_val_loss before this epoch's result was folded in,
        # so also take the checkpoint's own validation loss into account.
        best_val_loss = min(checkpoint.get('best_val_loss', float('inf')),
                            checkpoint.get('val_loss', float('inf')))

        print(f"Resuming from epoch {start_epoch}/{epochs} with best validation loss: {best_val_loss:.4f}")

    for epoch in range(start_epoch, epochs):
        print(f"Starting epoch {epoch+1}/{epochs}")

        avg_train_loss = _train_one_epoch(model, criterion, optimizer, train_loader, device)
        avg_val_loss = _evaluate(model, criterion, test_loader, device)

        if scheduler is not None:
            scheduler.step(avg_val_loss)

        # Update the best loss BEFORE writing the epoch checkpoint, so that a resume from
        # this checkpoint knows about this epoch's result.
        is_best = avg_val_loss < best_val_loss
        if is_best:
            best_val_loss = avg_val_loss

        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'best_val_loss': best_val_loss,
        }
        if scheduler is not None:
            state['scheduler_state_dict'] = scheduler.state_dict()

        # Zero-padded epoch numbers keep file names sortable
        checkpoint_path = checkpoint_dir / f"epoch_{epoch:06d}.pth"
        torch.save(state, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")
        _prune_checkpoints(checkpoint_dir, max_checkpoints)

        if is_best:
            torch.save(state, best_model_path)
            print(f"  New best model saved with validation loss: {best_val_loss:.4f}")

        if (epoch + 1) % log_interval == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}, '
                  f'Val Loss: {avg_val_loss:.4f}, LR: {current_lr:.2e}')

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

    # Load the best model for the final save
    if best_model_path.exists():
        best_checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(best_checkpoint['model_state_dict'])
        print(f"Loaded best model from epoch {best_checkpoint['epoch']+1} with loss {best_checkpoint['val_loss']:.4f}")

    # Write the final model file only after all epochs have completed
    torch.save(model.state_dict(), modelfile)
    print(f"Final model saved to {modelfile}")

    return model


In [16]:
    modelfile = pathlib.Path(exp["data_dir"], exp["proprioception_mlp_model_file"])
epochs = exp.get("epochs", 20)

    # First check for existing final model
    if modelfile.exists() and exp.get("reload_existing_model", True):
        print(f"Loading existing final model from {modelfile}")
        model.load_state_dict(torch.load(modelfile, map_location=device))

        # Evaluate the loaded model
        model.eval()
        with torch.no_grad():
            val_loss = 0
            for batch_views, batch_y in test_loader:
                # Move views and targets to device
                batch_views = [view.to(device) for view in batch_views]
                batch_y = batch_y.to(device)

                predictions = model(batch_views)
                loss = criterion(predictions, batch_y)
                val_loss += loss.item()

            avg_val_loss = val_loss / len(test_loader)
            print(f"Loaded model validation loss: {avg_val_loss:.4f}")

        # Test model on some examples
        test_concat_model_direct(model, view_inputs_validation, targets_validation, device)
    else:
        # Check for checkpoints to resume from, otherwise start fresh training
        checkpoint_dir = modelfile.parent / "checkpoints"
        checkpoint_dir.mkdir(exist_ok=True)

        # Use the get_epoch_number function for reliable sorting
        def get_epoch_number(checkpoint_file):
            try:
                filename = checkpoint_file.stem
                parts = filename.split('_')
                if len(parts) >= 2:
                    return int(parts[1])
                return 0
            except:
                return 0

        checkpoint_files = list(checkpoint_dir.glob("epoch_*.pth"))
        if checkpoint_files:
            # Sort by epoch number
            checkpoint_files.sort(key=get_epoch_number)
            latest_epoch = get_epoch_number(checkpoint_files[-1])
            print(f"Found checkpoints up to epoch {latest_epoch}. Will resume training from last checkpoint.")
        else:
            print(f"Starting new training for {epochs} epochs")

        # Train the model with checkpoint support
        model = train_and_save_concat_vit_model(
            model, criterion, optimizer, modelfile,
            device=device, epochs=epochs, scheduler=lr_scheduler
        )

Loading existing final model from /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit_concat_multiview/vit_large_concat_multiview/proprioception_mlp.pth


  model.load_state_dict(torch.load(modelfile, map_location=device))


Loaded model validation loss: 0.0076

Testing concatenated image ViT model on random examples:
------------------------------------------------------------
Example 1:
  View 1 shape: torch.Size([1, 3, 256, 256])
  View 2 shape: torch.Size([1, 3, 256, 256])
  Latent shape: torch.Size([1, 128])
  Target position: [0.36338103 0.46871686 0.8122891  0.5504956  0.26950523 0.20594741]
  Predicted position: [0.3584069 0.5001888 0.7504191 0.5339994 0.2638954 0.2457543]
  Mean squared error: 0.001122

Example 2:
  View 1 shape: torch.Size([1, 3, 256, 256])
  View 2 shape: torch.Size([1, 3, 256, 256])
  Latent shape: torch.Size([1, 128])
  Target position: [0.91794205 0.4346888  0.26011056 0.13500535 0.49278548 0.8269405 ]
  Predicted position: [0.7771102  0.40149823 0.29160368 0.26021338 0.49528894 0.7670573 ]
  Mean squared error: 0.006866

Example 3:
  View 1 shape: torch.Size([1, 3, 256, 256])
  View 2 shape: torch.Size([1, 3, 256, 256])
  Latent shape: torch.Size([1, 128])
  Target position:

### Test the trained model


In [19]:

# Re-create the sensor processor; per its log it loads the trained encoder weights from disk.
sp = ConcatImageVitSensorProcessing(exp, device)
model = sp.enc  # the encoder module wrapped by the sensor processor

# Show predictions on a few random validation examples
test_concat_model_direct(model, view_inputs_validation, targets_validation, device)



Initializing Concatenated Image ViT Sensor Processing:
  Model: vit_l_16
  Number of views: 2
  Latent dimension: 128
  Image size: 224x224
Using vit_l_16 with output dimension 1024
Created projection network: 1024 → 512 → 256 → 128
Created proprioceptor: 128 → 64 → 64 → 6
Feature extractor frozen. Projection and proprioceptor layers are trainable.
Loading Concatenated Image ViT encoder weights from /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/sensorprocessing_propriotuned_Vit_concat_multiview/vit_large_concat_multiview/proprioception_mlp.pth

Testing concatenated image ViT model on random examples:
------------------------------------------------------------
Example 1:
  View 1 shape: torch.Size([1, 3, 256, 256])
  View 2 shape: torch.Size([1, 3, 256, 256])
  Latent shape: torch.Size([1, 128])
  Target position: [0.9088733  0.37754595 0.21844389 0.15448141 0.5594522  0.97095925]
  Predicted position: [0.82214594 0.40359777 0.29848638 0.29593843 0.51397604 0.8868769 ]
  

### Verify the model's encoding and forward methods

In [21]:
def test_model_encoding(model, view_inputs_validation, device, n_samples=3):
    """Test the encoding functionality of the concatenated image ViT model."""
    if n_samples > len(view_inputs_validation[0]):
        n_samples = len(view_inputs_validation[0])

    # Get first few samples
    indices = range(n_samples)

    print("\nTesting model encoding functionality:")
    print("-" * 60)

    # Get expected dimensions from model configuration
    latent_size = model.latent_size
    output_size = model.output_size

    # Ensure model is in evaluation mode
    model.eval()

    with torch.no_grad():
        for i in indices:
            # Get views
            views = [view[i].unsqueeze(0).to(device) for view in view_inputs_validation]

            print(f"\nSample {i+1}:")

            # Test image concatenation
            concat_image = model.concatenate_images(views)
            print(f"  Concatenated image shape: {concat_image.shape}")
            # Assert concatenated image has same batch and channel size as input
            assert concat_image.shape[0] == views[0].shape[0], "Batch size mismatch in concatenated image"
            assert concat_image.shape[1] == views[0].shape[1], "Channel size mismatch in concatenated image"

            # Test latent encoding
            latent = model.encode(views)
            print(f"  Latent representation shape: {latent.shape}")
            # Assert latent has correct shape
            assert latent.shape[1] == latent_size, f"Expected latent size {latent_size}, got {latent.shape[1]}"

            # Test full forward pass
            output = model.forward(views)
            print(f"  Model output shape: {output.shape}")
            # Assert output has correct shape
            assert output.shape[1] == output_size, f"Expected output size {output_size}, got {output.shape[1]}"

            # Test that the proprioceptor gives the expected output when fed the latent
            proprio_output = model.proprioceptor(latent)
            print(f"  Proprioceptor output shape: {proprio_output.shape}")
            # Assert proprioceptor output has correct shape
            assert proprio_output.shape[1] == output_size, f"Expected proprioceptor output size {output_size}, got {proprio_output.shape[1]}"

            # Verify that direct proprioceptor output matches the forward pass
            is_close = torch.allclose(output, proprio_output, rtol=1e-5, atol=1e-5)
            print(f"  Forward pass matches proprioceptor: {is_close}")
            assert is_close, "Forward pass output does not match proprioceptor output"

            print("  All shape assertions passed!")


In [22]:
# Run the shape/consistency checks on the first three validation samples
test_model_encoding(model=sp.enc, view_inputs_validation=view_inputs_validation, device=device, n_samples=3)



Testing model encoding functionality:
------------------------------------------------------------

Sample 1:
  Concatenated image shape: torch.Size([1, 3, 224, 224])
  Latent representation shape: torch.Size([1, 128])
  Model output shape: torch.Size([1, 6])
  Proprioceptor output shape: torch.Size([1, 6])
  Forward pass matches proprioceptor: True
  All shape assertions passed!

Sample 2:
  Concatenated image shape: torch.Size([1, 3, 224, 224])
  Latent representation shape: torch.Size([1, 128])
  Model output shape: torch.Size([1, 6])
  Proprioceptor output shape: torch.Size([1, 6])
  Forward pass matches proprioceptor: True
  All shape assertions passed!

Sample 3:
  Concatenated image shape: torch.Size([1, 3, 224, 224])
  Latent representation shape: torch.Size([1, 128])
  Model output shape: torch.Size([1, 6])
  Proprioceptor output shape: torch.Size([1, 6])
  Forward pass matches proprioceptor: True
  All shape assertions passed!


### Test the process_file method with cache handling

In [23]:
def test_process_file(sp, view_inputs_validation, n_samples=3):
    """
    Exercise ``sp.process_file`` (which caches views across calls) on temporary image files.

    For each of the first ``n_samples`` timesteps (capped at the number of samples), the view
    tensors are written as ``{timestep:05d}_camera{k}.jpg`` into a temporary directory that
    is removed afterwards. Then:
      * every view file is processed in turn ("all views available");
      * with more than one view, the private view caches are reset (when ``sp`` has a
        ``_view_cache`` attribute), and view 1 then view 2 are processed again.

    NOTE(review): the caches are not reset before the "all views" pass. From the second
    timestep on, that pass can therefore combine a view cached at the previous timestep;
    the saved output shows "Using 2 unique views" after processing only view 1 of timestep 2.
    """
    n_samples = min(n_samples, len(view_inputs_validation[0]))

    print("\nTesting process_file method with view caching:")
    print("-" * 60)

    import os
    import tempfile
    from torchvision.utils import save_image

    with tempfile.TemporaryDirectory() as temp_dir:
        for idx in range(n_samples):
            timestep = idx + 1
            print(f"\nProcessing timestep {timestep}:")

            # One image file per view for this timestep
            file_paths = []
            for view_no, view_tensor in enumerate(view_inputs_validation, start=1):
                file_path = os.path.join(temp_dir, f"{timestep:05d}_camera{view_no}.jpg")
                save_image(view_tensor[idx], file_path)
                file_paths.append(file_path)

            print("  Testing with all views available:")
            for view_no, file_path in enumerate(file_paths, start=1):
                latent = sp.process_file(file_path, camera_id=f"camera{view_no}")
                print(f"    Processed view {view_no}, latent shape: {latent.shape}")

            if len(file_paths) <= 1:
                continue

            print("  Testing with one missing view:")
            # Reset the private caches so this scenario starts from a clean cache
            if hasattr(sp, '_view_cache'):
                sp._view_cache = {}
                sp._timestep_cache = {}

            latent = sp.process_file(file_paths[0], camera_id="camera1")
            print(f"    Processed only view 1, latent shape: {latent.shape}")

            latent = sp.process_file(file_paths[1], camera_id="camera2")
            print(f"    Processed view 2 with view 1 from cache, latent shape: {latent.shape}")

# Exercise process_file and its view cache on the first three validation timesteps
test_process_file(sp, view_inputs_validation, n_samples=3)


Testing process_file method with view caching:
------------------------------------------------------------

Processing timestep 1:
  Testing with all views available:
Using 1 unique views (with 1 duplicated)
    Processed view 1, latent shape: (128,)
Using complete set of 2 views from timestep 1
    Processed view 2, latent shape: (128,)
  Testing with one missing view:
Using 1 unique views (with 1 duplicated)
    Processed only view 1, latent shape: (128,)
Using complete set of 2 views from timestep 1
    Processed view 2 with view 1 from cache, latent shape: (128,)

Processing timestep 2:
  Testing with all views available:
Using 2 unique views (with 0 duplicated)
    Processed view 1, latent shape: (128,)
Using complete set of 2 views from timestep 2
    Processed view 2, latent shape: (128,)
  Testing with one missing view:
Using 1 unique views (with 1 duplicated)
    Processed only view 1, latent shape: (128,)
Using complete set of 2 views from timestep 2
    Processed view 2 wi

### Print a summary of the trained model

In [24]:



# Summarize the configuration of the model trained (or reloaded) above
print("\nTraining complete!")
print("Vision Transformer type:", exp['vit_model'])
print("Number of views:", exp.get('num_views', 2))
print("Latent space dimension:", exp['latent_size'])
print("Output dimension (robot DOF):", exp['output_size'])
print("Use the ConcatImageVitSensorProcessing class to load and use this model for inference.")



Training complete!
Vision Transformer type: vit_l_16
Number of views: 2
Latent space dimension: 128
Output dimension (robot DOF): 6
Use the ConcatImageVitSensorProcessing class to load and use this model for inference.
