In [None]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
import tensorflow as tf
#tf.config.experimental.set_visible_devices([], "GPU")

import importlib
from simulation_research.diffusion import ode_datasets
from simulation_research.diffusion import diffusion_unet
from simulation_research.diffusion import samplers
from simulation_research.diffusion import diffusion as train
importlib.reload(ode_datasets)
importlib.reload(diffusion_unet)
importlib.reload(samplers)
importlib.reload(train)

import matplotlib.pyplot as plt
from matplotlib import rc
rc('animation', html='jshtml')
import jax.numpy as jnp
import numpy as np
import jax

In [None]:
dt = 6.
bs = 400
ds = ode_datasets.FitzHughDataset(N=4000+bs,dt=dt,integration_time=3000)

train_x = ds.Zs[bs:,:60]
test_x = ds.Zs[:bs,:60]
T_long =ds.T_long[:60]
dataset = tf.data.Dataset.from_tensor_slices(train_x)

dataiter = dataset.shuffle(len(dataset)).batch(bs).as_numpy_iterator

In [None]:
plt.plot(T_long,train_x[:300,:,:2].sum(-1).T/2)
plt.xlabel('Time (s)')
plt.ylabel(r'$\bar{x}$')

In [None]:
jnp.abs(train_x).max()

In [None]:
x = test_x#next(dataiter())
t = np.random.rand(x.shape[0])
model = diffusion_unet.UNet(diffusion_unet.unet_64_config(out_dim=x.shape[-1],base_channels=24))

In [None]:
from jax import jit,vmap
@jit
def rel_err(x,y):
  """Relative L1 error along the last axis: |x - y|_1 / (|x|_1 + |y|_1)."""
  numerator = jnp.sum(jnp.abs(x - y), axis=-1)
  denominator = jnp.sum(jnp.abs(x), axis=-1) + jnp.sum(jnp.abs(y), axis=-1)
  return numerator / denominator


# Number of leading timesteps of each sampled trajectory skipped before the
# comparison against the integrated ground truth.
kstart = 10


