In [1]:
import torch
from torch import nn

In [2]:
# Number of mass points of the flag object along each grid axis.
NUM_X_POINTS = 1_500
NUM_Y_POINTS = 1_000
NUM_Z_POINTS = 1

# Divisor turning an integer grid index into a position coordinate
# (position = index / INDEX_TO_POS).
INDEX_TO_POS = 1_000

In [3]:
# 1-based integer index grid of the object; "ij" indexing keeps axis 0 for
# NUM_X_POINTS, axis 1 for NUM_Y_POINTS and axis 2 for NUM_Z_POINTS.
grid_od1, grid_od2, grid_od3 = torch.meshgrid(
    *(torch.arange(1, count + 1) for count in (NUM_X_POINTS, NUM_Y_POINTS, NUM_Z_POINTS)),
    indexing="ij",
)
grid_od1.shape, grid_od2.shape, grid_od3.shape

(torch.Size([1500, 1000, 1]),
 torch.Size([1500, 1000, 1]),
 torch.Size([1500, 1000, 1]))

In [4]:
# Per-node coordinates: stack the three index grids on a new last axis
# (-> X, Y, Z, 3) and scale indices to positions (true division -> float).
position = torch.stack([grid_od1, grid_od2, grid_od3], dim=-1).div(INDEX_TO_POS)
position.shape

torch.Size([1500, 1000, 1, 3])

In [None]:
# Every node starts at rest: a zero velocity per node, same shape/dtype/device as position.
velocity = position.new_zeros(position.shape)

In [None]:
def f_gravity(position, g=9.8):
    """Uniform gravity on every node, pointing along -z (coordinate index 2).

    Args:
        position: tensor of node coordinates whose last dimension holds
            (x, y, z); only its shape, dtype and device are used.
        g: magnitude of the gravitational acceleration. Defaults to 9.8.

    Returns:
        Tensor shaped like ``position`` with ``-g`` in component 2 and zeros
        in every other component.
        NOTE(review): no mass factor is applied, so this is the force for a
        unit node mass — confirm this matches the integrator.
    """
    result = torch.zeros_like(position)
    result[..., 2] = -g
    return result


def _spring_offsets():
    """Grid offsets (di, dj, dk) joined by a spring: 6 axis-aligned + 12 in-plane diagonals."""
    steps = (-1, 0, 1)
    return [
        (di, dj, dk)
        for di in steps
        for dj in steps
        for dk in steps
        # 1 non-zero step = structural spring, 2 = shear (in-plane diagonal);
        # 0 is the node itself and 3 (body diagonals) are not modelled.
        if 1 <= abs(di) + abs(dj) + abs(dk) <= 2
    ]


def _neighbor_slices(offset, grid_shape):
    """Pair every node with its neighbour at `offset` using slices instead of torch.roll.

    Returns ``(own, other)`` such that ``position[other]`` is the neighbour of
    ``position[own]``. Nodes whose neighbour would lie outside the grid are
    excluded, so boundary nodes get no force from springs that do not exist
    (torch.roll would wrap around, and replacing the wrapped edge with the node
    itself produced a spurious ``-ks * rest_length`` term).
    """
    own, other = [], []
    for step, size in zip(offset, grid_shape):
        if step >= 0:
            own.append(slice(0, max(size - step, 0)))
            other.append(slice(step, size))
        else:
            own.append(slice(-step, size))
            other.append(slice(0, max(size + step, 0)))
    return tuple(own), tuple(other)


def f_spring(position, ks, rest_length):
    """Net Hookean spring force on every node of a 3-D mass-spring grid.

    Each node is connected to its 6 axis-aligned neighbours (structural springs)
    and its 12 in-plane diagonal neighbours (shear springs), where those exist.
    The force on node i from the spring to node j is

        ks * (|x_j - x_i| - L0) * (x_j - x_i) / |x_j - x_i|

    so a stretched spring pulls i towards j and a compressed one pushes it away;
    every spring therefore applies equal and opposite forces to its two ends.

    Args:
        position: tensor of shape (X, Y, Z, 3) holding node coordinates.
        ks: spring stiffness (scalar or broadcastable tensor).
        rest_length: rest length of the axis-aligned springs. Diagonal springs
            use ``rest_length * sqrt(2)``, which assumes the grid is equally
            spaced along every axis.

    Returns:
        Tensor with the same shape as ``position`` holding the summed spring
        force vector acting on each node.
    """
    grid_shape = position.shape[:3]
    forces = torch.zeros_like(position)
    for offset in _spring_offsets():
        own, other = _neighbor_slices(offset, grid_shape)
        delta = position[other] - position[own]
        length = delta.norm(dim=-1, keepdim=True)
        # sqrt(#non-zero steps): 1 for structural springs, sqrt(2) for shear springs.
        natural_length = rest_length * sum(abs(step) for step in offset) ** 0.5
        # Clamp avoids 0/0 -> NaN if two nodes ever coincide.
        direction = delta / length.clamp_min(1e-12)
        forces[own] += ks * (length - natural_length) * direction
    return forces


def f_damp(position, velocity, kd):
    """Damping force on every node — currently a zero-force placeholder.

    TODO: implement damping along object dimensions 1, 2 and 3 (per the
    original outline), presumably using ``velocity`` and ``kd``; confirm the
    damping model before relying on this.

    Args:
        position: tensor of node coordinates; defines the output shape.
        velocity: node velocities (not used yet).
        kd: damping coefficient (not used yet).

    Returns:
        Zero tensor shaped like ``position``, so it can already be summed with
        the other force terms (previously the function returned None).
    """
    return torch.zeros_like(position)


def f_wind(position, time):
    """Wind force on every node — currently a zero-force placeholder.

    TODO: implement a (possibly time-varying) wind model; ``time`` is accepted
    for that purpose but not used yet.

    Args:
        position: tensor of node coordinates; defines the output shape.
        time: simulation time (not used yet).

    Returns:
        Zero tensor shaped like ``position``, so it can already be summed with
        the other force terms (previously the function returned None).
    """
    return torch.zeros_like(position)