<a href="https://colab.research.google.com/github/kallviktor/RandomInterpolationGAN/blob/main/Interpolations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
def ker(h, a=1, b=1):
  """Kernel k(h) = exp(-b * |h|**a), a function of the lag h only.

  With the defaults (a=1, b=1) this is exp(-|h|). All operations are NumPy
  ufuncs, so h may be a scalar or an array (evaluated elementwise).

  Args:
    h: lag, scalar or array.
    a: exponent applied to |h|.
    b: rate multiplying |h|**a.

  Returns:
    exp(-b * |h|**a), with the same shape as h.
  """
  decay = b * np.abs(h) ** a
  return np.exp(-decay)

def muhat(t, z0, zT, T):
  """Mean of the bridge at time t, pinned to z0 at time 0 and zT at time T.

  The formula is the conditional mean of a zero-mean Gaussian process with
  covariance ker(t - s), given Z(0) = z0 and Z(T) = zT. At t=0 it
  reduces to z0, and at t=T to zT.

  Args:
    t: time at which to evaluate the mean.
    z0: value at time 0.
    zT: value at time T.
    T: total time span.
  """
  k_T = ker(T)            # correlation between the two pinned endpoints
  k_start = ker(t)        # correlation of time t with the t=0 endpoint
  k_end = ker(T - t)      # correlation of time t with the t=T endpoint
  from_start = z0 * (k_start - k_end * k_T)
  from_end = zT * (k_end - k_start * k_T)
  return (from_start + from_end) / (1 - k_T ** 2)

def khat(t, s, T):
  """Covariance of the bridge between times t and s, given both endpoints.

  This is the prior covariance ker(t - s) plus the correction from
  conditioning a zero-mean Gaussian process with covariance ker on its values
  at times 0 and T. Analytically it vanishes whenever t or s is 0 or T.

  Args:
    t: first time point.
    s: second time point.
    T: total time span.
  """
  k_T = ker(T)
  prior = ker(t - s)
  mixed = k_T * (ker(T - s) * ker(t) + ker(s) * ker(T - t))
  # Subtractions kept in the same left-to-right order for identical rounding.
  correction = (mixed - ker(T - s) * ker(T - t) - ker(s) * ker(t)) / (1 - k_T ** 2)
  return prior + correction

def BP(z0, zT, T, N):
  """Sample a bridge process from z0 to zT, independently for each coordinate.

  Each coordinate i is drawn from a multivariate normal on the N equally
  spaced times t_grid = np.linspace(0, T, N). The mean is
  muhat(t, z0[i], zT[i], T) and the covariance is khat(t, s, T).

  Args:
    z0: 1-D sequence of start values, one per coordinate.
    zT: 1-D sequence of end values, paired element-wise with z0.
    T: time span of the bridge.
    N: number of time points on [0, T], both endpoints included.

  Returns:
    np.ndarray of shape (len(z0), N); row i is the sampled path of
    coordinate i.
  """
  t_grid = np.linspace(0, T, N)    # Uniform grid on t-axis

  # khat does not depend on the endpoint values, so the covariance matrix is
  # identical for every coordinate: build it once instead of once per coordinate.
  # NOTE: khat vanishes analytically at t=0 and t=T (pinned endpoints), so this
  # matrix is singular. np.random.multivariate_normal's default SVD method
  # accepts positive-semidefinite input, but rounding may trigger its
  # "not positive-semidefinite" RuntimeWarning.
  cov = np.empty((N, N))
  for m, t in enumerate(t_grid):
    for k, s in enumerate(t_grid):
      cov[m, k] = khat(t, s, T)

  dim = len(z0)
  Z = np.zeros((dim, N))
  # zip stops at the shorter of z0/zT; any remaining rows stay zero.
  for i, (z0_i, zT_i) in enumerate(zip(z0, zT)):
    mean_i = np.array([muhat(t, z0_i, zT_i, T) for t in t_grid])
    Z[i] = np.random.multivariate_normal(mean_i, cov)
  return Z

