In [13]:
import math

import torch

In [5]:
__GAIT_CFG = {
    "tripod": {
        "default_alpha": ((0, math.pi, math.pi, 0, 0, math.pi)),
        "groups": ((0, 3, 4), (1, 2, 5)),
    },
    "tetrapod": {
        "default_alpha": ((0, 2 * math.pi / 3, 2 * math.pi / 3, 4 * math.pi / 3, 4 * math.pi / 3, 0)),
        "groups": ((0, 5), (1, 2), (3, 4)),
    },
    "wave": {
        "default_alpha": ((0, math.pi, math.pi / 3, 4 * math.pi / 3, 2 * math.pi / 3, 5 * math.pi / 3)),
        "groups": ((0,), (2,), (4,), (1,), (3,), (5,)),
    },
}

def __get_default_alpha(gaits: str, num_envs: int, device: str):
    __default_alpha_list = []
    try:
        for gait in gaits:
            __default_alpha_list.append(__GAIT_CFG[gait]["default_alpha"])
    except KeyError as gait_name:
        raise KeyError(f"Invalid gait: {gait_name}")

    __idxs = torch.randint(len(__default_alpha_list), (num_envs,))
    return torch.tensor(__default_alpha_list, device=device)[__idxs], __idxs


def __get_psi(gaits: str, idxs: torch.Tensor, device: str) -> torch.Tensor:
    __psi_list = []
    try:
        for gait in gaits:
            __default_alpha = torch.tensor(__GAIT_CFG[gait]["default_alpha"], device=device).unsqueeze(0)
            __psi_list.append((__default_alpha.T - __default_alpha).tolist())
    except KeyError as gait_name:
        raise KeyError(f"Invalid gait: {gait_name}")

    return torch.tensor(__psi_list, device=device)[idxs]


def __get_m(
    types: str,
    self_weight: float,
    in_group_weight: float,
    of_group_weight: float,
    num_envs: int,
    device: str,
):
    def __none_coupling():
        return torch.zeros((6, 6)).tolist()

    def __sparse_coupling(gait: str):
        __m = torch.zeros((6, 6)).tolist()
        try:
            groups = __GAIT_CFG[gait]["groups"]
        except KeyError as e:
            raise KeyError(f"Invalid gait: {e}")
        for group_idx in range(len(groups)):
            group = groups[group_idx]
            for i in group:
                for j in range(6):
                    if group[j] in group:
                        if i == j:
                            __m[i][j] = self_weight
                        else:
                            __m[i][j] = in_group_weight
                    else:
                        __m[i][j] = of_group_weight

        return __m

    def __ring_coupling():
        return (
            (
                self_weight,
                in_group_weight,
                in_group_weight,
                of_group_weight,
                of_group_weight,
                of_group_weight,
            ),
            (
                in_group_weight,
                self_weight,
                of_group_weight,
                in_group_weight,
                of_group_weight,
                of_group_weight,
            ),
            (
                in_group_weight,
                of_group_weight,
                self_weight,
                of_group_weight,
                in_group_weight,
                of_group_weight,
            ),
            (
                of_group_weight,
                in_group_weight,
                of_group_weight,
                self_weight,
                of_group_weight,
                in_group_weight,
            ),
            (
                of_group_weight,
                of_group_weight,
                in_group_weight,
                of_group_weight,
                self_weight,
                in_group_weight,
            ),
            (
                of_group_weight,
                of_group_weight,
                of_group_weight,
                in_group_weight,
                in_group_weight,
                self_weight,
            ),
        )

    def __all_to_all_coupling():
        return (torch.ones((6, 6)) * in_group_weight - torch.eye(6) * (in_group_weight - self_weight)).tolist()

    __m_list = []
    for type in types:
        if type == "none":
            __m_list.append(__none_coupling())
        elif type == "sparse":
            __m_list.append(__sparse_coupling(type))
        elif type == "ring":
            __m_list.append(__ring_coupling())
        elif type == "all-to-all":
            __m_list.append(__all_to_all_coupling())
        else:
            raise ValueError(f"Invalid coupling type: {type}")

    return torch.tensor(__m_list, device=device)[torch.randint(len(__m_list), (num_envs,), device=device)]

