loss function problem #97

Closed
alan0324 opened this issue Jul 6, 2023 · 6 comments

Comments

@alan0324

alan0324 commented Jul 6, 2023

Hello, when I run train_model.py I hit a shape mismatch in the loss tensors. Here is the error:
Traceback (most recent call last):
  File "F:\temp\demo\00\pythonProject\attentive-gan-derainnet\tools\train_model.py", line 324, in <module>
    train_model(args.dataset_dir, weights_path=args.weights_path)
  File "F:\temp\demo\00\pythonProject\attentive-gan-derainnet\tools\train_model.py", line 116, in train_model
    train_gan_loss, train_discriminative_loss, train_net_output = derain_net.compute_loss(
  File "F:\temp\demo\00\pythonProject\attentive-gan-derainnet\attentive_gan_model\derain_drop_net.py", line 52, in compute_loss
    auto_encoder_loss, auto_encoder_output = self._attentive_gan.compute_autoencoder_loss(
  File "F:\temp\demo\00\pythonProject\attentive-gan-derainnet\attentive_gan_model\attentive_gan_net.py", line 333, in compute_autoencoder_loss
    lm_loss = tf.add(lm_loss, mse_loss)
  File "F:\temp\demo\00\pythonProject\venv\lib\site-packages\tensorflow\python\util\traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "F:\temp\demo\00\pythonProject\venv\lib\site-packages\tensorflow\python\framework\ops.py", line 7262, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node _wrapped__AddV2_device/job:localhost/replica:0/task:0/device:CPU:0}} Incompatible shapes: [1,60,90] vs. [1,120,180]
How can I solve this problem?
Thank you very much for sharing the code!

@MaybeShewill-CV
Owner

@alan0324 Check whether the sizes of src_image and label_image match :)
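The check suggested here can be sketched as a small helper. This is only an illustration: the function name `assert_same_spatial_size` and the example shapes are hypothetical and not part of the repository. It assumes each image is described by a (height, width, channels) shape, such as the `.shape` of a loaded NumPy array.

```python
def assert_same_spatial_size(src_shape, label_shape):
    """Raise ValueError if two (height, width, channels) shapes differ in height or width."""
    src_hw = tuple(src_shape[:2])
    label_hw = tuple(label_shape[:2])
    if src_hw != label_hw:
        raise ValueError(f"size mismatch: src {src_hw} vs. label {label_hw}")
    return src_hw


# Hypothetical shapes; in practice they would come from the loaded image pair.
print(assert_same_spatial_size((480, 720, 3), (480, 720, 3)))
```

Running such a check over every source/label pair before training would rule the dataset in or out as the cause.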

@alan0324
Author

alan0324 commented Jul 6, 2023

Thank you for the reply! Do you mean the following?
if __name__ == '__main__':
    input_image = tf.placeholder(dtype=tf.float32, shape=[1, 256, 256, 3])
    auto_label_image = tf.placeholder(dtype=tf.float32, shape=[1, 256, 256, 3])
    rnn_label_image = tf.placeholder(dtype=tf.float32, shape=[1, 256, 256, 1])
    net = GenerativeNet(phase=tf.constant('train', tf.string))
    rnn_loss = net.compute_attentive_rnn_loss(input_image, rnn_label_image, name='rnn_loss')
    autoencoder_loss = net.compute_autoencoder_loss(input_image, auto_label_image, name='autoencoder_loss')
    for vv in tf.trainable_variables():
        print(vv.name)
Sorry, I am still not very familiar with this part QQ

@alan0324
Author

alan0324 commented Jul 6, 2023

In addition, here is what I printed inside the for loop:
lm_loss shape: ()
lm_loss size: tf.Tensor(1, shape=(), dtype=int32)
mse_loss shape: (1, 60, 90)
mse_loss size: tf.Tensor(5400, shape=(), dtype=int32)
lm_loss shape: (1, 60, 90)
lm_loss size: tf.Tensor(5400, shape=(), dtype=int32)
mse_loss shape: (1, 120, 180)
mse_loss size: tf.Tensor(21600, shape=(), dtype=int32)
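The printed shapes show why the first addition succeeds and the second fails: a scalar broadcasts against (1, 60, 90), but (1, 60, 90) cannot broadcast against (1, 120, 180). The pure-Python sketch below models only the shape rule, not the tensor values; `broadcast_shape` and `accumulate` are illustrative names, not functions from the repository. One possible explanation, not confirmed in this thread, is that each per-scale `mse_loss` is a per-pixel tensor rather than a scalar, so the running sum stops being a scalar after the first scale.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast result shape of a and b, or raise ValueError."""
    result = []
    # Compare dimensions from the trailing axis backwards; missing axes count as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"Incompatible shapes: {list(a)} vs. {list(b)}")
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


def accumulate(term_shapes):
    """Track the shape of a running sum that starts as a scalar."""
    total = ()
    for shape in term_shapes:
        total = broadcast_shape(total, shape)
    return total


# If every term is reduced to a scalar first, the running sum stays a scalar.
print(accumulate([(), (), ()]))  # ()

# Unreduced terms with the shapes printed above fail on the second addition.
try:
    accumulate([(1, 60, 90), (1, 120, 180)])
except ValueError as err:
    print(err)  # Incompatible shapes: [1, 60, 90] vs. [1, 120, 180]
```

Under that assumption, reducing each per-scale loss to a scalar before adding it to `lm_loss` would keep every addition shape-compatible.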

@MaybeShewill-CV
Owner

@alan0324 Yes. Judging from the error, it is caused by the source image and the label image you feed in having different sizes. Please check that ;)

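@alan0324
Author

alan0324 commented Jul 9, 2023

@MaybeShewill-CV Hello, after some experimenting, changing the sizes of the source and label images did not change the error message. My mse_loss still doubles in size during the for loop computation. Besides a size mismatch, what else could cause this error?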

@alan0324
Author

alan0324 commented Jul 9, 2023

Here is my size configuration:
__C.TRAIN.IMG_HEIGHT = 480
# Set train image width
__C.TRAIN.IMG_WIDTH = 720
# Set train image height
__C.TRAIN.CROP_IMG_HEIGHT = 240
# Set train image width
__C.TRAIN.CROP_IMG_WIDTH = 360

if __name__ == '__main__':
    input_image = tf.keras.Input(dtype=tf.float32, shape=[1, 240, 360, 3])
    auto_label_image = tf.keras.Input(dtype=tf.float32, shape=[1, 240, 360, 3])
    rnn_label_image = tf.keras.Input(dtype=tf.float32, shape=[1, 240, 360, 1])

if __name__ == '__main__':
    """
    test
    """
    input_tensor = tf.keras.Input(dtype=tf.float32, shape=[5, 480, 720, 3])
    label_tensor = tf.keras.Input(dtype=tf.float32, shape=[5, 480, 720, 3])
    mask_tensor = tf.keras.Input(dtype=tf.float32, shape=[5, 480, 720, 1])
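
The two sizes in the error message, 60x90 and 120x180, equal the 240x360 crop size divided by 4 and by 2. The arithmetic can be checked with the short sketch below. The scale factors (quarter, half, full resolution) are an assumption inferred from those numbers and are not stated anywhere in this thread; `scaled_sizes` is an illustrative helper.

```python
CROP_HEIGHT, CROP_WIDTH = 240, 360  # crop size from the config quoted above

# Assumed scale factors (quarter, half, full); inferred, not confirmed in the thread.
SCALES = (0.25, 0.5, 1.0)


def scaled_sizes(height, width, scales=SCALES):
    """Return the (height, width) of the crop at each scale factor."""
    return [(int(height * s), int(width * s)) for s in scales]


print(scaled_sizes(CROP_HEIGHT, CROP_WIDTH))
# [(60, 90), (120, 180), (240, 360)]
```

If that reading is right, the mismatched shapes come from different scales of the same crop rather than from a source image and label image of different sizes.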

alan0324 closed this as completed Aug 1, 2023