# Train models for visual proprioception

Train a regression model for visual proprioception. The input is sensory data (e.g. a camera image). This is encoded by a predefined sensor-processing component into a latent representation. What we train and save here is a regressor that maps the latent representation to the position of the robot (e.g. a vector of 6 degrees of freedom).

This regressor is specified in an experiment of the type "visual_proprioception". Running this notebook will train and save this model.

In [1]:
import sys
sys.path.append("..")
from settings import Config

import pathlib
from pprint import pprint
import matplotlib.pyplot as plt

import numpy as np
import torch
import torch.nn as nn
#import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(1)


from visual_proprioception.visproprio_helper import load_demonstrations_as_proprioception_training,load_concat_demonstrations_as_proprioception_training, get_visual_proprioception_sp, load_multiview_demonstrations_as_proprioception_training,load_concat_demonstrations_as_proprioception_training2
from visual_proprioception.visproprio_models import VisProprio_SimpleMLPRegression


# Select the compute device once for the whole notebook: CUDA when it is
# available, otherwise the CPU. The choice is printed for the run log.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Loading pointer config file: /home/ssheikholeslami/.config/BerryPicker/mainsettings.yaml
Loading machine-specific config file: /home/ssheikholeslami/SaharaBerryPickerData/settings-sahara.yaml
Using device: cuda


In [2]:
experiment = "visual_proprioception"

##############################################
#                 SingleView                 #
##############################################

# the latent space 128 ones

# run = "vp_aruco_128"  #DONE
# run = "vp_convvae_128" #DONE
# run = "vp_ptun_vgg19_128" #DONE
# run = "vp_ptun_resnet50_128" #DONE

# the latent space 256 ones
# run = "vp_convvae_256" #DONE
# run = "vp_ptun_vgg19_256" #DONE
# run = "vp_ptun_resnet50_256" #DONE

#vits
run ="vit_base_128" #DONE
# run ="vit_base_256" #DONE

# run ="vit_large_128" #DONE
# run ="vit_large_256" #DONE


##############################################
#                 MultiViews  - NEW!         #
##############################################

############  latent space: 128  ############
#concat_proj

# run ="vit_base_multiview_128"  #DONE
# run ="vit_large_multiview_128"  #DONE


##  indiv_proj
# run = "vit_base_multiview_indiv_proj_128"  # ViT Base_indiv_proj_128  #DONE
# run = "vit_large_multiview_indiv_proj_128" # ViT Large_indiv_proj_128  #DONE

##  attention
# run = "vit_base_multiview_attention_128"  # ViT Base_attention  #DONE
# run = "vit_large_multiview_attention_128" # ViT Large_attention  #DONE


##  weighted_sum
# run = "vit_base_multiview_weighted_sum_128"  # ViT Base_weighted_sum  #DONE
# run = "vit_large_multiview_weighted_sum_128" # ViT Large_weighted_sum  #DONE

##  gated
# run = "vit_base_multiview_gated_128"  # ViT Base_gated  #DONE
# run = "vit_large_multiview_gated_128" # ViT Large_gated  #DONE

########## the latent space 256 ones #########


# run ="vit_base_multiview_256"  #DONE
# run ="vit_large_multiview_256"  #DONE


##  indiv_proj
# run = "vit_base_multiview_indiv_proj_256"  # ViT Base_indiv_proj_256   #DONE
# run = "vit_large_multiview_indiv_proj_256" # ViT Large_indiv_proj_256 #DONE

##  attention
# run = "vit_base_multiview_attention_256"  # DONE
# run = "vit_large_multiview_attention_256" # ViT Large_attention #DONE


##  weighted_sum
# run = "vit_base_multiview_weighted_sum_256"  # ViT Base_weighted_sum   #DONE
# run = "vit_large_multiview_weighted_sum_256" # ViT Large_weighted_sum  #DONE


##  gated
# run = "vit_base_multiview_gated_256"  # ViT Base_gated  #DONE
# run = "vit_large_multiview_gated_256" # ViT Large_gated  #DONE


##############################################
#          MultiViews Image Concat - NEW!    #
##############################################
# the latent space 128 ones
# run = "vit_base_concat_multiview_128" # ViT Base  #DONE
# run = "vit_large_concat_multiview_128"  # ViT Large  #DONE
# run = "vp_convvae_128_concat_multiview"  #DONE

# the latent space 256 ones

# run = "vit_base_concat_multiview_256" # ViT Base  #DONE
# run = "vit_large_concat_multiview_256"  # ViT Large  #DONE
# run = "vp_convvae_256_concat_multiview" #DONE


##############################################
#          MultiViews CNN - NEW!             #
##############################################

# run = "vp_ptun_vgg19_128_multiview" #DONE
# run = "vp_ptun_resnet50_128_multiview" #DONE
# run = "vp_ptun_vgg19_256_multiview" #DONE
# run = "vp_ptun_resnet50_256_multiview" #DONE




# Load the configuration of the experiment/run selected above and display it.
exp = Config().get_experiment(experiment, run)
pprint(exp)

# Sensor-processing object used by the cells below to turn images into latent
# encodings.
# NOTE(review): presumably built from the settings in `exp` and placed on
# `device` - confirm in visproprio_helper.get_visual_proprioception_sp.
sp = get_visual_proprioception_sp(exp, device)


No system dependent experiment file
 /home/ssheikholeslami/SaharaBerryPickerData/experiments-Config/visual_proprioception/vit_base_128_sysdep.yaml,
 that is ok, proceeding.
