In [8]:
import torch
import numpy as np
import sys
import os

# Torch device on which every tensor in this notebook is allocated.
device = "cpu"
# Exponent of the Hamiltonian; it is bound as the default `p` of g and H below.
p = 2
# Directory (with trailing slash) that receives the saved residual data.
save_dir = "HJ_residuals/"

In [9]:
# Display the configured output directory as the cell result.
save_dir

'HJ_residuals/'

In [10]:
def get_inf_convolution(xt, f, g, delta, domain, int_samples, alpha=1, linesearch_iters=0, max_linesearch_iters=50):
  """Monte-Carlo softmax approximation of the inf-convolution min_y f(y) + g(y, x, t).

  Points y are drawn uniformly from the box [domain[0], domain[1]]^dim, weighted
  by softmax(-(alpha/delta) * (f(y) + g(y, x, t))), and averaged to obtain an
  approximate minimizer (the prox point). The objective is then evaluated at
  that point.

  Inputs:
    xt: tensor of size dim+1 by 1, where dim = spatial dimension; the first dim
        rows are x and the last row is t.
    f: callable f(y) -> one value per row of y.
    g: callable g(y, x, t) -> one value per row of y (same shape as f(y)).
    delta: smoothing parameter; the softmax logits are scaled by alpha/delta.
    domain: pair (low, high) bounding the uniform samples in every coordinate.
    int_samples: number of uniform samples to draw.
    alpha: logit scale; halved on every retry after a non-finite softmax.
    linesearch_iters: number of attempts made so far (used by the recursion).
    max_linesearch_iters: upper bound on the number of attempts.

  Returns:
    (min_output, argmin_output, linesearch_iters) where min_output is f + g
    evaluated at the prox point, argmin_output is the prox point as a tensor of
    shape (dim,), and linesearch_iters is the number of attempts used.

  Raises:
    RuntimeError: if the softmax weights are still non-finite after
      max_linesearch_iters attempts.
  """

  assert(xt.shape[1]==1)
  linesearch_iters +=1

  # Halving alpha cannot repair inputs that are themselves non-finite, so the
  # retry loop must be bounded to avoid unbounded recursion.
  if linesearch_iters > max_linesearch_iters:
    raise RuntimeError('softmax weights still non-finite after ' + str(max_linesearch_iters) + ' attempts')

  # Sample on the same device as the input instead of relying on a global.
  sample_device = xt.device

  dim = xt.shape[0]-1

  x = xt[0:dim]
  t = xt[dim]

  y = (domain[1]-domain[0])*torch.rand(int_samples,dim, device=sample_device) + domain[0]

  # both values below are of size n_samples
  g_val = g(y, x, t)
  f_val = f(y)

  assert(g_val.shape==f_val.shape)

  z = -(alpha/delta)*(f_val + g_val) # shape =  n_samples
  w = torch.softmax(z, dim=0) # shape = n_samples

  assert(z.shape==w.shape)

  if not torch.isfinite(w).all():
    print('xt = ', xt)
    print('z = ', z)
    print('w = ', w)
    # Retry with fresh samples and a smaller logit scale, forwarding the full
    # original argument list.
    alpha = 0.5*alpha
    return get_inf_convolution(xt, f, g, delta, domain, int_samples, alpha=alpha, linesearch_iters=linesearch_iters, max_linesearch_iters=max_linesearch_iters)

  # Softmax-weighted average of the samples: the approximate minimizer.
  prox_term = torch.matmul(w.t(), y)
  prox_term = prox_term.view(-1,1)

  prox_is_finite = torch.isfinite(prox_term).all()
  if not prox_is_finite:
    print('prox overflowed: ', prox_term)
  assert(prox_is_finite)

  min_output = f(prox_term.view(1,-1)) + g(prox_term.view(1,-1), x,t)

  assert(min_output<np.inf)
  assert(np.isnan(min_output.item())==False)

  # Report the computed prox point rather than a placeholder of zeros.
  argmin_output = prox_term.view(-1)

  return min_output, argmin_output, linesearch_iters

