In [None]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
import tensorflow as tf
#tf.config.experimental.set_visible_devices([], "GPU")

import importlib
from simulation_research.diffusion import ode_datasets
from simulation_research.diffusion import unet
from simulation_research.diffusion import samplers
from simulation_research.diffusion import diffusion
from simulation_research.diffusion import config as cfg
from simulation_research.diffusion import train
from clu import checkpoint
importlib.reload(ode_datasets)
importlib.reload(unet)
importlib.reload(samplers)
importlib.reload(train)

import matplotlib.pyplot as plt
from matplotlib import rc
rc('animation', html='jshtml')
import jax.numpy as jnp
import numpy as np
import jax

In [None]:
import os
import pickle

# Identifiers of the training run being analysed (kept for reference only).
username = "finzi"
exp_name = "test_3"  # previously: "all_datasets_and_ic"
xid = 0
# Output directory of the training run. Set the DIFFUSION_WORKDIR environment
# variable to analyse a different run without editing this hard-coded path.
workdir = os.environ.get(
    "DIFFUSION_WORKDIR",
    "/home/finzi/xm/test_diffusion/datasets_all2/43500559/1")

# SECURITY: pickle.load can execute arbitrary code embedded in the file. Only
# point `workdir` at trusted outputs written by the training script.
with tf.io.gfile.Open(os.path.join(workdir, 'config.pickle'), "rb") as f:
  config = pickle.load(f)
with tf.io.gfile.Open(os.path.join(workdir, 'data_std.pickle'), "rb") as f:
  data_std = pickle.load(f)

# Checkpoints written during training live in <workdir>/checkpoints.
checkpoint_dir = os.path.join(workdir, "checkpoints")
ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, {}, max_to_keep=2)


In [None]:
# Optional override of the dataset to evaluate (disabled).
#config.dataset="LorenzDataset"

In [None]:
# Display the training configuration restored from config.pickle.
config

In [None]:
# config = cfg.get_config()
# config.ic_conditioning=False
# config.dataset='FitzHughDataset'

from jax import random

key = random.PRNGKey(config.seed)
# Construct the dataset
timesteps = 60
ds = getattr(ode_datasets, config.dataset)(N=config.ds + config.bs)
Zs = ds.Zs[config.bs:, :timesteps]  # pylint: disable=invalid-name
test_x = ds.Zs[:config.bs, :timesteps]
T_long = ds.T_long[:timesteps]  # pylint: disable=invalid-name
dataset = tf.data.Dataset.from_tensor_slices(Zs)
dataiter = dataset.shuffle(len(dataset)).batch(config.bs).as_numpy_iterator
assert Zs.shape[1] == timesteps

# initialize the model
x = test_x  # (bs, N, C)
modelconfig = unet.unet_64_config(
    x.shape[-1], base_channels=config.channels, attention=config.attention)
model = unet.UNet(modelconfig)
noise = getattr(diffusion, config.noisetype)
difftype = getattr(diffusion, config.difftype)(noise)

In [None]:
# Draw one training batch only to give model.init concrete input shapes.
dataloader= dataiter
x = next(dataloader())
t = np.random.rand(x.shape[0])
# With ic_conditioning enabled, the first 3 time steps (axis 1) are the condition.
cond_fn = lambda x: (x[:, :3] if config.ic_conditioning else None)
key = random.PRNGKey(config.seed)
key, init_seed = random.split(key)
params = model.init(init_seed, x=x, t=t, train=False, cond=cond_fn(x))

In [None]:
# Restore trained weights from the checkpoint. Presumably `params` acts as the
# restore template and the result holds the EMA weights — confirm against
# clu.checkpoint and the training script.
ema_params = ckpt.restore(params)

In [None]:
def score(params, x, t, train, cond=None):
  """Evaluates the denoiser network as a score with EDM-style scaling.

  The network input is multiplied by 1/sqrt(sigma^2 + (scale*data_std)^2).
  The raw network output is divided by sqrt(sigma^2 + scale^2 * data_std^2).
  An optional conditioning array is divided by `data_std` before use.
  The scaling follows https://arxiv.org/abs/2206.00364.

  Args:
    params: model parameters forwarded to `model.apply`.
    x: noisy input batch.
    t: diffusion times, one per example.
    train: forwarded to `model.apply`.
    cond: optional conditioning array, or None.

  Returns:
    The rescaled network output.
  """
  sigma, scale = diffusion.unsqueeze_like(x, difftype.sigma(t), difftype.scale(t))
  inv_input_norm = 1 / jnp.sqrt(sigma**2 + (scale * data_std)**2)
  output_norm = jnp.sqrt(sigma**2 + scale**2 * data_std**2)
  scaled_cond = None if cond is None else cond / data_std
  net_out = model.apply(
      params, x=x * inv_input_norm, t=t, train=train, cond=scaled_cond)
  return net_out / output_norm