Configuration for experiment: visual_proprioception/vit_base_128 successfully loaded
{'batch_size': 8,
 'data_dir': PosixPath('/home/ssheikholeslami/SaharaBerryPickerData/experiment_data/visual_proprioception/vit_base_128'),
 'encoding_size': 128,
 'epochs': 1000,
 'exp_run_sys_indep_file': PosixPath('/lustre/fs1/home/ssheikholeslami/BerryPicker/src/experiment_configs/visual_proprioception/vit_base_128.yaml'),
 'freeze_backbone': False,
 'freeze_feature_extractor': True,
 'group_name': 'visual_proprioception',
 'image_size': 224,
 'latent_size': 128,
 'learning_rate': 0.0001,
 'loss': 'MSE',
 'model_type': 'ViTProprioTunedRegression',
 'name': 'vit-base-128',
 'output_size': 6,
 'projection_hidden_dim': 512,
 'proprio_step_1': 64,
 'proprio_step_2': 64,
 'proprioception_input_file': 'train_inputs.pt',
 'propriocepti

  self.enc.load_state_dict(torch.load(modelfile, map_location=device))


In [3]:
"""
test_conv_vae_concat_fix.py

Updated test script that handles both list and tensor formats for cached data.
"""

import sys
import torch
import numpy as np
import pathlib
from settings import Config
from torch.utils.data import DataLoader, TensorDataset
from visual_proprioception.visproprio_models import VisProprio_SimpleMLPRegression

def test_conv_vae_concat_fix(experiment="sensorprocessing_conv_vae_concat_multiview",
                             run="proprio_128_concat_multiview"):
    """
    Smoke-test the Conv-VAE concatenated multiview sensor processing.

    Four checks are run in order; the function stops at the first failure:

    1. Construct ``ConcatConvVaeSensorProcessing`` from the experiment config.
    2. Encode random dummy views and compare the number of latent elements
       with ``exp["latent_size"]`` (default 128).
    3. Pass that latent through a freshly created
       ``VisProprio_SimpleMLPRegression`` and compare the output width with
       ``exp["output_size"]`` (default 6).
    4. If both cached proprioception files exist, load them and run the first
       few samples through the MLP. Cached inputs may be a list of per-view
       tensors (encoded here through ``sp.encode``) or a single tensor of
       already-processed latents.

    The latent returned by ``sp.encode`` may be a numpy array or a torch
    tensor; both are handled. The MLP is only used for shape checks, so it is
    put in eval mode and called without gradient tracking.

    Args:
        experiment: Experiment name
        run: Run name

    Returns:
        True if all tests pass, False otherwise
    """
    import traceback

    # Number of cached samples pushed through the encoder/MLP in Test 4.
    num_cached_samples = 5

    def report_failure(message, error):
        """Print a failure line, then the traceback of the active exception."""
        print(f"✗ {message}: {error}")
        traceback.print_exc()

    def count_elements(latent):
        """Return the element count of a latent (torch tensor or array-like).

        ``latent.size`` is an int on numpy arrays but a bound method on torch
        tensors, so it cannot be compared against an int directly.
        """
        if isinstance(latent, torch.Tensor):
            return latent.numel()
        return int(np.size(latent))

    def as_float_tensor(latent):
        """Convert a latent (tensor, numpy array or sequence) to float32."""
        if isinstance(latent, torch.Tensor):
            # torch.tensor(tensor) emits a UserWarning; detach/cast instead.
            return latent.detach().to(dtype=torch.float32)
        return torch.tensor(latent, dtype=torch.float32)

    print(f"Testing fix for {experiment}/{run}")

    # Load experiment configuration
    exp = Config().get_experiment(experiment, run)
    exp["debug"] = True  # Enable debug output

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Import the sensor processing class under test
    from sensorprocessing.sp_conv_vae_concat_multiview import ConcatConvVaeSensorProcessing

    print("\n=== Test 1: Creating ConcatConvVaeSensorProcessing ===")
    try:
        sp = ConcatConvVaeSensorProcessing(exp, device)
        print("✓ Successfully created ConcatConvVaeSensorProcessing")
    except Exception as e:
        report_failure("Failed to create ConcatConvVaeSensorProcessing", e)
        return False

    print("\n=== Test 2: Processing dummy views ===")
    try:
        # Create one random 64x64 RGB image (batch of 1) per view
        num_views = exp.get("num_views", 2)
        dummy_views = [torch.randn(1, 3, 64, 64).to(device) for _ in range(num_views)]
        print(f"Created {num_views} dummy views with shape {dummy_views[0].shape}")

        # Process through encoder
        latent = sp.encode(dummy_views)
        print(f"Encoded latent shape: {latent.shape}")

        # Verify latent size
        expected_size = exp.get("latent_size", 128)
        latent_size = count_elements(latent)
        if latent_size != expected_size:
            print(f"✗ Latent size mismatch: expected {expected_size}, got {latent_size}")
            return False

        print(f"✓ Successfully encoded dummy views to latent size {latent_size}")
    except Exception as e:
        report_failure("Failed to process dummy views", e)
        return False

    print("\n=== Test 3: Testing with MLP model ===")
    try:
        # Create the MLP model
        mlp_model = VisProprio_SimpleMLPRegression(exp)
        mlp_model.to(device)
        # Eval mode: the checks below feed a batch of one sample, which
        # BatchNorm layers (if the MLP has any) reject in training mode.
        mlp_model.eval()
        print(f"Created MLP with input_size={mlp_model.input_size}, output_size={mlp_model.output_size}")

        # Convert latent to tensor and reshape if needed
        latent_tensor = as_float_tensor(latent).to(device)
        if len(latent_tensor.shape) == 1:
            latent_tensor = latent_tensor.unsqueeze(0)  # Add batch dimension

        print(f"Latent tensor shape for MLP: {latent_tensor.shape}")

        # Forward pass through MLP (shape check only, no gradients needed)
        with torch.no_grad():
            output = mlp_model(latent_tensor)
        print(f"MLP output shape: {output.shape}")

        # Verify output size
        expected_output_size = exp.get("output_size", 6)
        if output.shape[1] != expected_output_size:
            print(f"✗ Output size mismatch: expected {expected_output_size}, got {output.shape[1]}")
            return False

        print(f"✓ Successfully passed latent through MLP to get output of size {output.shape[1]}")
    except Exception as e:
        report_failure("Failed to test with MLP model", e)
        return False

    print("\n=== Test 4: Testing data loading ===")
    try:
        # Data paths
        proprioception_input_file = pathlib.Path(exp["data_dir"], exp["proprioception_input_file"])
        proprioception_target_file = pathlib.Path(exp["data_dir"], exp["proprioception_target_file"])

        # Only the cached files are exercised; nothing is regenerated when
        # they are missing, which keeps the test fast.
        if proprioception_input_file.exists() and proprioception_target_file.exists():
            print("Cached data files exist, testing data loading...")
            try:
                # Load the data - handle both list and tensor formats
                raw_inputs = torch.load(proprioception_input_file, weights_only=True)
                raw_targets = torch.load(proprioception_target_file, weights_only=True)

                if isinstance(raw_inputs, list):
                    # List format: one tensor of raw images per view
                    print("Data is in list format (raw views)")
                    # Take just a few samples for testing
                    view_samples = [view[:num_cached_samples] for view in raw_inputs]

                    # Encode each sample (one image per view) into a latent
                    encoded_samples = []
                    for i in range(len(view_samples[0])):
                        sample_views = [view[i].unsqueeze(0).to(device) for view in view_samples]
                        encoded_samples.append(as_float_tensor(sp.encode(sample_views)))

                    test_inputs = torch.stack(encoded_samples).to(device)
                    test_targets = raw_targets[:num_cached_samples].to(device)

                    print(f"  Processed {len(encoded_samples)} samples")
                    print(f"  Each view shape: {[view.shape for view in view_samples]}")
                    print(f"  Processed inputs shape: {test_inputs.shape}")
                else:
                    # Tensor format: inputs are already processed latents
                    print("Data is in tensor format (processed latents)")
                    test_inputs = raw_inputs[:num_cached_samples].to(device)
                    test_targets = raw_targets[:num_cached_samples].to(device)

                    print(f"  Inputs shape: {test_inputs.shape}")

                print(f"  Targets shape: {test_targets.shape}")

                # Test with MLP
                with torch.no_grad():
                    output = mlp_model(test_inputs)
                print(f"  MLP output with loaded data: {output.shape}")

                print("✓ Successfully tested data loading and MLP forward pass")
            except Exception as e:
                report_failure("Failed to load and test cached data", e)
                return False
        else:
            print("No cached data found, skipping data loading test")
    except Exception as e:
        report_failure("Failed to import or test data loading", e)
        return False

    print("\n=== All tests passed! ===")
    print("The Conv-VAE concatenated multiview model fix is working correctly.")
    return True

# if __name__ == "__main__":
#     success = test_conv_vae_concat_fix()
#     sys.exit(0 if success else 1)

In [4]:
# Run the Conv-VAE concat multiview smoke test with its default experiment and
# run names; `success` is True only when every check passed.
success = test_conv_vae_concat_fix()

Testing fix for sensorprocessing_conv_vae_concat_multiview/proprio_128_concat_multiview
No system dependent experiment file
 /home/ssheikholeslami/SaharaBerryPickerData/experiments-Config/sensorprocessing_conv_vae_concat_multiview/proprio_128_concat_multiview_sysdep.yaml,
 that is ok, proceeding.
Configuration for experiment: sensorprocessing_conv_vae_concat_multiview/proprio_128_concat_multiview successfully loaded
Using device: cuda

=== Test 1: Creating ConcatConvVaeSensorProcessing ===
Initializing ConcatConvVaeSensorProcessing:
  num_views: 2
  stack_mode: width
  latent_size: 128
✓ Successfully created ConcatConvVaeSensorProcessing

=== Test 2: Processing dummy views ===
Created 2 dummy views with shape torch.Size([1, 3, 64, 64])
Concatenated shape: torch.Size([1, 3, 64, 128]), resizing to 64x64
Final latent shape: (128,), size: 128
Encoded latent shape: (128,)
✓ Successfully encoded dummy views to latent size 128

=== Test 3: Testing with MLP model ===
Created MLP with input_siz

In [5]:
# Create the regression model, its loss function and its optimizer

model = VisProprio_SimpleMLPRegression(exp)

# The loss function is selected by the "loss" entry of the experiment.
if exp["loss"] == "MSE":
    criterion = nn.MSELoss()
elif exp["loss"] == "L1":
    criterion = nn.L1Loss()
else:
    raise ValueError(f'Unknown loss type {exp["loss"]}')

# Take the learning rate from the experiment configuration instead of
# hard-coding it, so the value in the run's yaml file is actually honoured.
# The previous hard-coded value (0.001) is kept as the fallback for
# experiments that do not define "learning_rate".
optimizer = optim.Adam(model.parameters(), lr=exp.get("learning_rate", 0.001))

### Load and cache the training data. 
* Iterate through the images and process them into latent encodings. 
* Iterate through the json files describing the robot position
* Save the input and target values into files in the experiment directory. These will act as caches for later runs
* Create the training and validation splits

In [6]:

# Task name and the cache files (inside the experiment's data directory) that
# hold the training inputs and targets.
task = exp["proprioception_training_task"]
proprioception_input_file = pathlib.Path(
    exp["data_dir"], exp["proprioception_input_file"])
proprioception_target_file = pathlib.Path(
    exp["data_dir"], exp["proprioception_target_file"])


# Decide whether this run uses several camera views. Any one of the following
# is sufficient; a missing "sensor_processing" entry is treated as "".
is_multiview = (
    exp.get("sensor_processing", "").endswith("_multiview") or
    exp.get("sensor_processing", "").startswith("Vit_concat") or  # image-concat ViT variants
    "concat" in exp.get("sensor_processing", "").lower() or       # any other "concat" variant
    exp.get("num_views", 1) > 1
)
# Earlier, narrower form of the check, kept for reference:
# is_multiview = exp.get("sensor_processing", "").endswith("_multiview") or exp.get("num_views", 1) > 1
# The CNN-based multi-view models are recognised by exact name.
is_cnn_multiview = (
    exp.get("sensor_processing", "") == "VGG19ProprioTunedSensorProcessing_multiview" or
    exp.get("sensor_processing", "") == "ResNetProprioTunedRegression_multiview"
)

# Conv-VAE with concatenated views, also recognised by exact name. In this
# cell the flag is only printed; the code that used it is disabled below.
is_conv_vae_concat =  exp.get("sensor_processing", "") == "ConvVaeSensorProcessing_concat_multiview"
print(is_conv_vae_concat)
print(is_multiview)
if is_multiview:
    print(f"Using multi-view approach with {exp.get('num_views', 2)} views")

    # ---- Disabled experiment (kept for reference, not executed) ----------
    # # Create appropriate dataset and dataloader based on the data format
    # if is_conv_vae_concat:
    #     # For the concat model approach with pre-processed latents
    #     print("Using pre-processed latent approach for ConvVAE concat model")

    #     # Load data with your custom function
    #     tr = load_concat_demonstrations_as_proprioception_training2(
    #         task,
    #         proprioception_input_file,
    #         proprioception_target_file,
    #         num_views=exp.get("num_views", 2),
    #         sp=sp)

    #     # Create a dataset that wraps each input tensor in a list to match the multi-view format
    #     class ViewListDataset(torch.utils.data.Dataset):
    #         def __init__(self, inputs, targets, num_views=2):
    #             self.inputs = inputs
    #             self.targets = targets
    #             self.num_views = num_views

    #         def __len__(self):
    #             return len(self.targets)

    #         def __getitem__(self, idx):
    #             # Create a list with num_views copies of the same tensor
    #             views_list = [self.inputs[idx]] * self.num_views
    #             return views_list, self.targets[idx]

    #     is_preprocessed = tr.get("is_preprocessed", False)

    #     # Use ViewListDataset to ensure we get a list of tensors
    #     train_dataset = ViewListDataset(tr["inputs_training"], tr["targets_training"], num_views=exp.get("num_views", 2))
    #     test_dataset = ViewListDataset(tr["inputs_validation"], tr["targets_validation"], num_views=exp.get("num_views", 2))



    # Use standard TensorDataset for ConvVAE since views are already concatenated
        # train_dataset = TensorDataset(tr["inputs_training"], tr["targets_training"])
        # test_dataset = TensorDataset(tr["inputs_validation"], tr["targets_validation"])
    # ---- End of disabled experiment ---------------------------------------



    # NOTE(review): both branches below call the same loader with the same
    # arguments; the CNN branch only adds a log line.
    if is_cnn_multiview:
        print(f"Detected CNN-based multi-view model: {exp.get('sensor_processing')}")
        # Use the multiview loading function
        tr = load_multiview_demonstrations_as_proprioception_training(
        task,
        proprioception_input_file,
        proprioception_target_file,
        num_views=exp.get("num_views", 2))
    else:
        tr = load_multiview_demonstrations_as_proprioception_training(
        task,
        proprioception_input_file,
        proprioception_target_file,
        num_views=exp.get("num_views", 2)

    )

    # Create a custom dataset for multi-view data
    class MultiViewDataset(torch.utils.data.Dataset):
        """Dataset that yields, per index, one entry from every view plus the target."""
        def __init__(self, view_inputs, targets):
            self.view_inputs = view_inputs  # List of tensors, one per view
            self.targets = targets
            self.num_samples = len(targets)

        def __len__(self):
            return self.num_samples

        def __getitem__(self, idx):
            # Get corresponding sample from each view
            views = [view[idx] for view in self.view_inputs]
            target = self.targets[idx]
            return views, target

    # Dataset for already-processed latents.
    # NOTE(review): defined here but not instantiated anywhere in this cell.
    class ProcessedLatentDataset(torch.utils.data.Dataset):
        """Dataset that returns stored (input, target) pairs unchanged."""
        def __init__(self, inputs, targets):
            self.inputs = inputs  # These are already processed latents
            self.targets = targets

        def __len__(self):
            return len(self.targets)

        def __getitem__(self, idx):
            # Return inputs directly - these are already processed
            # No need to pass through sp.encode again
            return self.inputs[idx], self.targets[idx]
    # Batch size (default 32) and the multi-view train/validation datasets
    batch_size = exp.get('batch_size', 32)
    train_dataset = MultiViewDataset(tr["view_inputs_training"], tr["targets_training"])
    test_dataset = MultiViewDataset(tr["view_inputs_validation"], tr["targets_validation"])

else:
    print("Using single-view approach")

    # Use the original loading function, which also receives the sensor
    # processing object `sp`
    tr = load_demonstrations_as_proprioception_training(
        sp, task, proprioception_input_file, proprioception_target_file
    )

    inputs_training = tr["inputs_training"]
    targets_training = tr["targets_training"]
    inputs_validation = tr["inputs_validation"]
    targets_validation = tr["targets_validation"]

    # Batch size (default 32) and the single-view train/validation datasets
    batch_size = exp.get('batch_size', 32)
    train_dataset = TensorDataset(inputs_training, targets_training)
    test_dataset = TensorDataset(inputs_validation, targets_validation)





# Create DataLoaders for batching: the training data is reshuffled every
# epoch, the validation data keeps its order.


train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

False
Using single-view approach


In [7]:

# Helper function to detect if a model has batch normalization layers
def has_batch_norm(model):
    """Tell whether ``model`` contains a BatchNorm1d, BatchNorm2d or BatchNorm3d layer.

    Every module yielded by ``model.modules()`` is inspected.

    Returns:
        True if at least one such layer is found, False otherwise.
    """
    from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d

    batch_norm_types = (BatchNorm1d, BatchNorm2d, BatchNorm3d)
    return any(isinstance(layer, batch_norm_types) for layer in model.modules())

### Perform the training

In [8]:
def train_and_save_proprioception_model(exp,is_preprocessed=False):
    """Trains and saves the proprioception model, handling both single and multi-view inputs
    with checkpoint support for resuming interrupted training
    """
    final_modelfile = pathlib.Path(exp["data_dir"], exp["proprioception_mlp_model_file"])
    checkpoint_dir = pathlib.Path(exp["data_dir"], "checkpoints")
    checkpoint_dir.mkdir(exist_ok=True)

    # Maximum number of checkpoints to keep (excluding the best model)
    max_checkpoints = 2

    # Check if we're using a multi-view approach
    is_multiview = (
        exp.get("sensor_processing", "").endswith("_multiview") or
        exp.get("sensor_processing", "").startswith("Vit_concat") or
        "concat" in exp.get("sensor_processing", "").lower() or
        exp.get("num_views", 1) > 1
    )

    # Detect concatenated model approach specifically
    is_concat_model =is_concat_model = "concat" in exp.get("sensor_processing", "").lower()


    is_cnn_multiview = (
        exp.get("sensor_processing", "") == "VGG19ProprioTunedSensorProcessing_multiview" or
        exp.get("sensor_processing", "") == "ResNetProprioTunedRegression_multiview"
    )
    num_views = exp.get("num_views", 2)

    # We'll always use the standard MLP model regardless of approach
    # This is model = VisProprio_SimpleMLPRegression(exp) defined outside this function
    train_model = model

    # Helper function to check if model has batch normalization
    def has_batch_norm(model):
        """Return True when any module of ``model`` is a BatchNorm1d/2d/3d layer, else False."""
        import torch.nn as nn

        norm_layer_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
        matching_layers = (
            submodule for submodule in model.modules()
            if isinstance(submodule, norm_layer_types)
        )
        # Only the existence of a first match matters.
        return next(matching_layers, None) is not None

    # First check for existing final model
    if final_modelfile.exists() and exp.get("reload_existing_model", True):
        print(f"Loading existing final model from {final_modelfile}")
        train_model.load_state_dict(torch.load(final_modelfile, map_location=device))

        # Evaluate the loaded model
        train_model.eval()
        with torch.no_grad():
            total_loss = 0
            batch_count = 0

            for batch_data in test_loader:
                try:
                    if is_multiview:
                        # if is_conv_vae_concat and is_preprocessed:
                        #     print("we are here in concattttttttt convvvvv")
                        #     # ConvVAE concat model expects single tensor with concatenated views
                        #     batch_X, batch_y = batch_data
                        #     batch_X = batch_X.to(device)
                        #     predictions = train_model(batch_X)
                        #     continue
                        if is_concat_model:
                            print ("not in is_concat_modelllllllllllllllllllll")
                            # For concat model, use sp.enc.encode to get latent features
                            batch_views, batch_y = batch_data

                            # Check batch size for BatchNorm compatibility
                            if batch_views[0].size(0) == 1 and has_batch_norm(train_model):
                                print(f"Warning: Skipping evaluation batch with size 1 (incompatible with BatchNorm)")
                                continue

                            # Process each sample to get latent features
                            batch_size = batch_views[0].size(0)
                            batch_features = []

                            for i in range(batch_size):
                                # Get views for this sample
                                sample_views = [view[i].unsqueeze(0).to(device) for view in batch_views]

                                # Get latent using concat model's encoder (without proprioceptor)
                                sample_features = sp.enc.encode(sample_views).cpu().numpy()

                                # Convert to tensor and move to device
                                sample_features_tensor = torch.tensor(sample_features, dtype=torch.float32).to(device)
                                batch_features.append(sample_features_tensor)

                            # Stack features
                            batch_X = torch.stack(batch_features).squeeze(1).to(device)
                            predictions = train_model(batch_X)
                        elif is_cnn_multiview:
                            # Handle CNN-based multi-view processing similar to standard multi-view
                            batch_views, batch_y = batch_data

                            # Check batch size for BatchNorm compatibility
                            if batch_views[0].size(0) == 1 and has_batch_norm(train_model):
                                print(f"Warning: Skipping evaluation batch with size 1 (incompatible with BatchNorm)")
                                continue

                            batch_size = batch_views[0].size(0)
                            batch_features = []

                            for i in range(batch_size):
                                sample_views = [view[i].unsqueeze(0).to(device) for view in batch_views]
                                # For CNN multi-view, use encode_views method of the model directly
                                sample_features = sp.process(sample_views)
                                # Convert numpy array to tensor and move to device
                                sample_features_tensor = torch.tensor(sample_features, dtype=torch.float32).to(device)
                                batch_features.append(sample_features_tensor)

                            batch_X = torch.stack(batch_features).to(device)
                            predictions = train_model(batch_X)
                        else:
                            # Standard multi-view processing
                            batch_views, batch_y = batch_data

                            # Check batch size for BatchNorm compatibility
                            if batch_views[0].size(0) == 1 and has_batch_norm(train_model):
                                print(f"Warning: Skipping evaluation batch with size 1 (incompatible with BatchNorm)")
                                continue

                            batch_size = batch_views[0].size(0)
                            batch_features = []

                            for i in range(batch_size):
                                sample_views = [view[i].unsqueeze(0).to(device) for view in batch_views]
                                sample_features = sp.process(sample_views)
                                # Convert numpy array to tensor and move to device
                                sample_features_tensor = torch.tensor(sample_features, dtype=torch.float32).to(device)
                                batch_features.append(sample_features_tensor)

                            batch_X = torch.stack(batch_features).to(device)
                            predictions = train_model(batch_X)
                    else:
                        batch_X, batch_y = batch_data

                        # Check batch size for BatchNorm compatibility
                        if batch_X.size(0) == 1 and has_batch_norm(train_model):
                            print(f"Warning: Skipping evaluation batch with size 1 (incompatible with BatchNorm)")
                            continue

                        batch_X = batch_X.to(device)
                        predictions = train_model(batch_X)

                    # Make sure batch_y is on the same device
                    batch_y = batch_y.to(device)
                    loss = criterion(predictions, batch_y)
                    total_loss += loss.item()
                    batch_count += 1

                except Exception as e:
                    print(f"Error in evaluation batch: {e}")
                    continue

            avg_loss = total_loss / max(batch_count, 1)
            print(f"Loaded model evaluation loss: {avg_loss:.4f}")

        return train_model

    # Function to extract epoch number from checkpoint file
    def get_epoch_number(checkpoint_file):
        """Extract the epoch number from a checkpoint path named ``epoch_<N>.pth``.

        Args:
            checkpoint_file: Object with a ``stem`` attribute (e.g. a
                ``pathlib.Path``) whose stem looks like ``epoch_<N>``.

        Returns:
            The integer found after the first underscore of the stem, or 0
            when the stem has no underscore, the part after it is not an
            integer, or ``checkpoint_file`` is not path-like.
        """
        try:
            parts = checkpoint_file.stem.split('_')
            if len(parts) >= 2:
                return int(parts[1])  # Get the number after "epoch_"
            return 0
        except (AttributeError, TypeError, ValueError):
            # Only malformed names/arguments map to 0. Unlike a bare
            # `except:`, KeyboardInterrupt and SystemExit still propagate.
            return 0

    # Function to clean up old checkpoints
    def cleanup_old_checkpoints():
        """Delete all but the newest ``max_checkpoints`` epoch checkpoints.

        Files matching ``epoch_*.pth`` in ``checkpoint_dir`` are ordered by the
        epoch number parsed from their name; the ones with the lowest numbers
        are removed. A deletion failure is reported and does not stop the
        cleanup of the remaining files.
        """
        # Sort by actual epoch number, not just filename
        checkpoint_files = sorted(checkpoint_dir.glob("epoch_*.pth"), key=get_epoch_number)

        # Count the surplus explicitly: slicing with `[:-max_checkpoints]`
        # yields an empty list when max_checkpoints is 0, so nothing would
        # ever be deleted in that case.
        num_surplus = len(checkpoint_files) - max(max_checkpoints, 0)
        if num_surplus <= 0:
            return

        for file_path in checkpoint_files[:num_surplus]:
            try:
                file_path.unlink()
                print(f"Deleted old checkpoint: {file_path.name}")
            except Exception as e:
                # Best effort: report and continue with the next file.
                print(f"Failed to delete {file_path.name}: {e}")

    # Make sure model is on the correct device
    train_model.to(device)
    print(f"Model moved to {device}")

    # Set training parameters
    num_epochs = exp["epochs"]
    start_epoch = 0
    best_loss = float('inf')

    # Check for existing checkpoints to resume from
    checkpoint_files = list(checkpoint_dir.glob("epoch_*.pth"))
    if checkpoint_files:
        # Sort by epoch number for more reliable ordering
        checkpoint_files.sort(key=get_epoch_number)

        # Get the most recent checkpoint
        latest_checkpoint = checkpoint_files[-1]
        epoch_num = get_epoch_number(latest_checkpoint)

        print(f"Found checkpoint from epoch {epoch_num}. Resuming training...")

        # Load checkpoint
        checkpoint = torch.load(latest_checkpoint, map_location=device)
        train_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint.get('best_loss', float('inf'))

        print(f"Resuming from epoch {start_epoch}/{num_epochs} with best loss: {best_loss:.4f}")
    else:
        print(f"Starting new training for {num_epochs} epochs")

    # Start or resume training
    for epoch in range(start_epoch, num_epochs):
        print(f"Starting epoch {epoch+1}/{num_epochs}")
        train_model.train()
        total_loss = 0
        batch_count = 0

        # Training loop handles both single and multi-view cases
        for batch_idx, batch_data in enumerate(train_loader):
            try:
                # if isinstance(batch_data, tuple) and len(batch_data) == 2 and not isinstance(batch_data[0], list):
                #     # This is a standard (tensor, target) tuple, no list of views
                #     batch_X, batch_y = batch_data
                #     batch_X = batch_X.to(device)
                #     predictions = train_model(batch_X)
                if is_multiview:
                    # if is_conv_vae_concat and is_preprocessed:
                    #         print("we are here in concattttttttt convvvvv")
                    #         # ConvVAE concat model expects single tensor with concatenated views
                    #         batch_X, batch_y = batch_data
                    #         batch_X = batch_X.to(device)
                    #         predictions = train_model(batch_X)
                    #         continue
                    if is_concat_model:
                        # For concat model, get latent from sp.enc.encode
                        batch_views, batch_y = batch_data

                        # Check batch size for BatchNorm compatibility
                        if batch_views[0].size(0) == 1 and has_batch_norm(train_model):
                            print(f"Warning: Skipping batch {batch_idx} with size 1 (incompatible with BatchNorm)")
                            continue

                        # Process each sample to get latent features
                        batch_size = batch_views[0].size(0)
                        batch_features = []

                        for i in range(batch_size):
                            # Get views for this sample
                            sample_views = [view[i].unsqueeze(0).to(device) for view in batch_views]

                            # Get latent using concat model's encoder (without proprioceptor)
                            with torch.no_grad():  # Don't compute gradients through sp.enc
                                # sample_features = sp.enc.encode(sample_views).cpu().numpy()
                                sample_features = sp.enc.encode(sample_views)
                            # Convert to tensor and move to device
                            sample_features_tensor = torch.tensor(sample_features, dtype=torch.float32).to(device)
                            batch_features.append(sample_features_tensor)

                        # Stack features
                        batch_X = torch.stack(batch_features).squeeze(1).to(device)
                        predictions = train_model(batch_X)


                    elif is_cnn_multiview:
                        # Handle CNN-based multi-view processing similar to standard multi-view
                        batch_views, batch_y = batch_data

                        # Check if we have a problematic batch size=1 situation
                        if batch_views[0].size(0) == 1 and has_batch_norm(train_model):
                            print(f"Warning: Skipping batch {batch_idx} with size 1 (incompatible with BatchNorm)")
                            continue

                        # With CNN multi-view, batch_views is a list of tensors, each with shape [batch_size, C, H, W]
                        batch_size = batch_views[0].size(0)
                        batch_features = []

                        # Process each sample in the batch
                        for i in range(batch_size):
                            # Extract this sample's views
                            sample_views = [view[i].unsqueeze(0).to(device) for view in batch_views]

                            # Process this sample through sp
                            with torch.no_grad():  # Don't compute gradients through sp
                                sample_features = sp.process(sample_views)

                            # Convert numpy array to tensor and move to device
                            sample_features_tensor = torch.tensor(sample_features, dtype=torch.float32).to(device)
                            batch_features.append(sample_features_tensor)

                        # Stack all samples' features into a batch
                        batch_X = torch.stack(batch_features).to(device)

                        # Forward pass
                        predictions = train_model(batch_X)


                    else:
                        # Standard multi-view processing
                        batch_views, batch_y = batch_data

                        # Check if we have a problematic batch size=1 situation
                        if batch_views[0].size(0) == 1 and has_batch_norm(train_model):
                            print(f"Warning: Skipping batch {batch_idx} with size 1 (incompatible with BatchNorm)")
                            continue

                        # With standard multi-view, batch_views is a list of tensors, each with shape [batch_size, C, H, W]
                        batch_size = batch_views[0].size(0)
                        batch_features = []

                        # Process each sample in the batch
                        for i in range(batch_size):
                            # Extract this sample's views
                            sample_views = [view[i].unsqueeze(0).to(device) for view in batch_views]

                            # Process this sample through sp
                            sample_features = sp.process(sample_views)

                            # Convert numpy array to tensor and move to device
                            sample_features_tensor = torch.tensor(sample_features, dtype=torch.float32).to(device)
                            batch_features.append(sample_features_tensor)

                        # Stack all samples' features into a batch
                        batch_X = torch.stack(batch_features).to(device)

                        # Forward pass
                        predictions = train_model(batch_X)
                else:
                    batch_X, batch_y = batch_data

                    # Check if we have a problematic batch size=1 situation
                    if batch_X.size(0) == 1 and has_batch_norm(train_model):
                        print(f"Warning: Skipping batch {batch_idx} with size 1 (incompatible with BatchNorm)")
                        continue

                    # Move to device
                    batch_X = batch_X.to(device)
                    # Standard single-view processing
                    predictions = train_model(batch_X)

                # Make sure batch_y is on the same device
                batch_y = batch_y.to(device)
                loss = criterion(predictions, batch_y)

                # Backward pass and optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                batch_count += 1

                # Print progress every few batches
                if (batch_idx + 1) % 10 == 0:
                    print(f"  Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item():.4f}")

            except Exception as e:
                print(f"Error in batch {batch_idx}: {e}")
                import traceback
                traceback.print_exc()  # Print the full stack trace for debugging
                # Save emergency checkpoint in case of error - use formatted epoch and batch numbers
                save_path = checkpoint_dir / f"emergency_epoch_{epoch:06d}_batch_{batch_idx:06d}.pth"
                torch.save({
                    'epoch': epoch,
                    'batch': batch_idx,
                    'model_state_dict': train_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': total_loss / max(batch_count, 1),
                    'best_loss': best_loss
                }, save_path)
                print(f"Emergency checkpoint saved to {save_path}")
                continue

        # Calculate average loss for the epoch
        avg_loss = total_loss / max(batch_count, 1)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

        # Evaluate the model
        train_model.eval()
        test_loss = 0
        eval_batch_count = 0
        with torch.no_grad():
            for eval_batch_idx, batch_data in enumerate(test_loader):
                try:
                    if is_multiview:
                        # if is_conv_vae_concat and is_preprocessed:
                        #     print("we are here in concattttttttt convvvvv")
                        #     # ConvVAE concat model expects single tensor with concatenated views
                        #     batch_X, batch_y = batch_data
                        #     batch_X = batch_X.to(device)
                        #     predictions = train_model(batch_X)
                        #     continue
                        if is_concat_model:
                            # For concat model, get latent from sp.enc.encode
                            batch_views, batch_y = batch_data

                            # Check batch size for BatchNorm compatibility
                            if batch_views[0].size(0) == 1 and has_batch_norm(train_model):
                                print(f"Warning: Skipping eval batch {eval_batch_idx} with size 1 (incompatible with BatchNorm)")
                                continue

                            # Process each sample to get latent features
                            batch_size = batch_views[0].size(0)
                            batch_features = []

                            for i in range(batch_size):
                                # Get views for this sample
                                sample_views = [view[i].unsqueeze(0).to(device) for view in batch_views]

                                # Get latent using concat model's encoder (without proprioceptor)
                                # sample_features = sp.enc.encode(sample_views).cpu().numpy()
                                sample_features = sp.enc.encode(sample_views)
                                # Convert to tensor and move to device
                                sample_features_tensor = torch.tensor(sample_features, dtype=torch.float32).to(device)
                                batch_features.append(sample_features_tensor)

                            # Stack features
                            batch_X = torch.stack(batch_features).squeeze(1).to(device)
                            predictions = train_model(batch_X)

                        elif is_cnn_multiview:
                            # Handle CNN-based multi-view processing
                            batch_views, batch_y = batch_data

                            # Check batch size for BatchNorm compatibility
                            if batch_views[0].size(0) == 1 and has_batch_norm(train_model):
                                print(f"Warning: Skipping eval batch {eval_batch_idx} with size 1 (incompatible with BatchNorm)")
                                continue

                            batch_size = batch_views[0].size(0)
                            batch_features = []

                            for i in range(batch_size):
                                sample_views = [view[i].unsqueeze(0).to(device) for view in batch_views]
                                sample_features = sp.process(sample_views)
                                sample_features_tensor = torch.tensor(sample_features, dtype=torch.float32).to(device)
                                batch_features.append(sample_features_tensor)

                            batch_X = torch.stack(batch_features).to(device)
                            predictions = train_model(batch_X)
                        else:
                            # Process the batch the same way as in training
                            batch_views, batch_y = batch_data

                            # Check batch size for BatchNorm compatibility
                            if batch_views[0].size(0) == 1 and has_batch_norm(train_model):
                                print(f"Warning: Skipping eval batch {eval_batch_idx} with size 1 (incompatible with BatchNorm)")
                                continue

                            batch_size = batch_views[0].size(0)
                            batch_features = []

                            for i in range(batch_size):
                                sample_views = [view[i].unsqueeze(0).to(device) for view in batch_views]
                                sample_features = sp.process(sample_views)
                                # Convert numpy array to tensor and move to device
                                sample_features_tensor = torch.tensor(sample_features, dtype=torch.float32).to(device)
                                batch_features.append(sample_features_tensor)

                            batch_X = torch.stack(batch_features).to(device)
                            predictions = train_model(batch_X)
                    else:
                        batch_X, batch_y = batch_data

                        # Check batch size for BatchNorm compatibility
                        if batch_X.size(0) == 1 and has_batch_norm(train_model):
                            print(f"Warning: Skipping eval batch {eval_batch_idx} with size 1 (incompatible with BatchNorm)")
                            continue

                        batch_X = batch_X.to(device)
                        predictions = train_model(batch_X)

                    # Make sure batch_y is on the same device
                    batch_y = batch_y.to(device)
                    loss = criterion(predictions, batch_y)
                    test_loss += loss.item()
                    eval_batch_count += 1

                except Exception as e:
                    print(f"Error in evaluation batch {eval_batch_idx}: {e}")
                    continue

        avg_test_loss = test_loss / max(eval_batch_count, 1)
        print(f'Validation Loss: {avg_test_loss:.4f}')

        # Save checkpoint after each epoch - using formatted epoch numbers for reliable sorting
        checkpoint_path = checkpoint_dir / f"epoch_{epoch:06d}.pth"
        torch.save({
            'epoch': epoch,
            'model_state_dict': train_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': avg_loss,
            'test_loss': avg_test_loss,
            'best_loss': best_loss
        }, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

        # Clean up old checkpoints to save space
        cleanup_old_checkpoints()

        # Update best model if improved
        if avg_test_loss < best_loss:
            best_loss = avg_test_loss
            best_model_path = checkpoint_dir / "best_model.pth"
            torch.save({
                'epoch': epoch,
                'model_state_dict': train_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_loss,
                'test_loss': avg_test_loss,
                'best_loss': best_loss
            }, best_model_path)
            print(f"New best model saved with test loss: {best_loss:.4f}")

        # Update learning rate if using a scheduler
        # if scheduler is not None:
        #     scheduler.step(avg_test_loss)
        #     print(f"Current learning rate: {optimizer.param_groups[0]['lr']:.6f}")

    # Training completed successfully
    print(f"Training complete. Best test loss: {best_loss:.4f}")

    # Load the best model for final save
    best_model_path = checkpoint_dir / "best_model.pth"
    if best_model_path.exists():
        best_checkpoint = torch.load(best_model_path, map_location=device)
        train_model.load_state_dict(best_checkpoint['model_state_dict'])
        print(f"Loaded best model from epoch {best_checkpoint['epoch']+1} with test loss {best_checkpoint['test_loss']:.4f}")

    # Save the final model
    torch.save(train_model.state_dict(), final_modelfile)
    print(f"Final model saved to {final_modelfile}")

    return train_model

In [9]:
# Train (or resume training of) the proprioception regressor described by `exp`.
# The call is the last expression of the cell, so the model it returns is
# rendered as the cell output.
# NOTE(review): the exact meaning of `is_preprocessed=False` is defined by
# train_and_save_proprioception_model -- confirm against its definition.
#
# Legacy manual load-or-train logic, kept for reference only (not executed):
# modelfile = pathlib.Path(Config()["explorations"]["proprioception_mlp_model_file"])

# if modelfile.exists():
#     model.load_state_dict(torch.load(modelfile))
# else:
train_and_save_proprioception_model(exp, is_preprocessed=False)

Loading existing final model from /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/visual_proprioception/vit_base_128/proprioception_mlp.pth
Error in evaluation batch: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
Error in evaluation batch: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
Error in evaluation batch: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
Error in evaluation batch: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
Error in evaluation batch: Expected all tensors to be on the same device

  train_model.load_state_dict(torch.load(final_modelfile, map_location=device))


VisProprio_SimpleMLPRegression(
  (model): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=6, bias=True)
  )
)

