
Network not learning #22

Open
ValterFallenius opened this issue Mar 5, 2022 · 7 comments
Labels: bug (Something isn't working)

@ValterFallenius

Training loss is not decreasing
I have implemented the network from the PR "lightning" branch with PyTorch Lightning and tried to find any bugs. The network compiles without issues and appears to produce gradients, but it still fails to learn anything. I have tried playing around with the learning rate and plotting the data at different stages, but even with 4 training samples (which it should be able to overfit) the loss barely decreases after 100 epochs...
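For context, here is a minimal, self-contained sketch of this kind of overfit sanity check using PyTorch Lightning's `overfit_batches` flag. The `TinyOverfitModel` below is just a stand-in with matching input/output shapes, not the actual MetNet module from the branch:

```python
import torch
import pytorch_lightning as pl

class TinyOverfitModel(pl.LightningModule):
    """Stand-in module (not the real MetNet) just to illustrate the overfit check."""
    def __init__(self, lr=1e-2):
        super().__init__()
        self.net = torch.nn.Conv2d(15, 6, kernel_size=1)  # 15 input channels -> 6 classes
        self.lr = lr

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.net(x), y)
        self.log("train/loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# 4 fixed samples: the training loss should collapse towards zero within a few hundred steps
x = torch.randn(4, 15, 112, 112)
y = torch.randint(0, 6, (4, 112, 112))
loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x, y), batch_size=1)

trainer = pl.Trainer(max_epochs=100, overfit_batches=4, log_every_n_steps=1)
trainer.fit(TinyOverfitModel(), loader)
```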

Here is the training loss plotted:
[W&B chart: training loss, 2022-03-05 11:20]

It seems like it's doing something, but not nearly quickly enough to overfit the small dataset. Something is wrong...

Hyperparameters:
n_samples = 4
hidden_dim=8,
forecast_steps=1,
input_channels=15,
output_channels=6, #512
input_size=112, # 112
n_samples = 100,
num_workers = 8,
batch_size = 1,
learning_rate = 1e-2

Below is a Weights & Biases gradient report. As you can see, most gradients are non-zero; I'm not sure why image_encoder has very small gradients for its biases...

wandb report (epoch 83, trainer/global_step 251):

| Parameter (grad_2.0_norm)                          | epoch   | step   |
| -------------------------------------------------- | ------- | ------ |
| head.bias                                           | 0.0746  | 0.049  |
| head.weight                                         | 0.0862  | 0.081  |
| image_encoder.module.module.0.bias                  | 0.0     | 0.0    |
| image_encoder.module.module.0.weight                | 0.06653 | 0.043  |
| image_encoder.module.module.2.bias                  | 0.00017 | 0.0001 |
| image_encoder.module.module.2.weight                | 0.003   | 0.0019 |
| image_encoder.module.module.3.bias                  | 0.0     | 0.0    |
| image_encoder.module.module.3.weight                | 0.16387 | 0.1125 |
| image_encoder.module.module.4.bias                  | 0.00013 | 0.0001 |
| image_encoder.module.module.4.weight                | 0.00203 | 0.0012 |
| image_encoder.module.module.5.bias                  | 0.0     | 0.0    |
| image_encoder.module.module.5.weight                | 0.15237 | 0.1151 |
| image_encoder.module.module.6.bias                  | 0.0032  | 0.0018 |
| image_encoder.module.module.6.weight                | 0.00157 | 0.0012 |
| image_encoder.module.module.7.bias                  | 0.00497 | 0.003  |
| image_encoder.module.module.7.weight                | 0.11753 | 0.0915 |
| temporal_agg.0.axial_attentions.0.fn.to_kv.weight   | 0.03763 | 0.0277 |
| temporal_agg.0.axial_attentions.0.fn.to_out.bias    | 0.0412  | 0.0289 |
| temporal_agg.0.axial_attentions.0.fn.to_out.weight  | 0.05167 | 0.0369 |
| temporal_agg.0.axial_attentions.0.fn.to_q.weight    | 0.0008  | 0.0008 |
| temporal_agg.0.axial_attentions.1.fn.to_kv.weight   | 0.04393 | 0.0216 |
| temporal_agg.0.axial_attentions.1.fn.to_out.bias    | 0.0412  | 0.0289 |
| temporal_agg.0.axial_attentions.1.fn.to_out.weight  | 0.04287 | 0.027  |
| temporal_agg.0.axial_attentions.1.fn.to_q.weight    | 0.0014  | 0.0009 |
| temporal_enc.rnn.cell_list.0.conv_h1.bias           | 0.00197 | 0.001  |
| temporal_enc.rnn.cell_list.0.conv_h1.weight         | 0.03313 | 0.0216 |
| temporal_enc.rnn.cell_list.0.conv_h2.bias           | 0.00103 | 0.0004 |
| temporal_enc.rnn.cell_list.0.conv_h2.weight         | 0.00353 | 0.002  |
| temporal_enc.rnn.cell_list.0.conv_zr.bias           | 0.00133 | 0.0009 |
| temporal_enc.rnn.cell_list.0.conv_zr.weight         | 0.02123 | 0.0147 |
| grad_2.0_norm_total                                 | 0.31513 | 0.2254 |

Losses: train/loss_epoch 1.72826, train/loss_step 1.73303, validation/loss_epoch 1.76064
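In case anyone wants to reproduce this report: the `grad_2.0_norm/...` metrics are what PyTorch Lightning 1.x logs when gradient-norm tracking is enabled on the Trainer. A minimal sketch, assuming a configured W&B logger and an existing `model`/`train_loader` (both placeholders here); the same numbers can also be printed by hand after a backward pass:

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# PL 1.x: log the L2 norm of every parameter's gradient to W&B as
# grad_2.0_norm/<param> plus grad_2.0_norm_total.
trainer = pl.Trainer(
    logger=WandbLogger(project="metnet"),
    track_grad_norm=2,
    max_epochs=100,
)
trainer.fit(model, train_loader)

# Equivalent manual check after loss.backward():
for name, p in model.named_parameters():
    if p.grad is not None:
        print(f"{name}: {p.grad.norm(2).item():.5f}")
```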

I have plotted the inputs as they flow through the layers, and none of them seems to do anything unexpected:
[Attached plots: input layer · after image_encoder · after temp_encoder · after agg · after head · after softmax · output vs ground truth]

I'm out of ideas and would appreciate any input.

To Reproduce
Steps to reproduce the behavior:

  1. Clone https://github.com/ValterFallenius/metnet
  2. Download the data samples from the link in the README
  3. Install requirements
  4. Run the code
@ValterFallenius added the bug label on Mar 5, 2022
@JackKelly (Member)

Great work plotting all this information!

I'm afraid I'm not familiar enough with the code or the model to really help. But, in the meantime, would it be possible to share a link to the Weights & Biases run for this training, please? That would help us debug.

@ValterFallenius (Author)

Here is a link to this particular run: w&b-run

The author Casper mentioned that the loss seemed too big. Right now I'm using torch.nn.CrossEntropyLoss(y_hat-y). Is this the same loss they use in the paper?
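For reference, `torch.nn.CrossEntropyLoss` expects the raw logits and the integer class targets as two separate arguments, not their difference, and it applies log-softmax internally, so the logits should not already be softmaxed. A minimal sketch of the usual call pattern; the shapes are assumptions based on the hyperparameters above:

```python
import torch

criterion = torch.nn.CrossEntropyLoss()

# Raw logits straight from the head: (batch, classes, H, W), no softmax applied.
y_hat = torch.randn(1, 6, 112, 112, requires_grad=True)
# Integer class targets: (batch, H, W), values in [0, num_classes).
y = torch.randint(0, 6, (1, 112, 112))

loss = criterion(y_hat, y)   # two arguments, not criterion(y_hat - y)
loss.backward()
print(loss.item(), y_hat.grad.abs().sum().item())
```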

See his comment below:

My question:

My network doesn't seem to be able to train properly, but I am having a hard time finding the bug. I am trying to run it as you suggested on a small subset of the data, with only a single lead time and far fewer hidden layers, but it doesn't train well. The PyTorch model compiles and doesn't report any errors, but the network still won't reduce the error even when run repeatedly on the same training sample.

His answer:

If the plot is showing log loss, the number seems too high, e.g. your probability mass on the correct class is on average exp(-1.77) ≈ 0.17; just predicting no-rain all the time should give you something much better. However, one thing to note is that log-loss changes after the first few hundred updates are usually super small. Most likely you have a sign error in the loss, or you are not updating the params, given the grads are not zero. I can't really help beyond that.
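A quick way to test the "params not updating" hypothesis is to snapshot the parameters before an optimizer step and compare afterwards. A minimal sketch; the tiny model and random batch below are placeholders for the real ones:

```python
import torch

# Sanity check: do the parameters actually change after one optimizer step?
model = torch.nn.Conv2d(15, 6, kernel_size=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

x = torch.randn(1, 15, 112, 112)
y = torch.randint(0, 6, (1, 112, 112))

before = {n: p.detach().clone() for n, p in model.named_parameters()}

optimizer.zero_grad()
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()

for n, p in model.named_parameters():
    delta = (p.detach() - before[n]).abs().max().item()
    print(f"{n}: max abs change after one step = {delta:.2e}")  # should be > 0
```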

@bmkor

bmkor commented Apr 24, 2022

Sorry to chip in, @ValterFallenius.
I'm wondering where you got the sample data: openclimatefix/goes and openclimatefix/mrms on Hugging Face, or somewhere else? Would you mind sharing a bit? I'd like to give it a try as well. Thanks in advance.

@ValterFallenius (Author)

ValterFallenius commented Apr 24, 2022

Hey @bmkor,

I am actually using neither; you can find my raw data in #27. However, I have not published the elevation and longitude/latitude data I used, so let me know if you need it. But unless you are writing a thesis for the Swedish government, I think you might be better off using the original dataset available on Hugging Face ^^

Also, I am not using any GOES data, since it's of poor quality over Sweden because of the lack of geostationary satellite coverage.

/Valter

@bmkor

bmkor commented Apr 24, 2022

Thanks a lot for your prompt reply and comment. I will try the data available on Hugging Face first and see if I can get the model to run.

@jacobbieker (Member)

> Thanks a lot for your prompt reply and comment. I will try the data available on Hugging Face first and see if I can get the model to run.

Hi, just so you know, the GOES dataset currently doesn't have data in it; I'm working on adding data for that. The MRMS dataset does, although I am still finishing up the dataset script. But if you want to get started with that radar data, you can just download the Zarr files themselves and open them locally quite easily.
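A minimal sketch of opening one of those Zarr files locally with xarray; "mrms_example.zarr" is a placeholder path, not an actual filename from the openclimatefix/mrms repository:

```python
import xarray as xr

# Open a locally downloaded MRMS Zarr store and inspect its contents.
ds = xr.open_zarr("mrms_example.zarr")  # placeholder path

print(ds)                  # dimensions, coordinates, variables
print(list(ds.data_vars))  # available fields
```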

@CUITCHENSIYU

Hello, can you tell me where I can download the MRMS dataset? Thank you very much!