from functools import partial


@jax.jit
def score_out(x, t, cond=None) -> jnp.ndarray:
  """Score of the restored EMA model at noisy input `x` and time `t`.

  A Python scalar or 0-d `t` is expanded to one time per batch element.
  `score` is then called with `ema_params` in inference mode (train=False).
  """
  t_is_scalar = not (hasattr(t, 'shape') and t.shape)
  if t_is_scalar:
    t = jnp.ones(x.shape[0]) * t
  return score(ema_params, x, t, train=False, cond=cond)


# `score_fn` is the plain score.
# `eval_scorefn` additionally binds the conditioning returned by cond_fn for
# the held-out test trajectories.
score_fn = score_out
eval_scorefn = partial(score_out, cond=cond_fn(test_x))

In [None]:
import pandas as pd
import numpy as np
import jax

In [None]:
# Short alias for the diffusion process used by the sampling/likelihood cells.
diff =difftype
#sde_samples = samplers.sde_sample(diff, eval_scorefn, key, test_x.shape,nsteps=1000)
#ode_samples = samplers.discrete_ode_sample(diff, eval_scorefn, key, test_x.shape,nsteps=1000)

In [None]:
def lorenz_C(x, threshold=.6):
    """Event constraint for the Lorenz dataset; the event is lorenz_C(x) > 0.

    Takes the magnitude of the real FFT of channel 0 along the time axis.
    It averages that magnitude over all non-DC frequencies and returns
    `threshold - mean_magnitude`. The event is therefore "low average spectral
    magnitude" in the first coordinate.

    Args:
      x: trajectories with time on axis -2 and channels on axis -1.
      threshold: cutoff on the mean non-DC Fourier magnitude. The default of
        0.6 is the previously hard-coded value.

    Returns:
      Array with the leading (batch) dimensions of `x`.
    """
    fourier_mag = jnp.abs(jnp.fft.rfft(x[..., 0], axis=-1))
    return -(fourier_mag[..., 1:].mean(-1) - threshold)

def fitz_C(x, threshold=2.5):
    """Event constraint for the FitzHugh dataset; the event is fitz_C(x) > 0.

    Averages the first two channels and takes the maximum over time. It then
    subtracts `threshold`, so the event is "the two-channel mean exceeds
    `threshold` at some time step".

    Args:
      x: trajectories with time on axis -2 and channels on axis -1.
      threshold: spike cutoff. The default of 2.5 is the previously
        hard-coded value.

    Returns:
      Array with the leading (batch) dimensions of `x`.
    """
    C = jnp.max(x[..., :2].mean(-1), -1) - threshold
    return C

def pendulum_C(x):
    """Placeholder: no event constraint is defined for NPendulumDataset yet."""
    del x  # unused until a constraint is implemented
    raise NotImplementedError

# Event constraint per dataset name: trajectory x counts as an "event" when
# event_constraint(x) > 0.
constraints = dict(
    FitzHughDataset=fitz_C,
    LorenzDataset=lorenz_C,
    NPendulumDataset=pendulum_C,
)
event_constraint = constraints[config.dataset]

In [None]:
# Fourier magnitude spectrum (DC bin dropped) of one test trajectory.
# NOTE(review): this inspects channel 2, while lorenz_C uses channel 0 — confirm
# which channel is intended.
plt.plot(jnp.abs(jnp.fft.rfft(test_x[0,:,2]))[1:])

In [None]:
# Combine score_fn with the dataset's event constraint. Presumably this yields
# the event-conditioned score (reg=1e-3 is passed through) — confirm in
# samplers.event_scores.
event_scores = samplers.event_scores(diff,score_fn, event_constraint, reg=1e-3)