In [11]:
def f(y):
  """Return the 1-norm of every row of y.

  y is expected to have size n_samples x n_dim; the result has one entry per
  sample.
  """
  row_l1_norms = torch.norm(y, p=1, dim=1)
  return row_l1_norms

def g(y,x,t,p=p):
  """Evaluate t * H_star((y - x)/t) for every sample row of y.

  H_star(v) = (p-1)/p * sum_i |v_i|^(p/(p-1)), summed over the spatial
  coordinates.

  Inputs:
    y: assumed to have shape n_samples by n_dim.
    x: assumed to have shape n_dim, 1 (it is flattened before broadcasting).
    t: scalar-like time value dividing the displacement and scaling the result.
    p: exponent; defaults to the module-level p at definition time. Must be
       finite, and p = 1 divides by zero in the exponent p/(p-1).

  Returns:
    Tensor with one value per sample.

  Raises:
    NotImplementedError: if p is not finite (previously None was returned
      silently).
  """
  # np.inf is the canonical spelling; the np.infty alias no longer exists in
  # NumPy 2.0.
  if not p < np.inf:
    raise NotImplementedError('g is only implemented for finite p')

  v = torch.abs((y-x.view(-1))/t)

  H_star = (p-1)/p * torch.sum(v**(p/(p-1)), dim=1)

  return t*H_star

def H(u, p=p):
  """Hamiltonian H(u) = (||u||_p / p^(1/p))^p, with the norm taken along dim 0.

  Inputs:
    u: assumed to be of size dim x 1 (the norm reduces over the first axis).
    p: exponent; defaults to the module-level p at definition time. Must be
       finite.

  Returns:
    Tensor with one value per column of u.

  Raises:
    NotImplementedError: if p is not finite (previously None was returned
      silently).
  """
  # np.inf is the canonical spelling; the np.infty alias no longer exists in
  # NumPy 2.0.
  if not p < np.inf:
    raise NotImplementedError('H is only implemented for finite p')

  return (torch.norm(u, p=p, dim=0)/(p**(1/p)))**p

def compute_HJ_residual(xt, H, domain, f, g, delta, int_samples):
  """Compute the Hamilton-Jacobi residual u_t + H(grad_x u) at a single point.

  u is the value returned by get_inf_convolution at xt; its derivatives with
  respect to xt are obtained by automatic differentiation.

  Inputs:
    xt: tensor of shape dim+1 x 1; the first dim rows are x, the last row is t.
        The caller's tensor is not modified.
    H: callable applied to the spatial gradient (shape dim x 1).
    domain, f, g, delta, int_samples: forwarded to get_inf_convolution.

  Returns:
    Detached tensor u_t + H(grad_x u), where u_t has shape 1 x 1.
  """
  # Differentiate with respect to a private leaf so that requires_grad is not
  # switched on for the tensor owned by the caller.
  xt_leaf = xt.detach().requires_grad_(True)

  dim = xt_leaf.shape[0]-1
  u_val, _, _ = get_inf_convolution(xt_leaf, f, g, delta, domain, int_samples)

  # Only first derivatives are needed and they are detached immediately, so no
  # higher-order graph is built.
  grad_u_full = torch.autograd.grad(outputs=u_val, inputs=xt_leaf, grad_outputs=torch.ones_like(u_val))[0]

  u_t = grad_u_full[dim,:].detach().view(1,-1)
  grad_u = grad_u_full[0:dim, :].detach().view(dim, -1)

  H_term = H(grad_u)
  HJ_res = u_t + H_term

  return HJ_res

In [12]:
# Experiment sizes: number of evaluation points, spatial dimension and number
# of repeated trials.
int_samples = int(1e5)
n_input_samples = int(1e3)
dim = 5

n_trials = 50

# Times are drawn uniformly from [0.1, 1.1).
t_array = torch.rand(n_input_samples, device=device) + 1e-1

