In [None]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
import tensorflow as tf

from colabtools import adhoc_import
import importlib
from simulation_research.diffusion import ode_datasets
from simulation_research.diffusion import diffusion_unet
from simulation_research.diffusion import samplers
importlib.reload(ode_datasets)
importlib.reload(diffusion_unet)
importlib.reload(samplers)

import matplotlib.pyplot as plt
from matplotlib import rc
rc('animation', html='jshtml')
import jax.numpy as jnp
import numpy as np

In [None]:
from jax import devices,device_count
device_count()

In [None]:
tf.executing_eagerly()

# Generate the Trajectories

## N-Link Pendulum

In [None]:
# Build the pendulum dataset via ode_datasets.NPendulum with step size dt, then split the
# stored states ds.Zs into (thetas, vs).
# NOTE(review): presumably N is the number of trajectories, n the number of links, and
# (thetas, vs) are angles and velocities — confirm against ode_datasets.
dt = .1
ds = ode_datasets.NPendulum(N=2000,n=1,dt=dt)
thetas,vs = ode_datasets.unpack(ds.Zs)

In [None]:
# for i in range(20):
#   fig = plt.figure()
#   ax = fig.add_subplot(1, 1, 1)
#   line1, = ax.plot(ds.T_long,thetas[i,:,0])
#   line2, = ax.plot(ds.T_long,thetas[i,:,1])#
#   #line2, = ax.plot(ds.T_long,jnp.cos(thetas[i,:,1])+jnp.cos(thetas[i,:,0]))
#   plt.xlabel('Time t')
#   plt.ylabel(r'State')
#   plt.legend([r'$\theta_0$',r'$\theta_1$'])

In [None]:
# Wrap thetas in a tf.data pipeline, one element per slice along its leading axis.
dataset = tf.data.Dataset.from_tensor_slices(thetas)
# Scalar standard deviation over every entry of thetas; reused in the loss-weighting cell below.
data_std = thetas.std()

In [None]:
# Largest pairwise Euclidean distance among the first 400 trajectories (norm over the last
# two axes), divided by sqrt(number of entries per trajectory).
jnp.sqrt(((thetas[None,:400]-thetas[:400,None])**2).sum((-1,-2))).max()/jnp.sqrt(np.prod(thetas.shape[1:]))

In [None]:
bs = 400  # batch size
# dataiter is the bound method itself (deliberately not called here): every dataiter()
# call below creates a fresh NumPy iterator over the shuffled, batched dataset.
dataiter = dataset.shuffle(len(dataset)).batch(bs).as_numpy_iterator

In [None]:
from matplotlib import rc
rc('animation', html='jshtml')
#ds.animate()

## Diffusion

In [None]:
import numpy as np
import jax.numpy as jnp
from jax import random
import jax
import flax

In [None]:
# Pull one batch to fix the input shapes, draw one uniform random time per example,
# and initialise the UNet parameters from that example input.
x = next(dataiter())
t = np.random.rand(x.shape[0])
# The output dimension matches the last axis of the data batch.
model = diffusion_unet.UNet(diffusion_unet.unet_64_config(out_dim=x.shape[-1],base_channels=24))
params = model.init(random.PRNGKey(42), x=x,t=t,train=False)
x.shape

In [None]:
def count_params(params):
  """Count the scalar entries in a (possibly nested) parameter tree.

  Args:
    params: a JAX or NumPy array, or a mapping (dict, flax FrozenDict, ...)
      whose values are themselves arrays or mappings.

  Returns:
    Total number of array elements, as a Python int.

  Raises:
    TypeError: if a node is neither an array nor a mapping.
  """
  # Local import keeps the cell self-contained; FrozenDict is a Mapping subclass,
  # so this check covers both plain dicts and flax parameter containers.
  from collections.abc import Mapping
  if isinstance(params, (jax.numpy.ndarray, np.ndarray)):
    # .size is an exact integer count (1 for 0-d arrays), unlike np.prod(()) which is a float.
    return int(params.size)
  if isinstance(params, Mapping):
    return sum(count_params(v) for v in params.values())
  # Explicit raise instead of `assert False`, which disappears under `python -O`.
  raise TypeError(f'Unsupported parameter container: {type(params)}')

