In [1]:
# Shuffle and Slice .mat File Data

# This notebook performs the following steps:
# 1.  Loads a single `.mat` file.
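# 2.  Shuffles the arrays stored under a configured set of keys, using one shared
#     random permutation so that samples stay aligned across keys.
# 3.  Keeps the first `num_samples_to_keep` samples (500 by default) of the shuffled data.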
# 4.  Saves the result to a new `.mat` file.

In [4]:
import scipy.io
import numpy as np
import os

In [5]:
# --- 1. Configuration ---
# Everything a user is expected to tune lives in this cell.

# Input: the .mat file to read.
# IMPORTANT: point this at your own source file before running.
source_file_path = "/home/zhuyekun/projects/repos/deeprte/data/raw/train/merge/merged_train.mat"

# Output: where the reduced .mat file is written.
# IMPORTANT: choose the destination path and filename you want.
output_path_shuffled = "/home/zhuyekun/projects/repos/deeprte/data/raw/train/merge/shuffled_500_samples.mat"

# Size of the random subset that ends up in the output file.
num_samples_to_keep = 500

# Keys whose arrays are shuffled and sliced along their first axis.
# The first entry is also used to determine the total number of samples.
keys_to_process = [
    'scattering_kernel',
    'sigma_a',
    'sigma_t',
    'phi',
    'psi_bc',
    'psi_label',
]

In [6]:
# --- 2. Processing Logic ---
#
# Loads the source .mat file, draws one random subset of sample indices, applies
# that same subset to every array listed in `keys_to_process`, writes the result
# to `output_path_shuffled`, and finally re-loads the output to report shapes.
# Entries whose keys are not listed in `keys_to_process` are copied unchanged.
#
# Errors are reported with a printed message instead of a traceback:
#   * FileNotFoundError -> the source file is missing
#   * ValueError        -> invalid configuration or inconsistent array sizes
#   * anything else     -> generic "unexpected error" message

# Seed for the shuffle so that "Restart & Run All" selects the same samples.
# Set to None to draw a different random subset on every run.
shuffle_seed = 0

print(f"Processing file: {source_file_path}")
print(f"Will keep {num_samples_to_keep} shuffled samples.")
print(f"Will process the following keys: {keys_to_process}")

try:
    # --- Load Data ---
    print(f"Loading data from: {source_file_path}")
    data = scipy.io.loadmat(source_file_path)

    # --- Shuffle and Slice ---

    # Reject an empty key list *before* touching keys_to_process[0]; otherwise
    # building the error message would itself fail with an IndexError.
    if not keys_to_process:
        raise ValueError("keys_to_process is empty. Cannot determine number of samples.")

    # A negative count would make the slice below mean "all but the last k".
    if num_samples_to_keep < 0:
        raise ValueError(f"num_samples_to_keep must be non-negative, got {num_samples_to_keep}.")

    # The first configured key defines how many samples the file contains.
    reference_key = keys_to_process[0]
    if reference_key not in data:
        raise ValueError(f"Key '{reference_key}' not found in the file. Cannot determine number of samples.")

    num_total_samples = data[reference_key].shape[0]
    print(f"Found {num_total_samples} total samples in the file.")

    # Clamp into a local variable so the configuration value is never
    # overwritten; re-running this cell then always starts from the same input.
    num_selected = num_samples_to_keep
    if num_total_samples < num_selected:
        print(f"Warning: Total samples ({num_total_samples}) is less than requested samples ({num_samples_to_keep}). Keeping all samples.")
        num_selected = num_total_samples

    # Generate a reproducible shuffled sequence of indices
    rng = np.random.default_rng(shuffle_seed)
    shuffled_indices = rng.permutation(num_total_samples)

    # Take the first `num_selected` indices from the shuffled sequence
    selected_indices = shuffled_indices[:num_selected]

    # Start from a copy of every entry except the '__'-prefixed ones
    processed_data = {k: v for k, v in data.items() if not k.startswith('__')}

    # For the specified keys, overwrite the data with the shuffled and sliced array
    print("Shuffling and slicing data for specified keys...")
    for key in keys_to_process:
        if key not in processed_data:
            print(f"Warning: Key '{key}' not found in the file. Skipping.")
            continue

        original_array = processed_data[key]

        # Enforce the assumption that every processed array has the same number
        # of samples; a mismatch would silently pair unrelated samples.
        if original_array.shape[0] != num_total_samples:
            raise ValueError(
                f"Key '{key}' has {original_array.shape[0]} samples along its first axis, "
                f"but '{reference_key}' has {num_total_samples}."
            )

        # The same indices are used for every key so samples stay aligned
        processed_data[key] = original_array[selected_indices]

    # --- 3. Save the new file ---
    print(f"Saving {num_selected} shuffled samples to: {output_path_shuffled}")
    output_dir = os.path.dirname(output_path_shuffled)
    if output_dir:
        # Only create a directory when the output path actually contains one
        os.makedirs(output_dir, exist_ok=True)

    scipy.io.savemat(output_path_shuffled, processed_data)
    print("Processing complete!")

    # --- 4. Verification (Optional) ---
    # Re-load the written file and report the shapes of the processed keys
    print("\nVerifying new file...")
    verification_data = scipy.io.loadmat(output_path_shuffled)
    for key, value in verification_data.items():
        if not key.startswith('__') and key in keys_to_process:
            print(f"Key: '{key}', New Shape: {value.shape}")

except FileNotFoundError:
    print(f"Error: Source file not found at {source_file_path}")
except ValueError as ve:
    print(f"Error: {ve}")
except Exception as e:
    print(f"An unexpected error occurred: {e}")

Processing file: /home/zhuyekun/projects/repos/deeprte/data/raw/train/merge/merged_train.mat
Will keep 500 shuffled samples.
Will process the following keys: ['scattering_kernel', 'sigma_a', 'sigma_t', 'phi', 'psi_bc', 'psi_label']
Loading data from: /home/zhuyekun/projects/repos/deeprte/data/raw/train/merge/merged_train.mat
Found 3000 total samples in the file.
Shuffling and slicing data for specified keys...
Found 3000 total samples in the file.
Shuffling and slicing data for specified keys...
Saving 500 shuffled samples to: /home/zhuyekun/projects/repos/deeprte/data/raw/train/merge/shuffled_500_samples.mat
Saving 500 shuffled samples to: /home/zhuyekun/projects/repos/deeprte/data/raw/train/merge/shuffled_500_samples.mat
Processing complete!

Verifying new file...
Key: 'phi', New Shape: (500, 41, 41)
Key: 'psi_bc', New Shape: (500, 164, 12)
Key: 'psi_label', New Shape: (500, 41, 41, 24)
Key: 'scattering_kernel', New Shape: (500, 24, 24)
Key: 'sigma_a', New Shape: (500, 41, 41)
Key: '