# In [None]:  (notebook cell marker left over from export; kept as a comment so the module parses)
import torch
import numpy as np
# You might need scipy for time warping
try:
    from scipy.interpolate import CubicSpline, interp1d, PchipInterpolator
except ImportError:
    print("Scipy not found. Time Warping augmentation will not be available.")
    CubicSpline, interp1d, PchipInterpolator = None, None, None


# Assume TRAIN is defined elsewhere, e.g., as True/False or an enum
# For these helper functions, we'll make them generic or controlled by parameters
# For the dataset integration, we typically apply augmentations during "training mode"

# --- Helper Augmentation Functions ---

def jitter(seq: torch.Tensor, noise_level: float) -> torch.Tensor:
    """
    Perturb a sequence with zero-mean Gaussian noise.

    Args:
        seq (torch.Tensor): Sequence to perturb. The noise is drawn with
            ``torch.randn_like(seq)``, so it matches ``seq`` in shape, dtype
            and device.
        noise_level (float): Standard deviation of the added noise. A value
            less than or equal to zero disables the augmentation.

    Returns:
        torch.Tensor: ``seq`` plus the scaled noise, or ``seq`` itself (the
        same object, not a copy) when ``noise_level <= 0``.
    """
    # Nothing to add: hand the input back untouched.
    if noise_level <= 0:
        return seq

    # Unit-variance samples stretched to the requested standard deviation.
    perturbation = torch.randn_like(seq) * noise_level
    return seq + perturbation

def scale(seq: torch.Tensor, scale_range=(0.8, 1.2)) -> torch.Tensor:
    """
    Multiply the whole sequence by one random factor.

    Args:
        seq (torch.Tensor): Sequence to rescale.
        scale_range (tuple): ``(min_scale, max_scale)`` bounds of the uniform
            distribution the factor is drawn from.

    Returns:
        torch.Tensor: ``seq`` multiplied by a single factor shared by every
        element, or ``seq`` itself (same object) when
        ``min_scale >= max_scale``.
    """
    lower, upper = scale_range

    # An empty or inverted range means "no scaling".
    if lower >= upper:
        return seq

    # torch.rand(1) is uniform on [0, 1); stretch and shift it into the range.
    factor = lower + (upper - lower) * torch.rand(1)
    return seq * factor

def time_warp(seq: torch.Tensor, num_control_points: int = 4, max_displacement_ratio: float = 0.1) -> torch.Tensor:
    """
    Applies a smooth, monotone random time warp to the sequence.
    Requires scipy.

    Evenly spaced control points are placed on the time axis, the interior
    ones are shifted by a random amount, and a monotone cubic (PCHIP)
    interpolant through them maps every original time step to a warped time.
    The sequence values are then linearly interpolated at those warped times.

    Args:
        seq (torch.Tensor): Input sequence [L, D]. A 2-D ``numpy.ndarray`` is
            accepted as well; the result then is an ndarray too.
        num_control_points (int): Number of control points for the time map.
                                  Must be >= 2. First and last points are fixed.
        max_displacement_ratio (float): Maximum shift of an interior control
                                       point, as a fraction of
                                       ``L / (num_control_points - 1)``.

    Returns:
        torch.Tensor: Time-warped sequence [L, D] with the dtype of ``seq``
                      (and its device, for tensors). The original ``seq`` is
                      returned unchanged when scipy is not available, when
                      ``L <= 2`` or ``num_control_points < 2``, or when
                      warping fails.

    Raises:
        ValueError: If ``seq`` is not 2-D (raised by unpacking ``seq.shape``).
    """
    # Guard on the scipy names this function actually calls.
    if PchipInterpolator is None or interp1d is None:
        print("Scipy not available. Skipping time warping.")
        return seq

    L, _ = seq.shape
    if L <= 2 or num_control_points < 2:
        return seq  # Cannot warp very short sequences

    is_tensor = isinstance(seq, torch.Tensor)

    try:
        # scipy works on CPU numpy arrays; detach so tensors that require grad
        # (or live on another device) can be converted as well.
        seq_np = seq.detach().cpu().numpy() if is_tensor else np.asarray(seq)

        # Original time points
        t_orig = np.arange(L)

        # Original control points (evenly spaced)
        t_orig_ctrl = np.linspace(0, L - 1, num_control_points)

        # Randomly displace the interior control points; the two ends stay
        # fixed so the warped sequence still starts and ends where it did.
        max_shift = (L / (num_control_points - 1)) * max_displacement_ratio
        displacements = (np.random.rand(num_control_points - 2) - 0.5) * 2 * max_shift
        t_warp_ctrl = np.copy(t_orig_ctrl)
        t_warp_ctrl[1:-1] += displacements

        # Keep the warped control points inside [0, L-1] and in increasing order.
        t_warp_ctrl = np.sort(np.clip(t_warp_ctrl, 0, L - 1))

        # PCHIP keeps the time map monotone between the sorted control points.
        time_map_func = PchipInterpolator(t_orig_ctrl, t_warp_ctrl)

        # Warped time for every original step, clipped back into range.
        t_warp_full = np.clip(time_map_func(t_orig), 0, L - 1)

        # Linearly resample all feature columns at once along the time axis.
        resample = interp1d(t_orig, seq_np, kind='linear', axis=0, fill_value="extrapolate")
        warped_seq_np = resample(t_warp_full).astype(seq_np.dtype, copy=False)

        if is_tensor:
            return torch.tensor(warped_seq_np, dtype=seq.dtype, device=seq.device)
        return warped_seq_np

    except Exception as e:
        # Best effort: an augmentation failure must not abort data loading.
        print(f"Time warping failed: {e}. Returning original sequence.")
        return seq