In [None]:
count_params(params)

Initialize the UNet

In [None]:
x.shape

In [None]:
from tqdm.auto import tqdm
import optax
from jax import jit
import pandas as pd
importlib.reload(samplers)

#sigma_min = 1e-3#2e-4#2e-3
#sigma_max = 1#100


In [None]:
# Training loop: 601 epochs over the shuffled batches, run inside a 1-D device mesh whose
# single axis is named 'data'.
# NOTE(review): Mesh, mesh_utils, update_fn, ema_params, opt_state, jloss and eval_metrics are
# not defined or imported in any cell visible in this notebook; they must exist beforehand.
key = random.PRNGKey(38)
with Mesh(mesh_utils.create_device_mesh((device_count(),)), ('data',)):
  for epoch in tqdm(range(601)):
    for data in dataiter():
      # update_fn returns new params, EMA params, optimiser state, PRNG key and the batch loss.
      params,ema_params,opt_state,key,loss_val = update_fn(params,ema_params,opt_state,key,data)
    if epoch % 5 == 0:
      # Both losses refer to the last batch of the epoch (`data` / `loss_val` leak from the inner loop).
      ema_loss = jloss(ema_params,data,key)
      message = f'Loss epoch {epoch}: {loss_val:.3f} Ema {ema_loss:.3f}'
      # if not epoch % 30:
      #   val = pmetric(samplers.stochastic_sampler(denoiser,params,key,(512,)+data.shape[1:],500)[0])[0]
      #   #message += f'     Precision: {}'
      print(message)
    if epoch %200 ==0:
      print(eval_metrics(dataiter,ema_params,key))

# From here on `params` names the EMA weights; the raw trained weights are no longer bound.
params=ema_params

In [None]:
# Small evaluation minibatch: the first 30 examples of whatever batch `data` currently holds
# (after training, the last batch seen by the loop).
mb = data[:30]

In [None]:
importlib.reload(samplers)
# Jitted wrapper that broadcasts one scalar sigma to a per-example vector and calls
# `denoised` in eval mode. NOTE(review): `denoised` is not defined in any visible cell.
denoiser = jit(lambda params,x,sigma: denoised(params,x,jnp.ones(x.shape[0])*sigma,train=False))  
def conditioning_scores(observed_values,s=.2):
  """Return the gradient of a Gaussian log-likelihood that pins a trajectory prefix.

  Args:
    observed_values: array of shape (batch, n_observed, channels) holding the
      values the first n_observed steps should match.
    s: standard deviation of the observation noise.

  Returns:
    A function mapping x (any shape that reshapes to (batch, -1, channels)) to the
    gradient of -sum((x[:, :n_observed] - observed_values)**2) / (2*s**2) with
    respect to x; entries past the observed prefix get zero gradient.
  """
  batch,num_observed,channels = observed_values.shape

  def log_likelihood(x):
    # Only the leading num_observed steps are compared with the observations.
    prefix = x.reshape(batch,-1,channels)[:,:num_observed]
    residual = prefix-observed_values
    return -jnp.sum(residual**2)/(2*s**2)

  return jax.grad(log_likelihood)
#conditioning_scores(mb[:,:20]),

  


In [None]:
# Compare, for one example, the ground truth, a noised copy at time t, and the sampler output.
# NOTE(review): t_max is not defined in any visible cell.
importlib.reload(samplers)
t=.001
z = samplers.sample(denoiser,params,key,mb.shape,t,t_max)#,conditioning_scores(mb[:,:50]))
# Noise the ground truth to level t: x_t = s(t)*x0 + s(t)*sigma(t)*eps.
noised_x = mb*samplers.s(t)+np.random.randn(*mb.shape)*(samplers.s(t)*samplers.sigma(t))
import matplotlib.pyplot as plt
i=2
plt.plot(ds.T_long,mb[i,:,0])
plt.plot(ds.T_long,noised_x[i,:,0])
plt.plot(ds.T_long,z[i,:,0])

plt.xlabel('Time t')
plt.ylabel(r'State')
plt.legend([r'GT','GT noised xt',r'Model xt'])

