In [1]:
import matplotlib.pyplot as plt
import torch
from torch import nn
from tqdm import trange

In [2]:
# Cloth lattice resolution: number of nodes along each grid axis.
NUM_X_POINTS, NUM_Y_POINTS, NUM_Z_POINTS = 100, 70, 1
# Mass of a single node (forces are divided by it in the simulation loop).
MASS = 1

# Integer grid indices are divided by INDEX_TO_POS to obtain positions, so
# neighbouring nodes along an axis start exactly REST_LENGTH apart.
INDEX_TO_POS = 100
REST_LENGTH = 1 / INDEX_TO_POS
# Spring stiffness (KS) and damping (KD) coefficients for the internal forces.
KS, KD = 1, 0.8

In [3]:
# Node positions as an (X, Y, Z, 3) tensor: 1-based integer lattice coordinates
# along each axis ("ij" indexing), scaled down by INDEX_TO_POS.
grid_od1, grid_od2, grid_od3 = torch.meshgrid(
    *(torch.arange(1, count + 1) for count in (NUM_X_POINTS, NUM_Y_POINTS, NUM_Z_POINTS)),
    indexing="ij",
)
position = torch.stack([grid_od1, grid_od2, grid_od3], dim=-1).div(INDEX_TO_POS)

In [4]:
# Give every node a small random out-of-plane (z) offset, uniform in (-0.01, 0.01].
# A dedicated, seeded generator makes this initial condition reproducible on
# Restart & Run All without touching torch's global RNG state.
Z_JITTER_SEED = 0
z_jitter_generator = torch.Generator(device=position.device).manual_seed(Z_JITTER_SEED)
position[..., 2] = (
    .5
    - torch.rand(
        position.shape[:-1],
        generator=z_jitter_generator,
        dtype=position.dtype,
        device=position.device,
    )
) / 50
# For a perfectly flat initial cloth use instead:
# position[..., 2] = 0

In [5]:
# Every node starts at rest: zero velocity with the same shape, dtype and device as `position`.
velocity = position.new_zeros(position.shape)

In [6]:
def f_gravity(position, gravity=-9.8, mass=1.0):
    """Return the gravitational force acting on every node.

    Args:
        position: tensor whose last dimension holds (x, y, z) components; only
            its shape, dtype and device are used.
        gravity: acceleration along the z axis.
        mass: mass of a single node.

    Returns:
        A new tensor shaped like ``position`` whose z component is
        ``mass * gravity`` and whose x and y components are zero.
    """
    result = torch.zeros_like(position)
    result[..., 2] = mass * gravity
    return result


def roll_with_repeat(tensor, shifts, dims):
    """Shift ``tensor`` by one position along ``dims`` without wrap-around.

    Behaves like ``torch.roll`` with a shift of +1 or -1, except that the slot
    which would receive the wrapped-around slice is filled with a copy of the
    original edge slice instead, i.e. the border value is repeated. The input
    tensor is not modified.

    Args:
        tensor: tensor with at least ``dims + 1`` dimensions.
        shifts: +1 or -1.
        dims: axis to shift along; 0, 1 or 2.

    Returns:
        A new tensor with the same shape as ``tensor``.

    Raises:
        ValueError: if ``shifts`` is not +/-1 or ``dims`` is not 0, 1 or 2.
    """
    if shifts not in [-1, 1]:
        raise ValueError("Shifts should be -1 or +1.")
    if dims not in [0, 1, 2]:
        raise ValueError("Dims should be 0, 1, or 2.")
    step = 1 if shifts == 1 else -1
    axis = int(dims)
    # A +1 roll wraps the last slice onto index 0; a -1 roll wraps the first
    # slice onto index -1. That slot is overwritten with the original edge.
    edge = 0 if step == 1 else -1
    edge_index = (slice(None),) * axis + (edge,)
    rolled = torch.roll(tensor, shifts=step, dims=axis)
    rolled[edge_index] = tensor[edge_index]
    return rolled


def roll_position_and_velocity(position, velocity, shifts, dims):
    """Apply the same edge-repeating one-step roll to ``position`` and ``velocity``.

    Returns:
        Tuple ``(rolled_position, rolled_velocity)`` produced by
        ``roll_with_repeat`` with the given ``shifts`` and ``dims``.
    """
    rolled_position = roll_with_repeat(position, shifts, dims)
    rolled_velocity = roll_with_repeat(velocity, shifts, dims)
    return rolled_position, rolled_velocity


def compute_spring_and_damper_forces(position, position_roll, velocity, velocity_roll, ks, kd, rest_length):
    """Force exerted on each node by a damped spring to its paired node.

    For node a paired with node b (``position_roll`` / ``velocity_roll``), with
    unit direction d = normalize(x_b - x_a) along the last dimension:

        F_a = (ks * (|x_b - x_a| - rest_length) - kd * ((v_a - v_b) . d)) * d

    A stretched spring pulls a toward b; relative motion along d is damped.
    Coincident nodes get a zero direction (``normalize`` clamps the norm with
    its eps), hence zero force.

    Returns:
        Tensor shaped like ``position`` holding the force on each node.
    """
    offset = position_roll - position
    direction = nn.functional.normalize(offset, dim=-1)
    relative_velocity = velocity - velocity_roll
    spring_magnitude = ks * (torch.linalg.norm(offset, dim=-1) - rest_length)
    damper_magnitude = kd * (direction * relative_velocity).sum(-1)
    return direction * (spring_magnitude - damper_magnitude)[..., None]


