In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import plb
from plb.envs import make
import taichi as ti

# Build the "Move-v1" task and keep a handle to its underlying TaichiEnv;
# the cells below drive `taichi_env` directly (set_state / step / loss and
# gradient kernels) rather than going through `env`.
env = make("Move-v1")

env.reset()
taichi_env = env.taichi_env

[Taichi] mode=release
[Taichi] preparing sandbox at /tmp/taichi-520yfvs6
[Taichi] version 0.7.14, llvm 10.0.0, commit 58feee37, linux, python 3.7.3
[Taichi] Starting on arch=cuda
Building primitive
action:
  dim: 3
  scale: (0.01, 0.01, 0.01)
color: (0.7, 0.7, 0.7)
friction: 0.9
init_pos: (0.5757143040494873, 0.5619162002773135, 0.7515980438048129)
init_rot: (1.0, 0.0, 0.0, 0.0)
lower_bound: (0.0, 0.0, 0.0)
radius: 0.03
shape: Sphere
upper_bound: (1.0, 1.0, 1.0)
variations: None
Building primitive
action:
  dim: 3
  scale: (0.01, 0.01, 0.01)
color: (0.7, 0.7, 0.7)
friction: 0.9
init_pos: (0.7757143040494873, 0.5619162002773135, 0.7515980438048129)
init_rot: (1.0, 0.0, 0.0, 0.0)
lower_bound: (0.0, 0.0, 0.0)
radius: 0.03
shape: Sphere
upper_bound: (1.0, 1.0, 1.0)
variations: None
{'radius': 0.1024534880385289, 'init_pos': (0.6757143040494873, 0.5619162002773135, 0.7515980438048129), 'color': 8323072}
Initialize Renderer
bake_size: 6  
camera_pos: (0.5, 1.2, 4.0)  
camera_rot: (0.2, 0)  


In [3]:
def make_copy_and_clear_kernel(sim: plb.engine.mpm_simulator.MPMSimulator):
    """Build a Taichi kernel that moves gradients from frame 0 to frame ``f``.

    The returned kernel ``copy_and_clear(f)`` does three things, in order:

    1. copies the gradients of the particle fields ``x``, ``v``, ``F``, ``C``
       and of each primitive's ``position`` / ``rotation`` from frame 0 to
       frame ``f``;
    2. zeroes each primitive's ``position``, ``v``, ``w`` and ``rotation``
       gradients for frames ``0 .. f-1``;
    3. zeroes the particle-field gradients for frames ``0 .. f-1``.

    In ``forward2`` this is applied after a checkpoint segment has been
    re-simulated, so the gradient that reached frame 0 of the later segment
    becomes the incoming gradient at the last frame of the earlier segment.

    Args:
        sim: the MPM simulator whose gradient fields are read and written;
            ``n_particles`` and ``n_primitive`` are baked into the kernel at
            build time.

    Returns:
        The compiled ``@ti.kernel`` taking a single ``ti.i32`` frame index.
    """
    @ti.kernel
    def copy_and_clear(f:ti.i32):
        # (1a) Particle gradients: frame 0 -> frame f.
        # Taichi runs a kernel's top-level loops one after another, so this
        # copy finishes before frame 0 is zeroed by the last loop below.
        for i in range(sim.n_particles):
            sim.x.grad[f, i] = sim.x.grad[0, i]
            sim.v.grad[f, i] = sim.v.grad[0, i]
            sim.F.grad[f, i] = sim.F.grad[0, i]
            sim.C.grad[f, i] = sim.C.grad[0, i]

        # (1b) Primitive pose gradients: frame 0 -> frame f (unrolled at
        # compile time over the primitives).
        # NOTE(review): v.grad / w.grad are cleared below but never copied to
        # frame f -- presumably intentional; confirm against the primitive step.
        if ti.static(sim.n_primitive>0):
            for i in ti.static(range(sim.n_primitive)):
                sim.primitives[i].position.grad[f] = sim.primitives[i].position.grad[0]
                sim.primitives[i].rotation.grad[f] = sim.primitives[i].rotation.grad[0]


        for s in range(f):
            # (2) clear primitive gradients for frames 0 .. f-1, component by
            # component (3 for position / v / w, 4 for rotation -- init_rot
            # in the build log is a 4-tuple, presumably a quaternion).
            # NOTE(review): range(3) hard-codes a 3-D scene instead of sim.dim.
            if ti.static(sim.n_primitive>0):
                for i in ti.static(range(sim.n_primitive)):
                    for j in ti.static(range(3)):
                        sim.primitives[i].position.grad[s][j] = 0
                        sim.primitives[i].v.grad[s][j] = 0
                        sim.primitives[i].w.grad[s][j] = 0
                    for j in ti.static(range(4)):
                        sim.primitives[i].rotation.grad[s][j] = 0

        # (3) clear particle gradients for frames 0 .. f-1.
        # F and C get a (dim, dim)-shaped zero; this form evidently works in
        # the Taichi version used here (the notebook ran), though
        # ti.Matrix.zero would state the intent more directly.
        for i in range(sim.n_particles):
            for s in range(f):
                sim.x.grad[s, i] = ti.Vector.zero(sim.dtype, sim.dim)
                sim.v.grad[s, i] = ti.Vector.zero(sim.dtype, sim.dim)
                sim.F.grad[s, i] = ti.Vector.zero(sim.dtype, sim.dim, sim.dim)
                sim.C.grad[s, i] = ti.Vector.zero(sim.dtype, sim.dim, sim.dim)

    return copy_and_clear