@jit
def log_prediction_metric(qs):
  """Mean log relative error of one trajectory against the true dynamics.

  The reference is `ds.integrate` started from the trajectory's state at index
  `kstart`, over `T_long[kstart:]`; errors are averaged over steps
  1 .. len(T)//3 - 1 of that window. `qs` is presumably one trajectory of shape
  (time, channels) -- it is vmapped over the batch in `pmetric`.
  """
  traj = qs[kstart:]
  times = T_long[kstart:]
  reference = ds.integrate(traj[0], times)
  window = rel_err(traj, reference)[1:len(times) // 3]
  return jnp.mean(jnp.log(window))


@jit
def pmetric(qs):
  """Geometric mean of the per-trajectory metric and its geometric std-error factor."""
  per_traj = vmap(log_prediction_metric)(qs)
  n_traj = per_traj.shape[0]
  geo_mean = jnp.exp(per_traj.mean())
  geo_err = jnp.exp(per_traj.std() / jnp.sqrt(n_traj))
  return geo_mean, geo_err

In [None]:
noisetype='White'#@param ['White','Pink','Brown']
noise = {'White':train.Identity,'Pink':train.PinkCovariance,'Brown':train.BrownianCovariance}[noisetype]
difftype='VE'#@param ['VP','VE','SubVP','Test']
diff = {'VP':train.VariancePreserving,'VE':train.VarianceExploding,
        'SubVP':train.SubVariancePreserving,'Test':train.Test}[difftype](noise)
epochs = 2000#@param {'type':'integer'}
score_fn = train.train_diffusion(model,dataiter,epochs,diffusion=diff,lr=3e-4)
key= jax.random.PRNGKey(38)
nll = samplers.compute_nll(diff,score_fn,key,x).mean()
stoch_samples = samplers.stochastic_sample(diff,score_fn,key,x[:30].shape,N=1000,traj=False)
err = pmetric(stoch_samples)[0]
print(f"{noise.__name__} gets NLL {nll:.3f} and err {err:.3f}")

In [None]:
z = jnp.linspace(-3,3,100)
plt.plot(z,jax.scipy.stats.norm.cdf(z),label='probit')
plt.plot(z,jax.nn.sigmoid(1.6*z),label='logit')
plt.legend()

In [None]:
import jax 
jax.config.update('jax_default_matmul_precision', 'float32')

In [None]:
from jax import grad,jit
condition_amount = 13
mb = x[:30,:]

def event_constraint(x):
    """Event constraint C(x); the event of interest is C(x) > 0.

    C(x) = max over the second-to-last axis of the mean of the first two
    channels, minus the threshold 2.
    """
    channel_mean = x[..., :2].mean(-1)
    peak = jnp.max(channel_mean, -1)
    return peak - 2

def statistic(x):
    """Max over the second-to-last axis of the mean of the first two channels."""
    channel_mean = x[..., :2].mean(-1)
    return jnp.max(channel_mean, -1)

In [None]:
sample_traj = samplers.stochastic_sample(diff,score_fn,key,x[:400].shape,N=1000,traj=True)

In [None]:
#from train import unsqueeze_like
# Aliases used by the diagnostic helpers below. `scorefn` is the unconditional
# score here; the trailing comment shows how to inspect the event-conditioned one.
diffusion=diff
scorefn =score_fn#event_scores(diff,score_fn,event_constraint)
constraint=event_constraint
def xhat(xt,t):
    """One-step denoised estimate of the clean sample from a noisy xt.

    Returns (xt + sigma(t)^2 * scorefn(xt, t)) / scale(t), with `t` broadcast
    against `xt` via train.unsqueeze_like.
    """
    #print(xt.shape,t.shape)
    tt = train.unsqueeze_like(xt,t)
    # Time offset for the (commented-out) two-sided averaging below; with
    # dt = 0 only the estimate at t itself is returned.
    dt = .00
    score_xhat1 = (xt+diffusion.sigma(tt+dt)**2*scorefn(xt,t+dt))/diffusion.scale(tt+dt)
    #score_xhat2 = (xt+diffusion.sigma(tt-dt)**2*scorefn(xt,t-dt))/diffusion.scale(tt-dt)
    #limiting_xhat = (xt/(1+diffusion.sigma(tt)**2/data_std**2))/diffusion.scale(tt)
    #m1 = (t+dt<=1)+0.
    #m2 = (t-dt>=0)+0.
    #m1,m2 = train.unsqueeze_like(xt,m1,m2)
    return (score_xhat1)#*m1)#+score_xhat2*m2)/(m1+m2)
def cstd(xt,t):
    """Constraint value at the denoised estimate and a normalizing scale.

    Args:
      xt: batch of noisy samples; the sums over axes (-1, -2) below assume
        each sample is rank 2 (see NOTE).
      t: per-sample diffusion times, vmapped alongside xt.

    Returns:
      (C, std): C = constraint(xhat(xt, t)) per sample, and
      std = sqrt(|DC . SigmaDC * scale(t)| + 1e-4) * sigma(t) / scale(t), where
      DC is the gradient of the constraint at xhat and SigmaDC is the gradient
      of constraint(xhat(x)) with respect to the noisy input x.
    """
    xh = xhat(xt,t)
    C,DC = vmap(jax.value_and_grad(constraint))(xh)
    # Gradient of the constraint through the denoiser, w.r.t. the noisy input.
    SigmaDC = vmap(jax.grad(lambda x,t: constraint(xhat(x[None],t)[0])))(xt,t)
    std2 = ((DC*SigmaDC).sum((-1,-2))*diffusion.scale(t))# NOTE: will not work with img inputs
    # Unused alternative (squared norm of DC times scale(t)), kept for experiments.
    std3 = (DC*DC).sum((-1,-2))*diffusion.scale(t)
    # Equivalent to |std2|.
    std2 = jnp.sqrt(jnp.abs(std2**2))
    # `diff` and `diffusion` are the same object here (aliased above).
    std = jnp.sqrt(jnp.abs(std2)+1e-4)*(diff.sigma(t)/diff.scale(t))
    return C,std

def log_p(xt,t):
    """Summed log-probability that the event C > 0 holds, given xt.

    Uses the logistic surrogate log sigmoid(1.6 * C / std) for the Gaussian
    log-CDF (cf. the probit vs. logit comparison plot); the exact probit form
    is kept commented out below.
    """
    Cs,stds = cstd(xt,t)
    return jax.nn.log_sigmoid(1.6*Cs/stds).sum()
    #return jax.scipy.stats.norm.logcdf(Cs/stds).sum()

In [None]:
#jnp.where(event_constraint(sample_traj[-1])>0)

In [None]:
import numpy as np
# Diffusion times: every 4th midpoint (j + 0.5)/N in descending order,
# starting at (N - 0.5)/N.
N=2000
ts = (.5+np.arange(N)[::-1])[:-1:4]/N
#xt = sample_traj[:,77]
# NOTE(review): `i` is not used below -- the sample index 6 is hard-coded.
i = 195 #@param {type:"slider", min:0, max:200, step:1}
#xt = sample_traj[:,i] #
# NOTE(review): `event_samples_traj` is not assigned above this cell; it comes
# from the conditional sampling near the `event_scores` definition further
# down, so this cell relies on out-of-order execution. The [::4] stride is
# presumably meant to line up with `ts` -- TODO confirm the sampler's time grid.
xt = event_samples_traj[::4,6]
Cs,stds = cstd(xt,ts)
#Cs,stds = log_p(event_samples_traj[:,1],ts)
grads = grad(log_p)(xt,ts)
xh = xhat(xt,ts)

In [None]:
# sigma(t)-scaled scores along the trajectory: unconditional vs. event-conditioned.
tt = train.unsqueeze_like(xt,ts)
normed_scores = diffusion.sigma(tt)*scorefn(xt,ts)
# NOTE(review): `event_scores` is defined in a later cell; run that first on a fresh kernel.
normed_scores2 = diffusion.sigma(tt)*event_scores(diff,score_fn,event_constraint)(xt,ts)

In [None]:

plt.plot(ts,jnp.sqrt((normed_scores2**2).mean((-1,-2))),label=r'Normed Score w/ event constraint')
plt.plot(ts,jnp.sqrt((normed_scores**2).mean((-1,-2))),label=r'Normed Score')
plt.xlabel('t')
plt.legend()
plt.yscale('log')

In [None]:
xh = jnp.where(jnp.isnan(xh),jnp.zeros_like(xh),xh)

In [None]:
# `data_std` is otherwise first assigned further down (next to event_scores);
# compute it here from the same test data so this cell runs on a fresh kernel.
data_std = test_x.std()
xnorm = jnp.sqrt((xh*xh).mean((-1,-2)))
plt.plot(ts,xnorm)
# Reference curve s a^2 / sqrt(s^2 a^2 + sigma^2) with s = scale(t), a = data_std.
plt.plot(ts,diff.scale(ts)*data_std**2/jnp.sqrt((diff.scale(ts)*data_std)**2+diff.sigma(ts)**2))
plt.plot(ts,diff.sigma(ts)/100)
plt.legend([r'$||\hat{x}_t||$',r'$sa^2/\sqrt{s^2a^2+\sigma^2}$',r'$\sigma_t/100$'])
plt.xlabel('t')
plt.yscale('log')
plt.ylim(.5*jnp.min(xnorm),2*jnp.max(xnorm))

In [None]:
import matplotlib.pyplot as plt

import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable

i=4 #@param {type:"slider", min:0, max:30, step:1}

cmap='inferno'


fig1 = plt.figure()
ax1 = fig1.add_subplot(111)
# RMS over channels of the denoised estimate, every 5th diffusion step
# (last 50 dropped): one curve per diffusion time.
data = jnp.sqrt((xh[:-50:5,:]**2).mean(-1)).T
ax1.plot(T_long,data[:],alpha=.6,lw=2)
# Color curves in trajectory order. `plt.get_cmap` replaces `mpl.cm.get_cmap`
# (removed in matplotlib 3.9); `line_idx` avoids clobbering the slider `i`.
colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))
for line_idx,line in enumerate(ax1.lines):
    line.set_color(colors[line_idx])