In [None]:
importlib.reload(samplers)
# samplers.compute_nll on (at most) the first 400 examples of `data`;
# presumably per-example negative log-likelihoods — confirm in samplers.
nll = samplers.compute_nll(denoiser,params,key,data[:400])

In [None]:
nll.mean()

In [None]:
importlib.reload(samplers)
from jax import grad
def score(x,t):
  """Score estimate at time t for a flattened batch x, built from the denoiser.

  x is reshaped to mb.shape and un-scaled by s(t) before denoising; the result is
  (D(x/s, sigma) - x/s) / (s * sigma**2), returned flat like x.
  """
  scale = samplers.s(t)
  noise_level = samplers.sigma(t)
  clean_estimate = denoiser(params,x.reshape(mb.shape)/scale,noise_level).reshape(-1)
  return (clean_estimate-x/scale)/(scale*noise_level**2)


def dynamics(t,x):
  """ODE right-hand side expressed through `score`.

  NOTE(review): looks like the probability-flow ODE with scale s(t) and noise
  level sigma(t) — confirm against samplers.
  """
  scale = samplers.s(t)
  noise_level = samplers.sigma(t)
  scale_term = grad(samplers.s)(t)*x/scale
  score_term = (scale**2)*grad(samplers.sigma)(t)*score(x,t).reshape(-1)*noise_level
  return scale_term-score_term


def dynamics2(t,x):
  """Same right-hand side as `dynamics`, written directly in terms of the denoiser output."""
  scale = samplers.s(t)
  noise_level = samplers.sigma(t)
  sigma_log_rate = grad(samplers.sigma)(t)/noise_level
  linear_coeff = grad(samplers.s)(t)/scale+sigma_log_rate
  clean_estimate = denoiser(params,x.reshape(mb.shape)/scale,noise_level).reshape(-1)
  return linear_coeff*x - sigma_log_rate*scale*clean_estimate

In [None]:
dynamics(.99,xt.reshape(-1))

In [None]:
dynamics2(.99,xt.reshape(-1))

In [None]:
xt = np.random.randn(*mb.shape)*samplers.s(t_max)*samplers.sigma(t_max)

In [None]:
xt.shape

In [None]:
jnp.max(jnp.abs(samplers.score(denoiser,params,mb.shape)(mb.reshape(-1),1.)))

In [None]:
t=1.
denoiser(params,mb/samplers.s(t),samplers.sigma(t)).reshape(-1)

In [None]:
denoiser(params,mb/samplers.s(t),t)

In [None]:
denoised(params,mb,jnp.ones(mb.shape[0])*samplers.sigma(t),train=False)

In [None]:
dynamics(1.,mb.reshape(-1))

In [None]:
1/samplers.s(t_max)

In [None]:
# z = jax.random.normal(key,(64,)+input_data.shape[1:])
# y = denoiser(z,.1)
# import numpy as np
# perm = np.random.permutation(z.shape[0])
# y2 = denoiser(z[perm],.1)[np.argsort(perm)]
# print(jnp.linalg.norm(y-y2))

In [None]:
# Overlay ground truth and sampler output (channel 0) for example i of the minibatch.
import matplotlib.pyplot as plt
i=5
plt.plot(ds.T_long,mb[i,:,0])
plt.plot(ds.T_long,z[i,:,0])
plt.xlabel('Time t')
plt.ylabel(r'State')
plt.legend([r'GT',r'Model'])

In [None]:
# Fresh batch and a fixed PRNG key for the evaluation cells that follow.
data = next(dataiter())
key = random.PRNGKey(26)

In [None]:
# NOTE(review): sigma_min / sigma_max appear in this notebook only as commented-out
# assignments, so they must be defined elsewhere before this cell runs.
nll = samplers.compute_nll(denoiser,params,key,data[:400],smin=sigma_min,smax=sigma_max,num_probes=1)

In [None]:
nll.mean()

In [None]:
nll.std(0)/jnp.sqrt(len(nll))

In [None]:

# Run samplers.forward_process2 on the batch; presumably this pushes data to the
# noise end of the process — confirm in samplers.
noised_data = samplers.forward_process2(denoiser,params,key,data,smin=sigma_min,smax=sigma_max)

In [None]:
noised_data.shape

In [None]:
noised_data.std()

In [None]:
noised_data.std()