_copy_and_clear = make_copy_and_clear_kernel(taichi_env.simulator)


def copy_and_clear(f):
    """Reset action gradients, then shift simulator gradients to frame ``f``.

    Every primitive's ``action_buffer`` gradient is zeroed in full. After
    that, the compiled kernel copies the frame-0 particle/primitive gradients
    to frame ``f`` and zeroes frames ``0 .. f-1``.
    """
    for primitive in taichi_env.primitives:
        primitive.action_buffer.grad.fill(0)
    _copy_and_clear(f)


# Run once right away -- presumably a smoke test that also compiles the kernel.
copy_and_clear(21)

In [4]:
import taichi as ti
import numpy as np
from plb.engine.taichi_env import TaichiEnv
env.reset()
state = taichi_env.get_state()
sim_state = state['state']

# Random per-step actions for the gradient-equivalence check below. A locally
# seeded generator makes reruns use the same trajectory without touching
# NumPy's global RNG state.
N_STEPS = 50         # env steps in the test trajectory
ACTION_DIM = 6       # presumably 2 primitives x 3-D action each (see build log)
ACTION_SCALE = 0.01  # keeps per-step actions small
ACTION_SEED = 0
actions = (
    np.random.RandomState(ACTION_SEED).random_sample(size=(N_STEPS, ACTION_DIM))
    * ACTION_SCALE
)

def forward(taichi_env: TaichiEnv, sim_state, action):
    """Reference rollout: differentiate the whole trajectory with one ``ti.Tape``.

    Restores ``sim_state``, then steps through every action while the tape
    records the kernels, so a single backward pass covers all steps. This is
    the ground truth that the checkpointed ``forward2`` is compared against.

    Args:
        taichi_env: environment to roll out.
        sim_state: simulator state accepted by ``taichi_env.set_state``
            (here ``taichi_env.get_state()['state']``).
        action: sequence of per-step actions; each element is passed to
            ``taichi_env.step``.

    Returns:
        ``(loss, grad)``: the value of ``taichi_env.loss.loss`` after the
        rollout, and ``taichi_env.primitives.get_grad(len(action))``.
    """
    # NOTE(review): the meaning of the 666 / False arguments is defined by
    # TaichiEnv.set_state (not shown here); forward2 passes the same values.
    taichi_env.set_state(sim_state, 666, False)
    with ti.Tape(loss=taichi_env.loss.loss):
        for step_action in action:
            taichi_env.step(step_action)
            # Return value unused: the loss is read from taichi_env.loss.loss
            # once the tape has run.
            taichi_env.compute_loss()
    loss = taichi_env.loss.loss[None]
    return loss, taichi_env.primitives.get_grad(len(action))


# Reference loss / gradients from a single tape over the whole trajectory.
loss, grad = forward(taichi_env, sim_state, actions)

print(loss)

663.3039895763777