def random_crop_or_pad(seq: torch.Tensor, target_len: int, pad_value: float = 0.0, mode: str = 'random') -> torch.Tensor:
    """
    Brings a sequence to ``target_len`` rows by cropping or constant padding.

    Args:
        seq (torch.Tensor): Input sequence tensor [L, D].
        target_len (int): The desired length of the output sequence.
        pad_value (float): The value written into padded rows.
        mode (str): Where the data sits relative to the crop window / padding.
                    'start' keeps the first ``target_len`` rows when cropping
                    and puts all padding after the data; 'end' keeps the last
                    rows when cropping and puts all padding before the data;
                    'random' draws the offset uniformly with
                    ``np.random.randint``.

    Returns:
        torch.Tensor: Tensor of shape [target_len, D]. When ``L`` already
                      equals ``target_len`` the input object itself is
                      returned.

    Raises:
        ValueError: If ``mode`` is not one of the three names above. The mode
                    is only checked when the length actually has to change.
    """
    current_len, _ = seq.shape
    missing = target_len - current_len

    if missing == 0:
        return seq

    # Number of rows to drop (crop) or to add (pad).
    slack = abs(missing)

    # The same offset rule serves both cases: it is the crop start index
    # or the number of padded rows placed in front of the data.
    if mode == 'random':
        offset = np.random.randint(0, slack + 1)
    elif mode == 'start':
        offset = 0
    elif mode == 'end':
        offset = slack
    else:
        raise ValueError(f"Invalid mode: {mode}")

    if missing < 0:
        # Too long: keep a window of target_len rows.
        return seq[offset : offset + target_len]

    # Too short: F.pad takes (last_dim_before, last_dim_after, prev_dim_before, ...),
    # so the feature axis gets (0, 0) and the time axis gets the split padding.
    rows_after = slack - offset
    return torch.nn.functional.pad(seq, (0, 0, offset, rows_after), "constant", pad_value)

# --- Example Integration into MotionDataset ---