# Every column is one (x, t) point: all rows are first drawn uniformly from
# [-5, 5), then the last row is overwritten with the times above.
xt_array = 10*torch.rand(dim+1, n_input_samples, device=device)-5
xt_array[dim,:] = t_array

# Sampling box for the inf-convolution, and the grid of smoothing parameters
# and sample counts to sweep over.
domain = [-10,10]
delta_array = [1e2, 1e1, 1e0, 1e-1, 1e-2]
samples_array = [1e1, 1e2, 1e3, 1e4, 1e5]

# Mean residual norm, indexed by [trial, delta, sample count].
HJ_residual_norms = torch.zeros(n_trials, len(delta_array), len(samples_array))

for k in range(n_trials):
  print('\ntrial ', k+1)
  for i, current_delta in enumerate(delta_array):

    # Largest number of attempts seen so far for this delta.
    max_ls_iters = 1

    print('\n')
    for l, sample_count in enumerate(samples_array):

      current_int_samples = int(sample_count)

      for j in range(n_input_samples):

        # Column j as a (dim+1) x 1 tensor.
        xt = xt_array[:, j].view(-1, 1)

        inf_terms, prox_terms, linesearch_iters = get_inf_convolution(xt, f, g, current_delta, domain, current_int_samples)
        max_ls_iters = np.max([max_ls_iters, linesearch_iters])

        HJ_residuals = compute_HJ_residual(xt, H, domain, f, g, current_delta, current_int_samples).cpu()
        # Running average of the residual norm over the evaluation points.
        HJ_residual_norms[k,i,l] += torch.norm(HJ_residuals, p=2)/n_input_samples

      print('delta: ', current_delta, ', num samples:', current_int_samples, ', HJ_res norms: ', HJ_residual_norms[k,i,l], ', ls_iters: ', max_ls_iters)


trial  1


delta:  100.0 , num samples: 10 , HJ_res norms:  tensor(157.2915) , ls_iters:  1
delta:  100.0 , num samples: 100 , HJ_res norms:  tensor(20.7180) , ls_iters:  1
delta:  100.0 , num samples: 1000 , HJ_res norms:  tensor(8.9706) , ls_iters:  1
delta:  100.0 , num samples: 10000 , HJ_res norms:  tensor(9.4493) , ls_iters:  1
delta:  100.0 , num samples: 100000 , HJ_res norms:  tensor(9.6992) , ls_iters:  1


delta:  10.0 , num samples: 10 , HJ_res norms:  tensor(5984.5649) , ls_iters:  1
delta:  10.0 , num samples: 100 , HJ_res norms:  tensor(667.9370) , ls_iters:  1
delta:  10.0 , num samples: 1000 , HJ_res norms:  tensor(36.5180) , ls_iters:  1
delta:  10.0 , num samples: 10000 , HJ_res norms:  tensor(5.8470) , ls_iters:  1
delta:  10.0 , num samples: 100000 , HJ_res norms:  tensor(1.1043) , ls_iters:  1


delta:  1.0 , num samples: 10 , HJ_res norms:  tensor(131893.7969) , ls_iters:  1
delta:  1.0 , num samples: 100 , HJ_res norms:  tensor(9479.0684) , ls_iters:  1
delta: 

In [14]:
# Persist the residual table together with the settings that produced it.
# Create the output directory first so the save also works on a fresh checkout.
if save_dir:
  os.makedirs(save_dir, exist_ok=True)

data_name = 'hj_residuals_p' + str(p) + '_dim' + str(int(dim)) + '.pth'
# os.path.join gives the same path as before when save_dir ends with '/', and
# also works when it does not.
file_name = os.path.join(save_dir, data_name)
state = {
    'HJ_residual_norms': HJ_residual_norms,
    'n_input_samples': n_input_samples,
    'dim': dim,
    'xt_array': xt_array,
    'delta_array': delta_array,
    'samples_array': samples_array,
    'domain': domain,
    'n_trials': n_trials
}
torch.save(state, file_name)
print('files saved to ' + file_name)

files saved to HJ_residuals/hj_residuals_p2_dim5.pth