In [None]:
# Sample trajectories with the event-conditioned score using both the SDE
# sampler and the discrete ODE sampler. `ode_event_samples` is plotted by a
# later cell, so it must be produced here for Restart & Run All to succeed.
sde_event_samples = samplers.sde_sample(diff, event_scores, key, test_x.shape, nsteps=1000)
ode_event_samples = samplers.discrete_ode_sample(diff, event_scores, key, test_x.shape, nsteps=1000)

In [None]:
# Unconditional samples from both the SDE and the discrete ODE sampler.
# `ode_samples` is used by the rollout-error cells further down, so it must be
# produced here for Restart & Run All to succeed.
sde_samples = samplers.sde_sample(diff, score_fn, key, test_x.shape, nsteps=1000)
ode_samples = samplers.discrete_ode_sample(diff, score_fn, key, test_x.shape, nsteps=1000)

In [None]:
# Time grid of the first `timesteps` steps (same values as T_long).
T = ds.T_long[:timesteps]

In [None]:
# C(x) for every training trajectory, plus the train/test trajectories that
# satisfy the event C(x) > 0.
event_distribution = event_constraint(Zs)
events_train = Zs[event_distribution > 0]
events_test = test_x[event_constraint(test_x) > 0]

In [None]:
# Channel 0 of up to five unconditional SDE samples that happen to satisfy the event.
plt.plot(T,sde_samples[event_constraint(sde_samples)>0][:5,:,0].T)
plt.xlabel(r'$\tau$')
plt.ylabel('x')
plt.title('Example model events')

In [None]:
# Channel 0 of five held-out test trajectories, for visual comparison.
plt.plot(T,test_x[10:15,:,0].T)
plt.xlabel(r'$\tau$')
plt.ylabel('x')
plt.title('Data samples')

In [None]:
# Channel 0 of up to five test trajectories that satisfy the event.
plt.plot(T,events_test[:5,:,0].T)
plt.xlabel(r'$\tau$')
plt.ylabel('x')
plt.title('Example events in dataset')

In [None]:
# One FitzHugh event (mean of the first two channels) against the threshold
# used by fitz_C. Plot against T_long, the grid the data were sliced with, and
# do not reassign the global `T` that other cells plot against.
plt.plot(T_long, events_test[0, :, :2].mean(-1), label='example spike')
# Must match fitz_C: an event is max_tau mean(x_1, x_2) > 2.5.
plt.axhline(2.5, color='k', label='our cutoff y')
plt.xlabel(r'Time ($\tau$)')
plt.ylabel(r'$x(\tau)$')
plt.legend()

In [None]:
# Distribution of the event statistic C(x) for unconditional diffusion samples
# versus data; the event region is C(x) > 0, right of the vertical line.
plt.hist(np.array(event_constraint(sde_samples)), bins=80, color='red',
         density=True, alpha=.5, label='Diffusion samples')
plt.hist(np.array(event_constraint(ds.Zs[:4000, :timesteps])), bins=80,
         color='y', density=True, alpha=.5, label='True distribution')
plt.axvline(x=0, color='k', label='y cutoff')
# lorenz_C is 0.6 minus the MEAN (not L1 norm) of the non-DC Fourier magnitudes.
plt.xlabel(r'$.6-\mathrm{mean}(|F[x]_{1:}|)$')
plt.ylabel('Density')
plt.ylim(1e-2, 2.5)
# Labels are attached to the artists above. A bare list of strings passed to
# legend() would be paired with individual histogram bars instead.
plt.legend()

In [None]:
# Shape check of the event-conditioned samples.
sde_event_samples.shape

In [None]:
# Distribution of C(x) for event-conditioned diffusion samples versus the
# training trajectories that satisfy the event (C(x) > 0).
plt.hist(np.array(event_constraint(sde_event_samples)), bins=100, color='red',
         density=True, alpha=.5, label='Conditional diffusion samples x|E')
plt.hist(np.array(event_constraint(events_train[:, :timesteps])), bins=50,
         color='y', density=True, alpha=.5, label='True distribution x|E')
plt.axvline(0, color='k', label='y cutoff')
# lorenz_C is 0.6 minus the MEAN (not L1 norm) of the non-DC Fourier magnitudes.
plt.xlabel(r'$.6-\mathrm{mean}(|F[x]_{1:}|)$')
plt.ylabel('density')
plt.ylim(1e-2, 6)
# Labels are attached to the artists above. A bare list of strings passed to
# legend() would be paired with individual histogram bars instead.
plt.legend()