# You would add augmentation parameters to your dataset __init__
class MotionDatasetAugmented(torch.utils.data.Dataset):
    """
    Dataset that augments sequences, blends pairs with Mixup and returns
    fixed-length ``(x, y)`` tensors.

    Each item is built as: pick a random partner sample -> apply the enabled
    augmentations (jitter, scale, time_warp, in that order) to both inputs ->
    Mixup of inputs and labels -> crop/pad to ``target_length``.

    NOTE: only ``x`` is augmented; ``y`` is mixed and cropped/padded with ``x``
    but is not time-warped, so per-frame labels can shift slightly against a
    warped input.

    Args:
        X: Inputs. Either one tensor / numeric ``numpy.ndarray`` whose first
            axis indexes samples, or a sequence of per-sample ``[L_i, D_x]``
            arrays/tensors whose lengths may differ.
        y: Labels in the same layout as ``X``; each sample is expected to be
            ``[L_i, D_y]`` with the same length as its input.
        alpha (float): Beta(alpha, alpha) parameter for the Mixup weight.
            A value <= 0 disables Mixup.
        target_length (int): Length of every returned sequence.
        augmentations (list): Names of augmentations to apply; any of
            'jitter', 'scale', 'time_warp'. Defaults to none.
        aug_params (dict): Per-augmentation overrides merged over the defaults.
        pad_mode (str): ``mode`` forwarded to ``random_crop_or_pad``.
        pad_value (float): ``pad_value`` forwarded to ``random_crop_or_pad``.

    Raises:
        ValueError: If ``X`` and ``y`` hold a different number of samples.
    """

    def __init__(self, X, y, alpha=0.2,
                 target_length=100,  # Define a standard length for sequences
                 augmentations=None, # List of augmentation names to apply
                 aug_params=None,    # Dictionary of parameters for augmentations
                 pad_mode='random',  # Padding mode for random_crop_or_pad
                 pad_value=0.0       # Padding value
                 ):
        super().__init__()

        # Dense arrays become one tensor; ragged collections become a list of
        # per-sample tensors so variable-length data can be stored as well.
        self.X = self._to_sample_store(X)
        self.y = self._to_sample_store(y)  # Assuming y is sequence-like

        # Explicit check instead of assert, which is stripped under -O.
        if len(self.X) != len(self.y):
            raise ValueError("X and y must have the same number of samples.")

        self.alpha = alpha  # Mixup parameter
        self.target_length = target_length
        self.pad_mode = pad_mode
        self.pad_value = pad_value

        # Default parameters
        self._default_aug_params = {
            'jitter': {'noise_level': 0.05},  # 5% of the standard deviation of the data might be better
            'scale': {'scale_range': (0.9, 1.1)},
            'time_warp': {'num_control_points': 4, 'max_displacement_ratio': 0.1},
            # Note: random_crop_or_pad is handled separately at the end
        }
        self.augmentations = augmentations if augmentations is not None else []

        # Copy the inner dicts too, so user overrides never leak into the defaults.
        self.aug_params = {name: dict(defaults) for name, defaults in self._default_aug_params.items()}
        if aug_params:
            for aug_name, params in aug_params.items():
                if aug_name in self.aug_params:
                    self.aug_params[aug_name].update(params)
                else:
                    print(f"Warning: Unknown augmentation '{aug_name}' ignored.")

    @staticmethod
    def _to_sample_store(data):
        """ Converts the input into a tensor, or a list of per-sample tensors. """
        if isinstance(data, torch.Tensor):
            return data.clone().detach()
        if isinstance(data, np.ndarray) and data.dtype != object:
            return torch.tensor(data, dtype=torch.float32)
        # Anything else is treated as a collection of samples (possibly ragged).
        return [
            item.clone().detach() if isinstance(item, torch.Tensor)
            else torch.tensor(np.asarray(item), dtype=torch.float32)
            for item in data
        ]

    def __len__(self):
        return len(self.X)

    def _apply_single_augmentation(self, x: torch.Tensor, aug_name: str) -> torch.Tensor:
        """
        Applies one named augmentation to ``x``.

        Unknown names and any exception raised by the augmentation are
        reported with ``print`` and ``x`` is returned unchanged.
        """
        params = self.aug_params.get(aug_name, {})
        try:
            if aug_name == 'jitter':
                # Fixed noise level; scaling it by torch.std(x) would be an alternative.
                return jitter(x, noise_level=params.get('noise_level', 0.05))
            if aug_name == 'scale':
                return scale(x, scale_range=params.get('scale_range', (0.9, 1.1)))
            if aug_name == 'time_warp':
                if CubicSpline is None:
                    return x  # Skip if scipy not available
                # time_warp takes the tensor itself; it does its own numpy conversion.
                warped = time_warp(x,
                                   num_control_points=params.get('num_control_points', 4),
                                   max_displacement_ratio=params.get('max_displacement_ratio', 0.1))
                return torch.as_tensor(warped, dtype=x.dtype)
            # Add other augmentations here if needed
            print(f"Warning: Unknown augmentation '{aug_name}' specified.")
            return x
        except Exception as e:
            # Best effort: a failing augmentation must not abort data loading.
            print(f"Error applying augmentation '{aug_name}': {e}. Returning original sequence.")
            return x

    def _apply_augmentations(self, x: torch.Tensor) -> torch.Tensor:
        """ Applies the enabled augmentations in the fixed order jitter -> scale -> time_warp. """
        augmented_x = x
        # The order can matter, so it is fixed here rather than taken from the user list.
        for aug_name in ['jitter', 'scale', 'time_warp']:
            if aug_name in self.augmentations:
                augmented_x = self._apply_single_augmentation(augmented_x, aug_name)
        return augmented_x

    def _fit_to_target_length(self, x: torch.Tensor, y: torch.Tensor):
        """
        Crops/pads ``x`` and ``y`` to ``target_length`` with one shared offset.

        Both are concatenated along the feature axis so a single call to
        ``random_crop_or_pad`` treats them identically, then split again.
        """
        joined = torch.cat([x, y], dim=-1)
        fitted = random_crop_or_pad(joined,
                                    target_len=self.target_length,
                                    pad_value=self.pad_value,
                                    mode=self.pad_mode)
        d_x = x.shape[-1]
        return fitted[:, :d_x], fitted[:, d_x:]

    def __getitem__(self, idx):
        """
        Returns one augmented, mixed and length-standardized ``(x, y)`` pair.

        Returns:
            tuple: ``x`` of shape [target_length, D_x] and ``y`` of shape
            [target_length, D_y].
        """
        x1_orig, y1 = self.X[idx], self.y[idx]

        # Mixup disabled: augment and standardize the single sample only.
        if self.alpha is None or self.alpha <= 0:
            return self._fit_to_target_length(self._apply_augmentations(x1_orig), y1)

        # Get a random second sample for Mixup
        shuffle_index = np.random.randint(0, len(self.X))
        x2_orig, y2 = self.X[shuffle_index], self.y[shuffle_index]

        # Augment each input on its own before mixing.
        x1_aug = self._apply_augmentations(x1_orig)
        x2_aug = self._apply_augmentations(x2_orig)

        # Samples of different length cannot be added element-wise, so bring
        # both to the target length first. Equal-length pairs are mixed as-is
        # and standardized afterwards, sharing one crop/pad offset.
        if x1_aug.shape[0] != x2_aug.shape[0]:
            x1_aug, y1 = self._fit_to_target_length(x1_aug, y1)
            x2_aug, y2 = self._fit_to_target_length(x2_aug, y2)

        # Mixup: convex combination of inputs and of (sequence-like) labels.
        weight = np.random.beta(self.alpha, self.alpha)
        weight_t = torch.tensor(weight, dtype=x1_aug.dtype)  # tensor so dtypes stay consistent

        x_mix = x1_aug * weight_t + x2_aug * (1 - weight_t)
        y_mix = y1 * weight_t + y2 * (1 - weight_t)

        # Standardize the length last (a no-op when it was already done above).
        return self._fit_to_target_length(x_mix, y_mix)


