# In [None]:
import torch
import random
from typing import List
from federated_inference.client.sensor_view import SensorView


class OneDStridePartitionTransform:
    """Partition a 1D tensor into a list of ``SensorView`` windows.

    The window layout is computed once, in ``__init__``, from the config and
    is then reused for every tensor the transform is called on. Because the
    'fixed_number', 'drop_probability' and 'random' methods draw from the
    global ``random`` module at construction time, two instances built from
    the same config may hold different layouts unless the caller seeds
    ``random`` first.
    """

    def __init__(self, config):
        """
        Config should have:
        - METHOD_NAME: 'fixed_number', 'drop_probability', 'full', 'random'
        - MASK_SIZE: int, length of each window (must be >= 1)
        - STRIDE: int, step between window starts (must be >= 1 for the
          strided methods; not used by 'random')
        - N_POSITION: int (if fixed_number)
        - DROP_P: float (if drop_probability)
        - TENSOR_SIZE: int

        Raises:
            ValueError: if the config describes an invalid layout (see
                ``_compute_positions``).
        """
        self.method_name = config.METHOD_NAME
        self.mask_size = config.MASK_SIZE
        self.stride = config.STRIDE
        self.n_position = getattr(config, "N_POSITION", None)
        self.drop_p = getattr(config, "DROP_P", None)
        self.tensor_size = config.TENSOR_SIZE

        self.positions = self._compute_positions()

    def _compute_positions(self):
        """Compute the window layout described by the config.

        Returns:
            For 'random': a list of index lists. The indices
            ``0..tensor_size-1`` are shuffled and cut into consecutive chunks
            of ``mask_size`` (the last chunk may be shorter).
            For the strided methods: a list of integer window start offsets.

        Raises:
            ValueError: if ``mask_size`` is not positive or exceeds
                ``tensor_size``, if ``stride`` is not positive for a strided
                method, if a method-specific setting is missing, or if
                ``method_name`` is unknown.
        """
        # Reject degenerate windows up front; otherwise they surface later as
        # an unrelated range()/IndexError or as silently empty views.
        if self.mask_size <= 0:
            raise ValueError("Mask size must be a positive integer.")
        if self.mask_size > self.tensor_size:
            raise ValueError("Mask size too large for the given tensor size.")

        # --- RANDOM MODE ---
        if self.method_name == 'random':
            indices = list(range(self.tensor_size))
            random.shuffle(indices)
            # Create random groups of indices (chunks of mask_size).
            # Each chunk is a list of positions (non-contiguous).
            return [
                indices[i:i + self.mask_size]
                for i in range(0, len(indices), self.mask_size)
            ]

        # --- STRIDED MODES ---
        if self.stride <= 0:
            raise ValueError("Stride must be a positive integer.")

        # With 1 <= mask_size <= tensor_size and stride >= 1 this range always
        # contains offset 0, so positions[-1] below is safe.
        positions = list(range(0, self.tensor_size - self.mask_size + 1, self.stride))
        # If the stride does not land exactly on the end, add one final window
        # flush with the end so the tail of the tensor is covered.
        if positions[-1] + self.mask_size < self.tensor_size:
            positions.append(self.tensor_size - self.mask_size)

        if self.method_name == 'fixed_number':
            if not self.n_position:
                raise ValueError("n_position must be set for method 'fixed_number'")
            # Sample without replacement, capped at the number of windows.
            return random.sample(positions, min(self.n_position, len(positions)))

        if self.method_name == 'drop_probability':
            if self.drop_p is None:
                raise ValueError("drop_p must be set for method 'drop_probability'")
            # Keep each window independently; a draw <= drop_p drops it.
            return [pos for pos in positions if random.random() > self.drop_p]

        if self.method_name == 'full':
            return positions

        raise ValueError(f"Unknown method_name: {self.method_name}")

    def _slice_tensor(self, tensor: torch.Tensor) -> List["SensorView"]:
        """
        Slices a 1D tensor into views, one per entry of ``self.positions``.

        In 'random' mode, each view contains the elements at that chunk's
        shuffled indices and carries the index list itself. In the strided
        modes each view is the contiguous window ``tensor[pos:end]`` and
        carries the matching ``slice`` object. Every ``SensorView`` is built
        as ``SensorView(data, index, tensor.shape)``.

        Raises:
            ValueError: if ``tensor`` is not 1-dimensional.
        """
        if tensor.ndim != 1:
            raise ValueError("Expected 1D tensor")

        views = []

        # Handle RANDOM mode separately
        if self.method_name == 'random':
            for chunk in self.positions:
                # Build the index tensor as integer indices on the same device
                # as the data being gathered.
                gather_index = torch.tensor(chunk, dtype=torch.long, device=tensor.device)
                slice_tensor = tensor[gather_index]
                index_slice = chunk  # not contiguous, store as list
                views.append(SensorView(slice_tensor, index_slice, tensor.shape))
            return views

        # Normal strided slicing
        for pos in self.positions:
            # Clip the window end to the actual tensor length.
            end = min(pos + self.mask_size, tensor.shape[0])
            slice_tensor = tensor[pos:end]
            index_slice = slice(pos, end)
            views.append(SensorView(slice_tensor, index_slice, tensor.shape))

        return views

    def __call__(self, tensor: torch.Tensor) -> List["SensorView"]:
        """Apply the precomputed layout to ``tensor``; see ``_slice_tensor``."""
        return self._slice_tensor(tensor)