In [None]:
# The sum of consecutive differences telescopes to T[-1]-T[0].
# Note that `T` is rebound to a time grid in a later cell.
T = samplers.timesteps(30,sigma_min,sigma_max)
print(np.sum(T[1:]-T[:-1]))

In [None]:

key = random.PRNGKey(45)

# 32 samples from the stochastic sampler with N=1000; the second return value is kept as
# `history` and sliced in the plotting cell below.
s,history = samplers.stochastic_sampler(denoiser,params,key,(32,)+data.shape[1:],N=1000,smin=sigma_min,smax=sigma_max)

In [None]:
#stochastic_sampler(params,key,(128,)+input_data.shape[1:],N=2000)

In [None]:
# Rebinds `s` to 64 samples from samplers.sample, replacing the stochastic-sampler output;
# `history` still refers to the earlier stochastic run.
s = samplers.sample(denoiser,params,random.split(key)[0],(64,)+data.shape[1:])

In [None]:
# Reference plot: first and last channel of dataset trajectory 2.
import matplotlib.pyplot as plt
plt.plot(ds.T_long,thetas[2,:,0])
plt.plot(ds.T_long,thetas[2,:,-1])
plt.xlabel('Time t')
plt.ylabel(r'State')
plt.legend([r'$\theta_0$',r'$\theta_1$'])

In [None]:
# Every 200th entry of the stochastic-sampler history for sample 1 (last channel), faded.
for i,h in enumerate(history[::200]):
  plt.plot(ds.T_long,h[1,:,-1],label=str(i),alpha=1/3)
# NOTE(review): `s` was rebound by samplers.sample above, and `i` is the last loop index,
# so this line repeats the final history label — confirm that is intended.
plt.plot(ds.T_long,s[1,:,-1],label=str(i))
plt.xlabel('Time t')
plt.ylabel(r'State')
plt.legend()
plt.ylim((-3,3))
#plt.legend([r'$\theta_0$',r'$\theta_1$'])

In [None]:
import matplotlib.pyplot as plt
from ipywidgets import interact



# @interact(i=(0,s.shape[0]-1))
# def plot(i=1):
#   fig = plt.figure()
#   ax = fig.add_subplot(1, 1, 1)
#   line1, = ax.plot(ds.T_long,s[i,:,0])
#   line2, = ax.plot(ds.T_long,s[i,:,1])
#   plt.xlabel('Time t')
#   plt.ylabel(r'State')
#   plt.legend([r'$\theta_0$',r'$\theta_1$'])
  #plt.ylim(-2,2)

In [None]:
# One figure per sample: first and last channel of the first two generated trajectories.
for i in range(2):
  fig = plt.figure()
  ax = fig.add_subplot(1, 1, 1)
  line1, = ax.plot(ds.T_long,s[i,:,0])
  line2, = ax.plot(ds.T_long,s[i,:,-1])
  plt.xlabel('Time t')
  plt.ylabel(r'State')
  plt.legend([r'$\theta_0$',r'$\theta_1$'])

In [None]:
# Same plot as the previous cell, for the first ten generated trajectories.
for i in range(10):
  fig = plt.figure()
  ax = fig.add_subplot(1, 1, 1)
  line1, = ax.plot(ds.T_long,s[i,:,0])
  line2, = ax.plot(ds.T_long,s[i,:,-1])
  plt.xlabel('Time t')
  plt.ylabel(r'State')
  plt.legend([r'$\theta_0$',r'$\theta_1$'])

In [None]:
key = random.PRNGKey(45)
#s=s2#,history = samplers.stochastic_sampler(denoiser,params,key,(32,)+data.shape[1:],N=500,smin=sigma_min,smax=sigma_max)


