In [1]:
import torch
import theseus as th
import numpy as np

# Problem size: `arr_size` independent curve-fitting problems, each sampled at
# `n_points` evenly spaced x-values in [0, 1].
arr_size = 1_000
n_points = 100

# Seed BOTH generators: the coefficients below are drawn from NumPy's global
# RNG, so torch.manual_seed alone does not make the generated data reproducible.
torch.manual_seed(0)
np.random.seed(0)

# Every row of data_x is the same grid -> shape (arr_size, n_points).
data_x = torch.tensor(np.tile(np.linspace(0, 1, n_points), (arr_size, 1)), dtype=torch.float32)
# One ground-truth coefficient per problem, uniform in [0, 1) -> shape (arr_size,).
data_coeffs = torch.tensor(np.random.random(data_x.shape[0]), dtype=torch.float32)
# Noise-free targets y = c * exp(x); unsqueeze broadcasts each c across its row.
data_y = data_coeffs.unsqueeze(-1) * torch.exp(data_x)

In [2]:
# Wrap the observed data as named auxiliary (non-optimized) variables. These
# names are the keys used to feed tensors in through `theseus_inputs` later on.
x, y = (th.Variable(values, name=label) for values, label in ((data_x, 'x'), (data_y, 'y')))

# The quantity being solved for: a 1-dof vector (one scalar coefficient per
# problem in the batch).
a = th.Vector(1, name='a')

In [3]:
def error_fn(optim_vars, aux_vars):
    """Residual of the exponential model ``y = a * exp(x)``.

    Args:
        optim_vars: one-element sequence holding the coefficient variable ``a``.
        aux_vars: ``(x, y)`` variables holding sample points and observations.

    Returns:
        The tensor ``y - a * exp(x)``; ``a.tensor`` is broadcast against
        ``x.tensor`` by ordinary torch broadcasting rules.
    """
    (coeff,) = optim_vars
    x_var, y_var = aux_vars
    model = coeff.tensor * torch.exp(x_var.tensor)
    return y_var.tensor - model

optim_vars = (a,)
aux_vars = (x, y)
# The cost's residual dimension must equal the length of the residual that
# error_fn returns (one entry per sample point). Derive it from the data
# instead of hard-coding it, so changing the sample grid cannot cause a mismatch.
cost_function = th.AutoDiffCostFunction(
    optim_vars, error_fn, y.tensor.shape[-1],
    aux_vars=aux_vars, name='cost_fn'
)
objective = th.Objective()
objective.add(cost_function)
# The residual is linear in `a`, so an undamped Gauss-Newton step
# (step_size=1.0) would reach the least-squares solution in one iteration.
# step_size=0.5 instead halves the remaining parameter error per step
# (quartering the squared-error objective, as the logged errors show), which is
# why 15 iterations are budgeted.
optimizer = th.GaussNewton(
    objective,
    max_iterations=15,
    step_size=0.5
)
theseus_optim = th.TheseusLayer(optimizer)

In [4]:
# Initial guess a = 1 for every problem; keys match the declared variable names.
theseus_inputs = {
    'x': data_x,
    'y': data_y,
    'a': torch.ones((arr_size, 1))
}
# Inference only: no gradients are needed through the inner optimization.
with torch.no_grad():
    # Call the module (rather than .forward) so standard nn.Module hooks run.
    updated_inputs, info = theseus_optim(
        theseus_inputs,
        optimizer_kwargs={'track_best_solution': True, 'verbose': True}
    )
best_a = info.best_solution['a']
# best_a has one row per problem; show a short summary instead of every row.
print('Best solution (first 5 of', best_a.shape[0], 'rows):', best_a[:5].flatten().tolist())

# NOTE: coeffs_est keeps the (arr_size, 1) layout of `a`, while coeffs_exact is
# (arr_size,). Index coeffs_est[:, 0] before comparing element-wise; subtracting
# the two directly would broadcast to an (arr_size, arr_size) matrix.
coeffs_est = np.array(best_a.cpu())
coeffs_exact = np.array(data_coeffs)
# Noise-free data: after 15 damped GN steps the remaining error in `a` is about
# 0.5**15 * |1 - c| <= ~3e-5, so this tolerance is comfortably loose.
np.testing.assert_allclose(coeffs_est[:, 0], coeffs_exact, atol=1e-4)

Nonlinear optimizer. Iteration: 0. Error: 54.42876434326172
Nonlinear optimizer. Iteration: 1. Error: 13.607196807861328
Nonlinear optimizer. Iteration: 2. Error: 3.4017996788024902
Nonlinear optimizer. Iteration: 3. Error: 0.8504500389099121
Nonlinear optimizer. Iteration: 4. Error: 0.212612584233284
Nonlinear optimizer. Iteration: 5. Error: 0.05315316095948219
Nonlinear optimizer. Iteration: 6. Error: 0.013288293965160847
Nonlinear optimizer. Iteration: 7. Error: 0.0033220762852579355
Nonlinear optimizer. Iteration: 8. Error: 0.0008305194787681103
Nonlinear optimizer. Iteration: 9. Error: 0.00020763030624948442
Nonlinear optimizer. Iteration: 10. Error: 5.1907663873862475e-05
Nonlinear optimizer. Iteration: 11. Error: 1.2976941434317268e-05
Nonlinear optimizer. Iteration: 12. Error: 3.244261961299344e-06
Nonlinear optimizer. Iteration: 13. Error: 8.110659450721869e-07
Nonlinear optimizer. Iteration: 14. Error: 2.027695416018105e-07
Nonlinear optimizer. Iteration: 15. Error: 5.0693220