In [1]:
import torch
from torch import nn

In [2]:
# num flag object points
NUM_X_POINTS = 2
NUM_Y_POINTS = 2
NUM_Z_POINTS = 1

INDEX_TO_POS = 1000

In [3]:
# object dimensions grid
grid_od1, grid_od2, grid_od3 = torch.meshgrid(torch.arange(1, NUM_X_POINTS+1), torch.arange(1, NUM_Y_POINTS+1), torch.arange(1, NUM_Z_POINTS+1), indexing="ij")
grid_od1.shape, grid_od2.shape, grid_od3.shape

(torch.Size([2, 2, 1]), torch.Size([2, 2, 1]), torch.Size([2, 2, 1]))

In [4]:
position = torch.stack((grid_od1, grid_od2, grid_od3), dim=-1) / INDEX_TO_POS
position.shape

torch.Size([2, 2, 1, 3])

In [5]:
# velocity = torch.randn_like(position) * .01
velocity = torch.zeros_like(position)

In [6]:
def f_gravity(position):
    # gravity force
    result = torch.zeros_like(position)
    result[..., 2] = -9.8
    result[0, 0, 0], f_gravity[1, 0, 0], f_gravity[0, 1, 0]
    return result


def roll_with_repeat(tensor, shifts, dims):
    if shifts not in [-1, 1]:
        raise ValueError("Shifts should be -1 or +1.")
    if dims not in [0, 1, 2]:
        raise ValueError("Dims should be 0, 1, or 2.")
    # roll tensor
    if dims == 0 and shifts == -1:
        roll_tensor = torch.roll(tensor, shifts=-1, dims=0)
        roll_tensor[-1, ...] = tensor[-1, ...]
    if dims == 0 and shifts == 1:
        roll_tensor = torch.roll(tensor, shifts=1, dims=0)
        roll_tensor[0, ...] = tensor[0, ...]
    if dims == 1 and shifts == -1:
        roll_tensor = torch.roll(tensor, shifts=-1, dims=1)
        roll_tensor[:, -1, ...] = tensor[:, -1, ...]
    if dims == 1 and shifts == 1:
        roll_tensor = torch.roll(tensor, shifts=1, dims=1)
        roll_tensor[:, 0, ...] = tensor[:, 0, ...]
    if dims == 2 and shifts == -1:
        roll_tensor = torch.roll(tensor, shifts=-1, dims=2)
        roll_tensor[:, :, -1, ...] = tensor[:, :, -1, ...]
    if dims == 2 and shifts == 1:
        roll_tensor = torch.roll(tensor, shifts=1, dims=2)
        roll_tensor[:, :, 0, ...] = tensor[:, :, 0, ...]
    return roll_tensor


def roll_position_and_velocity(position, velocity, shifts, dims):
    return roll_with_repeat(position, shifts, dims), roll_with_repeat(velocity, shifts, dims)


def compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length):
    delta_position = position - position_roll
    delta_position_normalize = nn.functional.normalize(delta_position, dim=-1)
    delta_velocity = velocity - velocity_roll
    f_spring = (torch.linalg.norm(delta_position, dim=-1) - rest_length) * ks
    f_damper = (delta_position_normalize * delta_velocity).sum(-1) * kd
    return (f_spring + f_damper).unsqueeze(-1) * delta_position_normalize


def compute_internal_forces(position, velocity, ks, kd, rest_length):
    ##################
    # adjacent nodes #
    ##################
    # object dimension 1 spring
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, 1, 0)
    od1_pos = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, -1, 0)
    od1_neg = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    # object dimension 2 spring
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, 1, 1)
    od2_pos = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, -1, 1)
    od2_neg = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    # object dimension 3 spring
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, 1, 2)
    od3_pos = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, -1, 2)
    od3_neg = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    ##################
    # diagonal nodes #
    ##################
    # object dimension diagonal 12
    # pos pos
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, 1, 0)
    position_roll, velocity_roll = roll_position_and_velocity(position_roll, velocity_roll, 1, 1)
    od12_pos_pos = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    # pos neg
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, 1, 0)
    position_roll, velocity_roll = roll_position_and_velocity(position_roll, velocity_roll, -1, 1)
    od12_pos_neg = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    # neg pos
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, -1, 0)
    position_roll, velocity_roll = roll_position_and_velocity(position_roll, velocity_roll, 1, 1)
    od12_neg_pos = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    # neg neg
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, -1, 0)
    position_roll, velocity_roll = roll_position_and_velocity(position_roll, velocity_roll, -1, 1)
    od12_neg_neg = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    # object dimension diagonal 13
    # pos pos
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, 1, 0)
    position_roll, velocity_roll = roll_position_and_velocity(position_roll, velocity_roll, 1, 2)
    od13_pos_pos = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    # pos neg
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, 1, 0)
    position_roll, velocity_roll = roll_position_and_velocity(position_roll, velocity_roll, -1, 2)
    od13_pos_neg = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    # neg pos
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, -1, 0)
    position_roll, velocity_roll = roll_position_and_velocity(position_roll, velocity_roll, 1, 2)
    od13_neg_pos = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    # neg neg
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, -1, 0)
    position_roll, velocity_roll = roll_position_and_velocity(position_roll, velocity_roll, -1, 2)
    od13_neg_neg = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    # object dimension diagonal 23
    # pos pos
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, 1, 1)
    position_roll, velocity_roll = roll_position_and_velocity(position_roll, velocity_roll, 1, 2)
    od23_pos_pos = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    # pos neg
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, 1, 1)
    position_roll, velocity_roll = roll_position_and_velocity(position_roll, velocity_roll, -1, 2)
    od23_pos_neg = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    # neg pos
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, -1, 1)
    position_roll, velocity_roll = roll_position_and_velocity(position_roll, velocity_roll, 1, 2)
    od23_neg_pos = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    # neg neg
    position_roll, velocity_roll = roll_position_and_velocity(position, velocity, -1, 1)
    position_roll, velocity_roll = roll_position_and_velocity(position_roll, velocity_roll, -1, 2)
    od23_neg_neg = compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length)
    ##########
    # sum up #
    ##########
    adjacent_forces = od1_neg + od1_pos + od2_neg + od2_pos + od3_neg + od3_pos
    diagonal_forces = od12_pos_pos + od12_pos_neg + od12_neg_pos + od12_neg_neg + od13_pos_pos + od13_pos_neg + od13_neg_pos + od13_neg_neg + od23_pos_pos + od23_pos_neg + od23_neg_pos + od23_neg_neg
    forces = adjacent_forces + diagonal_forces
    return forces


def f_wind(position, time):
    torch.zeros_like(position)

In [7]:
ks, kd, rest_length = 1, 1, 1
compute_internal_forces(position, velocity, ks, kd, rest_length)

tensor([[[[ 4.7021,  4.7021, -0.0000]],

         [[ 4.7021, -4.7021, -0.0000]]],


        [[[-4.7021,  4.7021, -0.0000]],

         [[-4.7021, -4.7021, -0.0000]]]])