In [None]:
### Imports
%load_ext autoreload
%autoreload 2

import os

os.chdir("..")
import torchquad
import numpy as np
import torch
import matplotlib.pyplot as plt

%matplotlib notebook

In [None]:
# Display tensors with 8 digits of precision and let torchquad emit
# INFO-level log messages.
torch.set_printoptions(precision=8)
torchquad.set_log_level("INFO")

In [None]:
def tf1(x):
    """Test integrand sin(4*pi*x0**4) * x1, evaluated row by row.

    Args:
        x (torch.Tensor): 2D tensor of sample points; only columns 0 and 1
            are read.

    Returns:
        torch.Tensor: One function value per row of ``x``.
    """
    first_coord, second_coord = x[:, 0], x[:, 1]
    phase = first_coord**4 * 4 * np.pi
    return second_coord * torch.sin(phase)

# Reference value of the integral of tf1 over [0, 1] x [0, 1].
gt_tf1 = 0.08224537642165943298735264808181737465596841727271552769029466418570857556005241752395781124068434672

# Ground truth for tf2 below (note the factor 24 inside the sine, matching tf2):
# https://www.wolframalpha.com/input?i=integrate+-%28x-0.5%29%5E4++*+sin%28%28y-0.5%29%5E2*24%29+from+x%3D0+to+x%3D1+from+y%3D0+to+y%3D1

def tf2(x):
    """Test integrand -(x0 - 0.5)**4 * sin(24 * (x1 - 0.5)**2), row by row.

    Args:
        x (torch.Tensor): 2D tensor of sample points; only columns 0 and 1
            are read.

    Returns:
        torch.Tensor: One function value per row of ``x``.
    """
    shifted_x = x[:, 0] - 0.5
    shifted_y = x[:, 1] - 0.5
    envelope = -(shifted_x**4)
    return envelope * torch.sin(shifted_y**2 * 24)

# Reference value of the integral of tf2 over [0, 1] x [0, 1].
gt_tf2 = -0.002237582933156115933795974818428222657210309163702382340277410905010095369201654703359706808120094476

def print_error(gt, val):
    """Print a reference value, a computed value, and their errors.

    Args:
        gt (float): Reference ("correct") value.
        val (float or torch.Tensor): Computed value. Single-element tensors
            on any device are accepted; they are converted with ``float()``.

    The relative error is undefined when ``gt`` is zero; it is printed as
    ``nan`` in that case instead of raising ZeroDivisionError.
    """
    # Convert up front so tensors of shape (1,) format correctly and the
    # error is computed in float64 rather than in the tensor's dtype.
    gt = float(gt)
    val = float(val)
    abs_err = abs(val - gt)
    rel_err = abs_err / abs(gt) if gt != 0 else float("nan")
    print(f"Correct val: {gt:.8e}, TQ={val:.8e}, AbsError={abs_err:.4e},  RelError={rel_err:.4e}")

In [None]:
def plot_adaptive_grid(grid, dpi=100):
    """Plots the adaptive grid and corresponding function value.

    Every sample point of every subdomain is drawn as a 3D scatter point:
    the first two point coordinates on the X/Y axes and the stored function
    value on the Z axis.

    Args:
        grid (AdaptiveGrid): AdaptiveGrid of evaluated function. Each entry of
            ``grid.subdomains`` must expose ``points`` and ``fval`` tensors.
        dpi (int, optional): Plot dpi. Defaults to 100.

    Raises:
        ValueError: If ``grid`` has no subdomains to plot.
    """
    subdomains = list(grid.subdomains)
    if not subdomains:
        raise ValueError("plot_adaptive_grid: grid has no subdomains to plot")

    # Collect per-subdomain arrays and concatenate once; concatenating inside
    # the loop would re-copy all accumulated data on every iteration.
    # detach() keeps this working for tensors that are part of an autograd graph.
    points = np.concatenate([sd.points.detach().cpu().numpy() for sd in subdomains])
    fvals = np.concatenate([sd.fval.detach().cpu().numpy() for sd in subdomains])

    fig = plt.figure(figsize=(8, 6), dpi=dpi)
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(points[:, 0], points[:, 1], fvals, s=0.1, c="red")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Function Value")
    plt.show()


In [None]:
# Integrand under study and its reference integral
# (use tf1 / gt_tf1 here to benchmark the other test function).
f, gt = tf2, gt_tf2

In [None]:
# Baseline with torchquad's Boole integrator on the unit square.
# Note: `trap` holds a Boole integrator despite its name.
trap = torchquad.Boole()
val = trap.integrate(
    f,
    dim=2,
    N=10000,
    integration_domain=[[0, 1], [0, 1]],
)
print_error(gt, val)

In [None]:
# Adaptive Boole integrator with the same N and domain as the baseline.
at = torchquad.AdaptiveBoole()
val = at.integrate(
    f,
    dim=2,
    N=10000,
    integration_domain=[[0, 1], [0, 1]],
    subdomains_per_dim=8,
    max_refinement_level=6,
)
print_error(gt, val)

In [None]:
# `_grid` is a private attribute of the integrator; presumably it holds the
# grid from the most recent integrate() call — confirm against torchquad.
plot_adaptive_grid(grid=at._grid, dpi=150)

In [None]:
# Print the properties of every subdomain of the adaptive grid.
for sd in at._grid.subdomains:
    sd.print_subdomain_properties()