def gen_particles(z0, zT, T, N, batch_size):
  """Propose the next point of every particle by sampling a bridge from it to its target.

  For particle i, a full bridge path is drawn with BP(z0[:, i], zT[:, i], T, N).
  Only the value at the first grid point after the start (time index 1) is kept.

  Args:
    z0: array of shape (dim, >= batch_size); column i is particle i's current point.
    zT: array of shape (dim, >= batch_size); column i is particle i's end point.
    T: remaining time span of the bridge.
    N: number of grid points on [0, T], both endpoints included; must be >= 2.
    batch_size: number of particles (columns) to process.

  Returns:
    np.ndarray of shape (dim, batch_size) holding each particle's proposed next point.
  """
  dim = z0.shape[0]    # derive from the input instead of relying on a global `dim`
  next_points = np.zeros(shape=(dim, batch_size))
  next_tenth = 0    # integer milestone counter; avoids float drift from repeated += 0.1
  for i in range(batch_size):
    path = BP(z0[:, i], zT[:, i], T, N)
    next_points[:, i] = path[:, 1]    # only the step right after the start is needed
    reached = (i * 10) // batch_size    # tenths of the batch started so far
    if reached >= next_tenth:
      print('Percentage finished: {}%'.format(reached * 10))
      next_tenth = reached + 1
  print('Percentage finished: 100%')
  print('Done!')

  return next_points

def weight_func(z, G, D):
  """Importance weight D(x) / (1 - D(x)) for the sample x = G(z).

  Presumably D is a GAN discriminator, so that D/(1-D) estimates the
  density ratio p_data/p_G (confirm against how the models were trained).

  Args:
    z: 1-D latent vector; a leading batch axis of size 1 is added before G.
    G: generator callable taking a batch of latent vectors.
    D: discriminator callable taking G's output.

  Returns:
    D(x) / (1 - D(x)) in whatever array shape D returns. It is inf when
    D(x) == 1.
  """
  x = G(z[np.newaxis])
  d = D(x)    # one discriminator pass, reused for numerator and denominator
  return d / (1 - d)

In [None]:
z0 = np.zeros(100)    # z0.shape --> (100,)
zT = np.ones(100)   # zT.shape --> (100,)

dim = len(z0)
T = 1
N = 10    # Number of time points of each path on [0, T], both endpoints included

n = 10    # Number of particles
parts = np.zeros((dim,n))
paths = np.zeros((dim,N,n))    # paths[:, k, j] is particle j at time index k
paths[:,0,:] = z0[:, np.newaxis]    # every particle starts at z0 ...
paths[:,-1,:] = zT[:, np.newaxis]   # ... and is pinned to zT at the final time

S = range(n)    # Particle indices; resampling draws indices rather than data points directly

G = mnist_dcgan.generator.predict
D = mnist_dcgan.DM.predict

# BP places its N points with np.linspace(0, T, N), i.e. spacing T/(N-1).
# With this dt, every shrinking bridge below keeps that same grid spacing.
dt = T/(N-1)    # Length of timestep
zT = paths[:,-1,:]
for step in range(N-2):
  z0 = paths[:,step,:]
  Tcur = T - dt*step    # time remaining from time index `step` to T
  Ncur = N - step       # grid points remaining, including current point and endpoint
  parts = gen_particles(z0, zT, Tcur, Ncur, n)    # proposals for time index step+1
  weights = np.zeros(n)
  for k in range(n):
    w = weight_func(parts[:,k], G, D)
    # w is a size-1 array (e.g. shape (1, 1)); take its scalar explicitly.
    weights[k] = np.asarray(w).item()
  weights = weights/np.sum(weights)   # Normalize weights
  print(weights)
  # Resampling: draw n particle indices with replacement, proportional to weights
  S_re = np.random.choice(S, size=len(S), replace=True, p=weights)
  parts = parts[:, S_re]    # particles are columns, so index the second axis

  # Save steps taken
  paths[:,step+1,:] = parts