plt.xlabel('Time t')
plt.ylabel(r'State')
plt.ylim(-1,7)
divider = make_axes_locatable(plt.gca())
ax_cb = divider.new_horizontal(size="5%", pad=0.05)
# `ts` is built in descending order, but Normalize requires vmin <= vmax.
t_lo, t_hi = sorted([float(ts[0]), float(ts[-50])])
norm = mpl.colors.Normalize(vmin=t_lo, vmax=t_hi)
cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)
cb1.set_label('diffusion time (0,1)')
plt.gcf().add_axes(ax_cb)

In [None]:
grad_norm = jnp.sqrt((grads**2).sum((-1,-2)))

In [None]:
plt.plot(ts,Cs,label='Cs',alpha=.5)
plt.plot(ts,stds,label='sigma',alpha=.5)
#plt.plot(ts,Cs/stds,label='ratio',alpha=.5)
plt.legend()
plt.ylim(-2,8)
plt.xlabel('t')

In [None]:
plt.plot(ts,jnp.abs(Cs/stds))
plt.yscale('log')
plt.xlabel('t')
plt.ylabel('C/std')

In [None]:
plt.plot(ts,stds)
plt.plot(ts,diff.sigma(ts))
plt.plot(ts,grad_norm)
plt.plot(ts,grad_norm*diff.sigma(ts))
plt.ylim(1e-4,1e3)
plt.yscale('log')
plt.xlabel('t')
plt.legend([r'$\sqrt{\nabla C^T\Sigma_t\nabla C}$',r'$\sigma_t$',r'$\nabla \log \Phi$',r'$\sigma_t \nabla \log \Phi$'])

In [None]:
plt.plot(ts,jax.scipy.stats.norm.cdf(Cs/stds),label='probit')
plt.plot(ts,jax.nn.sigmoid(1.6*Cs/stds),label='logit')
plt.ylabel('P(E|xt)')
plt.xlabel('t')
plt.yscale('log')
plt.legend()

In [None]:
data_std = test_x.std()
def event_scores(diffusion,scorefn,constraint):
  """Score function conditioned on the inequality event C(x) > 0.

  Adds grad_xt log P(C > 0 | xt) to the unconditional score. The probability
  is approximated by the logistic surrogate sigmoid(1.6 * C / std) of a
  Gaussian CDF, with C evaluated at the one-step denoised estimate xhat(xt, t).

  Args:
    diffusion: diffusion process providing `sigma(t)` and `scale(t)`.
    scorefn: unconditional score, called as scorefn(xt, t).
    constraint: per-sample function returning a scalar C(x); the event is C > 0.

  Returns:
    A jitted function (xt, t) -> conditioned score with the shape of xt
    (assumed rank 3: batch, time, channels). `t` may be a scalar or a
    per-sample vector of length xt.shape[0].
  """
  def xhat(xt,t):
    tt = train.unsqueeze_like(xt,t)
    return (xt+diffusion.sigma(tt)**2*scorefn(xt,t))/diffusion.scale(tt)

  def conditioned_scores(xt,t):
    b,_,c = xt.shape
    unobserved_score = scorefn(xt,t).reshape(b,-1,c)
    # Give every sample its own time so the per-sample gradients can be vmapped.
    if not hasattr(t,'shape') or not len(t.shape):
      tt = t*jnp.ones(b)
    else:
      tt = t
    def log_p(xt):
      xh = xhat(xt,tt)
      C,DC = vmap(jax.value_and_grad(constraint))(xh)
      # Gradient of C(xhat(x)) w.r.t. the noisy input, one sample at a time.
      SigmaDC = vmap(jax.grad(lambda x,t: constraint(xhat(x[None],t)[0])))(xt,tt)
      # NOTE: summing over the last two axes assumes rank-2 samples; will not work with img inputs.
      std2 = ((DC*SigmaDC).sum((-1,-2))*diffusion.scale(t))
      std2 = jnp.sqrt(jnp.abs(std2*std2))
      # Use the `diffusion` argument (not a global) so the schedule matches the
      # process this conditioned score is built for.
      std = jnp.sqrt(jnp.abs(std2)+1e-2)*(diffusion.sigma(t)/diffusion.scale(t))
      return jax.nn.log_sigmoid(1.6*C/std).sum()
      #return jax.scipy.stats.norm.logcdf(C/std).sum()
    unobserved_score += grad(log_p)(xt)
    return unobserved_score
  return jit(conditioned_scores)