# --- How to Use ---

if __name__ == '__main__':
    # Demo: build dummy variable-length data, wrap it in the augmented
    # dataset, fetch one sample and one DataLoader batch, and check shapes.
    # Dummy data: 100 samples, sequence length 50-150, 10 input features.
    # y is a per-frame label sequence of the same length as its input.
    num_samples = 100
    min_len, max_len = 50, 150
    num_features_x = 10
    num_features_y = 3 # Example: one-hot encoding for per-frame classification

    X_data = []
    y_data = []
    for i in range(num_samples):
        current_len = np.random.randint(min_len, max_len + 1)
        # Create some simple wavy data + noise
        t = np.linspace(0, 4 * np.pi, current_len)
        # Every entry must contribute rows of length current_len, so the
        # filler noise is (num_features_x - 3, L); the final .T gives [L, D_x].
        x_sample = np.vstack([
            np.sin(t) + np.random.randn(current_len) * 0.1 * (i / num_samples),
            np.cos(t * 0.5) + np.random.randn(current_len) * 0.1 * (i / num_samples),
            np.linspace(0, 1, current_len) + np.random.randn(current_len) * 0.05,
            np.random.rand(num_features_x - 3, current_len) * 0.1 # Other features as noise
        ]).T.astype(np.float32) # Shape [L, D_x]

        # Create dummy per-frame labels (e.g., based on time)
        y_sample_idx = (t > np.pi).astype(int) + (t > 2*np.pi).astype(int) # 0, 1, or 2
        y_sample = np.zeros((current_len, num_features_y), dtype=np.float32)
        y_sample[np.arange(current_len), y_sample_idx] = 1 # One-hot like [L, D_y]

        X_data.append(x_sample)
        y_data.append(y_sample)

    print(f"Generated {num_samples} samples with variable lengths ({min_len}-{max_len})")
    print(f"Sample 0 X shape: {X_data[0].shape}, y shape: {y_data[0].shape}")

    # --- Configure the Augmented Dataset ---
    target_sequence_length = 128

    # Augmentations to use: jitter, scale, time_warp
    # random_crop_or_pad is applied automatically at the end
    augmentations_to_apply = ['jitter', 'scale', 'time_warp']

    # Optional: Customize parameters
    custom_aug_params = {
        'jitter': {'noise_level': 0.02},
        'scale': {'scale_range': (0.85, 1.15)},
        # Use default params for time_warp
    }

    dataset = MotionDatasetAugmented(
        X=X_data,
        y=y_data,
        alpha=0.4, # Mixup strength
        target_length=target_sequence_length,
        augmentations=augmentations_to_apply,
        aug_params=custom_aug_params,
        pad_mode='random', # 'random', 'start', or 'end' padding/cropping
        pad_value=0.0
    )

    print(f"\nCreated augmented dataset with target length {target_sequence_length}")
    print(f"Augmentations active: {dataset.augmentations}")
    print(f"Mixup alpha: {dataset.alpha}")
    print(f"Padding/Cropping mode: {dataset.pad_mode}")

    # --- Test getting a sample ---
    sample_idx = 0
    x_orig_sample, y_orig_sample = X_data[sample_idx], y_data[sample_idx]
    x_aug_mix_processed, y_aug_mix_processed = dataset[sample_idx]

    print(f"\nOriginal sample {sample_idx}: X shape {x_orig_sample.shape}, y shape {y_orig_sample.shape}")
    print(f"Processed sample {sample_idx} from dataset: X shape {x_aug_mix_processed.shape}, y shape {y_aug_mix_processed.shape}")

    # Verify shape is the target length
    assert x_aug_mix_processed.shape[0] == target_sequence_length
    assert y_aug_mix_processed.shape[0] == target_sequence_length
    assert x_aug_mix_processed.shape[1] == num_features_x
    assert y_aug_mix_processed.shape[1] == num_features_y

    print("\nSample processed successfully and has target length.")

    # --- Example with DataLoader ---
    # num_workers > 0 loads and augments in worker processes.
    # NOTE(review): the dataset draws from NumPy's global RNG; confirm how your
    # PyTorch version seeds NumPy in workers before raising num_workers.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=32,
        shuffle=True, # Shuffle the dataset indices
        num_workers=0 # Set to > 0 for faster loading in practice
    )

    print(f"\nCreated DataLoader with batch size {dataloader.batch_size}")

    # Get a batch
    for batch_x, batch_y in dataloader:
        print(f"\nReceived batch:")
        print(f" Batch X shape: {batch_x.shape}") # Expected: [batch_size, target_length, num_features_x]
        print(f" Batch y shape: {batch_y.shape}") # Expected: [batch_size, target_length, num_features_y]
        print(f" Data type X: {batch_x.dtype}, Data type y: {batch_y.dtype}")
        # Assert batch dimensions match expectations
        assert batch_x.shape == (dataloader.batch_size, target_sequence_length, num_features_x)
        assert batch_y.shape == (dataloader.batch_size, target_sequence_length, num_features_y)
        break # Just take one batch for demonstration

    print("\nDataLoader batch processed successfully.")
    # Remember to install scipy if time_warp is used: pip install scipy