In [7]:
import torch

def prepare_state_action_pairs(states, actions):
    """
    Pair every state in the batch with every candidate action.

    Parameters:
        states (torch.Tensor): The state tensor of shape [batch_size, state_dim].
        actions (torch.Tensor): A linspace tensor of actions of shape [num_actions].

    Returns:
        torch.Tensor: A tensor of shape [batch_size * num_actions, state_dim + 1]. Rows are
                      grouped by state: the first num_actions rows hold state 0 paired with
                      each action in order, the next num_actions rows hold state 1, and so on.
                      The action is appended as the last column. An empty batch or an empty
                      action set yields a tensor with zero rows instead of raising.
    """
    num_actions = actions.size(0)
    batch_size = states.size(0)
    num_pairs = batch_size * num_actions

    # [batch_size, state_dim] -> [batch_size, num_actions, state_dim]
    tiled_states = states.unsqueeze(1).repeat(1, num_actions, 1)
    # Pass the trailing size explicitly: view(0, -1) cannot infer the -1 dimension
    # when there are no elements, so an empty batch would otherwise raise.
    tiled_states = tiled_states.view(num_pairs, tiled_states.size(-1))

    # Tile the action list once per state, then flatten to a single column so the
    # row order matches tiled_states (all actions for state 0, then state 1, ...).
    tiled_actions = actions.repeat(batch_size, 1).view(num_pairs, 1)

    # Append the action as the final feature of each row.
    return torch.cat((tiled_states, tiled_actions), dim=1)

# Example usage: two 4-feature states paired with two candidate actions.
# The sizes are declared first so the example tensors and the reshapes in the
# following cells share a single source of truth.
NUM_ACTIONS = 2
N_FEATURES = 4
N_BATCH = 2

states = torch.tensor([[0.5, -0.1, 0.3, .4], [0.6, -0.16, 0.6, -.7]])  # Example state vectors for two instances
actions = torch.linspace(-1, 1, steps=NUM_ACTIONS)  # Example actions

state_action_pairs = prepare_state_action_pairs(states, actions)

# Fail loudly if the constants ever drift away from the example data.
assert states.shape == (N_BATCH, N_FEATURES)
assert state_action_pairs.shape == (N_BATCH * NUM_ACTIONS, N_FEATURES + 1)

print(state_action_pairs)


tensor([[ 0.5000, -0.1000,  0.3000,  0.4000, -1.0000],
        [ 0.5000, -0.1000,  0.3000,  0.4000,  1.0000],
        [ 0.6000, -0.1600,  0.6000, -0.7000, -1.0000],
        [ 0.6000, -0.1600,  0.6000, -0.7000,  1.0000]])


In [10]:
# Fold the NUM_ACTIONS consecutive rows that belong to each state into one wide row,
# giving shape [N_BATCH, NUM_ACTIONS * (N_FEATURES + 1)].
row_width = NUM_ACTIONS * (N_FEATURES + 1)
state_action_pairs = state_action_pairs.view(N_BATCH, row_width)
state_action_pairs

tensor([[ 0.5000, -0.1000,  0.3000,  0.4000, -1.0000,  0.5000, -0.1000,  0.3000,
          0.4000,  1.0000],
        [ 0.6000, -0.1600,  0.6000, -0.7000, -1.0000,  0.6000, -0.1600,  0.6000,
         -0.7000,  1.0000]])

In [13]:
# View the data as one (state, action) pair per row,
# giving shape [N_BATCH * NUM_ACTIONS, N_FEATURES + 1].
n_pairs = N_BATCH * NUM_ACTIONS
state_action_pairs = state_action_pairs.view(n_pairs, N_FEATURES + 1)
state_action_pairs


tensor([[ 0.5000, -0.1000,  0.3000,  0.4000, -1.0000],
        [ 0.5000, -0.1000,  0.3000,  0.4000,  1.0000],
        [ 0.6000, -0.1600,  0.6000, -0.7000, -1.0000],
        [ 0.6000, -0.1600,  0.6000, -0.7000,  1.0000]])

In [18]:
# Toy column vector of shape [4, 1]: one integer per row.
t = torch.tensor([1, 4, 7, 2]).unsqueeze(1)
t

tensor([[1],
        [4],
        [7],
        [2]])

In [21]:
# Regroup the flat column so each row holds NUM_ACTIONS consecutive entries.
n_rows = t.numel() // NUM_ACTIONS
t = t.view(n_rows, NUM_ACTIONS)

t

tensor([[1, 4],
        [7, 2]])