
Failure to fit complex data. #618

Open
zhangrentu opened this issue Oct 20, 2023 · 29 comments

Comments

@zhangrentu

❓ Questions and Help

When attempting to fit some simple regression problems, such as y = ax + b, where both x and y are complex numbers, I encountered errors. Could you please advise on methods or modifications to resolve this issue?
[screenshot of the error]

@luisenp
Contributor

luisenp commented Oct 20, 2023

Hi @zhangrentu. We have never tested Theseus with complex numbers, so I'm not really sure which parts of it would need to be modified. Our custom linear solvers will certainly not work with this type of data, so in the best case you are limited to using DenseLinearization with either CholeskyDenseSolver or LUDenseSolver, and even then there is no guarantee that it will work.

If you send me a code snippet I can try to run it and perhaps suggest what modifications would be necessary.
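For reference, a dense-only configuration would look roughly like this (a minimal sketch that assumes an objective has already been built, and mirrors the optimizer construction used later in this thread):

import theseus as th

# Sketch: restrict the inner optimizer to the dense linearization/solvers, since the
# custom sparse/CUDA solvers will not handle complex-valued data.
optimizer = th.LevenbergMarquardt(
    objective,               # a previously constructed th.Objective
    th.CholeskyDenseSolver,  # or th.LUDenseSolver
    max_iterations=10,
)
layer = th.TheseusLayer(optimizer)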

@luisenp
Contributor

luisenp commented Oct 23, 2023

Hi @zhangrentu. I made a number of tweaks and removed some of our dtype constraints (adding torch.complex32 here), but eventually it failed with a torch error when trying to run your error_fn(). You can see my changes below.

import torch
import theseus as th


def y_model(x, a, b, c):
    return a * torch.exp((-1j * b + c) * x)  # y = a * exp((-1j*b + c) * x)


def generate_data(num_points=512, a=1, b=2, c=3):
    data_x = torch.linspace(0, 1, num_points).view(1, -1)
    data_y = y_model(data_x, a, b, c)
    return data_x, data_y


def read_data():
    data_x, data_y_clean = generate_data()
    return (
        data_x,
        data_y_clean,
        1 * torch.ones(1, 1),
        2 * torch.ones(1, 1),
        3 * torch.ones(1, 1),
    )


x_true, y_true, a_true, b_true, c_true = read_data()
x = th.Variable(torch.randn_like(x_true), name="x")
y = th.Variable(y_true, name="y")
a = th.Vector(1, name="a")  # a manifold subclass of Variable for optim_vars
b = th.Vector(1, name="b")  # a manifold subclass of Variable for optim_vars
c = th.Vector(1, name="c")  # a manifold subclass of Variable for optim_vars
for v in [x, y, a, b, c]:
    v.to(dtype=torch.complex32)


def error_fn(optim_vars, aux_vars):  # returns y - a * exp((-1j*b + c) * x)
    x, y = aux_vars
    return y.tensor - optim_vars[0].tensor * torch.exp(
        (-1j * optim_vars[1].tensor + optim_vars[2].tensor) * x.tensor
    )


objective = th.Objective(dtype=torch.complex32)
w = th.ScaleCostWeight(1.0)
w.to(dtype=torch.complex32)
cost_function = th.AutoDiffCostFunction(
    [a, b, c], error_fn, y_true.shape[1], aux_vars=[x, y], cost_weight=w
)
objective.add(cost_function)
layer = th.TheseusLayer(th.GaussNewton(objective, max_iterations=10))

phi = torch.nn.Parameter(x_true + 0.1 * torch.ones_like(x_true))
outer_optimizer = torch.optim.Adam([phi], lr=0.001)
input_tensors = {
    "x": phi.clone(),
    "a": 0.5 * torch.ones(1, 1),
    "b": torch.ones(1, 1),
    "c": torch.ones(1, 1),
}
input_tensors = {k: t.to(dtype=torch.complex32) for k, t in input_tensors.items()}
for epoch in range(20):
    outer_optimizer.zero_grad()
    solution, info = layer.forward(
        input_tensors=input_tensors,
        optimizer_kwargs={"backward_mode": "implicit"},
    )
    outer_loss1 = torch.nn.functional.mse_loss(solution["a"], a_true)
    outer_loss2 = torch.nn.functional.mse_loss(solution["b"], b_true)
    outer_loss3 = torch.nn.functional.mse_loss(solution["c"], c_true)
    outer_loss = outer_loss1 + outer_loss2 + outer_loss3
    outer_loss.backward()
    outer_optimizer.step()
    print("Outer loss: ", outer_loss.item())

@zhangrentu
Author

Thank you very much for your response. I added the dtype (torch.complex64) and ran the computation with PyTorch, but I found that the results are not converging. Do you have any suggestions?
[screenshots of the non-converging results]

@luisenp
Contributor

luisenp commented Oct 24, 2023

Can you share the new version of your script?

@zhangrentu
Author

Thanks. The following is the new version, with the added dtype constraint (torch.complex64).

import torch
import theseus as th


def y_model(x, a, b, c):
    return a * torch.exp((-1j * b + c) * x)  # y = a * exp((-1j*b + c) * x)


def generate_data(num_points=512, a=1, b=2, c=3):
    data_x = torch.linspace(0, 1, num_points).view(1, -1)
    data_y = y_model(data_x, a, b, c)
    return data_x, data_y


def read_data():
    data_x, data_y_clean = generate_data()
    return (
        data_x,
        data_y_clean,
        1 * torch.ones(1, 1),
        2 * torch.ones(1, 1),
        3 * torch.ones(1, 1),
    )


x_true, y_true, a_true, b_true, c_true = read_data()
x = th.Variable(torch.randn_like(x_true), name="x")
y = th.Variable(y_true, name="y")
a = th.Vector(1, name="a")  # a manifold subclass of Variable for optim_vars
b = th.Vector(1, name="b")  # a manifold subclass of Variable for optim_vars
c = th.Vector(1, name="c")  # a manifold subclass of Variable for optim_vars
for v in [x, y, a, b, c]:
    v.to(dtype=torch.complex64)


def error_fn(optim_vars, aux_vars):  # returns y - a * exp((-1j*b + c) * x)
    x, y = aux_vars
    return y.tensor - optim_vars[0].tensor * torch.exp(
        (-1j * optim_vars[1].tensor + optim_vars[2].tensor) * x.tensor
    )


objective = th.Objective(dtype=torch.complex64)
w = th.ScaleCostWeight(1.0)
w.to(dtype=torch.complex64)
cost_function = th.AutoDiffCostFunction(
    [a, b, c], error_fn, y_true.shape[1], aux_vars=[x, y], cost_weight=w
)
objective.add(cost_function)
optimizer = th.LevenbergMarquardt(
    objective,
    th.CholeskyDenseSolver,
    max_iterations=10,
    step_size=0.001,
)
layer = th.TheseusLayer(optimizer)

phi = torch.nn.Parameter(x_true + 0.1 * torch.ones_like(x_true))
outer_optimizer = torch.optim.Adam([phi], lr=0.001)
input_tensors = {
    "x": phi.clone(),
    "a": 0.5 * torch.ones(1, 1),
    "b": torch.ones(1, 1),
    "c": torch.ones(1, 1),
}
input_tensors = {k: t.to(dtype=torch.complex64) for k, t in input_tensors.items()}
for epoch in range(20):
    outer_optimizer.zero_grad()
    solution, info = layer.forward(
        input_tensors=input_tensors,
        optimizer_kwargs={"backward_mode": "implicit"},
    )
    outer_loss1 = torch.nn.functional.mse_loss(solution["a"], a_true)
    outer_loss2 = torch.nn.functional.mse_loss(solution["b"], b_true)
    outer_loss3 = torch.nn.functional.mse_loss(solution["c"], c_true)
    outer_loss = outer_loss1 + outer_loss2 + outer_loss3
    outer_loss.backward()
    outer_optimizer.step()
    print("Outer loss: ", outer_loss.item())

@zhangrentu
Author

While attempting data-parallel computation (torch.nn.DataParallel), I encountered the following error. The model, composed of the th.TheseusLayer and the rest of the network, is placed on the same device. Do you know what is causing this?
[screenshots of the error]

@luisenp
Contributor

luisenp commented Nov 3, 2023

Hi @zhangrentu. The error in your last comment should now be fixed after #623 is merged.

@luisenp
Contributor

luisenp commented Nov 3, 2023

I took a quick look at your script. One change I had to make was to set autograd_mode="dense" in AutoDiffCostFunction (it looks like vmap doesn't support complex data). I did see problems in the linear system used to compute gradients for our optimizers, but I don't have a good intuition on how to solve this because I'm not too familiar with complex data. Perhaps you need to change your error function to return some real-valued error vector that depends on your complex inputs, but I'm not sure.
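For reference, the only change needed on the cost function side is the extra keyword argument (this mirrors the construction in the full script in my next comment):

cost_function = th.AutoDiffCostFunction(
    [a, b, c],
    error_fn,
    y_true.shape[1],
    aux_vars=[x, y],
    cost_weight=w,
    autograd_mode="dense",  # vmap-based autodiff does not support complex tensors
)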

@luisenp
Contributor

luisenp commented Nov 3, 2023

I'm able to run the optimizer if I use a very high damping, set the error function to return the .abs() of the previous error, and make some slight changes to the loss (although I don't see the other loss changing at all).

You can see my changes here

import torch
import theseus as th


def y_model(x, a, b, c):
    return a * torch.exp((-1j * b + c) * x)  # y = a * exp((-1j*b + c) * x)


def generate_data(num_points=4, a=1, b=2, c=3):
    data_x = torch.linspace(0, 1, num_points).view(1, -1)
    data_y = y_model(data_x, a, b, c)
    return data_x, data_y


def read_data():
    data_x, data_y_clean = generate_data()
    return (
        data_x,
        data_y_clean,
        1 * torch.ones(1, 1),
        2 * torch.ones(1, 1),
        3 * torch.ones(1, 1),
    )


x_true, y_true, a_true, b_true, c_true = read_data()
x = th.Variable(torch.randn_like(x_true), name="x")
y = th.Variable(y_true, name="y")
a = th.Vector(1, name="a")  # a manifold subclass of Variable for optim_vars
b = th.Vector(1, name="b")  # a manifold subclass of Variable for optim_vars
c = th.Vector(1, name="c")  # a manifold subclass of Variable for optim_vars
for v in [x, y, a, b, c]:
    v.to(dtype=torch.complex64)


def error_fn(optim_vars, aux_vars):  # returns y - a * exp((-1j*b + c) * x)
    x, y = aux_vars
    return (
        y.tensor
        - optim_vars[0].tensor
        * torch.exp((-1j * optim_vars[1].tensor + optim_vars[2].tensor) * x.tensor)
    ).abs()


objective = th.Objective(dtype=torch.complex64)
w = th.ScaleCostWeight(1.0)
w.to(dtype=torch.complex64)
cost_function = th.AutoDiffCostFunction(
    [a, b, c],
    error_fn,
    y_true.shape[1],
    aux_vars=[x, y],
    cost_weight=w,
    autograd_mode="dense",
)
objective.add(cost_function)
optimizer = th.LevenbergMarquardt(
    objective,
    th.CholeskyDenseSolver,
    max_iterations=5,
    step_size=0.1,
)
layer = th.TheseusLayer(optimizer)

phi = torch.nn.Parameter(x_true + 0.1 * torch.ones_like(x_true))
outer_optimizer = torch.optim.Adam([phi], lr=1.0)
input_tensors = {
    "x": phi.clone(),
    "a": 0.5 * torch.ones(1, 1),
    "b": torch.ones(1, 1),
    "c": torch.ones(1, 1),
}
input_tensors = {k: t.to(dtype=torch.complex64) for k, t in input_tensors.items()}
for epoch in range(20):
    outer_optimizer.zero_grad()
    solution, info = layer.forward(
        input_tensors=input_tensors,
        optimizer_kwargs={
            "backward_mode": "unroll",
            "verbose": True,
            "damping": 100.0,
        },
    )
    outer_loss1 = torch.nn.functional.mse_loss(solution["a"].real, a_true)
    outer_loss2 = torch.nn.functional.mse_loss(solution["b"].real, b_true)
    outer_loss3 = torch.nn.functional.mse_loss(solution["c"].real, c_true)
    outer_loss = outer_loss1 + outer_loss2 + outer_loss3
    outer_loss.backward()
    outer_optimizer.step()
    print("Outer loss: ", outer_loss.item())

@zhangrentu
Author

Thank you very much for your response. However, when we use the TheseusLayer in the model, it still returns NoneType. Additionally, the optim_vars and aux_vars data are automatically placed on cuda:0, but the objective is on cuda, which causes an error when calling objective.update, making it impossible to use data-parallel training (nn.DataParallel).

[screenshots of the error]

@zhangrentu
Author

We tried changing the error function to return the real part, for example real(exp(i*a)) = cos(a), but after the modification the error did not converge, or the matrix was non-positive definite, etc. We also tried reducing the step_size, which showed a slight improvement. Are there any other adjustments we could make? Currently, we are using the following configuration:

optimizer = th.LevenbergMarquardt(
    objective,
    th.CholeskyDenseSolver,
    th.DenseLinearization,
    max_iterations=20,
    step_size=0.02,
)

@luisenp
Contributor

luisenp commented Nov 4, 2023

@zhangrentu This code has not been merged to main yet. Are you using the code directly from that branch? If you are, then please share a short snippet of code that results in the device error, because I'm not sure how that can happen in the code from that branch. Thanks!

@zhangrentu
Author

Yes, I'm using the code from that branch. The main flow of the parallel computation is as follows: the data is placed on the primary GPU (cuda:0), and the model is distributed to the GPUs used for parallel computing (e.g., cuda:0, cuda:1, cuda:2). Since the TheseusLayer is treated as a layer within the network, it belongs to the model part. However, during parameter updates, the objective of the cost function was not distributed, resulting in the error. The specific error is shown in the image below:

[screenshots of the error]

@luisenp
Contributor

luisenp commented Nov 6, 2023

Ah, I see. We have never tested this inside a DataParallel model, so I don't have a lot of insight yet. Could you share a short repro script?

@luisenp
Contributor

luisenp commented Nov 6, 2023

@zhangrentu Regarding the convergence, in the script I shared above one thing I did was to increase the damping to a really large value (I used 100.0), and the error does seem to decrease when doing this.
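Concretely, the damping is passed through optimizer_kwargs at the call site, as in the forward call from the script I shared above:

solution, info = layer.forward(
    input_tensors=input_tensors,
    optimizer_kwargs={
        "backward_mode": "unroll",
        "verbose": True,
        "damping": 100.0,  # large damping; with this the inner error does decrease
    },
)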

@zhangrentu
Author

zhangrentu commented Nov 7, 2023

@zhangrentu Regarding the convergence, in the script I shared above one thing I did was to increase the damping to a really large value (I used 100.0), and the error does seem to decrease when doing this.

I really appreciate your response. It seems that the complex form is not currently supported. So I ran a simple real-valued parameter-estimation experiment; the code is below. It estimates the parameters of an exponential in the presence of noise, where y is the noisy signal. My idea is to update y with an outer network while the inner NLLS estimates the parameters given y. However, the results do not converge either. I've tried adjusting the NLLS step size in the optimizer and the learning rate of the outer Adam, but it had no effect. I'm not sure what the reason is.

import torch
import theseus as th


def y_model(x, a, b, c):
    return a * torch.exp(b * x + c)  # y = a * exp(b*x + c)


def generate_data(num_points=512, a=1, b=2, c=3, noise_std=0.1):
    data_x = torch.linspace(0, 1, num_points).view(1, -1)
    data_y_noise = y_model(data_x, a, b, c) + noise_std * torch.randn((1, num_points))
    return data_x, data_y_noise


def read_data():
    data_x, data_y_noise = generate_data()
    return (
        data_x,
        data_y_noise,
        1 * torch.ones(1, 1),
        2 * torch.ones(1, 1),
        3 * torch.ones(1, 1),
    )


x_true, y_true, a_true, b_true, c_true = read_data()
x = th.Variable(torch.randn_like(x_true), name="x")
y = th.Variable(torch.randn_like(y_true), name="y")
a = th.Vector(1, name="a")  # a manifold subclass of Variable for optim_vars
b = th.Vector(1, name="b")  # a manifold subclass of Variable for optim_vars
c = th.Vector(1, name="c")  # a manifold subclass of Variable for optim_vars


def error_fn(optim_vars, aux_vars):  # returns y - a * exp(b * x + c)
    x, y = aux_vars
    return y.tensor - optim_vars[0].tensor * torch.exp(
        optim_vars[1].tensor * x.tensor + optim_vars[2].tensor
    )


objective = th.Objective(dtype=torch.float32)
w = th.ScaleCostWeight(1.0)
w.to(dtype=torch.float32)
cost_function = th.AutoDiffCostFunction(
    [a, b, c], error_fn, y_true.shape[1], aux_vars=[x, y], cost_weight=w
)
objective.add(cost_function)
optimizer = th.LevenbergMarquardt(
    objective,
    # th.CholeskyDenseSolver,
    max_iterations=20,
    step_size=0.01,
)
layer = th.TheseusLayer(optimizer)
phi = torch.nn.Parameter(y_true)

outer_optimizer = torch.optim.Adam([phi], lr=0.001)
for epoch in range(20):
    outer_optimizer.zero_grad()
    solution, info = layer.forward(
        input_tensors={
            "x": x_true,
            "y": phi.clone(),
            "a": 0.5 * torch.ones(1, 1),
            "b": torch.ones(1, 1),
            "c": torch.ones(1, 1),
        },
        optimizer_kwargs={"backward_mode": "implicit"},
    )
    outer_loss1 = torch.nn.functional.mse_loss(solution["a"], a_true)
    outer_loss2 = torch.nn.functional.mse_loss(solution["b"], b_true)
    outer_loss3 = torch.nn.functional.mse_loss(solution["c"], c_true)
    outer_loss = outer_loss1 + outer_loss2 + outer_loss3
    print("Outer loss:", outer_loss.item())
    print("a_value:", solution["a"])
    print("b_value:", solution["b"])
    print("c_value:", solution["c"])
    outer_loss.backward()
    outer_optimizer.step()

[screenshot of the printed output]

@zhangrentu
Author

zhangrentu commented Nov 7, 2023

Ah, I see. We have never tested this inside a DataParallel model, so I don't have a lot of insight yet. Could you share a short repro script?
I believe that parallel training can accelerate the operation of the network and significantly improve efficiency for code debugging. Here is a simple example code:

import os

import torch.nn.functional
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6"

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import theseus as th
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

torch.set_default_dtype(torch.float32)
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

# Custom dataset - each item contains x, y, and the coefficients to be estimated
class ExponentialData(Dataset):
    def __init__(self, num_data, num_points_per_data, noise_stddev):
        self.num_data = num_data
        self.num_points_per_data = num_points_per_data
        self.noise_stddev = noise_stddev
        self.data = []

        for _ in range(num_data):
            a = torch.rand(1)  # randomly generated coefficients for each item
            b = torch.rand(1)
            c = torch.rand(1)
            x = torch.linspace(0, 1, self.num_points_per_data)
            y = a * torch.exp(b * x + c) + torch.randn(self.num_points_per_data) * self.noise_stddev
            self.data.append((x, y, a, b, c))

    def __len__(self):
        return self.num_data

    def __getitem__(self, idx):
        return self.data[idx]

# Create a dataset and data loader
num_data = 120  # generate 120 items
num_points_per_data = 512  # each item contains 512 data points
noise_stddev = 0.1
batch_size = 20
dataset = ExponentialData(num_data=num_data, num_points_per_data=num_points_per_data, noise_stddev=noise_stddev)
data_loader = DataLoader(dataset, batch_size=batch_size * 3, shuffle=True)  # the number of GPUs is 3

# Define a neural network model that wraps the TheseusLayer
class ModelWithCoefficient(nn.Module):
    def __init__(self, Optimizer, device):
        super(ModelWithCoefficient, self).__init__()
        self.fc1 = nn.Linear(1, 10)
        self.fc2 = nn.Linear(10, 1)
        self.Para = th.TheseusLayer(Optimizer).to(device=device, dtype=torch.float32)

    def forward(self, x, x_true, batch_size, device):
        x = torch.relu(self.fc1(x))
        y_hat = self.fc2(x)
        input_tensors = {
            "x": x_true,
            "y": y_hat,
            "a": 0.5 * torch.ones(batch_size, 1).to(device),
            "b": torch.ones(batch_size, 1).to(device),
            "c": torch.ones(batch_size, 1).to(device),
        }
        optimizer_kwargs = {"backward_mode": "implicit"}
        solution, _ = self.Para(input_tensors, optimizer_kwargs)

        return solution

# Define the inner (Theseus) optimizer
x = th.Variable(tensor=torch.rand(batch_size, 512, dtype=torch.float32), name="x")
y = th.Variable(tensor=torch.rand(batch_size, 512, dtype=torch.float32), name="y")
a = th.Vector(tensor=torch.rand(batch_size, 1, dtype=torch.float32), name="a")  # a manifold subclass of Variable for optim_vars
b = th.Vector(tensor=torch.rand(batch_size, 1, dtype=torch.float32), name="b")  # a manifold subclass of Variable for optim_vars
c = th.Vector(tensor=torch.rand(batch_size, 1, dtype=torch.float32), name="c")  # a manifold subclass of Variable for optim_vars

def error_fn(optim_vars, aux_vars):  # returns y - a * exp(b*x + c)
    x, y = aux_vars
    return y.tensor - optim_vars[0].tensor * torch.exp(optim_vars[1].tensor * x.tensor + optim_vars[2].tensor)

objective = th.Objective(dtype=torch.float32)
w = th.ScaleCostWeight(1.0)
w.to(dtype=torch.float32)
cost_function = th.AutoDiffCostFunction(
    [a, b, c], error_fn, num_points_per_data, aux_vars=[x, y], cost_weight=w
)
objective.add(cost_function)
optimizer = th.LevenbergMarquardt(
    objective,
    # th.CholeskyDenseSolver,
    max_iterations=20,
    step_size=0.01,
)

# Create the model, loss function, and optimizer
model = ModelWithCoefficient(optimizer, device).to(torch.float32)
model = nn.DataParallel(model.to(device), device_ids=[0, 1, 2])

optimizer = optim.Adam(model.parameters(), lr=0.01)

# Train the model
num_epochs = 100

for epoch in range(num_epochs):
    for x_true, y_noise, a_true, b_true, c_true in data_loader:
        optimizer.zero_grad()
        solution = model(y_noise.view(-1, 1), x_true.view(-1, 1), batch_size, device)
        outer_loss1 = torch.nn.functional.mse_loss(solution["a"], a_true)
        outer_loss2 = torch.nn.functional.mse_loss(solution["b"], b_true)
        outer_loss3 = torch.nn.functional.mse_loss(solution["c"], c_true)
        loss = outer_loss1 + outer_loss2 + outer_loss3
        loss.backward()
        optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}')

@luisenp
Contributor

luisenp commented Nov 7, 2023

@zhangrentu Regarding the convergence, in the script I shared above one thing I did was to increase the damping to a really large value (I used 100.0), and the error does seem to decrease when doing this.

I really appreciate your response. It seems that the complex form is not currently supported. So I ran a simple real-valued parameter-estimation experiment; the code is below. It estimates the parameters of an exponential in the presence of noise, where y is the noisy signal. My idea is to update y with an outer network while the inner NLLS estimates the parameters given y. However, the results do not converge either. I've tried adjusting the NLLS step size in the optimizer and the learning rate of the outer Adam, but it had no effect. I'm not sure what the reason is.

Hi @zhangrentu. The problem in this example is that your system is underconstrained, resulting in a Jacobian of rank 2 when you have 3 optimization variables. Note that $Ae^{Bx + C} = Ae^{Bx} \cdot e^{C} = De^{Bx}$, where $D= Ae^{C}$, so there are actually infinite solutions for $A$ and $C$.

The system converges if you add one more constraint, for example, make $C=1$, by including the following.

t = th.Vector(tensor=torch.ones(1, 1), name="t")
objective.add(th.Difference(c, t, w, name="c_constraint"))

@zhangrentu
Author

I'm sorry for providing an inappropriate example, and thank you again for your prompt responses. I currently have two main questions:

  1. When the system being solved turns out to be indefinite, what could the causes be, and are there any parameters that can be adjusted besides the step size or learning rate?
  2. Applying constraints to the optimized parameters can reduce the solution space. Most of the applications I encounter involve non-negativity constraints or parameters restricted to a specific range. Does the library currently provide suitable cost functions for this? If not, do you have any recommendations?

Thank you for your answers.

@zhangrentu
Author

@zhangrentu Regarding the convergence, in the script I shared above one thing I did was to increase the damping to a really large value (I used 100.0), and the error does seem to decrease when doing this.

I really appreciate your response. It seems that the complex form is not currently supported. So I ran a simple real-valued parameter-estimation experiment; the code is below. It estimates the parameters of an exponential in the presence of noise, where y is the noisy signal. My idea is to update y with an outer network while the inner NLLS estimates the parameters given y. However, the results do not converge either. I've tried adjusting the NLLS step size in the optimizer and the learning rate of the outer Adam, but it had no effect. I'm not sure what the reason is.

Hi @zhangrentu. The problem in this example is that your system is underconstrained, resulting in a Jacobian of rank 2 when you have 3 optimization variables. Note that $Ae^{Bx + C} = Ae^{Bx} \cdot e^{C} = De^{Bx}$, where $D = Ae^{C}$, so there are actually infinite solutions for $A$ and $C$.

The system converges if you add one more constraint, for example, make $C=1$, by including the following.

t = th.Vector(tensor=torch.ones(1, 1), name="t")
objective.add(th.Difference(c, t, w, name="c_constraint"))

For this problem, if you have only two optimization parameters, A and B, and you know that A equals B, how would you incorporate this constraint into the equation?

@luisenp
Contributor

luisenp commented Dec 19, 2023

We don't yet have a principled solver for constrained problems. However, you can use soft penalties with a high cost weight to approximate the constraint, which works well in many cases. If A is an optimization variable but B is not (e.g., a constant target), then you can use the th.Difference() cost function (or perhaps th.eb.HingeCost). If both A and B are optimization variables, then you could use a th.Between constraint and set the measurement auxiliary variable to a zero value (or the identity if you are using rotation groups).

Something like (may have some syntax errors)

A = th.Vector(...)
B = th.Vector(...)
Z = th.Vector(tensor=torch.zeros(batch_size, dof), name="zeros")
cf = th.Between(A, B, Z, th.ScaleCostWeight(100.0), name="a_eq_b_constraint")
obj.add(cf)

This adds a soft version of the constraint (A - B) == 0, i.e., a squared penalty on (A - B), to the objective.

@zhangrentu
Author

Thank you for your suggestion. Regarding inequality constraints, such as imposing non-negativity on optimization variables, do you have any recommendations for effective penalty functions?

@luisenp
Contributor

luisenp commented Dec 22, 2023

In this case you can use th.eb.HingeCost, which is defined here. Here is a visual example of what the error looks like when down_limit=-5, up_limit=3 and threshold=0. If threshold is non-zero, its effect is to push both limits towards zero.

[plot of the HingeCost error for down_limit=-5, up_limit=3, threshold=0]

For a non-negativity constraint, you could use something like down_limit=0, up_limit=torch.inf, and some threshold to discourage getting too close to zero.

Hope this helps. Do note that we haven't tested this cost function as extensively as the others, so please let us know if you have any feedback.
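As a rough sketch (not tested), a non-negativity penalty on a Vector variable v might look like the following; I'm assuming the keyword names down_limit, up_limit, and threshold mentioned above, plus the usual cost_weight argument, so double-check against the HingeCost definition linked above:

v = th.Vector(1, name="v")  # variable we want to keep non-negative
hinge = th.eb.HingeCost(
    v,
    down_limit=0.0,      # non-negativity
    up_limit=torch.inf,  # no upper bound
    threshold=0.1,       # discourage getting too close to zero
    cost_weight=th.ScaleCostWeight(100.0),  # high weight so the soft penalty acts like a constraint
    name="v_nonneg",
)
objective.add(hinge)  # assumes an existing th.Objective named `objective`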

@zhangrentu
Author

Thank you very much for your advice, and I wish you a Merry Christmas in advance. While this method does constrain the parameters to be non-negative, the iterative process is not stable. Nevertheless, I appreciate your input.
[screenshot of the unstable iterations]

@wangdomg

wangdomg commented Feb 1, 2024

Ah, I see. We have never tested this inside a DataParallel model, so I don't have a lot of insight yet. Could you share a short repro script?
I believe that parallel training can accelerate the operation of the network and significantly improve efficiency for code debugging. Here is a simple example code:

[quoted example script omitted; see the full script in the comment above]

Hi, have you managed to use Theseus successfully in a parallel network (torch.nn.DataParallel)?

@zhangrentu
Author

I'm very sorry for the late reply. It seems that the current library does not support parallel training. One possible reason could be that the data is not distributed consistently across devices during parallel training, with part of it remaining on the primary GPU.

[quotes the previous comment, including the DataParallel example script and the question above]

@zhangrentu
Author

zhangrentu commented Mar 11, 2024 via email

@luisenp
Contributor

luisenp commented Mar 12, 2024

Hi @zhangrentu. One option I can think of would be to add a new variable C that represents the difference between A and B. Then you can add the following costs (pseudocode):

def c_eq_a_diff_b(optim_vars, aux_vars):
  A, B, C = optim_vars
  return C.tensor - (A.tensor - B.tensor)

cost_1 = AutoDiffCostFunction([A, B, C], c_eq_a_diff_b, C.dof)   # C = A - B
cost_2 = HingeCost(C, 0, np.inf, threshold)  # C > 0
cost_3 = HingeCost(B, 0, np.inf, threshold)  # B > 0

Not sure if this will work well, it might be tricky to optimize properly.

@zhangrentu
Author

Hi @zhangrentu. One option I can think of would be to add a new variable C that represents the difference between A and B. Then you can add the following costs (pseudocode):

def c_eq_a_diff_b(optim_vars, aux_vars):
  A, B, C = optim_vars
  return C.tensor - (A.tensor - B.tensor)

cost_1 = AutoDiffCostFunction([A, B, C], c_eq_a_diff_b, C.dof)   # C = A - B
cost_2 = HingeCost(C, 0, np.inf, threshold)  # C > 0
cost_3 = HingeCost(B, 0, np.inf, threshold)  # B > 0

Not sure if this will work well, it might be tricky to optimize properly.

Thanks, as before, the convergence is not stable.