In [5]:
def forward2(taichi_env: TaichiEnv, init_state, action, T=12, compute_grad=True):
    """Loss and action gradients via gradient checkpointing every ``T`` steps.

    Computes the same quantities as ``forward`` without taping the whole
    rollout. The trajectory is split into segments of ``T`` env steps, and
    only the simulator state at the start of each segment is stored. Each
    segment restarts at simulator frame 0 via ``set_state``, so only frames
    ``0 .. T * substeps`` are used.

    Backward pass: steps are visited in reverse, and the ``.grad`` variants
    of the loss, substep and ``set_velocity`` kernels are called by hand.
    When the first step of a segment is reached:

    - that segment's action gradients are collected;
    - the previous segment is re-simulated from its checkpoint;
    - ``copy_and_clear`` carries the gradient that reached frame 0 over to
      the end frame of the re-simulated segment.

    Depends on the notebook-level ``copy_and_clear`` helper.

    Args:
        taichi_env: environment to roll out.
        init_state: starting simulator state, in the format accepted by
            ``taichi_env.set_state`` (the same kind of value as
            ``taichi_env.simulator.get_state`` returns).
        action: sequence of per-step actions.
        T: number of env steps per checkpointed segment; must be >= 1.
        compute_grad: when False, only the forward pass is run.

    Returns:
        ``(total_loss, grad)``.

        - ``total_loss`` is the loss summed over all segments.
        - ``grad`` is the concatenation, in forward order, of the
          per-segment ``taichi_env.primitives.get_grad`` results, or ``None``
          when ``compute_grad`` is False.

    Raises:
        ValueError: if ``T < 1``.
    """
    if T < 1:
        raise ValueError(f"T must be a positive number of steps, got {T}")

    max_timesteps = len(action)
    substeps = taichi_env.simulator.substeps
    loss = taichi_env.loss.loss
    checkpoints = {}  # first env step of a segment -> simulator state at its start
    total_loss = 0

    taichi_env.set_state(init_state, 666, False)  # also clears the loss

    # An empty tape is used purely for its side effects: it clears the
    # gradients before the manual backward pass below (and, per ti.Tape
    # semantics, seeds the gradient of `loss`).
    with ti.Tape(loss=loss):
        pass

    # ---- Forward pass: roll out segment by segment, saving start states. ----
    for i in range(max_timesteps):
        if i % T == 0:
            # Segment boundary: the previous segment's end state (frame
            # T * substeps) becomes the new segment's start state.
            state = init_state if i == 0 else taichi_env.simulator.get_state(T * substeps)
            checkpoints[i] = state
            total_loss += loss[None]  # bank the finished segment; set_state clears the loss
            taichi_env.set_state(state, 666, False)
        taichi_env.step(action[i])
        taichi_env.compute_loss()
    total_loss += loss[None]  # loss of the final (possibly partial) segment

    if not compute_grad:
        return total_loss, None

    # ---- Backward pass: walk the steps in reverse. ----
    action_grads = []
    last = max_timesteps  # one past the last step of the segment being differentiated
    for i in range(max_timesteps - 1, -1, -1):
        f = (i % T) * substeps  # first simulator frame of step i within its segment
        # Loss evaluated after step i, i.e. at frame f + substeps.
        taichi_env.loss.compute_loss_kernel.grad(taichi_env.loss, f + substeps)
        for s in reversed(range(f, f + substeps)):
            taichi_env.simulator.substep_grad(s)
        # Back-propagate through each primitive's set_velocity kernel for this
        # step; presumably this is what fills action_buffer.grad (there is no
        # separate gradient for setting the action itself).
        for p in taichi_env.primitives:
            p.set_velocity.grad(i % T, substeps)

        if i % T == 0:
            # Segment [i, last) is fully differentiated: collect its gradients.
            action_grads.append(taichi_env.primitives.get_grad(last - i))
            last = i

            if i > 0:
                # Re-simulate the previous segment [i - T, i) from its
                # checkpoint (i is a multiple of T here) to restore its frames.
                start = i - T
                taichi_env.set_state(checkpoints[start], 666, False)
                for s in range(start, i):
                    taichi_env.step(action[s])
                    taichi_env.compute_loss()
                # Carry the frame-0 gradient over to frame T * substeps and clear
                # the rest, so backprop continues into the re-simulated segment.
                copy_and_clear(T * substeps)

    # Segments were collected last-to-first; restore forward order.
    return total_loss, np.concatenate(action_grads[::-1])

# Checkpointed run (segments of 5 steps): it must reproduce the single-tape
# reference loss and gradients from `forward`.
loss2, grad2 = forward2(taichi_env, sim_state, actions, T=5, compute_grad=True)

assert np.allclose(loss, loss2)
max_grad_err = np.abs(grad - grad2).max()
print(max_grad_err)
assert max_grad_err < 1e-4

1.534771548383773e-05
