<a href="https://colab.research.google.com/github/johannnamr/Discrepancy-based-inference-using-QMC/blob/main/Helper-functions/Plot_fcts.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Plot functions

Imports:

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats

Set size for plot labels:

In [None]:
# Font sizes shared by every plot produced with this notebook.
SMALL_SIZE = 20
MEDIUM_SIZE = 22
BIGGER_SIZE = 26

# Apply all sizes with one rcParams update.
plt.rcParams.update({
  'font.size': MEDIUM_SIZE,         # default text size
  'axes.titlesize': MEDIUM_SIZE,    # axes title
  'axes.labelsize': MEDIUM_SIZE,    # x and y axis labels
  'xtick.labelsize': MEDIUM_SIZE,   # x tick labels
  'ytick.labelsize': MEDIUM_SIZE,   # y tick labels
  'legend.fontsize': MEDIUM_SIZE,   # legend entries
  'figure.titlesize': BIGGER_SIZE,  # figure-level title
})

## Plot samples

Plot samples of the **uniform distribution**:

In [None]:
def plot_unif(y,d,cols,fig_size):
  """Print summary statistics and plot one histogram per dimension of `y`.

  For each of the first `d` columns of `y` a density-normalised histogram is
  drawn in its own subplot, together with a constant reference line at height
  1 over [0, 1). The figure is saved as 'Histogram.pdf' and then shown.

  Parameters
  ----------
  y : 2-D array indexed as y[:, j]; column j holds the samples of dimension j.
  d : number of columns of `y` to summarise and plot.
  cols : number of subplot columns in the grid.
  fig_size : passed to plt.figure as figsize.
  """
  dims = range(d)

  # check distribution
  print('mean: ',[np.round(np.mean(y[:,j]),4) for j in dims])
  print('min: ',[np.round(np.min(y[:,j]),4) for j in dims])
  print('max: ',[np.round(np.max(y[:,j]),4) for j in dims])

  # number of rows required: ceil(d / cols).
  # (d // cols + d % cols would reserve too many rows whenever d % cols > 1.)
  rows = -(-d // cols)

  # main figure
  fig = plt.figure(figsize=fig_size)

  # reference line at height 1; identical for every subplot, so built once
  rng = np.arange(0, 1, 0.001)
  reference = np.ones(len(rng))

  # add subplots (subplot positions are 1-based)
  for j in dims:
    ax = fig.add_subplot(rows,cols,j+1)
    ax.hist(y[:,j], bins=np.linspace(np.min(y[:,j]), np.max(y[:,j]),100), density=True)
    ax.set_ylim((0,1.6))
    ax.set_ylabel('density')
    ax.set_xlabel('y')
    ax.set_title('Dimension ' + str(j+1))
    ax.plot(rng,reference)

  plt.tight_layout()
  plt.savefig('Histogram.pdf')
  plt.show()

Plot samples of the **Gaussian location model**:

In [None]:
def plot_gaussian(y,theta,d,s,cols,fig_size):
  """Print summary statistics and plot one histogram per dimension of `y`.

  For each of the first `d` columns of `y` a density-normalised histogram is
  drawn in its own subplot and overlaid with the normal density with mean
  theta[j] and standard deviation `s`, evaluated on
  [theta[j]-10, theta[j]+10). The figure is saved as 'Histogram.pdf' and then
  shown.

  Parameters
  ----------
  y : 2-D array indexed as y[:, j]; column j holds the samples of dimension j.
  theta : sequence indexed as theta[j]; centre of the overlay in dimension j.
  d : number of columns of `y` to summarise and plot.
  s : scale passed to stats.norm.pdf for every dimension.
  cols : number of subplot columns in the grid.
  fig_size : passed to plt.figure as figsize.
  """
  dims = range(d)

  # check distribution
  print('mean: ',[np.round(np.mean(y[:,j]),4) for j in dims])
  print('sd:   ',[np.round(np.std(y[:,j]),4) for j in dims])

  # number of rows required: ceil(d / cols).
  # (d // cols + d % cols would reserve too many rows whenever d % cols > 1.)
  rows = -(-d // cols)

  # main figure
  fig = plt.figure(figsize=fig_size)

  # add subplots (subplot positions are 1-based)
  for j in dims:
    lower, upper = theta[j]-10, theta[j]+10
    ax = fig.add_subplot(rows,cols,j+1)
    ax.hist(y[:,j], bins=np.linspace(np.min(y[:,j]), np.max(y[:,j]),100), density=True)
    rng = np.arange(lower, upper, 0.001)
    pdf = stats.norm.pdf(rng,theta[j],s)
    ax.set_ylim((0,0.3))
    ax.set_xlim((lower, upper))
    ax.set_ylabel('density')
    ax.set_xlabel('y')
    ax.set_title('Dimension ' + str(j+1))
    ax.plot(rng,pdf)

  plt.tight_layout()
  plt.savefig('Histogram.pdf')
  plt.show()

Plot samples from the **beta distribution**:

In [None]:
def plot_beta(y,fig_size,theta):
  """Histogram of `y` overlaid with the beta density with parameters theta[0], theta[1].

  The density is evaluated on the grid [-0.05, 1.05) with step 0.001. The
  figure is saved as 'beta_hist.pdf' and then shown.
  """
  plt.figure(figsize=fig_size)

  # histogram of the samples on 100 equally spaced bin edges
  bin_edges = np.linspace(np.min(y), np.max(y),100)
  plt.hist(y, bins=bin_edges, density=True)

  # true density on a fine grid
  grid = np.arange(-0.05, 1.05, 0.001)
  density = stats.beta.pdf(grid,theta[0],theta[1])

  plt.ylabel('density')
  plt.xlabel('y')
  plt.title('Histogram')
  plt.plot(grid,density)

  plt.savefig('beta_hist.pdf')
  plt.show()

Plot samples from the **g-and-k distribution**:

In [None]:
def plot_gandk(y,fig_size,theta):
  """Draw two figures: the generator curve for `theta` and a histogram of `y`.

  The first figure plots gen_gandk(z, theta) against u, where u runs over
  [0.01, 1) in steps of 0.01 and z is the standard normal quantile of u; it is
  saved as 'generator.pdf'. The second figure is a density-normalised
  histogram of `y`, saved as 'histogram.pdf'. `gen_gandk` is not defined in
  this notebook and must be available in the calling namespace.
  """
  # check generator
  u = np.arange(0.01,1,0.01)
  z = stats.norm.ppf(u, loc=0, scale=1)
  quantiles = gen_gandk(z,theta)
  plt.figure(figsize=fig_size)
  plt.plot(u,quantiles)
  plt.title('Quantile function')
  plt.xlabel(r'$u$')
  plt.ylabel(r'$G_\theta(u)$')
  plt.savefig('generator.pdf')
  plt.show()

  # plot histogram
  plt.figure(figsize=fig_size)
  bin_edges = np.linspace(np.min(y), np.max(y),100)
  plt.hist(y, bins=bin_edges, density=True)
  plt.title('Histogram')
  plt.xlabel('y')
  plt.ylabel('density')
  plt.savefig('histogram.pdf')
  plt.show()

Plot samples from the **multivariate g-and-k distribution**:

In [None]:
def plot_mvgandk(y,dim,fig_size,theta):
  """Draw the generator curve and a histogram for one component of `y`.

  The first figure plots column `dim` of gen_mvgandk(z, theta) against u,
  where u runs over [0.01, 0.99) in steps of 0.01 and z holds the standard
  normal quantiles of u repeated once per column of `y`; it is saved as
  'generator.pdf'. The second figure is a density-normalised histogram of
  y[:, dim], saved as 'histogram.pdf'. `gen_mvgandk` is not defined in this
  notebook and must be available in the calling namespace.
  """
  n_components = y.shape[1]

  # check generator
  u = np.arange(0.01,0.99,0.01)
  z_single = stats.norm.ppf(u, loc=0, scale=1)
  z = np.squeeze(np.dstack([z_single]*n_components))
  plt.figure(figsize=fig_size)
  plt.plot(u,gen_mvgandk(z,theta)[:,dim])
  plt.title('Quantile function')
  plt.xlabel(r'$u$')
  plt.ylabel(r'$G_\theta(u)$')
  plt.savefig('generator.pdf')
  plt.show()

  # plot histogram; the bin edges span the range of all of y, not only column dim
  plt.figure(figsize=fig_size)
  bin_edges = np.linspace(np.min(y), np.max(y),100)
  plt.hist(y[:,dim], bins=bin_edges, density=True)
  plt.title('Histogram')
  plt.xlabel('y')
  plt.ylabel('density')
  plt.savefig('histogram.pdf')
  plt.show()

Plot samples from the **SV model**:

In [None]:
def plot_sv(y,fig_size):
  """Plot ten rows of `y` as curves in graded shades of blue.

  Rows 10 to 19 of `y` are drawn and labelled y_1 ... y_10. The figure is
  saved as 'stochastic_vol.pdf' and then shown.
  """
  # colour scale: the values 1..10 are mapped onto the 'Blues' colormap
  levels = np.arange(1, 11)
  scale = mpl.colors.Normalize(vmin=levels.min(), vmax=levels.max())
  shades = mpl.cm.ScalarMappable(norm=scale, cmap=mpl.cm.Blues)
  shades.set_array([])

  # plot 10 realisations
  plt.figure(figsize=fig_size)
  for k in range(1, 11):
    plt.plot(y[k+9,:], c=shades.to_rgba(k), label=r'$y_{'+str(k)+'}$')
  plt.title('Stochastic Volatility Model')
  plt.ylabel(r'$y$')
  plt.xlabel('t')
  plt.legend(bbox_to_anchor=(1.05, 0.92),fancybox=True)
  plt.savefig('stochastic_vol.pdf', bbox_inches="tight")
  plt.show()

Plot samples from the **banana-shaped distribution**:

In [None]:
def plot_banana(y,fig_size):
  """Scatter the first column of `y` against the second.

  The figure is saved as 'scatter.pdf' and then shown.

  Parameters
  ----------
  y : 2-D array indexed as y[:, 0] (x-axis) and y[:, 1] (y-axis).
  fig_size : passed to plt.figure as figsize.
  """
  plt.figure(figsize=fig_size)
  plt.scatter(y[:,0],y[:,1])
  plt.title('Scatterplot')
  plt.xlabel(r'$x$')
  plt.ylabel(r'$x^2+y$')
  plt.savefig('scatter.pdf')
  # plt.show must be called; the bare attribute reference did nothing
  plt.show()

## Plot optimisation results

Plot MMD$^2$ for MC, QMC and RQMC in one plot:

In [None]:
def plot_loss(it,d,fig_size,loss,label):
  """Plot the absolute value of the first `it` loss values on a log y-axis.

  The figure is saved as 'MMD_loss_d=<d>.pdf' and then shown.

  Parameters
  ----------
  it : number of leading entries of `loss` to plot.
  d : only used in the output file name.
  fig_size : passed to plt.figure as figsize.
  loss : sequence of loss values, sliced as loss[:it].
  label : legend label of the curve.
  """
  plt.figure(figsize=fig_size)
  plt.plot(np.abs(loss[:it]),label=label)
  plt.ylim(1e-7,1)
  # Base 10 is the default of the log scale. The former `basey=10` keyword was
  # removed from Matplotlib and raises a TypeError on current versions.
  plt.yscale('log')
  plt.xlabel('Descent steps (t)')
  plt.ylabel(r'$| \widehat{MMD}^2 |$')
  plt.title('MMD'+r'$^2$'+' loss')
  plt.legend(loc='lower right')
  plt.savefig('MMD_loss_d='+str(d)+'.pdf')
  plt.show()

Plot estimates against iterations for MC, QMC and RQMC in one plot:

In [None]:
def plot_estimates(cols,p,d,fig_size,max_it,theta,theta_star,label):
  """Plot the trajectory of each parameter estimate against the descent step.

  One subplot per parameter shows theta[:, j] over steps 0..max_it together
  with a dashed grey horizontal line at theta_star[j]. The figure is saved as
  'Estimates_d=<d>.pdf' and then shown.

  Parameters
  ----------
  cols : number of subplot columns in the grid.
  p : number of parameters (subplots).
  d : only used in the output file name.
  fig_size : passed to plt.figure as figsize.
  max_it : the x-axis is range(max_it + 1), so theta[:, j] is expected to
    have max_it + 1 entries.
  theta : 2-D array indexed as theta[:, j].
  theta_star : sequence indexed as theta_star[j]; reference value per parameter.
  label : legend label of the estimate curve.
  """
  # number of rows required: ceil(p / cols).
  # (p // cols + p % cols would reserve too many rows whenever p % cols > 1.)
  rows = -(-p // cols)

  # main figure
  fig = plt.figure(figsize=fig_size)
  cmap = plt.get_cmap("tab10")
  steps = range(max_it+1)

  # add subplots (subplot positions are 1-based)
  for j in range(p):
    ax = fig.add_subplot(rows,cols,j+1)
    ax.plot(steps,theta[:,j], label=label, color=cmap(0))
    ax.plot(steps,theta_star[j]*np.ones(max_it+1),linestyle='--', color='grey')
    ax.set_xlabel('Descent steps (t)')
    ax.set_ylabel(r'$\hat{\theta}_{'+str(j+1)+'}$')
    ax.set_title('Estimate for '+ r'$\theta_{'+str(j+1)+'}$')
    ax.legend(loc='lower right')

  # Save once, after the layout is tightened, so the file matches what is shown
  # (saving inside the loop rewrote the file p times with the untightened layout).
  plt.tight_layout()
  plt.savefig('Estimates_d='+str(d)+'.pdf')
  plt.show()

Plot of MSE against iterations for MC, QMC and RQMC in one plot:

In [None]:
def plot_mse(cols,p,d,fig_size,max_it,mse1,mse3,label1,label3,model):
  """Plot two MSE trajectories per parameter on a log y-axis.

  One subplot per parameter shows mse1[:, j] and mse3[:, j] against
  range(max_it - 1). The figure is saved as 'MSE_d=<d>.pdf' and then shown.

  Parameters
  ----------
  cols : number of subplot columns in the grid.
  p : number of parameters (subplots).
  d : only used in the output file name.
  fig_size : passed to plt.figure as figsize.
  max_it : the x-axis is range(max_it - 1), so each column of mse1/mse3 is
    expected to have max_it - 1 entries.
  mse1, mse3 : 2-D arrays indexed as [:, j].
  label1, label3 : legend labels of the two curves.
  model : 'gaussian' sets the y-limits to (1e-3, 1), 'gandk' to (1e-2, 6.7);
    any other value leaves them automatic.
  """
  # number of rows required: ceil(p / cols).
  # (p // cols + p % cols would reserve too many rows whenever p % cols > 1.)
  rows = -(-p // cols)

  # main figure
  fig = plt.figure(figsize=fig_size)
  cmap = plt.get_cmap("tab10")
  steps = range(max_it-1)

  # add subplots (subplot positions are 1-based)
  for j in range(p):
    ax = fig.add_subplot(rows,cols,j+1)
    ax.plot(steps,mse1[:,j], label=label1, color=cmap(0))
    ax.plot(steps,mse3[:,j], label=label3, color=cmap(2))
    if model == 'gaussian':
      ax.set_ylim(1e-3,1)
    if model == 'gandk':
      ax.set_ylim(1e-2,6.7)
    # Base 10 is the default of the log scale. The former `basey=10` keyword was
    # removed from Matplotlib and raises a TypeError on current versions.
    ax.set_yscale('log')
    ax.set_xlabel('Descent steps (t)')
    ax.set_ylabel('MSE')
    ax.set_title('MSE for '+ r'$\theta_{'+str(j+1)+'}$')
    ax.legend()

  # Save once, after the layout is tightened, so the file matches what is shown
  # (saving inside the loop rewrote the file p times with the untightened layout).
  plt.tight_layout()
  plt.savefig('MSE_d='+str(d)+'.pdf')
  plt.show()

## Convergence of MMD$^2$

Plot MMD$^2$ against $n$ with closed form expressions for Gaussian location model and uniform distribution:

In [None]:
def plot_mmd_conv_closedform(d,fig_size,mmd1,mmd2,mmd3,label1,label2,label3,mmd1_min,mmd1_max,mmd3_min,mmd3_max,model,mmd_all_mc=None,maxall=15):
  """Plot |MMD^2| against n for three estimators on log-log axes.

  The x-values are read from a variable `n` that is not a parameter of this
  function; it has to be defined in the surrounding (global) namespace. The
  figure is saved as 'MMD_against_n_d=<d>.pdf' (with an '_allmc' suffix when
  `mmd_all_mc` is given) and then shown.

  Parameters
  ----------
  d : only used in the output file name.
  fig_size : passed to plt.figure as figsize.
  mmd1, mmd2, mmd3 : values plotted in absolute value against n.
  label1, label2, label3 : legend labels of the three curves.
  mmd1_min, mmd1_max, mmd3_min, mmd3_max : lower and upper ends of the error
    bars drawn around |mmd1| and |mmd3|.
  model : 'unif' or 'gaussian' sets the y-limits to (1e-12, 0.45); any other
    value leaves them automatic.
  mmd_all_mc : optional 2-D collection indexed as [:, i]; its first columns
    are drawn as faint curves in the first tab10 colour.
  maxall : maximum number of columns of `mmd_all_mc` to draw.
  """
  plt.figure(figsize=fig_size)
  cmap = plt.get_cmap("tab10")
  abs1, abs2, abs3 = np.abs(mmd1), np.abs(mmd2), np.abs(mmd3)

  plt.plot(n, abs1, linewidth=2, label=label1)
  plt.plot(n, abs2, linewidth=2, label=label2)
  plt.plot(n, abs3, linewidth=2, label=label3)
  # asymmetric error bars: distance from the curve down to *_min and up to *_max
  plt.errorbar(n, abs1, yerr=[abs1-np.array(mmd1_min),np.array(mmd1_max)-abs1], fmt='.', color=cmap(0), capsize=10, elinewidth=1.5)
  plt.errorbar(n, abs3, yerr=[abs3-np.array(mmd3_min),np.array(mmd3_max)-abs3], fmt='.', color=cmap(2), capsize=10, elinewidth=1.5)

  if mmd_all_mc is not None:
    all_mc = np.array(mmd_all_mc)
    for i in range(min(maxall, len(mmd_all_mc[0]))):
      plt.plot(n, np.abs(all_mc[:,i]),color=cmap(0),alpha=0.3)

  if model in ('unif', 'gaussian'):
    plt.ylim(1e-12,0.45)
  # `basex`/`basey` were removed from Matplotlib (TypeError on current
  # versions); `base` is the replacement and base 10 is the default.
  plt.xscale('log', base=2)
  plt.yscale('log')
  plt.xlabel('n')
  plt.ylabel(r'$|\widehat{MMD}^2(P||P^n)|$')
  plt.legend()
  if mmd_all_mc is not None:
    plt.savefig('MMD_against_n_d='+str(d)+'_allmc.pdf')
  else:
    plt.savefig('MMD_against_n_d='+str(d)+'.pdf')
  plt.show()

Plot MMD$^2$ against $d$ with closed form expressions for Gaussian location model and uniform distribution:

In [None]:
def plot_mmd_conv_closedform_d(d,fig_size,mmd1,mmd2,mmd3,label1,label2,label3,mmd1_min,mmd1_max,mmd3_min,mmd3_max,model,mmd_all_mc=None,maxall=15):
  """Plot |MMD^2| against d for three estimators on a log y-axis.

  The figure is saved as 'MMD_against_n_d=<d>.pdf' (with an '_allmc' suffix
  when `mmd_all_mc` is given) and then shown.

  Parameters
  ----------
  d : x-values of all curves; also inserted via str(d) into the file name.
  fig_size : passed to plt.figure as figsize.
  mmd1, mmd2, mmd3 : values plotted in absolute value against d.
  label1, label2, label3 : legend labels of the three curves.
  mmd1_min, mmd1_max, mmd3_min, mmd3_max : lower and upper ends of the error
    bars drawn around |mmd1| and |mmd3|.
  model : 'unif' or 'gaussian' sets the y-limits to (1e-12, 0.45); any other
    value leaves them automatic.
  mmd_all_mc : optional 2-D collection indexed as [:, i]; its first columns
    are drawn as faint curves in the first tab10 colour.
  maxall : maximum number of columns of `mmd_all_mc` to draw.
  """
  plt.figure(figsize=fig_size)
  cmap = plt.get_cmap("tab10")
  abs1, abs2, abs3 = np.abs(mmd1), np.abs(mmd2), np.abs(mmd3)

  plt.plot(d, abs1, linewidth=2, label=label1)
  plt.plot(d, abs2, linewidth=2, label=label2)
  plt.plot(d, abs3, linewidth=2, label=label3)
  # asymmetric error bars: distance from the curve down to *_min and up to *_max
  plt.errorbar(d, abs1, yerr=[abs1-np.array(mmd1_min),np.array(mmd1_max)-abs1], fmt='.', color=cmap(0), capsize=10, elinewidth=1.5)
  plt.errorbar(d, abs3, yerr=[abs3-np.array(mmd3_min),np.array(mmd3_max)-abs3], fmt='.', color=cmap(2), capsize=10, elinewidth=1.5)

  if mmd_all_mc is not None:
    all_mc = np.array(mmd_all_mc)
    for i in range(min(maxall, len(mmd_all_mc[0]))):
      plt.plot(d, np.abs(all_mc[:,i]),color=cmap(0),alpha=0.3)

  if model in ('unif', 'gaussian'):
    plt.ylim(1e-12,0.45)
  # Base 10 is the default of the log scale. The former `basey=10` keyword was
  # removed from Matplotlib and raises a TypeError on current versions.
  plt.yscale('log')
  plt.xlabel('d')
  plt.ylabel(r'$|\widehat{MMD}^2(P||P^n)|$')
  plt.legend()
  # NOTE(review): the 'MMD_against_n_d=' prefix and str(d) of the whole x-axis
  # look copied from the against-n plot; confirm the intended file name.
  if mmd_all_mc is not None:
    plt.savefig('MMD_against_n_d='+str(d)+'_allmc.pdf')
  else:
    plt.savefig('MMD_against_n_d='+str(d)+'.pdf')
  plt.show()

Plot MMD$^2$ against n for MC, QMC and RQMC:

In [None]:
def plot_mmd_conv(d,fig_size,mmd1,mmd2,mmd3,label1,label2,label3,mmd1_min,mmd1_max,mmd3_min,mmd3_max,model,stat_type,mmd_all_mc=None,maxall=15):
  """Plot |MMD^2| against n for the first and third estimator on log-log axes.

  The x-values are read from a variable `n` that is not a parameter of this
  function; it has to be defined in the surrounding (global) namespace.
  `mmd2` and `label2` are accepted but not plotted. The figure is saved as
  'MMD_against_n_d=<d>_<stat_type>stat.pdf' (with an '_allmc' suffix when
  `mmd_all_mc` is given) and then shown.

  Parameters
  ----------
  d : only used in the output file name.
  fig_size : passed to plt.figure as figsize.
  mmd1, mmd3 : values plotted in absolute value against n.
  label1, label3 : legend labels of the two curves.
  mmd1_min, mmd1_max, mmd3_min, mmd3_max : lower and upper ends of the error
    bars drawn around |mmd1| and |mmd3|.
  model : selects fixed y-limits for the known model names listed below; any
    other value leaves them automatic.
  stat_type : string inserted into the output file name.
  mmd_all_mc : optional 2-D collection indexed as [:, i]; its first columns
    are drawn as faint curves in the first tab10 colour.
  maxall : maximum number of columns of `mmd_all_mc` to draw.
  """
  # y-limits per model name
  y_limits = {
    'unif': (1e-12,0.8),
    'gaussian': (1e-12,0.8),
    'beta': (1e-8,0.3),
    'mvgandk': (1e-12,0.8),
    'gandk': (1e-7,0.25),
    'bibeta': (1e-12,0.3),
    'banana': (1e-8,0.25),
  }

  plt.figure(figsize=fig_size)
  cmap = plt.get_cmap("tab10")
  abs1, abs3 = np.abs(mmd1), np.abs(mmd3)

  plt.plot(n, abs1, linewidth=2, label=label1)
  plt.plot(n, abs3, linewidth=2, label=label3, color=cmap(2))
  # asymmetric error bars: distance from the curve down to *_min and up to *_max
  plt.errorbar(n, abs1, yerr=[abs1-np.array(mmd1_min),np.array(mmd1_max)-abs1], fmt='.', color=cmap(0), capsize=10, elinewidth=1.5)
  plt.errorbar(n, abs3, yerr=[abs3-np.array(mmd3_min),np.array(mmd3_max)-abs3], fmt='.', color=cmap(2), capsize=10, elinewidth=1.5)

  if mmd_all_mc is not None:
    all_mc = np.array(mmd_all_mc)
    for i in range(min(maxall, len(mmd_all_mc[0]))):
      plt.plot(n, np.abs(all_mc[:,i]),color=cmap(0),alpha=0.3)

  if model in y_limits:
    plt.ylim(*y_limits[model])
  # `basex`/`basey` were removed from Matplotlib (TypeError on current
  # versions); `base` is the replacement and base 10 is the default.
  plt.xscale('log', base=2)
  plt.yscale('log')
  plt.xlabel('n=m')
  plt.ylabel(r'$|\widehat{MMD}^2_V(P^n||P^m)|$')
  plt.legend()
  if mmd_all_mc is not None:
    plt.savefig('MMD_against_n_d='+str(d)+'_'+stat_type+'stat_allmc.pdf')
  else:
    plt.savefig('MMD_against_n_d='+str(d)+'_'+stat_type+'stat.pdf')
  plt.show()

Plot MMD$^2$ against d for MC, QMC and RQMC:

In [None]:
def plot_mmd_conv_d(d,fig_size,mmd1,mmd2,mmd3,label1,label2,label3,mmd1_min,mmd1_max,mmd3_min,mmd3_max,model,stat_type,mmd_all_mc=None,maxall=15):
  """Plot |MMD^2| against d for three estimators on a log y-axis.

  The figure is saved as 'MMD_against_d_<stat_type>stat_n=8192.pdf', or as
  'MMD_against_d_n=8192_<stat_type>stat_allmc.pdf' when `mmd_all_mc` is given,
  and then shown. The 'n=8192' part of the name is hard-coded.

  Parameters
  ----------
  d : x-values of all curves.
  fig_size : passed to plt.figure as figsize.
  mmd1, mmd2, mmd3 : values plotted in absolute value against d.
  label1, label2, label3 : legend labels of the three curves.
  mmd1_min, mmd1_max, mmd3_min, mmd3_max : lower and upper ends of the error
    bars drawn around |mmd1| and |mmd3|.
  model : 'unif' or 'gaussian' sets the y-limits to (1e-12, 0.45), 'mvgandk'
    to (1e-8, 1e-3); any other value leaves them automatic.
  stat_type : string inserted into the output file name.
  mmd_all_mc : optional 2-D collection indexed as [:, i]; its first columns
    are drawn as faint curves in the first tab10 colour.
  maxall : maximum number of columns of `mmd_all_mc` to draw.
  """
  plt.figure(figsize=fig_size)
  cmap = plt.get_cmap("tab10")
  abs1, abs2, abs3 = np.abs(mmd1), np.abs(mmd2), np.abs(mmd3)

  plt.plot(d, abs1, linewidth=2, label=label1)
  plt.plot(d, abs2, linewidth=2, label=label2)
  plt.plot(d, abs3, linewidth=2, label=label3)
  # asymmetric error bars: distance from the curve down to *_min and up to *_max
  plt.errorbar(d, abs1, yerr=[abs1-np.array(mmd1_min),np.array(mmd1_max)-abs1], fmt='.', color=cmap(0), capsize=10, elinewidth=1.5)
  plt.errorbar(d, abs3, yerr=[abs3-np.array(mmd3_min),np.array(mmd3_max)-abs3], fmt='.', color=cmap(2), capsize=10, elinewidth=1.5)

  if mmd_all_mc is not None:
    all_mc = np.array(mmd_all_mc)
    for i in range(min(maxall, len(mmd_all_mc[0]))):
      plt.plot(d, np.abs(all_mc[:,i]),color=cmap(0),alpha=0.3)

  if model in ('unif', 'gaussian'):
    plt.ylim(1e-12,0.45)
  if model == 'mvgandk':
    plt.ylim(1e-8,1e-3)
  # Base 10 is the default of the log scale. The former `basey=10` keyword was
  # removed from Matplotlib and raises a TypeError on current versions.
  plt.yscale('log')
  plt.xlabel('d')
  plt.ylabel(r'$|\widehat{MMD}^2(P^n||P^m)|$')
  plt.legend()
  if mmd_all_mc is not None:
    plt.savefig('MMD_against_d_n=8192'+'_'+stat_type+'stat_allmc.pdf')
  else:
    plt.savefig('MMD_against_d'+'_'+stat_type+'stat_n=8192.pdf')
  plt.show()

### Convergence of Wasserstein distance

Plot Wasserstein distance against n for MC, QMC and RQMC:

In [None]:
def plot_W_conv(d,fig_size,W1,W2,W3,label1,label2,label3,W1_min,W1_max,W3_min,W3_max):
  """Plot |W| against n for three estimators on log-log axes.

  The x-values are read from a variable `n` that is not a parameter of this
  function; it has to be defined in the surrounding (global) namespace. The
  figure is saved as 'W_against_n_d=<d>.pdf' and then shown.

  Parameters
  ----------
  d : only used in the output file name.
  fig_size : passed to plt.figure as figsize.
  W1, W2, W3 : values plotted in absolute value against n.
  label1, label2, label3 : legend labels of the three curves.
  W1_min, W1_max, W3_min, W3_max : lower and upper ends of the error bars
    drawn around |W1| and |W3|.
  """
  plt.figure(figsize=fig_size)
  cmap = plt.get_cmap("tab10")
  abs1, abs2, abs3 = np.abs(W1), np.abs(W2), np.abs(W3)

  plt.plot(n, abs1, linewidth=2, label=label1)
  plt.plot(n, abs2, linewidth=2, label=label2)
  plt.plot(n, abs3, linewidth=2, label=label3)
  # asymmetric error bars: distance from the curve down to *_min and up to *_max
  plt.errorbar(n, abs1, yerr=[abs1-np.array(W1_min),np.array(W1_max)-abs1], fmt='.', color=cmap(0), capsize=10, elinewidth=1.5)
  plt.errorbar(n, abs3, yerr=[abs3-np.array(W3_min),np.array(W3_max)-abs3], fmt='.', color=cmap(2), capsize=10, elinewidth=1.5)
  #if d==1:
  #  plt.ylim(1e-3,0.35)
  #if d==5:
  #  plt.ylim(0.05,0.35)
  # `basex`/`basey` were removed from Matplotlib (TypeError on current
  # versions); `base` is the replacement and base 10 is the default.
  plt.xscale('log', base=2)
  plt.yscale('log')
  plt.xlabel('n')
  plt.ylabel(r'$|W_c(P,P^n)|$')
  plt.title('Wasserstein distance against the number of samples used - '+r'$c(x,y)=||x-y||$')
  plt.legend()
  plt.savefig('W_against_n_d='+str(d)+'.pdf')
  plt.show()

In [None]:
def plot_W_conv_d(d,fig_size,W1,W2,W3,label1,label2,label3,W1_min,W1_max,W3_min,W3_max):
  """Plot |W| against d for three estimators on a log y-axis.

  The figure is saved as 'W_against_d_n=8192.pdf' (the name is hard-coded) and
  then shown.

  Parameters
  ----------
  d : x-values of all curves.
  fig_size : passed to plt.figure as figsize.
  W1, W2, W3 : values plotted in absolute value against d.
  label1, label2, label3 : legend labels of the three curves.
  W1_min, W1_max, W3_min, W3_max : lower and upper ends of the error bars
    drawn around |W1| and |W3|.
  """
  plt.figure(figsize=fig_size)
  cmap = plt.get_cmap("tab10")
  abs1, abs2, abs3 = np.abs(W1), np.abs(W2), np.abs(W3)

  plt.plot(d, abs1, linewidth=2, label=label1)
  plt.plot(d, abs2, linewidth=2, label=label2)
  plt.plot(d, abs3, linewidth=2, label=label3)
  # asymmetric error bars: distance from the curve down to *_min and up to *_max
  plt.errorbar(d, abs1, yerr=[abs1-np.array(W1_min),np.array(W1_max)-abs1], fmt='.', color=cmap(0), capsize=10, elinewidth=1.5)
  plt.errorbar(d, abs3, yerr=[abs3-np.array(W3_min),np.array(W3_max)-abs3], fmt='.', color=cmap(2), capsize=10, elinewidth=1.5)
  #if d==1:
  #  plt.ylim(1e-3,0.35)
  #if d==5:
  #  plt.ylim(0.05,0.35)
  # Base 10 is the default of the log scale. The former `basey=10` keyword was
  # removed from Matplotlib and raises a TypeError on current versions.
  plt.yscale('log')
  plt.xlabel('d')
  plt.ylabel(r'$|W_c(P^n,P^m)|$')
  # NOTE(review): the title mentions the number of samples although the x-axis
  # is d; confirm the intended wording.
  plt.title('Wasserstein distance against the number of samples used - '+r'$c(x,y)=||x-y||$')
  plt.legend()
  plt.savefig('W_against_d_n=8192.pdf')
  plt.show()

### Convergence of Sinkhorn loss

Plot Sinkhorn loss against n for MC, QMC and RQMC:

In [None]:
def plot_sink_conv(d,fig_size,sink1,sink2,sink3,label1,label2,label3,sink1_min,sink1_max,sink3_min,sink3_max):
  """Plot |Sinkhorn loss| against n for three estimators on log-log axes.

  The x-values are read from a variable `n` that is not a parameter of this
  function; it has to be defined in the surrounding (global) namespace. The
  figure is saved as 'sink_against_n_d=<d>.pdf' and then shown.

  Parameters
  ----------
  d : only used in the output file name.
  fig_size : passed to plt.figure as figsize.
  sink1, sink2, sink3 : values plotted in absolute value against n.
  label1, label2, label3 : legend labels of the three curves.
  sink1_min, sink1_max, sink3_min, sink3_max : lower and upper ends of the
    error bars drawn around |sink1| and |sink3|; singleton axes are squeezed
    away before the differences are formed.
  """
  plt.figure(figsize=fig_size)
  cmap = plt.get_cmap("tab10")
  abs1, abs2, abs3 = np.abs(sink1), np.abs(sink2), np.abs(sink3)
  flat1, flat3 = np.squeeze(abs1), np.squeeze(abs3)

  plt.plot(n, abs1, linewidth=2, label=label1)
  plt.plot(n, abs2, linewidth=2, label=label2)
  plt.plot(n, abs3, linewidth=2, label=label3)
  # asymmetric error bars: distance from the curve down to *_min and up to *_max
  plt.errorbar(n, abs1, yerr=[flat1-np.squeeze(np.array(sink1_min)),np.squeeze(np.array(sink1_max))-flat1], fmt='.', color=cmap(0), capsize=10, elinewidth=1.5)
  plt.errorbar(n, abs3, yerr=[flat3-np.squeeze(np.array(sink3_min)),np.squeeze(np.array(sink3_max))-flat3], fmt='.', color=cmap(2), capsize=10, elinewidth=1.5)
  #if d==1:
  #  plt.ylim(0.000005,0.25)
  #if d==5:
  #  plt.ylim(0.005,0.25)
  # `basex`/`basey` were removed from Matplotlib (TypeError on current
  # versions); `base` is the replacement and base 10 is the default.
  plt.xscale('log', base=2)
  plt.yscale('log')
  plt.xlabel('n')
  plt.ylabel(r'$|\overline{W}_{c,\epsilon}(P,P^n)|$')
  plt.title('Sinkhorn loss against the number of samples used - '+r'$c(x,y)=||x-y||^2$')
  plt.legend()
  plt.savefig('sink_against_n_d='+str(d)+'.pdf')
  plt.show()

In [None]:
def plot_sink_conv_d(d,fig_size,sink1,sink2,sink3,label1,label2,label3,sink1_min,sink1_max,sink3_min,sink3_max):
  """Plot |Sinkhorn loss| against d for three estimators on a log y-axis.

  The figure is saved as 'sink_against_n_d=<d>.pdf' and then shown.

  Parameters
  ----------
  d : x-values of all curves; also inserted via str(d) into the file name.
  fig_size : passed to plt.figure as figsize.
  sink1, sink2, sink3 : values plotted in absolute value against d.
  label1, label2, label3 : legend labels of the three curves.
  sink1_min, sink1_max, sink3_min, sink3_max : lower and upper ends of the
    error bars drawn around |sink1| and |sink3|; singleton axes are squeezed
    away before the differences are formed.
  """
  plt.figure(figsize=fig_size)
  cmap = plt.get_cmap("tab10")
  abs1, abs2, abs3 = np.abs(sink1), np.abs(sink2), np.abs(sink3)
  flat1, flat3 = np.squeeze(abs1), np.squeeze(abs3)

  plt.plot(d, abs1, linewidth=2, label=label1)
  plt.plot(d, abs2, linewidth=2, label=label2)
  plt.plot(d, abs3, linewidth=2, label=label3)
  # asymmetric error bars: distance from the curve down to *_min and up to *_max
  plt.errorbar(d, abs1, yerr=[flat1-np.squeeze(np.array(sink1_min)),np.squeeze(np.array(sink1_max))-flat1], fmt='.', color=cmap(0), capsize=10, elinewidth=1.5)
  plt.errorbar(d, abs3, yerr=[flat3-np.squeeze(np.array(sink3_min)),np.squeeze(np.array(sink3_max))-flat3], fmt='.', color=cmap(2), capsize=10, elinewidth=1.5)
  #if d==1:
  #  plt.ylim(0.000005,0.25)
  #if d==5:
  #  plt.ylim(0.005,0.25)
  # Base 10 is the default of the log scale. The former `basey=10` keyword was
  # removed from Matplotlib and raises a TypeError on current versions.
  plt.yscale('log')
  plt.xlabel('d')
  plt.ylabel(r'$|\overline{W}_{c,\epsilon}(P,P^n)|$')
  # NOTE(review): title and file name refer to n although the x-axis is d, and
  # str(d) embeds the whole x-axis in the file name; confirm the intended text.
  plt.title('Sinkhorn loss against the number of samples used - '+r'$c(x,y)=||x-y||^2$')
  plt.legend()
  plt.savefig('sink_against_n_d='+str(d)+'.pdf')
  plt.show()

### Convergence of sliced Wasserstein distance

Plot sliced Wasserstein distance against n for MC, QMC and RQMC:

In [None]:
def plot_slicedW_conv(d,fig_size,W1,W2,W3,label1,label2,label3,W1_min,W1_max,W3_min,W3_max):
  """Plot |sliced W| against n for three estimators on log-log axes.

  The x-values are read from a variable `n` that is not a parameter of this
  function; it has to be defined in the surrounding (global) namespace. The
  figure is saved as 'slicedW_against_n_d=<d>.pdf' and then shown.

  Parameters
  ----------
  d : only used in the output file name.
  fig_size : passed to plt.figure as figsize.
  W1, W2, W3 : values plotted in absolute value against n.
  label1, label2, label3 : legend labels of the three curves.
  W1_min, W1_max, W3_min, W3_max : lower and upper ends of the error bars
    drawn around |W1| and |W3|.
  """
  plt.figure(figsize=fig_size)
  cmap = plt.get_cmap("tab10")
  abs1, abs2, abs3 = np.abs(W1), np.abs(W2), np.abs(W3)

  plt.plot(n, abs1, linewidth=2, label=label1)
  plt.plot(n, abs2, linewidth=2, label=label2)
  plt.plot(n, abs3, linewidth=2, label=label3)
  # asymmetric error bars: distance from the curve down to *_min and up to *_max
  plt.errorbar(n, abs1, yerr=[abs1-np.array(W1_min),np.array(W1_max)-abs1], fmt='.', color=cmap(0), capsize=10, elinewidth=1.5)
  plt.errorbar(n, abs3, yerr=[abs3-np.array(W3_min),np.array(W3_max)-abs3], fmt='.', color=cmap(2), capsize=10, elinewidth=1.5)
  #if d==1:
  #  plt.ylim(1e-3,0.35)
  #if d==5:
  #  plt.ylim(0.05,0.35)
  # `basex`/`basey` were removed from Matplotlib (TypeError on current
  # versions); `base` is the replacement and base 10 is the default.
  plt.xscale('log', base=2)
  plt.yscale('log')
  plt.xlabel('n')
  plt.ylabel(r'$|W_c^p(P,P^n)|$')
  plt.title('Sliced Wasserstein distance against n')
  plt.legend()
  plt.savefig('slicedW_against_n_d='+str(d)+'.pdf')
  plt.show()

In [None]:
def plot_slicedW_conv_d(d,fig_size,W1,W2,W3,label1,label2,label3,W1_min,W1_max,W3_min,W3_max):
  """Plot |sliced W| against d for three estimators on a log y-axis.

  The figure is saved as 'slicedW_against_d_n=8192.pdf' (the name is
  hard-coded) and then shown.

  Parameters
  ----------
  d : x-values of all curves.
  fig_size : passed to plt.figure as figsize.
  W1, W2, W3 : values plotted in absolute value against d.
  label1, label2, label3 : legend labels of the three curves.
  W1_min, W1_max, W3_min, W3_max : lower and upper ends of the error bars
    drawn around |W1| and |W3|.
  """
  plt.figure(figsize=fig_size)
  cmap = plt.get_cmap("tab10")
  abs1, abs2, abs3 = np.abs(W1), np.abs(W2), np.abs(W3)

  plt.plot(d, abs1, linewidth=2, label=label1)
  plt.plot(d, abs2, linewidth=2, label=label2)
  plt.plot(d, abs3, linewidth=2, label=label3)
  # asymmetric error bars: distance from the curve down to *_min and up to *_max
  plt.errorbar(d, abs1, yerr=[abs1-np.array(W1_min),np.array(W1_max)-abs1], fmt='.', color=cmap(0), capsize=10, elinewidth=1.5)
  plt.errorbar(d, abs3, yerr=[abs3-np.array(W3_min),np.array(W3_max)-abs3], fmt='.', color=cmap(2), capsize=10, elinewidth=1.5)
  #if d==1:
  #  plt.ylim(1e-3,0.35)
  #if d==5:
  #  plt.ylim(0.05,0.35)
  # Base 10 is the default of the log scale. The former `basey=10` keyword was
  # removed from Matplotlib and raises a TypeError on current versions.
  plt.yscale('log')
  plt.xlabel('d')
  plt.ylabel(r'$|W_c^p(P^n,P^m)|$')
  plt.title('Sliced Wasserstein distance against d')
  plt.legend()
  plt.savefig('slicedW_against_d_n=8192.pdf')
  plt.show()