def _genuine_neighbour_mask(grid_shape, shift_list, device):
    """Boolean (X, Y, Z) mask of nodes whose rolled partner is a real, distinct neighbour.

    ``roll_with_repeat`` clamps at the border: a +1 shift along an axis makes
    node 0 its own partner on that axis, and a -1 shift does the same for the
    last node. A node has a genuine partner only if no axis in ``shift_list``
    was clamped for it.
    """
    mask = torch.ones(grid_shape, dtype=torch.bool, device=device)
    for shift, axis in shift_list:
        size = grid_shape[axis]
        clamped_index = 0 if shift == 1 else size - 1
        keep = torch.arange(size, device=device) != clamped_index
        broadcast_shape = [1] * len(grid_shape)
        broadcast_shape[axis] = size
        mask = mask & keep.view(broadcast_shape)
    return mask


def compute_internal_forces(position, velocity, ks, kd, rest_length):
    """Sum of spring/damper forces from each node's structural and shear neighbours.

    Neighbours are one grid step away along a single axis (6 structural springs)
    or along two axes at once (12 shear/diagonal springs). Each spring's rest
    length is ``rest_length * sqrt(k)``, where k is the number of axes stepped.
    This assumes ``rest_length`` is the axial rest length of a square lattice,
    which is how the initial grid in this notebook is built.

    Springs whose partner would lie outside the grid are dropped. They are not
    aliased onto an axial neighbour by the border clamping of
    ``roll_with_repeat``. Shifts along an axis of size 1 are skipped entirely,
    since no node there can have such a neighbour.

    Args:
        position: node positions, assumed (X, Y, Z, 3) — TODO confirm for other callers.
        velocity: node velocities, same shape as ``position``.
        ks: spring stiffness.
        kd: damping coefficient.
        rest_length: rest length of an axial (single-step) spring.

    Returns:
        Tensor shaped like ``position`` with the total internal force on each node.
    """
    grid_shape = tuple(position.shape[:3])
    # Structural springs: one step along a single grid axis.
    shift_lists = [((sign, axis),) for axis in range(3) for sign in (1, -1)]
    # Shear springs: one step along each of two grid axes.
    shift_lists += [
        ((sign_a, axis_a), (sign_b, axis_b))
        for axis_a, axis_b in ((0, 1), (0, 2), (1, 2))
        for sign_a in (1, -1)
        for sign_b in (1, -1)
    ]
    forces = torch.zeros_like(position)
    for shift_list in shift_lists:
        # An axis with a single node never supplies a neighbour: skip the wasted work.
        if any(grid_shape[axis] < 2 for _, axis in shift_list):
            continue
        position_roll, velocity_roll = position, velocity
        for shift, axis in shift_list:
            position_roll, velocity_roll = roll_position_and_velocity(position_roll, velocity_roll, shift, axis)
        # Partners one step away along k axes sit sqrt(k) * rest_length apart at rest.
        pair_rest_length = rest_length * len(shift_list) ** 0.5
        pair_forces = compute_spring_and_damper_forces(
            position, position_roll, velocity, velocity_roll, ks, kd, pair_rest_length
        )
        mask = _genuine_neighbour_mask(grid_shape, shift_list, position.device)
        forces = forces + pair_forces * mask.unsqueeze(-1)
    return forces


def f_wind(position, time):
    """Return the wind force acting on every node at simulation time ``time``.

    Placeholder: no wind is modelled yet, so the result is a zero tensor shaped
    like ``position``. ``time`` is accepted for a future time-varying model and
    is currently unused.
    """
    return torch.zeros_like(position)

# simulation

In [7]:
# Time stepping: one frame is saved every `oversampling_factor` steps, so
# `steps` produces 1000 saved frames.
delta_t = .01
oversampling_factor = 100
steps = 1000 * oversampling_factor

plot_2d = False
t = 0
for step in trange(steps):
    t += delta_t
    # Only internal spring/damper forces are applied here; f_gravity and f_wind are not used.
    forces = compute_internal_forces(position, velocity, KS, KD, REST_LENGTH)
    # Pin the first row of nodes (index 0 along x): zero force keeps them at rest.
    forces[0] = 0
    # Semi-implicit (symplectic) Euler: update velocity, then move with the new velocity.
    velocity += forces / MASS * delta_t
    position += velocity * delta_t
    if step % oversampling_factor != 0:
        continue
    # Frame index is zero-padded to 4 digits to match the ffmpeg pattern below.
    frame = step // oversampling_factor
    torch.save(position.cpu(), f"/tmp/position{frame:04d}.pt")
    if plot_2d:
        # Scatter of the x/y node coordinates for this frame.
        fig, ax = plt.subplots()
        ax.set_xlim(0, 3)
        ax.set_ylim(0, 3)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.plot(position[..., 0].reshape(-1), position[..., 1].reshape(-1), 'bo', lw=2)
        fig.savefig(f"/tmp/frame{frame:04d}.png")
        # Close this figure explicitly so thousands of frames do not accumulate open figures.
        plt.close(fig)

# Assemble the saved frames into a video (run inside /tmp):
# ffmpeg -framerate 30 -i frame%04d.png -c:v libx264 -pix_fmt yuv420p output.mp4

100%|██████████| 100000/100000 [12:04<00:00, 138.12it/s]