# Drop the first k steps of every sample, then estimate velocities with a central
# difference, which loses one step at each end.
k = 5
q = s[:,k:]
v = -(q[:,:-2]-q[:,2:])/(2*(ds.T[1]-ds.T[0]))
# Apply ds.M to the velocities and pack with the positions. `jax.vmap` is spelled out
# because the bare name `vmap` is never imported in this notebook.
z = ode_datasets.pack(q[:,1:-1],(jax.vmap(jax.vmap(ds.M))(q[:,1:-1])@v[...,None]).squeeze(-1))
# Time grid aligned with q[:,1:-1]: skip k steps plus the one lost at each end.
T = ds.T_long[k+1:-1]
z0 = z[:,0]
# Three rollouts over T: from the model's initial state, from a 1e-3 Gaussian
# perturbation of it, and from freshly sampled initial conditions.
z_gts = jax.vmap(ds.integrate,(0,None),0)(z0,T)
z_pert = jax.vmap(ds.integrate,(0,None),0)(z0+1e-3*np.random.randn(*z0.shape),T)
z_random = jax.vmap(ds.integrate,(0,None),0)(ds.sample_initial_conditions(z0.shape[0]),T)




In [None]:
q.shape

In [None]:
# Channel 0 of the integrated rollout, the model sample and the perturbed rollout, per example.
for i in range(10):
  fig = plt.figure()
  ax = fig.add_subplot(1, 1, 1)
  line1, = ax.plot(T,z_gts[i,:,0])
  line2, = ax.plot(T,z[i,:,0])
  line3, = ax.plot(T,z_pert[i,:,0])
  plt.xlabel('Time t')
  plt.ylabel(r'State')
  plt.legend(['gt','model','pert'])

In [None]:
# First and last state channels of the integrated rollout versus the model sample, per example.
for i in range(10):
  fig = plt.figure()
  ax = fig.add_subplot(1, 1, 1)
  line1, = ax.plot(T,z_gts[i,:,0])
  line2, = ax.plot(T,z[i,:,0])
  line3, = ax.plot(T,z_gts[i,:,-1])
  line5, = ax.plot(T,z[i,:,-1])
  plt.xlabel('Time t')
  plt.ylabel(r'State')
  plt.legend([r'$\theta_0$ gt',r'$\theta_0$ model',r'v gt', r'v model'])

In [None]:
pmetric(s)

In [None]:
# Geometric mean (and multiplicative spread) over the batch of the error against z_gts,
# with errors floored at 1e-3 before taking logs.
# NOTE(review): rel_err is not defined in any visible cell.
for pred in [z,z_pert,z_random]:
  clamped_errs = jax.lax.clamp(1e-3,rel_err(pred,z_gts),np.inf)
  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))
  rel_stds = np.exp(jnp.log(clamped_errs).std(0))
  plt.plot(T,rel_errs)
  # Band from mean/std to mean*std, i.e. symmetric in log space.
  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)

plt.plot()
plt.yscale('log')
plt.xlabel('Time')
plt.ylabel('Prediction Error')
plt.legend(['Diffusion Model Rollout','1e-3 Perturbed GT','Random Init'])

In [None]:
# Exact repeat of the previous cell (same statements, same plot).
for pred in [z,z_pert,z_random]:
  clamped_errs = jax.lax.clamp(1e-3,rel_err(pred,z_gts),np.inf)
  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))
  rel_stds = np.exp(jnp.log(clamped_errs).std(0))
  plt.plot(T,rel_errs)
  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)

plt.plot()
plt.yscale('log')
plt.xlabel('Time')
plt.ylabel('Prediction Error')
plt.legend(['Diffusion Model Rollout','1e-3 Perturbed GT','Random Init'])

In [None]:
# Evaluate ds.H at every (example, time) state of each rollout and compare with the
# reference rollout. `jax.vmap` is spelled out because the bare name `vmap` is never
# imported in this notebook.
H_gts = jax.vmap(jax.vmap(ds.H))(z_gts)
for pred in [z,z_pert,z_random]:
  Hs = jax.vmap(jax.vmap(ds.H))(pred)
  # Errors are floored at 1e-3 before taking logs.
  # NOTE(review): the denominator is the product |Hs*H_gts| rather than a single
  # magnitude — confirm this normalisation is intended.
  clamped_errs = jax.lax.clamp(1e-3,jnp.abs(Hs-H_gts)/jnp.abs(Hs*H_gts),np.inf)
  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))
  rel_stds = np.exp(jnp.log(clamped_errs).std(0))
  plt.plot(T,rel_errs)
  # Band from mean/std to mean*std, i.e. symmetric in log space.
  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)

