In [1]:
%load_ext autoreload
%autoreload 2
%aimport -jax
%aimport -jaxlib

In [2]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(),"..","projects")))
print(sys.path)


['/home/daniel/Documents/code/stanza/notebooks', '/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/home/daniel/Documents/code/stanza/.venv/lib/python3.10/site-packages', '/home/daniel/Documents/code/stanza', '/home/daniel/Documents/code/stanza/projects']


In [3]:
import jax.numpy as jnp
import jax
from jax.random import PRNGKey

In [4]:
import stanza.envs.pusht as pusht
import stanza.envs as envs
env = envs.create("pusht")
data = pusht.expert_data()
data_sample = data[0][0]
print(data_sample)

Timestep(observation=PushTPositionObs(agent_pos=Array([222.,  97.], dtype=float32), block_pos=Array([222.99382, 381.59903], dtype=float32), block_rot=Array(3.0079994, dtype=float32)), action=Array([233.,  71.], dtype=float32))


In [5]:
from diffusion_policy.networks import pusht_net
from stanza.model.diffusion import DDPMSchedule
from functools import partial

if True:
    sigma_params = {
        0: jnp.load("../../pretrained_lowsample.npy", allow_pickle=True).item(),
        2: jnp.load("../../pretrained_lowsample_2.npy", allow_pickle=True).item(),
        4: jnp.load("../../pretrained_lowsample_4.npy", allow_pickle=True).item(),
        8: jnp.load("../../pretrained_lowsample_8.npy", allow_pickle=True).item(),
        1000: jnp.load("../../pretrained_lowsample_1000.npy", allow_pickle=True).item()
    }
else:
    sigma_params = {
         0: jnp.load("../../pretrained_highsample.npy", allow_pickle=True).item(),
         2: jnp.load("../../pretrained_highsample_2.npy", allow_pickle=True).item(),
         4: jnp.load("../../pretrained_highsample_4.npy", allow_pickle=True).item(),
         8: jnp.load("../../pretrained_highsample_8.npy", allow_pickle=True).item()
    }

action_norm = jnp.load("../../action_norm", allow_pickle=True).item()
obs_norm = jnp.load("../../obs_norm", allow_pickle=True).item()

from stanza import Partial
diffuser = DDPMSchedule.make_squaredcos_cap_v2(100, clip_sample_range=1)

action_sample_traj = jax.tree_util.tree_map(
    lambda x: jnp.repeat(jnp.expand_dims(x, 0), 16, axis=0),
    data_sample.action
)

@jax.jit
def sample(params, obs, sample_sigma, rng_key, sigma_rng):
    norm = obs_norm.normalize(obs)
    if sample_sigma is not None:
        flat, uf = jax.flatten_util.ravel_pytree(norm)
        flat = flat + sample_sigma*jax.random.normal(sigma_rng, flat.shape)
        norm = uf(flat)
    model_fn = Partial(pusht_net.apply, params, cond=norm)
    sample = diffuser.sample(rng_key, model_fn,
            action_sample_traj, 
            num_steps=diffuser.num_steps)
    action = action_norm.unnormalize(sample)
    return action

@partial(jax.jit,static_argnums=0)
def multi_sample(model_sigma, obs, sigma, rng_keys, sigmas_rngs):
    params = sigma_params[model_sigma]
    return jax.vmap(sample, in_axes=(None, None, None, 0, 0))(params, obs, sigma, rng_keys, sigmas_rngs)

obs = jax.tree_util.tree_map(lambda a, b: jnp.stack((a,b)), data[10][0].observation, data[10][1].observation)

In [10]:
import plotly.graph_objects as go
import itertools

def plot_samples(rng_key, obs, sample_sigma=None, model_sigma=0, color='blue'):
    sigma_rng_key, model_rng_key = jax.random.split(rng_key, 2)
    sample_rngs = jax.random.split(sigma_rng_key, 10)
    rngs = jax.random.split(model_rng_key, 10)
    targets = multi_sample(model_sigma, obs, sample_sigma, rngs, sample_rngs)
    for i in range(10):
        yield go.Scatter(
            x=targets[i,:,0],
            y=targets[i,:,1], 
            marker=dict(color=color),
            showlegend=i==0,
            opacity=0.5,
            name=f"sigma={int(sample_sigma*250)}"
        )

sigmas = [0, 4/250, 8/250, 16/250]
colors = ['blue', 'red', 'green','orange', 'turquoise']
def plot_noises(rng_key, obs, model_sigma=0):
    keys = jax.random.split(rng_key, len(sigmas))
    for k, s, c in zip(keys,sigmas,colors):
        yield from plot_samples(k, obs, s, model_sigma, color=c)
        
traces = list(plot_noises(PRNGKey(41), obs, model_sigma=1000))
go.Figure(traces)

In [7]:
import ot
import numpy as np
from stanza.util.random import PRNGSequence

def calc_ot(samples_a, samples_b):
    samples_a = samples_a.reshape((samples_a.shape[0], -1))
    samples_b = samples_b.reshape((samples_b.shape[0], -1))
    M = ot.dist(samples_a, samples_b, 'euclidean')
    d_emd = ot.emd2([], [], M)
    return d_emd

def calc_dist(model_sigma, rng_key, obs, dist, perturb_n=15, dist_n=10):
    rng_key, sk = jax.random.split(rng_key)
    rngs = jax.random.split(rng_key, perturb_n*dist_n)
    base_rngs = jax.random.split(sk, perturb_n*dist_n)
    perturb_rngs = jax.random.split(rng_key, perturb_n)
    perturb_rngs = jnp.repeat(jnp.expand_dims(perturb_rngs,0),dist_n,axis=0)
    perturb_rngs = perturb_rngs.reshape((-1,perturb_rngs.shape[-1]))
    base_samples = multi_sample(model_sigma, obs, None, base_rngs, perturb_rngs)
    base_samples = jax.tree_util.tree_map(
        lambda x: x.reshape((perturb_n,-1) + x.shape[1:]),
        base_samples
    )
    samples = multi_sample(model_sigma, obs, dist, rngs, perturb_rngs)
    samples = jax.tree_util.tree_map(
        lambda x: x.reshape((perturb_n,-1) + x.shape[1:]),
        samples
    )
    dists = []
    for i in range(base_samples.shape[0]):
        bs = base_samples[i]
        s = samples[i]
        dists.append(calc_ot(np.array(bs),np.array(s)))
    return np.mean(dists)

print(calc_dist(0, PRNGKey(2), obs, 0, 5, 10))
    

34.85783641815186


In [9]:
def plot_model_dists(rng_key, model_sigma, color='blue'):
    xs = [0, 1, 2, 4, 8]
    ys = []
    for x in xs:
        ys.append(calc_dist(model_sigma, rng_key, obs, x/250))
    yield go.Scatter(
            x=xs,
            y=ys, 
            marker=dict(color=color),
            showlegend=True,
            opacity=0.5,
            name=f"train sigma={model_sigma}"
    )

def plot_all_dists(rng_key):
    rks = jax.random.split(rng_key, len(sigma_params))
    for (r,s,c) in zip(rks, sigma_params.keys(), colors):
        yield from plot_model_dists(r, s, color=c)
traces = list(plot_all_dists(PRNGKey(21)))
fig = go.Figure(traces)  
fig.update_layout(xaxis_title="Sample Sigma", yaxis_title="OT Distance")