In [None]:
import seaborn
# KDE of max_tau mean(x_1, x_2), the statistic thresholded by fitz_C, for
# event-conditioned samples versus true events. Each curve is labelled so the
# two can be told apart.
seaborn.kdeplot(np.array(sde_event_samples[:, :timesteps, :2].mean(-1).max(-1)),
                label='Conditional diffusion samples x|E')
seaborn.kdeplot(np.array(events_train[:, :timesteps, :2].mean(-1).max(-1)),
                label='True distribution x|E')
plt.xlabel(r'$\max_\tau x(\tau)$')
plt.legend()

In [None]:
# Fraction of event-conditioned samples that actually satisfy C(x) > 0.
(event_constraint(sde_event_samples)>0).mean()

In [None]:
# Channel 0 of the event-satisfying conditional samples (first five plotted),
# against whatever grid the global `T` currently holds.
sde_events2 = sde_event_samples[event_constraint(sde_event_samples)>0,:,0]
plt.plot(T,sde_events2[:5].T)
plt.xlabel(r'$\tau$')
plt.ylabel('x')
plt.title('x|E conditional model samples')

In [None]:
# Animate the second event-satisfying conditional sample.
ds.animate(sde_event_samples[event_constraint(sde_event_samples)>0][1])

In [None]:
# Animate a held-out test trajectory for comparison.
ds.animate(test_x[1])

In [None]:
# Empirical event rate in the training data.
(event_constraint(Zs)>0).mean()

In [None]:
# Scratch check; does not feed into the analysis.
jnp.exp(-10)

In [None]:
# Scratch check of the return type of a jnp reduction.
type((jnp.ones(3)*.2).sum())

In [None]:
# Presumably an estimate of log P(event) under the unconditional model, plus
# its uncertainty — confirm in samplers.marginal_logprob.
logprob,logprob_std = samplers.marginal_logprob(diff, score_fn, event_constraint, test_x[0].shape,nsteps=1000)

In [None]:
# Likelihood of two event-conditioned samples under the conditional
# (event_scores) and unconditional (score_fn) models. The unconditional score
# function in this notebook is `score_fn`.
conditional_likelihood = samplers.discrete_time_likelihood(diff, event_scores, sde_event_samples[:2])
unconditional_likelihood = samplers.discrete_time_likelihood(diff, score_fn, sde_event_samples[:2])
print(conditional_likelihood, unconditional_likelihood)

In [None]:
print(logprob,logprob_std)

In [None]:
# NOTE(review): if `logprob` is log P(event), the event probability is
# exp(+logprob) and exp(-logprob) is its reciprocal — confirm the sign
# convention of samplers.marginal_logprob.
jnp.exp(-logprob)

In [None]:
# Requires `ode_event_samples` (discrete ODE samples drawn with event_scores)
# to be defined before this cell runs.
plt.plot(ode_event_samples[:5,:,:2].mean(-1).T)

In [None]:
#nll1 = samplers.compute_nll(diff,score_fn,key,sde_samples)
# Negative log-likelihood of 5 unconditional SDE samples, normalised by the
# number of entries in one sample.
nll2 = -samplers.discrete_time_likelihood(diff,score_fn,sde_samples[:5])/sde_samples[0].size

In [None]:
nll2

In [None]:
# FIXME: `mb`, `slc` and `inpainting_scores2` are not defined anywhere in this
# notebook, so this cell raises NameError on a fresh kernel.
# `expanded` tiles `mb` (assumed rank 3) 10 times along the batch axis.
expanded = (mb[None]+jnp.zeros((10,1,1,1))).reshape(mb.shape[0]*10,*mb.shape[1:])#[:,slc]
predictions = samplers.stochastic_sample(diff,inpainting_scores2(diff,score_fn,expanded[:,slc],slc,scale=300.),key,expanded.shape,N=2000,traj=False)

In [None]:
sde_samples.shape

