
Complex numbers in DeepXDE? #284

Closed
AneleNcube opened this issue May 9, 2021 · 24 comments

@AneleNcube

Is there a way to change the data type in deepxde from float32 to complex32 or complex64? I read the answer given for issue #28 about changing from the default single precision, float32, to double precision, float64, for real floating point numbers. The command used is:

dde.config.real.set_float64()

Is there a similar command for converting to complex numbers, or are only real numbers permitted by DeepXDE?

AneleNcube changed the title from "Complex numbers in DeepXDE" to "Complex numbers in DeepXDE?" on May 9, 2021
@lululxvi
Owner

The network only uses real numbers. You need to split the real and imaginary parts of the complex numbers explicitly.
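For example, a network with two real outputs can stand for the real and imaginary parts of the solution (a minimal sketch; the layer sizes and activation here are arbitrary choices):

import deepxde as dde

# Two real outputs: column 0 stands for Re(u), column 1 for Im(u).
layer_size = [1] + [20] * 3 + [2]
net = dde.maps.FNN(layer_size, "tanh", "Glorot uniform")

# Inside the PDE residual, slice the outputs to recover the two parts:
# u_re, u_im = y[:, 0:1], y[:, 1:2]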

@YouCantRedo

@AneleNcube Hi, I want to ask whether you solved the problem successfully. I have the same question, and I don't know how to change my code to work with the split real and imaginary parts. I really need your help.
@lululxvi I would like to know why this network doesn't support complex numbers, since complex numbers often occur in real-life problems. Is it possible for me to change the DeepXDE API so that it supports complex numbers? If so, where should I change it? Thanks very much!

@AneleNcube
Author

AneleNcube commented Aug 27, 2021

@YouCantRedo Hi, I managed to solve my problem by splitting the real and imaginary parts as @lululxvi suggested. I had to use two neural network outputs (instead of one): one for the real part and the other for the imaginary part of u(x), the surrogate of the solution to the differential equation.

Since the particular spatial domain I was dealing with was real-valued (i.e. -1.0 < x < 1.0), separating the real part of the differential equation from the imaginary part was not complicated: I just treated the two parts like a system of differential equations (similar to the DeepXDE example https://github.com/lululxvi/deepxde/blob/master/examples/Lorenz_inverse.py, which has more than two outputs and just one real-valued input). That way, I did not need a complex data type to represent the complex-valued differential equation I wanted to solve. I obtained the expected results after running the code.

This is a snippet from my code where I applied the DeepXDE package to solve an eigenvalue problem with complex eigenpairs:

omega1 = 0.5   # real part of the eigenvalue
omega2 = -0.5  # imaginary part of the eigenvalue

# the differential equation
def pde(x, y):
    y1, y2 = y[:, 0:1], y[:, 1:2]  # y1 = Re[y], y2 = Im[y]

    y1_x = dde.grad.jacobian(y, x, i=0)
    y2_x = dde.grad.jacobian(y, x, i=1)
    y1_xx = dde.grad.hessian(y, x, component=0, i=0, j=0)
    y2_xx = dde.grad.hessian(y, x, component=1, i=0, j=0)

    # real and imaginary parts of the differential equation
    return [
        ((1 - x**2) ** 2) * y1_xx - 2 * x * (1 - x**2) * y1_x - ((1 - x**2) / 2) * y1
        - 2 * omega1 * omega2 * y2 + (omega1**2 - omega2**2) * y1,
        ((1 - x**2) ** 2) * y2_xx - 2 * x * (1 - x**2) * y2_x - ((1 - x**2) / 2) * y2
        + 2 * omega1 * omega2 * y1 + (omega1**2 - omega2**2) * y2,
    ]

The rest of the code was more or less the same; some extra care just needs to be taken when writing out the real and imaginary parts of the differential equations. I think it may be possible to avoid expanding the differential equation yourself (with pen and paper) if you combine DeepXDE with a package such as PyTorch (but I haven't checked this). In that case, you would write your differential equation as usual and let the computer split it using np.real(...) and np.imag(...).

I hope this helps you overcome your current challenge. I'm not sure exactly why the neural network code cannot handle complex data types; perhaps it is something in the source code for the loss. I won't hazard a guess, so I will leave that for @lululxvi to answer.

@YouCantRedo

Dear @AneleNcube
Thank you very much for your patient reply! I really learned a lot from your answer. However, I still have a few questions that I want to discuss with you:

  1. About the pde code you shared, I wonder why you did not change i to 1 in y2_xx.
    Also, I am a little confused about the coefficients in the expressions you returned: shouldn't the first equation contain only omega1 and the second only omega2? In my case, I have only one purely imaginary coefficient for my function; could you please check whether the code below is correct?
def pde(xyz, u):
    # pde: du_z - con * (du_xx + du_yy)
    u1, u2 = u[:, 0:1], u[:, 1:2]  # u1 = Re[u], u2 = Im[u]

    du1_z = dde.grad.jacobian(u, xyz, i=0, j=2)
    du1_xx = dde.grad.hessian(u, xyz, component=0, i=0, j=0)
    du1_yy = dde.grad.hessian(u, xyz, component=0, i=0, j=1)

    du2_z = dde.grad.jacobian(u, xyz, i=1, j=2)
    du2_xx = dde.grad.hessian(u, xyz, component=1, i=1, j=0)
    du2_yy = dde.grad.hessian(u, xyz, component=1, i=1, j=1)

    con = -0.00795224193206157j

    con1 = 0                     # real part of coefficient
    con2 = -0.00795224193206157  # imag part of coefficient

    return [
        du1_z - con1 * (du1_xx + du1_yy),
        du2_z - con2 * (du2_xx + du2_yy),
    ]
  2. As shown in the Lorenz_inverse example, I realize I should define two ICs and two BCs (for the real and imaginary conditions). The tricky part is how to define func: should I define two functions, func1 returning only the real part of the original func and func2 returning the imaginary part? In that case I can pass func1 to ic1 and func2 to ic2. However, how should I pass func1 and func2 to the data.TimePDE part? The Lorenz_inverse example does not use a reference solution function, so maybe you know something about this issue.

  3. May I ask about the other places you changed in your code? I am afraid I will miss some parts that need changing.

@AneleNcube
Author

Hi @YouCantRedo.
I'm glad to assist. I am relatively new to neural networks and I am also learning from the discussions here. Let me see if I can answer your questions.

  1. Your pde seems to have 3 independent variables (x, y, and z) and 1 complex-valued dependent variable (u). I believe you have correctly represented the pde in the code, but when defining the second derivatives with the Hessian matrix note that:

The Hessian matrix H is defined by H[i][j] = d^2 y / dx_i dx_j, where i, j = 0, ..., dim_x - 1. Thus, in this case, unlike for the Jacobian, the indices "i" and "j" denote the input variables, in the sense that x_0 is x, x_1 is y, and x_2 is z. So, to define the second derivative of u2 w.r.t. the input x, for example, we must have i = j = 0, so that H[0][0] = d^2 u2 / dx_0 dx_0. The argument "component" in dde.grad.hessian(...) specifies the output variable. So I think you should have:

u1, u2 = u[:, 0:1], u[:, 1:2]      # u1= real[u]. u2 = imag[u]

du1_z = dde.grad.jacobian(u, xyz, i=0, j=2)
du1_xx = dde.grad.hessian(u, xyz, component=0, i=0, j=0)
du1_yy = dde.grad.hessian(u, xyz, component=0, i=1, j=1) # changed "i" to 1 so that we have (d^2 u1/dx_1dx_1)

du2_z = dde.grad.jacobian(u, xyz, i=1, j=2)
du2_xx = dde.grad.hessian(u, xyz, component=1, i=0, j=0) # changed "i" to 0 so that we have (d^2 u2 / dx_0 dx_0)
du2_yy = dde.grad.hessian(u, xyz, component=1, i=1, j=1)

Regarding the coefficients omega1 and omega2 in my code, they actually appear as (omega1 + omega2*1j)**2 before being split into real and imaginary parts. Thus, after expanding, I get (omega1**2 - omega2**2) as the real part and (2*omega1*omega2) as the imaginary part. So a lot of the terms mix when I split my differential equation, and both y1 and y2 appear in the real and imaginary parts, because originally I have:

[Equation image: (1 - x^2)^2 y'' - 2x(1 - x^2) y' - ((1 - x^2)/2) y + (omega1 + omega2*1j)^2 y = 0]
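As a quick sanity check of this expansion, you can compare the split terms against plain complex arithmetic (a small numpy sketch, independent of DeepXDE):

import numpy as np

omega1, omega2 = 0.5, -0.5
omega_sq = (omega1 + 1j * omega2) ** 2
print(np.real(omega_sq), omega1**2 - omega2**2)  # real parts should match
print(np.imag(omega_sq), 2 * omega1 * omega2)    # imaginary parts should match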

In the case of your equation, you have returned the correct expressions, since your pde has one purely imaginary coefficient.

  2. You're absolutely right here. I omitted a lot of other important changes I made to the boundary conditions. So, I have included the other parts of my code here:
def func1(y):  ### real part of y (true/reference solution)
    Y = (1 - y**2) ** (-1 / 4 - 1j / 4)
    sol1 = np.real(Y)
    return sol1

def func2(y):  ### imaginary part of y (true/reference solution)
    Y = (1 - y**2) ** (-1 / 4 - 1j / 4)
    sol2 = np.imag(Y)
    return sol2

def boundary_l(x, on_boundary):
    return on_boundary and np.isclose(x[0], -0.9)

def boundary_r(x, on_boundary):
    return on_boundary and np.isclose(x[0], 0.9)

geom = dde.geometry.Interval(-0.9, 0.9)  ### defines a one-dimensional spatial domain

### component 0 is the real part of y (boundary conditions for the real part)
bc1 = dde.DirichletBC(geom, func1, boundary_l, component=0)
bc2 = dde.DirichletBC(geom, func1, boundary_r, component=0)

### component 1 is the imaginary part of y (boundary conditions for the imaginary part)
bc3 = dde.DirichletBC(geom, func2, boundary_l, component=1)
bc4 = dde.DirichletBC(geom, func2, boundary_r, component=1)

data = dde.data.PDE(
    geom,
    pde,  # the pde function defined earlier
    [bc1, bc2, bc3, bc4],
    num_domain=100,
    num_boundary=2,
)

layer_size = [1] + [20] * 3 + [2]
net = dde.maps.FNN(layer_size, "tanh", "Glorot uniform")
model = dde.Model(data, net)
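For completeness, a model like this can then be compiled and trained along these lines (a minimal sketch; the optimizer, learning rate, and epoch count here are illustrative choices rather than the exact settings I used):

model.compile("adam", lr=0.001)
losshistory, train_state = model.train(epochs=20000)
dde.saveplot(losshistory, train_state, issave=True, isplot=True)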

As you suggested, and as I did in my code, I had boundary conditions for both the real and imaginary parts of the pde; I used the boundary values of my reference solution as Dirichlet boundary conditions. In your code, is "z" the time variable, and what are your initial conditions? If the initial conditions are given as a function, I think you can try something similar to the code from the reaction_inverse example https://github.com/lululxvi/deepxde/blob/master/examples/reaction_inverse.py:

def fun_init(x):
    return np.exp(-20 * x[:, 0:1])

geom = dde.geometry.Interval(0, 1)
timedomain = dde.geometry.TimeDomain(0, 10)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)

...
ic1 = dde.IC(geomtime, fun_init, lambda _, on_initial: on_initial, component=0)
ic2 = dde.IC(geomtime, fun_init, lambda _, on_initial: on_initial, component=1)
Depending on what your initial condition looks like, you can customise this for your equations (see the sketch below for a complex-valued initial condition).
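If the initial condition itself is complex-valued, one option (a sketch; the function names and the zero imaginary part are only placeholders for your own expressions) is to split it exactly like the boundary functions above:

def ic_real(x):
    return np.exp(-20 * x[:, 0:1])   # replace with the real part of your initial condition

def ic_imag(x):
    return np.zeros_like(x[:, 0:1])  # replace with the imaginary part

ic1 = dde.IC(geomtime, ic_real, lambda _, on_initial: on_initial, component=0)
ic2 = dde.IC(geomtime, ic_imag, lambda _, on_initial: on_initial, component=1)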

Since you have two other inputs (x and y), I suppose you are using dde.geometry.Polygon( ... ) to define the 2-D xy-plane (like in the example: Poisson_Lshape https://github.com/lululxvi/deepxde/blob/master/examples/Poisson_Lshape.py). Yes, I think you would need to define separate boundary conditions for the real and imaginary parts of your differential equation as was done in the Lorenz_Inverse example.

  3. I think the snippet above covers the rest of the important changes I made to solve my complex-valued differential equation. I hope this extra information clarifies things and assists you further.

@YouCantRedo

Dear @AneleNcube
I have read your answer carefully; it is very useful. Big thanks!

  1. Your explanation of dde.grad.hessian is very good; I realized that I made a big mistake before.

  2. Yes, I treated the z parameter as time, and my initial condition can be written as a function, but this function is also the reference solution for my problem. In this respect the example is similar to my problem (where the BC and the data share a common func), so I am a little confused about which func should go into the data argument, because I have already separated my func into func1 and func2 as you did. I tried solution=[func1, func2], but it seems this argument can only be a single function. Below is the main part of my code; I would be very grateful if you could look at it.
    [Screenshot: main part of the code]

  3. I also have a further question that I want to discuss with you: Directly import array ranther than func to Class IC #362
    In my case, I can directly obtain the initial-condition data (at z=0) instead of the function, and I think it should be fine in theory to pass the data (a numpy.array) in place of func. When I try this approach, it runs okay, but the num_initial parameter must be set equal to the size of the data array; otherwise, it won't run. Also, the result is slightly different from when I use func. Do you have any idea about this phenomenon?

@AneleNcube
Author

AneleNcube commented Aug 31, 2021

Hi @YouCantRedo. I am glad to discuss and share ideas.

  1. It wasn't a major mistake, lol.
  2. I had the same issue regarding the part where you tried solution=[func1, func2]. It seems the argument only admits a single function for comparing the neural network solution with the reference solution during training. I am not sure it would make sense to include either func1 or func2 on its own, since the neural network in our case has two outputs. For now, I can think of one way of monitoring the accuracy of the model as it trains (I assume that is the purpose of including the reference solution in data.TimePDE); however, it might be unnecessarily long.

You could save the model every 1000 or so training epochs; then, after training has completed, use a for loop to load each saved model and apply the test metric to it, generating arrays of accuracy values for both the real and imaginary parts of your solution. So, in your code, after setting up data = dde.data.TimePDE(...), include the following code for training your model.

checkpoint_filepath = "/content/model/model.ckpt"
model_checkpoint_callback = dde.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath, save_better_only=False, period=1000
)
...

losshistory, train_state = model.train(epochs=20000, callbacks=[model_checkpoint_callback])

This will save your model every n epochs, as specified by period=n in dde.callbacks.ModelCheckpoint(...). After training, the following code loads the saved models, calculates the L_2 relative error of their outputs, and stores the values in lists:

import numpy as np
import matplotlib.pyplot as plt

Err1 = []  ### L2 error (real part of solutions)
Err2 = []  ### L2 error (imag part of solutions)

x = geom.uniform_points(2500, True)  # test points, shape (2500, 1)
ytrue1 = func1(x)  # reference solution (real part)
ytrue2 = func2(x)  # reference solution (imag part)

for i in range(20):
    checkpoint = "/content/model/model.ckpt-{}".format((i + 1) * 1000)
    model.restore(checkpoint, verbose=1)
    Y1 = model.predict(x)[:, 0:1]  ### model output (real part)
    Y2 = model.predict(x)[:, 1:2]  ### model output (imag part)
    Err1.append(dde.metrics.l2_relative_error(ytrue1, Y1))
    Err2.append(dde.metrics.l2_relative_error(ytrue2, Y2))

### Plotting the L_2 relative error during the training
plt.figure(figsize=(8, 6))
plt.plot(np.arange(len(Err1)), Err1, "-b", linewidth=1, label=r"$Re[\hat{y}(x)]$")
plt.plot(np.arange(len(Err2)), Err2, "-r", linewidth=1, label=r"$Im[\hat{y}(x)]$")
plt.legend()
plt.xlabel("Thousand epochs")
plt.ylabel(r"$L_2$ relative error")
plt.tight_layout()

I tried this with one of my problems and obtained the following plot for the L_2 relative errors.

[Plot: L_2 relative errors for the real and imaginary parts vs. thousands of epochs]

Perhaps there is a more compact and quicker alternative, but for now I can only think of this "Rube Goldberg-like" approach.

  3. This is an interesting question, but I'm not sure what the reason for that outcome could be. I've been focusing mostly on time-independent differential equations, so I haven't had much opportunity to deal with ICs.

@lululxvi
Owner

lululxvi commented Sep 4, 2021

@AneleNcube @YouCantRedo The way to provide the solution is to define a function with multiple outputs:

def func(X):
    real = ... # The real part of the solution, column array
    imag = ... # The im part of the solution, column array
    return np.hstack((real, imag))

... TimePDE(..., solution=func)
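For instance, with the 1-D reference solution used earlier in this thread, such a function could look like this (a sketch reusing the expression from func1/func2 above; X is the column array of input points):

def func(X):
    Y = (1 - X**2) ** (-1 / 4 - 1j / 4)  # reference solution from earlier in this thread
    return np.hstack((np.real(Y), np.imag(Y)))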

@YouCantRedo

YouCantRedo commented Sep 4, 2021

@AneleNcube Thank you very much for your answer! I have 2 other questions.

  1. I have some doubts about my pde function. As I mentioned earlier, it returns:

con1 = np.real(con)  # real part of coefficient
con2 = np.imag(con)  # imag part of coefficient

return [
    du1_z - con1 * (du1_xx + du1_yy),
    du2_z - con2 * (du2_xx + du2_yy),
]
Even though the coefficient con is separated into real and imaginary parts, the variable u is also complex, so its derivatives are complex as well. I therefore think the first expression is not purely the real part but also contains imaginary-part terms, and likewise for the second expression. Is this okay?

  2. Recently, I tried your code for plotting the L_2 relative errors, but it didn't run successfully. I assume this is because your example is a 1-D problem and mine is 3-D, so I will have to make some changes to your code. But your code is indeed very helpful! A few hours ago, @lululxvi answered this question; I changed to solution=func as suggested and it runs okay. Below is the main part of my code:
    [Screenshot: code using solution=func]

The results are shown below:

[Screenshots: results (Fig. 1 and Fig. 2)]

I think this approach can achieve the same effect as your code (as shown in Fig. 2); is this right?
I am very new to neural networks (as you can see, haha) and I can't tell from Fig. 2 whether my model is good enough; is there any standard?
As you can see, my code is really basic, so I want to ask whether I should include further postprocessing. Did you just plot the L_2 relative errors to observe the results?

@AneleNcube
Author

Thank you @lululxvi and @YouCantRedo.

  1. That's right, I had overlooked that! The PDE was not correct. After a careful expansion, I think you should obtain the following for the real and imaginary parts of the PDE (see the quick symbolic check at the end of this comment):

return [
    du1_z - con1 * (du1_xx + du1_yy) + con2 * (du2_xx + du2_yy),  # real part
    du2_z - con2 * (du1_xx + du1_yy) - con1 * (du2_xx + du2_yy),  # imag part
]
  2. I too just tried @lululxvi's suggestion above...

@AneleNcube @YouCantRedo The way to provide the solution is to define a function with multiple outputs:

def func(X):
    real = ... # The real part of the solution, column array
    imag = ... # The im part of the solution, column array
    return np.hstack((real, imag))

... TimePDE(..., solution=func)

It works quite well for me too. For gauging the accuracy of your model, perhaps the best metric is the L2 relative error, which you can add like this:

model.compile("adam", lr=0.001, metrics=["l2 relative error"])

The L2 relative error ought to decrease as training progresses, and I think a good model will have an L2 relative error of at most about 1e-2 or 1e-3. But it depends on the level of accuracy you need from your model; if you need more accurate approximations, you may have to train for more epochs.

In terms of postprocessing, I can see from your image that you didn't obtain plots of your solution u(x,y,z), perhaps because it has 3 input variables and 1 response variable. I'm not sure yet how you can produce the plots in your scenario, but you may need to write some code to plot the predicted u(x,y) at different constant values of z.

Do you mind sharing your code so that I can try that approach?
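Coming back to point 1, one way to double-check a split like this is to expand the complex residual symbolically (a sketch using sympy, outside of DeepXDE; the symbols stand in for the derivative terms above):

import sympy as sp

du1_z, du2_z, du1_xx, du1_yy, du2_xx, du2_yy, con1, con2 = sp.symbols(
    "du1_z du2_z du1_xx du1_yy du2_xx du2_yy con1 con2", real=True
)

residual = sp.expand(
    (du1_z + sp.I * du2_z)
    - (con1 + sp.I * con2) * ((du1_xx + du1_yy) + sp.I * (du2_xx + du2_yy))
)
print(sp.re(residual))  # should match the real part returned above
print(sp.im(residual))  # should match the imaginary part returned above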

@YouCantRedo

@AneleNcube I would be happy to share the code with you! Could you please leave an email address so I can send it to you?

@AneleNcube
Author

My email address: ncubeanele4@gmail.com

@lululxvi
Owner

lululxvi commented Sep 9, 2021

@YouCantRedo @AneleNcube For the 3-D case, unfortunately, DeepXDE won't plot it for you. You have to predict the solution and plot it yourself.

@AneleNcube
Author

AneleNcube commented Sep 10, 2021

Hi @YouCantRedo and @lululxvi.

Thank you for sharing your code @YouCantRedo. I looked at it today and tried to figure out why there is a high L2 relative error (~1e+01) when comparing the model prediction and u_true (the reference solution provided by your npz files). I'm not sure of the reason yet, but I generated some plots comparing the real part of u(x,y,z) as given by u_pred, u_true, and func (the model prediction, reference solution in dataset form, and reference solution in functional form, respectively):

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

data_xyz, u_true = gen_testdata()  # u_true = reference solution in dataset form
u_func = func(data_xyz)            # u_func = reference solution in functional form
u_pred = model.predict(data_xyz)   # u_pred = model prediction

z_i = [0., 1., 25., 50.]  # z-coordinates to plot

for i in z_i:
    start = np.where(data_xyz[:, 2:3] == i)[0][0]  # first xyz index where z = i
    end = np.where(data_xyz[:, 2:3] == i)[0][-1]   # last xyz index where z = i

    X = data_xyz[:, 0:1][start:end]
    Y = data_xyz[:, 1:2][start:end]
    Z1 = u_pred[:, 1:2][start:end]
    Z2 = u_true[:, 1:2][start:end]
    Z3 = u_func[:, 1:2][start:end]

    fig = plt.figure(figsize=plt.figaspect(1 / 3))
    ax = fig.add_subplot(1, 3, 1, projection="3d")
    ax.scatter(X, Y, Z1, c=Z1, cmap="viridis", linewidth=0.5)
    ax.set_title("pred (real), z = {}".format(i))
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("u(x,y)")

    ax = fig.add_subplot(1, 3, 2, projection="3d")
    ax.scatter(X, Y, Z2, c=Z2, cmap="viridis", linewidth=0.5)
    ax.set_title("true (real), z = {}".format(i))
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("u(x,y)")

    ax = fig.add_subplot(1, 3, 3, projection="3d")
    ax.scatter(X, Y, Z3, c=Z3, cmap="viridis", linewidth=0.5)
    ax.set_title("func (real), z = {}".format(i))
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("u(x,y)")

print("L2 relative error u_true vs u_func:", dde.metrics.l2_relative_error(u_true, u_func))
print("L2 relative error u_pred vs u_func:", dde.metrics.l2_relative_error(u_pred, u_func))

Screenshot of output (u_pred vs. u_true vs. func):

[Plots: u_pred vs. u_true vs. func at z = 0, 1, 25, 50]

These are plots of u(x, y, z=constant) for some values of z I chose randomly: z = 0, 1, 25, 50. Notice there seems to be a mismatch between the two reference solutions u_true and func, as seen from the high L2 relative error between them and from the plots for z = 25 and z = 50. However, the two seem to agree for lower values of z, such as z = 0 and z = 1:

[Plots: u_true vs. func at z = 0 and z = 1]

Perhaps this mismatch between the reference solutions explains the low accuracy of the predicted solution? Which reference solution is correct?

@louhz

louhz commented Sep 12, 2021

I think I might know why DeepXDE does not support complex numbers. In the TensorFlow backend, tf.gradients could not distinguish between a complex number and its conjugate in the last few versions. But PyTorch 1.9.0 introduced complex tensor support (torch.complex()), which handles autograd for complex numbers by using the Wirtinger derivative, as you can see in: https://pytorch.org/docs/stable/complex_numbers.html
I am not sure whether the latest update of TensorFlow can deal with complex numbers.
A possible solution is to add complex support (torch.complex()) to the PyTorch backend of dde.grad.jacobian and dde.grad.hessian.
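As a small illustration of PyTorch's complex autograd (a sketch only, not DeepXDE code; it just shows that gradients of a real-valued loss with respect to a complex tensor are defined):

import torch

z = torch.tensor([1.0 + 2.0j], requires_grad=True)
loss = (z * z.conj()).real.sum()  # |z|^2, a real-valued loss
loss.backward()
print(z.grad)                     # gradient defined via Wirtinger calculus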

@YouCantRedo

YouCantRedo commented Sep 14, 2021

@AneleNcube Thank you very much for your work!
I also found the issue you mentioned a few days ago: it turned out I had made a small mistake in func (func/func1/func2). I have corrected my code and tested the results with your plotting code, and the plots of u_true vs. func are now consistent. Thank you very much for your plotting code!

[Plots: u_true vs. func, now consistent]

However, the L2 relative error is still relatively high (~1e-01) and decreases very slowly. Do you have any idea why?

[Plot: L2 relative error history]

I have sent the new version of my code with the corrected func to your email address; please feel free to check it!

@lululxvi
Owner

@louhz You are welcome to submit a pull request.

@YouCantRedo

@louhz Thank you very much for your answer! Now I am trying to use the approach provided by @AneleNcube.

@louhz

louhz commented Sep 24, 2021

@louhz You are welcome to submit a pull request.

Yes. I am trying to write the TensorFlow version of the Hessian and Jacobian for holomorphic functions. I will open a pull request after I finish the code.

@AneleNcube
Author

@louhz You are welcome to submit a pull request.

Yes. I am trying to write the TensorFlow version of the Hessian and Jacobian for holomorphic functions. I will open a pull request after I finish the code.

Thank you for considering this challenge @louhz.

@sandhu-githhub

Dear contributors to this thread,

greetings

@sandhu-githhub

sandhu-githhub commented Aug 20, 2022

Dear @lululxvi

I have recently started learning PINNs with the DeepXDE framework. I have read your communications and found them very helpful for the implementation of the diffusive PDE in this thread:

#284

I would highly appreciate it if you could kindly guide me with corrections to a basic implementation of the scalar (time-independent) Helmholtz equation in an unbounded homogeneous medium. It is the wave equation in 2D, and its solution should be a plane wave in either of the two spatial directions.

I would be grateful if you could have a look at the code below. I have commented it well and made it very simple.

I also have the following questions:

  1. I think there is no need to implement any boundary conditions. Do you agree?

  2. It is not necessary to provide the true solution; the PINN should eventually learn the right solution by reducing the loss function.

  3. I observed that the training loss decreases significantly, but I am afraid I could not see the expected results.

  4. I arbitrarily chose 10 (3) training (testing) points per wavelength; however, I could only obtain 960 points in the output and could not reshape them to visualize the image.

I will once again gratefully acknowledge your help when I am able to produce something.

Looking forward to hearing from you.

[Attachment: helmholtz_problem code]


@lululxvi
Owner

1. I think there is no need to implement any boundary conditions. Do you agree?

I am not sure about this.

2. It is not necessary to provide the true solution; the PINN should eventually learn the right solution by reducing the loss function.

Yes.

3. I observed that the training loss decreases significantly, but I am afraid I could not see the expected results.

This may be related to (1).

4. I arbitrarily chose 10 (3) training (testing) points per wavelength; however, I could only obtain 960 points in the output and could not reshape them to visualize the image.

You can predict the solution at any point after training.
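For example, something along these lines (a sketch; the domain bounds, grid resolution, and the assumption of a trained dde.Model named model are placeholders to adapt to your problem):

import numpy as np
import matplotlib.pyplot as plt

nx, ny = 100, 100
x = np.linspace(0, 1, nx)  # adjust to your domain
y = np.linspace(0, 1, ny)
X, Y = np.meshgrid(x, y)
points = np.vstack((X.ravel(), Y.ravel())).T  # shape (nx * ny, 2)

u = model.predict(points)     # predict at the grid points
U = u[:, 0].reshape(X.shape)  # first network output on the grid

plt.imshow(U, extent=(0, 1, 0, 1), origin="lower")
plt.colorbar()
plt.xlabel("x")
plt.ylabel("y")
plt.show()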