# Event-conditioned samples (C(x) > 0). Later cells read `event_samples`
# (finite-difference check, event rates, likelihoods, histograms) and
# `event_samples_traj` (diffusion-time diagnostics), so both must be computed
# for Restart & Run All to work. One jitted score function is shared by both.
event_score_fn = event_scores(diff,score_fn,event_constraint)
event_samples = samplers.stochastic_sample(diff,event_score_fn,key,mb.shape,N=2000,traj=False)
event_samples_traj = samplers.stochastic_sample(diff,event_score_fn,key,mb.shape,N=2000,traj=True)
#event_samples_det = samplers.sample(diff,event_scores(diff,score_fn,event_constraint),key,mb.shape)

In [None]:
event_samples_det = samplers.sample(diff,event_scores(diff,score_fn,event_constraint),key,mb.shape)

In [None]:
ds.Zs[(event_constraint(ds.Zs[:,:60])>0)].shape

In [None]:
from scipy.ndimage import correlate1d
inp = stoch_samples#train_x
ode = vmap(vmap(ds.dynamics))
for inp in [stoch_samples,train_x,event_samples,ds.Zs[(event_constraint(ds.Zs[:,:60])>0)][:,:60],stoch_samples[event_constraint(stoch_samples)>0],event_samples_det]:
  v = correlate1d(inp,np.array([-1,0,1])/2/(ds.T[1]-ds.T[0]),axis=1)
  F = ode(inp,None)
  print(rel_err(F[:,2:-2],v[:,2:-2]).mean())
  #print(F.shape)
  plt.plot(F[0,:,-1])
  plt.plot(v[0,:,-1])
  plt.xlabel('timesteps')
  plt.legend(['ODE F(z,t)','Finite diff dz/dt'])
  plt.show()

In [None]:
(event_constraint(event_samples_traj[-1])>0).mean()

In [None]:
(event_constraint(event_samples_det)>0).mean()

In [None]:
(event_constraint(event_samples)>0).mean()

In [None]:
from simulation_research.diffusion import samplers

prior_scale = 112/300

logp = samplers.logp(diff,score_fn,key,event_samples[5:10],prior_scale,num_probes=4)

In [None]:
logp

In [None]:
from functools import partial
def f(z):
  """Push samples `z` through `diffusion.dynamics` with samplers.heun_integrate2.

  Uses the unconditional `score_fn` and 1000 midpoint times (j + 0.5)/1000 in
  increasing order; returns the integrator's first output (its second output
  is discarded). Presumably this is the probability-flow map from data toward
  the prior -- TODO confirm against samplers.heun_integrate2.
  """
  n_steps = 1000
  t_grid = (np.arange(n_steps) + .5) / n_steps
  ode_rhs = jit(partial(diffusion.dynamics, score_fn))
  z_final, _ = samplers.heun_integrate2(ode_rhs, z, t_grid)
  return z_final


In [None]:
z0 = f(event_samples[5:6])

In [None]:
z0.std()

In [None]:
J = jax.jacfwd(f)(event_samples[5:6])

In [None]:
std_max = diffusion.sigma(diffusion.tmax)#*prior_scale
# Isotropic Gaussian prior log-density of the latent z0 with std sigma(tmax).
logpxf = -(z0.reshape(z0.shape[0],-1)**2/std_max**2 + jnp.log(2*np.pi*std_max**2)).sum(-1)/2
# jacfwd(f) has shape f(x).shape + x.shape, and f preserves shape, so flatten
# it to a square matrix of side z0.size instead of a hard-coded 240.
jac = J.reshape(z0.size,-1)
s,logdet = jnp.linalg.slogdet(jac)#+1e-3*jnp.eye(z0.size))
logpa = logpxf+logdet
print(logpa)

In [None]:
[562.35547]

In [None]:
[977.17847]

In [None]:
jnp.exp(973.3657-977.17847)

In [None]:
logpxf,logdet

In [None]:
conditional_logp = samplers.logp(diff,event_scores(diff,score_fn,event_constraint),key,event_samples[5:10],prior_scale,num_probes=4)

In [None]:
conditional_logp

In [None]:
jnp.exp(logp-conditional_logp)