In [10]:

# Debug summary: describe the encoder attached to the sensor-processing object
# `sp`, then report where the trained regressor file is expected to be.
print("\n*** Model Information ***")
if not hasattr(sp, 'enc'):
    print("Cannot determine model type - no 'enc' attribute found")
else:
    # An encoder counts as multi-view when it carries an nn.ModuleList named
    # `feature_extractors`.
    is_multiview = isinstance(getattr(sp.enc, 'feature_extractors', None), nn.ModuleList)
    if not is_multiview:
        print("✗ Single-view model detected")
        print(f"  Model type: {type(sp.enc).__name__}")
        if hasattr(sp.enc, 'latent_size'):
            print(f"  Latent size: {sp.enc.latent_size}")
    else:
        num_views = len(sp.enc.feature_extractors)

        # The encoder is labelled CNN-based when its type string mentions one of
        # the known CNN markers; everything else is reported as ViT-based.
        enc_type_repr = str(type(sp.enc))
        if any(marker in enc_type_repr for marker in ('CNN', 'ResNet', 'VGG')):
            view_type = "CNN-based"
        else:
            view_type = "ViT-based"

        print(f"✓ Multi-view model detected: {view_type} with {num_views} views")
        print(f"  Model type: {type(sp.enc).__name__}")
        print(f"  Latent size: {sp.enc.latent_size}")

        # Encoders without an explicit `fusion_type` are reported as plain
        # feature concatenation.
        if hasattr(sp.enc, 'fusion_type'):
            fusion_label = sp.enc.fusion_type
        else:
            fusion_label = "feature concatenation"
        print(f"  Fusion method: {fusion_label}")

print("\nTraining complete! Model saved to:")
# Path is assembled from the experiment configuration, not checked on disk.
saved_model_path = pathlib.Path(exp['data_dir'], exp['proprioception_mlp_model_file'])
print(f"  {saved_model_path}")


*** Model Information ***
✗ Single-view model detected
  Model type: ViTEncoder
  Latent size: 128

Training complete! Model saved to:
  /home/ssheikholeslami/SaharaBerryPickerData/experiment_data/visual_proprioception/vit_base_128/proprioception_mlp.pth