In [None]:
from jax import vmap
# Rollout error against held-out ground truth, compared across four sources:
#   - SDE samples and discrete-ODE samples;
#   - ground truth re-integrated from initial conditions perturbed by 1e-3;
#   - trajectories integrated from random initial conditions.
# Each curve is the geometric mean over the batch (axis 0) of train.rel_err,
# with a one-geometric-std band. Assumes rel_err returns (batch, time).
T = T_long
z1 = sde_samples
z2 = ode_samples
z_gts = test_x[:z1.shape[0]]
z0 = z_gts[:, 0]
z_pert = vmap(ds.integrate, (0, None), 0)(z0 + 1e-3 * np.random.randn(*z0.shape), T)
z_random = vmap(ds.integrate, (0, None), 0)(ds.sample_initial_conditions(z0.shape[0]), T)
for pred, label in zip([z1, z2, z_pert, z_random],
                       ['SDE Diffusion Model Rollout', 'ode', '1e-3 Perturbed GT', 'Random Init']):
  # Floor errors at 1e-3 so the log stays finite for near-exact matches.
  clamped_errs = jax.lax.clamp(1e-3, train.rel_err(pred, z_gts), np.inf)
  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))
  rel_stds = np.exp(jnp.log(clamped_errs).std(0))
  # The label goes on the line; a bare label list passed to legend() could be
  # paired with the fill_between bands instead.
  plt.plot(T, rel_errs, label=label)
  plt.fill_between(T, rel_errs / rel_stds, rel_errs * rel_stds, alpha=.1)

plt.yscale('log')
plt.xlabel('Time')
plt.ylabel('Prediction Error')
plt.legend()

In [None]:
from jax import vmap
# Physical self-consistency of the ODE samples. The true dynamics are
# integrated from each sample's own first state, and the sample is measured
# against that rollout. Perturbed and random-init rollouts give reference
# error levels. Assumes train.rel_err returns (batch, time).
T = T_long
z = ode_samples
z0 = z[:, 0]
z_gts = vmap(ds.integrate, (0, None), 0)(z0, T)
z_pert = vmap(ds.integrate, (0, None), 0)(z0 + 1e-3 * np.random.randn(*z0.shape), T)
z_random = vmap(ds.integrate, (0, None), 0)(ds.sample_initial_conditions(z0.shape[0]), T)
# The first curve comes from the discrete ODE sampler, not the SDE sampler.
for pred, label in zip([z, z_pert, z_random],
                       ['ODE Diffusion Model Rollout', '1e-3 Perturbed GT', 'Random Init']):
  # Floor errors at 1e-3 so the log stays finite for near-exact matches.
  clamped_errs = jax.lax.clamp(1e-3, train.rel_err(pred, z_gts), np.inf)
  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))
  rel_stds = np.exp(jnp.log(clamped_errs).std(0))
  plt.plot(T, rel_errs, label=label)
  plt.fill_between(T, rel_errs / rel_stds, rel_errs * rel_stds, alpha=.1)

plt.yscale('log')
plt.xlabel('Time')
plt.ylabel('Prediction Error')
plt.legend()

In [None]:
from flax.core.frozen_dict import FrozenDict
import numpy as np
def sum_params(params):
  """Returns the sum of every array entry in a (possibly nested) params tree.

  Useful as a cheap fingerprint to see whether two parameter trees differ.

  Args:
    params: an array (numpy or jax) or a nested mapping (dict, flax
      FrozenDict) / list / tuple of such arrays.

  Returns:
    The scalar sum of all entries. An empty container gives 0.

  Raises:
    TypeError: if a node is neither an array nor a mapping/list/tuple. An
      explicit raise is used instead of `assert` so the check survives
      `python -O`.
  """
  from collections.abc import Mapping  # covers dict and flax FrozenDict
  if isinstance(params, (jax.numpy.ndarray, np.ndarray)):
    return params.sum()
  if isinstance(params, Mapping):
    return sum(sum_params(v) for v in params.values())
  if isinstance(params, (list, tuple)):
    return sum(sum_params(v) for v in params)
  raise TypeError(f'Unsupported params node type: {type(params)}')
# Fingerprint of the freshly initialised params tree.
print(sum_params(params))
# FIXME: `p2` is not defined anywhere in this notebook (NameError on a fresh
# kernel); presumably a second params tree meant for comparison.
print(sum_params(p2))

In [None]:
# Scratch check; does not feed into the analysis.
type(None)

In [None]:
# FIXME: `p2` is undefined here as well.
type(p2)

In [None]:
import jax.numpy as jnp
# Scratch: exp(a - b) for two hard-coded values whose origin is not recorded here.
jnp.exp(973.3657-977.17847)