In [None]:
logp2 = samplers.logp(diff,score_fn,key,event_samples[5:10],prior_scale,num_probes=300)

In [None]:
logp3 = samplers.logp(diff,score_fn,key,event_samples[5:10],prior_scale,num_probes=1000)

In [None]:
logp3

In [None]:
logp2

In [None]:
stoch_samples = samplers.stochastic_sample(diff,score_fn,key,x[:400].shape,N=2000,traj=False)
sample_traj = samplers.stochastic_sample(diff,score_fn,key,x[:400].shape,N=2000,traj=True)
det_samples = samplers.sample(diff,score_fn,key,x[:400].shape)

In [None]:
true_events = (event_constraint(ds.Zs[:,:60])>0)
model_events_ode = (event_constraint(det_samples)>0)
model_events_sde = (event_constraint(stoch_samples)>0)
print(f"True event rate        {true_events.mean():.3f}+-{true_events.std()/jnp.sqrt(len(true_events)):.3f}")
print(f"model event rate (ODE) {model_events_ode.mean():.3f}+-{model_events_ode.std()/jnp.sqrt(len(model_events_ode)):.3f}")
print(f"model event rate (SDE) {model_events_sde.mean():.3f}+-{model_events_sde.std()/jnp.sqrt(len(model_events_sde)):.3f} (2k steps)")

In [None]:
vals = np.array(statistic(ds.Zs[:,:60]))
vals2 = np.array(statistic(event_samples))
vals3 = np.array(statistic(ds.Zs[(event_constraint(ds.Zs[:,:60])>0),:60]))
vals4 = np.array(statistic(event_samples_det))
plt.hist(vals2,bins=30,density=True,alpha=.5)
plt.hist(vals4,bins=30,density=True,alpha=.5)
plt.hist(vals,bins=30,density=True,alpha=.5)

plt.hist(vals3,bins=30,density=True,alpha=.5)
plt.legend(['Model x|E (SDE)','Model x|E (ODE)','Data x','Data x|E'])
plt.xlim(0,6)

In [None]:
vals = np.array(statistic(ds.Zs[:,:60]))
# Compute the ODE-conditioned statistic before it is plotted, so the cell does
# not depend on a value left over from an earlier cell.
vals4 = np.array(statistic(event_samples_det))

plt.hist(vals4,bins=30,density=True)
plt.hist(vals,bins=30,density=True)
plt.xlabel('Maximum value over trajectory')
# density=True normalizes the histograms, so the y-axis is a density.
plt.ylabel('Density')

In [None]:
import matplotlib.pyplot as plt
true_events = ds.Zs[(event_constraint(ds.Zs[:,:60])>0),:60]
i=28 # @param {type:"slider", min:0, max:30, step:1}
#plt.plot(T_long,conditioned_samples[-600::100,i,:,0].T,zorder=0,alpha=.2)
plt.plot(T_long,event_samples[i,  :,0].T,label='x|E model sde',zorder=2)
plt.plot(T_long,event_samples_det[i,  :,0].T,label='x|E model ode',zorder=2)
plt.plot(T_long,true_events[i,:,0],label='gt',alpha=1,zorder=99)
#plt.plot(T_long[slc],x[i,slc,0],label='cond',alpha=1,zorder=100,lw=3)

plt.xlabel('Time t')
plt.ylabel(r'State')
#plt.ylim(-3,3)
plt.legend()
#plt.legend([r'GT',r'Model'])

In [None]:

importlib.reload(samplers)
importlib.reload(train)
#samplers.probability_flow(diff,score_fn,x,1e-4,1.).std()

In [None]:
import jax
key= jax.random.PRNGKey(38)
samplers.compute_nll(diff,score_fn,key,x).mean()

Sample generation

In [None]:
import matplotlib.pyplot as plt
i=6 #@param {type:"slider", min:0, max:30, step:1}
plt.plot(T_long,sample_traj[0::100,i,:,0].T,alpha=1/2)
plt.xlabel('Time t')
plt.ylabel(r'State')
#plt.ylim(-5,5)
#plt.legend([r'GT',r'Model'])

In [None]:
from jax import vmap
n=sample_traj.shape[0]+1
ts = (.5+jnp.arange(n)[::-1])[:-1]/n
scores = vmap(score_fn)(sample_traj,ts).reshape(sample_traj.shape)
best_reconstructions = (sample_traj+diff.sigma(ts)[:,None,None,None]**2*scores)/diff.scale(ts)[:,None,None,None]

In [None]:
import matplotlib.pyplot as plt

import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable

i=4 #@param {type:"slider", min:0, max:30, step:1}

cmap='inferno'


fig1 = plt.figure()
ax1 = fig1.add_subplot(111)
# Last channel of the denoised estimates for sample i, every 25th diffusion
# step starting at step 100: one curve per diffusion time.
data = best_reconstructions[100::25,i,:,-1].T
ax1.plot(T_long,data[:],alpha=.6,lw=2)
# Color curves in trajectory order. `plt.get_cmap` replaces `mpl.cm.get_cmap`
# (removed in matplotlib 3.9); `line_idx` avoids clobbering the slider `i`.
colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))
for line_idx,line in enumerate(ax1.lines):
    line.set_color(colors[line_idx])