In [6]:
gaits = ("tripod", "tetrapod")

default_alpha, idxs = __get_default_alpha(gaits, 3, "cuda")
print(default_alpha)
psi = __get_psi(gaits, idxs, "cuda")
print(psi)

tensor([[0.0000, 3.1416, 3.1416, 0.0000, 0.0000, 3.1416],
        [0.0000, 3.1416, 3.1416, 0.0000, 0.0000, 3.1416],
        [0.0000, 2.0944, 2.0944, 4.1888, 4.1888, 0.0000]], device='cuda:0')
tensor([[[ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
         [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000],
         [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000],
         [ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
         [ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
         [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000]],

        [[ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
         [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000],
         [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000],
         [ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
         [ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
         [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000]],

        [[ 0.0000, -2.09

In [2]:
import torch
import math

def phase_to_weight_matrix(
    phase_offsets: torch.Tensor,   # shape: [num_envs, num_legs, num_legs]
    self_weight: float,
    in_group_weight: float,
    of_group_weight: float,
    phase_threshold: float = 0.1   # how close to zero modulo 2pi counts as "in group"
) -> torch.Tensor:
    """
    Convert phase offset matrix to coupling weight matrix.

    Args:
        phase_offsets: tensor of shape [num_envs, num_legs, num_legs] with phase differences in radians.
        self_weight: weight for diagonal elements (self-coupling).
        in_group_weight: weight for legs with phase difference close to 0 mod 2pi.
        of_group_weight: weight for legs with phase difference NOT close to 0 mod 2pi.
        phase_threshold: max absolute phase difference (mod 2pi) to count as "in group".

    Returns:
        weights: tensor same shape as phase_offsets with coupling weights.
    """
    num_envs, num_legs, _ = phase_offsets.shape

    # Normalize phase differences mod 2pi to [0, 2pi)
    two_pi = 2 * math.pi
    phase_mod = torch.fmod(phase_offsets, two_pi)
    # Shift negatives into [0, 2pi)
    phase_mod = torch.where(phase_mod < 0, phase_mod + two_pi, phase_mod)

    # Distance to 0 mod 2pi is min(phase_mod, 2pi - phase_mod)
    dist_to_zero = torch.minimum(phase_mod, two_pi - phase_mod)

    # Create weight tensor filled with of_group_weight by default
    weights = torch.full_like(phase_offsets, of_group_weight)

    # Self connections (diagonal)
    eye_mask = torch.eye(num_legs, dtype=torch.bool, device=phase_offsets.device).unsqueeze(0).expand(num_envs, -1, -1)
    weights[eye_mask] = self_weight

    # In-group connections: phase difference close to zero modulo 2pi, excluding diagonal
    in_group_mask = (dist_to_zero <= phase_threshold) & (~eye_mask)
    weights[in_group_mask] = in_group_weight

    return weights


In [53]:
phase_offsets = torch.tensor([[
    [0.0, -3.1416, -3.1416, 0.0, 0.0, -3.1416],
    [3.1416, 0.0, 0.0, 3.1416, 3.1416, 0.0],
    [3.1416, 0.0, 0.0, 3.1416, 3.1416, 0.0],
    [0.0, -3.1416, -3.1416, 0.0, 0.0, -3.1416],
    [0.0, -3.1416, -3.1416, 0.0, 0.0, -3.1416],
    [3.1416, 0.0, 0.0, 3.1416, 3.1416, 0.0],
]])

weights = phase_to_weight_matrix(
    phase_offsets=psi,
    self_weight=0.0,
    in_group_weight=0.5,
    of_group_weight=-0.5,
    phase_threshold=0.5
)

print(weights)


tensor([[[ 0.0000, -0.5000, -0.5000,  0.5000,  0.5000, -0.5000],
         [-0.5000,  0.0000,  0.5000, -0.5000, -0.5000,  0.5000],
         [-0.5000,  0.5000,  0.0000, -0.5000, -0.5000,  0.5000],
         [ 0.5000, -0.5000, -0.5000,  0.0000,  0.5000, -0.5000],
         [ 0.5000, -0.5000, -0.5000,  0.5000,  0.0000, -0.5000],
         [-0.5000,  0.5000,  0.5000, -0.5000, -0.5000,  0.0000]],

        [[ 0.0000, -0.5000, -0.5000,  0.5000,  0.5000, -0.5000],
         [-0.5000,  0.0000,  0.5000, -0.5000, -0.5000,  0.5000],
         [-0.5000,  0.5000,  0.0000, -0.5000, -0.5000,  0.5000],
         [ 0.5000, -0.5000, -0.5000,  0.0000,  0.5000, -0.5000],
         [ 0.5000, -0.5000, -0.5000,  0.5000,  0.0000, -0.5000],
         [-0.5000,  0.5000,  0.5000, -0.5000, -0.5000,  0.0000]],

        [[ 0.0000, -0.5000, -0.5000, -0.5000, -0.5000,  0.5000],
         [-0.5000,  0.0000,  0.5000, -0.5000, -0.5000, -0.5000],
         [-0.5000,  0.5000,  0.0000, -0.5000, -0.5000, -0.5000],
         [-0.5000, -0

In [38]:
jitter = torch.randn(10, 6)
print(jitter)
alpha = torch.randint(0, 2, (6,))
print(alpha)
res = alpha + jitter
print(res)

tensor([[ 1.5354,  0.1325,  0.7060,  1.3514, -1.7453, -0.2223],
        [-1.4312,  0.1611, -1.0931, -0.1287,  0.9859, -0.8875],
        [ 0.5958,  0.7297,  0.3742,  1.5163, -0.4440,  1.5191],
        [ 0.4367,  0.0983,  1.3179,  1.2835, -0.1600, -2.1167],
        [-0.0321,  1.3956, -0.5046, -0.7822, -1.7249,  0.6443],
        [ 0.1064,  2.2260, -0.4442, -0.8040,  0.3010,  0.3476],
        [-0.1815, -0.5898, -1.7871, -0.7962,  1.3015, -0.2152],
        [-0.4850, -0.5750,  0.5492, -0.3482,  1.1132,  1.9977],
        [ 0.0218,  1.1390, -0.7775, -0.8866,  0.2568, -0.0314],
        [ 1.1084, -0.1673, -1.2611, -0.0281, -0.1894,  0.6685]])
tensor([1, 0, 0, 0, 1, 1])
tensor([[ 2.5354,  0.1325,  0.7060,  1.3514, -0.7453,  0.7777],
        [-0.4312,  0.1611, -1.0931, -0.1287,  1.9859,  0.1125],
        [ 1.5958,  0.7297,  0.3742,  1.5163,  0.5560,  2.5191],
        [ 1.4367,  0.0983,  1.3179,  1.2835,  0.8400, -1.1167],
        [ 0.9679,  1.3956, -0.5046, -0.7822, -0.7249,  1.6443],
        [ 1.

In [47]:
arr = psi[0]
arr.unsqueeze(0).expand(3, -1, -1)

tensor([[[ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
         [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000],
         [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000],
         [ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
         [ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
         [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000]],

        [[ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
         [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000],
         [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000],
         [ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
         [ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
         [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000]],

        [[ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
         [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000],
         [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000],
         [ 0.0000, -3

In [59]:
phase_offsets[0]

tensor([[ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
        [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000],
        [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000],
        [ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
        [ 0.0000, -3.1416, -3.1416,  0.0000,  0.0000, -3.1416],
        [ 3.1416,  0.0000,  0.0000,  3.1416,  3.1416,  0.0000]])

In [64]:
r = torch.randint(0, 3, (10, 6))
alpha = torch.randint(-2, 2, (10, 6))
m = torch.randint(-1, 1, (6, 6))
psi = torch.randn((6, 6)) * torch.pi

num_envs, net_size = r.shape

m_exp = m.unsqueeze(0).expand(num_envs, -1, -1)     # (num_envs, net_size, net_size)
psi_exp = psi.unsqueeze(0).expand(num_envs, -1, -1) # (num_envs, net_size, net_size)

r_j = r.unsqueeze(1)    # (num_envs, 1, net_size)
alpha_i = alpha.unsqueeze(2)  # (num_envs, net_size, 1)
alpha_j = alpha.unsqueeze(1)  # (num_envs, 1, net_size)

diff = alpha_j - alpha_i      # (num_envs, net_size, net_size)
sin_term = torch.sin(diff - psi_exp)  # (num_envs, net_size, net_size)

result = (r_j * m_exp * sin_term).sum(dim=2)  # (num_envs, net_size)


print(result)


tensor([[ 0.0000,  0.0000,  0.0000, -1.7284,  1.7502,  0.0000],
        [ 0.3799, -2.4861,  1.4293,  1.6425, -1.6664, -3.3912],
        [-0.4757, -0.3705,  1.3467, -2.2338,  0.0838,  5.0986],
        [ 0.7406, -1.3262,  2.0067,  1.3693,  0.0000, -4.5630],
        [-0.9281, -1.2776,  0.8987,  2.3723, -0.9558,  3.8508],
        [ 0.0000,  0.7354,  1.7122, -1.6038, -1.8693, -1.7703],
        [ 0.0000,  0.0000,  1.9513,  1.6089,  0.0000,  1.9875],
        [ 1.4495, -0.5741, -0.1828, -0.3780, -0.8332, -1.3100],
        [ 0.1899, -0.2535, -0.7581, -0.2488, -2.6352, -1.9093],
        [ 0.0000,  0.0000, -1.1041, -0.0271, -1.2467, -1.9361]])


In [71]:
arr = torch.randint(0, 3, (6, 6))
print(arr)
res1 = arr.remainder(2)
print(res)
res2 = res1.remainder(2)
print(res2)

tensor([[2, 2, 1, 1, 0, 1],
        [2, 0, 2, 1, 2, 0],
        [2, 0, 0, 1, 2, 0],
        [0, 2, 1, 1, 1, 1],
        [2, 2, 1, 1, 0, 1],
        [0, 2, 2, 2, 0, 0]])
tensor([[0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0],
        [0, 0, 1, 0, 1, 0],
        [0, 0, 1, 0, 1, 0],
        [0, 0, 1, 1, 1, 1]])
tensor([[0, 0, 1, 1, 0, 1],
        [0, 0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 0, 1, 1, 1, 1],
        [0, 0, 1, 1, 0, 1],
        [0, 0, 0, 0, 0, 0]])


In [72]:
res1 = res2
res2 = res2 + 1
print(res1)
print(res2)

tensor([[0, 0, 1, 1, 0, 1],
        [0, 0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 0, 1, 1, 1, 1],
        [0, 0, 1, 1, 0, 1],
        [0, 0, 0, 0, 0, 0]])
tensor([[1, 1, 2, 2, 1, 2],
        [1, 1, 1, 2, 1, 1],
        [1, 1, 1, 2, 1, 1],
        [1, 1, 2, 2, 2, 2],
        [1, 1, 2, 2, 1, 2],
        [1, 1, 1, 1, 1, 1]])


In [22]:
import torch
torch.rand((10,), dtype=torch.float32)

tensor([0.1156, 0.1537, 0.3683, 0.5784, 0.7465, 0.2966, 0.1635, 0.7019, 0.5059,
        0.5744])