plt.plot()
plt.yscale('log')
plt.xlabel('Time')
plt.ylabel('Energy Error')
plt.legend(['Diffusion Model Rollout','1e-3 Perturbed GT','Random Init'])

In [None]:

# Drift of ds.H from its initial value along each trajectory. `Hs` is left over from the
# last iteration of the previous cell's loop (the z_random rollouts) and is sampled on
# the grid `T`, so plot against `T`: ds.T_long[1:-1] has k more points and would not
# match the length of H.
for H in Hs:
  plt.plot(T,jnp.abs(H-H[0]))
plt.yscale('log')
plt.xlabel('Time')
plt.ylabel('Energy Error')

In [None]:
# Sweep the number of sampler steps and record pmetric on 256 samples for each setting.
# The same `key` is reused for every N.
# NOTE(review): pmetric is not defined in any visible cell.
metric_vals =[]
metric_stds = []
Ns = [25,50,100,200,500,1000,2000]
for N in Ns:
  # This rebinds the notebook-level `s` to the latest batch of samples.
  s,_ = samplers.stochastic_sampler(denoiser,params,key,(256,)+data.shape[1:],N=N,smin=sigma_min,smax=sigma_max)
  mean,std = pmetric(s)
  metric_vals.append(mean)
  metric_stds.append(std)
metric_vals = np.array(metric_vals)
metric_stds = np.array(metric_stds)
plt.plot(Ns,metric_vals)
# Multiplicative band (value/std to value*std).
# NOTE(review): assumes pmetric returns a geometric-style spread — confirm.
plt.fill_between(Ns, metric_vals/metric_stds, metric_vals*metric_stds,alpha=.3)
plt.xlabel('Sampler steps')
plt.ylabel('Pmetric value')
plt.xscale('log')

In [None]:
# Re-draw the sweep result from the previous cell without recomputing it.
plt.plot(Ns,metric_vals)
plt.fill_between(Ns, metric_vals/metric_stds, metric_vals*metric_stds,alpha=.3)
plt.xlabel('Sampler steps')
plt.ylabel('Pmetric value')
plt.xscale('log')

In [None]:
# Per-example weighted denoising loss on one fresh batch, using the EMA parameters.
# NOTE(review): noise_input and denoised are not defined in any visible cell.
data = next(dataiter())
key = random.PRNGKey(26)
noised_x,sigma = noise_input(data,key)
# Weight (sigma^2 + data_std^2) / (sigma*data_std)^2, one value per example.
weighting = (sigma**2+data_std**2)/(sigma*data_std)**2
# Mean over the last two axes leaves one loss per example.
# NOTE(review): unlike the `denoiser` wrapper, no train= argument is passed here.
losses = jnp.mean(((denoised(ema_params,noised_x,sigma)-data)**2)*weighting[:,None,None],axis=(-1,-2))

In [None]:

# Per-example weighted loss against noise level, on log-log axes.
plt.scatter(sigma,losses)
#plt.plot(np.sort(sigma),jax.scipy.stats.norm.pdf(np.log(np.sort(sigma)),mu,std),color='y')
#plt.hline(1e-1)
#plt.scatter(sigma,weighting)
plt.yscale('log')
plt.xscale('log')
plt.ylabel('weighted loss')
plt.xlabel(r'$\sigma$')
# Only the scatter is drawn (the pdf overlay is commented out), so label just that artist.
# The previous slice `[:1:-1]` evaluated to an empty list, leaving the legend without entries.
plt.legend(['loss values'])

In [None]:
# Build a list of smoothing kernels by repeatedly convolving the previous kernel with itself,
# starting from [1/2, 1/2]; only len(x) is used here, to pick the number of doublings.
x = np.random.randn(256)