plt.xlabel('Time t')
plt.ylabel(r'State')
#plt.ylim(-2,2)
divider = make_axes_locatable(plt.gca())
ax_cb = divider.new_horizontal(size="5%", pad=0.05)
# `ts` is built in descending order, but Normalize requires vmin <= vmax.
t_lo, t_hi = sorted([float(ts[100]), float(ts[-25])])
norm = mpl.colors.Normalize(vmin=t_lo, vmax=t_hi)
cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)
cb1.set_label('diffusion time (0,1)')
plt.gcf().add_axes(ax_cb)

In [None]:
from scipy.ndimage import correlate1d
i=22 #@param {type:"slider", min:0, max:30, step:1}
# Central difference along the physical-time axis (axis=2). correlate1d with
# weights [-1, 0, 1]/(2*dt) gives (x[j+1] - x[j-1])/(2*dt), i.e. +dx/dt -- the
# same un-negated stencil used in the ODE consistency check, so no sign flip.
vs = correlate1d(best_reconstructions,np.array([-1,0,1])/2/(ds.T[1]-ds.T[0]),axis=2)
print(vs.shape)
fig1 = plt.figure()
ax1 = fig1.add_subplot(111)
data = vs[100::25,i,:,-1].T
ax1.plot(T_long,data[:],alpha=.6,lw=2)
# `plt.get_cmap` replaces `mpl.cm.get_cmap` (removed in matplotlib 3.9);
# `line_idx` avoids clobbering the slider `i`.
colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))
for line_idx,line in enumerate(ax1.lines):
    line.set_color(colors[line_idx])
plt.xlabel('Time t')
plt.ylabel(r'$\dot \theta$')
#plt.ylim(-2,2)
divider = make_axes_locatable(plt.gca())
ax_cb = divider.new_horizontal(size="5%", pad=0.05)
# `ts` is built in descending order, but Normalize requires vmin <= vmax.
t_lo, t_hi = sorted([float(ts[100]), float(ts[-25])])
norm = mpl.colors.Normalize(vmin=t_lo, vmax=t_hi)
cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)
cb1.set_label('diffusion time (0,1)')
plt.gcf().add_axes(ax_cb)

In [None]:
i=15 # @param {type:"slider", min:0, max:30, step:1}
# Magnitude spectrum along the physical-time axis of every saved sampler state.
fft = jnp.abs(np.fft.rfft(sample_traj,axis=2))
freq = np.fft.rfftfreq(sample_traj.shape[2],d=(ds.T[1]-ds.T[0]))

fig1 = plt.figure()
ax1 = fig1.add_subplot(111)
data = fft[0::25,i,:,-1].T
ax1.plot(freq,data[:,:],alpha=.6,lw=2)
# `plt.get_cmap` replaces `mpl.cm.get_cmap` (removed in matplotlib 3.9);
# `line_idx` avoids clobbering the slider `i`.
colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))
for line_idx,line in enumerate(ax1.lines):
    line.set_color(colors[line_idx])
plt.xlabel('Frequency f')
plt.ylabel(r'Fourier spectrum')
plt.yscale('log')
plt.xscale('log')
#plt.ylim(-2,2)
divider = make_axes_locatable(plt.gca())
ax_cb = divider.new_horizontal(size="5%", pad=0.05)
# `ts` is built in descending order, but Normalize requires vmin <= vmax.
t_lo, t_hi = sorted([float(ts[0]), float(ts[-25])])
norm = mpl.colors.Normalize(vmin=t_lo, vmax=t_hi)
cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)
cb1.set_label('diffusion time (0,1)')
plt.gcf().add_axes(ax_cb)
# Overlay spectra of every 10th test trajectory for reference.
ax1.plot(freq,jnp.abs(np.fft.rfft(x,axis=1))[::10,:,-1].T,color='blue',alpha=.1);

In [None]:
i=8 # @param {type:"slider", min:0, max:30, step:1}
# Magnitude spectrum along the physical-time axis of the denoised estimates.
fft = jnp.abs(np.fft.rfft(best_reconstructions,axis=2))
freq = np.fft.rfftfreq(best_reconstructions.shape[2],d=(ds.T[1]-ds.T[0]))

fig1 = plt.figure()
ax1 = fig1.add_subplot(111)
data = fft[100::25,i,:,-1].T
ax1.plot(freq,data[:,:],alpha=.6,lw=2)
# `plt.get_cmap` replaces `mpl.cm.get_cmap` (removed in matplotlib 3.9);
# `line_idx` avoids clobbering the slider `i`.
colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))
for line_idx,line in enumerate(ax1.lines):
    line.set_color(colors[line_idx])
plt.xlabel('Frequency f')
plt.ylabel(r'Fourier spectrum')
plt.yscale('log')
plt.xscale('log')
#plt.ylim(-2,2)
divider = make_axes_locatable(plt.gca())
ax_cb = divider.new_horizontal(size="5%", pad=0.05)
# Normalize requires vmin <= vmax; sort so the order of `ts` does not matter.
t_lo, t_hi = sorted([float(ts[100]), float(ts[-25])])
norm = mpl.colors.Normalize(vmin=t_lo, vmax=t_hi)
cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)
cb1.set_label('diffusion time (0,1)')
plt.gcf().add_axes(ax_cb)
# Overlay spectra of every 10th test trajectory for reference.
ax1.plot(freq,jnp.abs(np.fft.rfft(x,axis=1))[::10,:,-1].T,color='blue',alpha=.1);

In [None]:

import matplotlib.pyplot as plt
i=4 # @param {type:"slider", min:0, max:30, step:1}
plt.plot(T_long,test_x[i,:,-1])
#plt.plot(T_long,det_samples[i,:,-1])
plt.plot(T_long,stoch_samples[i,:,-1])
plt.xlabel('Time t')
plt.ylabel(r'State')
plt.legend([r'GT',r'Model (SDE)'])#r'Model (ODE)', r'Model (SDE)'])

Test the model's ability to condition on previous timesteps

In [None]:
# Deterministic (ODE) inpainting sample conditioned on the first `slc` timesteps of `mb`.
# NOTE(review): `inpainting_scores2` and `slc` are defined in a later cell; run it first on a fresh kernel.
conditioned_sample = samplers.sample(diff,inpainting_scores2(diff,score_fn,mb[:,slc],slc),key,mb.shape)

In [None]:
from jax import jit,vmap,random

@jit
def rel_err(z1,z2):
  """Relative L1 error along the last axis: |z1 - z2|_1 / (|z1|_1 + |z2|_1).

  This redefinition shadows the metric used by `pmetric` and by later plots,
  so it must use the same denominator: the SUM of the two norms. That keeps
  the ratio scale-invariant and in [0, 1]; a product of norms would not be.
  """
  return jnp.abs(z1-z2).sum(-1)/(jnp.abs(z1).sum(-1)+jnp.abs(z2).sum(-1))

gt = x[:30]
# NOTE(review): `conditioned_samples` (the SDE inpainting trajectory; [-1] is
# its final state) is assigned in a later cell -- run that first on a fresh kernel.
for pred in [conditioned_samples[-1],conditioned_sample]:
  # Clamp before the log so exact matches do not produce -inf.
  clamped_errs = jax.lax.clamp(1e-5,rel_err(pred,gt),np.inf)
  # Geometric mean over samples, with a geometric one-std band.
  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))
  rel_stds = np.exp(jnp.log(clamped_errs).std(0))
  plt.plot(T_long,rel_errs)
  plt.fill_between(T_long, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)

plt.plot()
plt.yscale('log')
plt.xlabel('Time')
plt.ylabel('Prediction Error')
plt.legend(['SDE completion','ODE completion'])

In [None]:
i=7 # @param {type:"slider", min:0, max:29, step:1}
plt.plot(T_long,x[i,:,1])
plt.plot(T_long[slc],x[i,slc,1],lw=3)
plt.plot(T_long,conditioned_sample[i,:,1])
plt.xlabel('Time t')
plt.ylabel(r'State')
plt.legend([r'GT','Conditioning',r'Model'])
#plt.ylim(-3,3)

Unconditional prediction quality

In [None]:
# stoch_samples = samplers.stochastic_sample(diff,score_fn,key,x[:30].shape,N=1000,traj=False)
# det_samples = samplers.sample(diff,score_fn,key,x[:30].shape)
print(f'ODE performance {pmetric(det_samples)[0]}')
print(f'SDE performance {pmetric(stoch_samples)[0]}')

In [None]:
from jax import random
key = random.PRNGKey(45)
#s=s2#,history = samplers.stochastic_sampler(denoiser,params,key,(32,)+data.shape[1:],N=500,smin=sigma_min,smax=sigma_max)
s = stoch_samples#energy_samples_det#stoch_samples

k = 5
z = q = s[:,k:]
T = T_long[k:]
z0 = z[:,0]
z_gts = vmap(ds.integrate,(0,None),0)(z0,T)
z_pert = vmap(ds.integrate,(0,None),0)(z0+1e-3*np.random.randn(*z0.shape)*jnp.abs(z0).mean(),T)
z_random = vmap(ds.integrate,(0,None),0)(ds.sample_initial_conditions(z0.shape[0]),T)


In [None]:
for pred in [z,z_pert,z_random]:
  clamped_errs = jax.lax.clamp(1e-3,rel_err(pred,z_gts),np.inf)
  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))
  rel_stds = np.exp(jnp.log(clamped_errs).std(0))
  plt.plot(T,rel_errs)
  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)

plt.plot()
plt.yscale('log')
plt.xlabel('Time')
plt.ylabel('Prediction Error')
plt.legend(['Diffusion Model Rollout','1e-3 Perturbed GT','Random Init'])

Comparing ground-truth, model and perturbed trajectories

In [None]:
for i in range(20):
  fig = plt.figure()
  ax = fig.add_subplot(1, 1, 1)
  line1, = ax.plot(T,z_gts[i,:,:2].sum(-1))
  line2, = ax.plot(T,z[i,:,:2].sum(-1))
  line3, = ax.plot(T,z_pert[i,:,:2].sum(-1))
  plt.xlabel('Time t')
  plt.ylabel(r'State')
  plt.legend(['gt','model','pert'])