binomial = [np.array([1., 1.])/2]
for _ in range(int(np.floor(np.log2(len(x))))):
  sqr = np.convolve(binomial[-1],binomial[-1])
  #binomial[-1] /= sqr[sqr.shape[0]//2+1]
  #binomial.append(sqr/sqr.sum())
  # NOTE(review): sqr has odd length, so its middle index is shape//2; the +1 normalises by
  # the tap one past the centre — confirm whether the centre tap was intended.
  binomial.append(sqr/sqr[sqr.shape[0]//2+1])
  
# Prepend the identity kernel and drop the widest one.
binomial = [np.array([1.])]+binomial[:-1]
def blur(z):
  """Convolve z with the last kernel of the notebook-level `binomial` list.

  The kernel is looked up at call time, so rebinding `binomial` changes the result.
  The output uses mode='same'.
  """
  last_kernel = binomial[-1]
  return jnp.convolve(last_kernel,z,mode='same')

#vblur = vmap(vmap(blur,0,0),2,2)

def vblur(z):
  """Attenuate high frequencies of z along axis 1 and renormalise the spectrum.

  Args:
    z: real array indexed as (batch, time, channel); the transform runs over axis 1.

  Returns:
    The inverse real FFT (over axis 1) of the spectrum after dividing each bin by
    sqrt(1 + |freq| * num_bins) and then by the mean spectral magnitude over axis 1.
    irfft is called without a length, so axis 1 of the result has 2*(num_bins-1) entries.
  """
  spectrum = jnp.fft.rfft(z,axis=1)
  num_bins = spectrum.shape[1]
  # abs() folds the negative Nyquist entry that fftfreq reports for even lengths.
  bin_freqs = jnp.abs(jnp.fft.fftfreq(z.shape[1])[:num_bins])
  attenuation = (1+bin_freqs*num_bins)[None,:,None]**.5
  shaped = spectrum/attenuation
  mean_magnitude = jnp.mean(jnp.abs(shaped),axis=1,keepdims=True)
  return jnp.fft.irfft(shaped/mean_magnitude,axis=1)


In [None]:
# Second construction of the kernel list; it rebinds both `x` and `binomial`.
# Each kernel is normalised to sum to 1 here, unlike the earlier cell.
x = np.random.randn(300)

binomial = [np.array([1., 1.])/2]
for _ in range(int(np.floor(np.log2(len(x))))):
  sqr = np.convolve(binomial[-1],binomial[-1])
  #binomial[-1] /= sqr[sqr.shape[0]//2+1]
  binomial.append(sqr/sqr.sum())
  #binomial.append(sqr/sqr[sqr.shape[0]//2+1])
  
# Prepend the identity kernel and drop the widest one.
binomial = [np.array([1.])]+binomial[:-1]

In [None]:
# Smooth the white-noise signal x with every kernel in `binomial`.
blurred = [jax.scipy.signal.convolve(x,bin,mode='same') for bin in binomial]
# Final entry: the running sum of x scaled by 1/sqrt(len(x)).
blurred.append(jnp.cumsum(x)/np.sqrt(len(x)))

In [None]:
# One line per blurred signal, labelled 2**i by list position. The loop leaves `bx` bound
# to the last entry of `blurred` at notebook level.
for i,bx in enumerate(blurred):
  plt.plot(bx,label=str(2**i))
plt.legend()

In [None]:
# Magnitude spectrum of each blurred signal over the first half of the FFT bins,
# with 1/f**2 and 1/f reference curves.
freq = np.fft.fftfreq(x.shape[0])[:x.shape[0]//2]*x.shape[0]
for i,bx in enumerate(blurred):
  plt.plot(freq, jnp.abs(np.fft.fft(bx)[:x.shape[0]//2]),label=str(2**i))

# freq[0] is 0, so both reference curves start with a division by zero (inf).
plt.plot(freq,1/freq**2,label='brown')
plt.plot(freq,1/freq,label='pink')
plt.yscale('log')
plt.xscale('log')
plt.legend()

In [None]:
# Reference curves only; the vblur spectrum overlay below is commented out.
#plt.plot(freq, jnp.abs(np.fft.fft(vblur(x[None,:,None])[0,:,0])[:x.shape[0]//2]),label=str(2**i))
plt.plot(freq,1/freq**2,label='brown')
plt.plot(freq,1/freq,label='pink')
plt.yscale('log')
plt.xscale('log')
plt.legend()

In [None]:
plt.plot(vblur(x[None,:,None])[0,:,0])

In [None]:
# Mean square of each blurred signal. The comprehension binds `bx` itself; iterating with
# `x` evaluated the stale notebook-level `bx` and repeated one value for every entry.
[(bx**2).mean() for bx in blurred]