In [None]:
for i in range(10):
  fig = plt.figure()
  ax = fig.add_subplot(1, 1, 1)
  line1, = ax.plot(T,z_gts[i,:,0])
  line2, = ax.plot(T,z[i,:,0])
  line3, = ax.plot(T,z_gts[i,:,-1])
  line5, = ax.plot(T,z[i,:,-1])
  plt.xlabel('Time t')
  plt.ylabel(r'State')
  plt.legend(['gt0','model0','gt3','model3'])

In [None]:
metric_vals =[]
metric_stds = []
Ns = [25,50,100,200,500,1000,2000]
for N in Ns:
  s = samplers.stochastic_sample(diff,score_fn,key,x[:30].shape,N=N)
  mean,std = pmetric(s)
  metric_vals.append(mean)
  metric_stds.append(std)
metric_vals = np.array(metric_vals)
metric_stds = np.array(metric_stds)

plt.plot(Ns,metric_vals)
plt.fill_between(Ns, metric_vals/metric_stds, metric_vals*metric_stds,alpha=.3)
plt.xlabel('Sampler steps')
plt.ylabel('Pmetric value')
plt.xscale('log')

In [None]:
from jax import grad,jit
condition_amount = 10# @param {type:"slider", min:0, max:50, step:1}
mb = x[:30,:]
data_std = x.std()

def inpainting_scores(diffusion,scorefn,observed_values,slc):
  """Inpainting score: model score everywhere, noise score on the observed slice.

  Args:
    diffusion: object providing `noise_score(x_slice, observed, t)`.
    scorefn: unconditional score, called as scorefn(xt, t).
    observed_values: observed data of shape (batch, n_observed, channels).
    slc: index along axis 1 selecting the observed positions.

  Returns:
    A function (xt, t) -> score of shape (batch, -1, channels).
  """
  batch, _, channels = observed_values.shape

  def conditioned_scores(xt,t):
    xt_batched = xt.reshape(batch, -1, channels)
    model_score = scorefn(xt, t).reshape(batch, -1, channels)
    clamp_score = diffusion.noise_score(xt_batched[:, slc], observed_values, t)
    return model_score.at[:, slc].set(clamp_score)

  return conditioned_scores

def inpainting_scores2(diffusion,scorefn,observed_values,slc,scale=300.):
  """Inpainting score with reconstruction guidance on the denoised estimate.

  On top of the model score, subtracts the gradient (w.r.t. xt) of the squared
  error between the one-step denoised estimate on `slc` and `observed_values`.
  The gradient is weighted by scale * scale(t)^2 / sigma(t)^2. The observed
  positions are then overwritten by `diffusion.noise_score`.

  Args:
    diffusion: object providing `sigma(t)`, `scale(t)` and
      `noise_score(x_slice, observed, t)`.
    scorefn: unconditional score, called as scorefn(xt, t).
    observed_values: observed data of shape (batch, n_observed, channels).
    slc: index along axis 1 selecting the observed positions.
    scale: guidance strength.

  Returns:
    A jitted function (xt, t) -> score of shape (batch, -1, channels). `t` is
    used without per-sample broadcasting, so it is presumably a scalar.
  """
  b,_,c = observed_values.shape
  def conditioned_scores(xt,t):
    unflat_xt = xt.reshape(b,-1,c)

    observed_score = diffusion.noise_score(unflat_xt[:,slc],observed_values,t)
    unobserved_score = scorefn(xt,t).reshape(b,-1,c)
    def reconstruction_error(xt):
      one_step_xhat = (xt+diffusion.sigma(t)**2*scorefn(xt,t))/diffusion.scale(t)
      return jnp.sum((one_step_xhat.reshape(b,-1,c)[:,slc]-observed_values)**2)
    # Use the `diffusion` argument (not a global) so the guidance weight
    # matches the process this score is built for.
    unobserved_score -= grad(reconstruction_error)(xt).reshape(unflat_xt.shape)*scale*diffusion.scale(t)**2/diffusion.sigma(t)**2
    return unobserved_score.at[:,slc].set(observed_score)
  return jit(conditioned_scores)

slc = slice(condition_amount)
conditioned_samples = samplers.stochastic_sample(diff,inpainting_scores2(diff,score_fn,mb[:,slc],slc),key,mb.shape,N=1000,traj=True)


In [None]:
k=30
expanded = (mb[None]+jnp.zeros((k,1,1,1))).reshape(mb.shape[0]*k,*mb.shape[1:])#[:,slc]
predictions = samplers.stochastic_sample(diff,inpainting_scores2(diff,score_fn,expanded[:,slc],slc,scale=300.),key,expanded.shape,N=2000,traj=False)

In [None]:
preds = predictions.reshape(k,-1,*predictions.shape[1:])
lower = np.percentile(preds.mean(-1),10,axis=0)
upper = np.percentile(preds.mean(-1),90,axis=0)
for i in range(mb.shape[0]):
  if i>30: break
  plt.plot(T_long,mb[i].mean(-1))
  #plt.plot(T_long,z_pert[i].mean(-1))
  plt.fill_between(T_long,lower[i],upper[i],alpha=.3,color='y')
  plt.plot()
  #plt.yscale('log')
  plt.xlabel('Time')
  plt.ylabel('State sum')
  plt.legend(['Ground Truth','Model 10-90 percentiles'])
  plt.show()