None gradients for 'y' layers #27

Closed
najwalb opened this issue Apr 1, 2023 · 4 comments

Comments


najwalb commented Apr 1, 2023

It looks like the gradients of y_mlp_out and of all components involving y in the last transformer layer are None, so this part of the model is not training. The components for the other inputs (X and E) seem to be working normally.

To reproduce the behavior, replace the 'trainer' line with this code:

    # Manual training loop replacing trainer.fit(); assumes model, cfg and
    # datamodule are already built as in the original script.
    model = model.train()
    # print(f'model {count_parameters(model)}')
    print('==== Done loading the model...')

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.train.lr, amsgrad=True,
                                  weight_decay=cfg.train.weight_decay)

    train_loader = datamodule.train_dataloader()
    losses = []
    for epoch in range(cfg.train.n_epochs):
        total_loss = 0
        for i, train_samples in enumerate(train_loader):
            # train_samples = train_samples.to(device)
            loss = model.training_step(train_samples, i)  # loss dict for one batch
            loss = loss['loss']
            loss.backward()
            # print the gradient of every parameter whose name involves 'y'
            for name, p in model.named_parameters():
                if '_y' in name:
                    print(f'name: {name}, requires {p.requires_grad}, p {p}, grad, {p.grad}\n')
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.cpu().detach().numpy()
            exit()  # stop after one iteration, just to show the None gradients

The problem appears to be that the 'y' output is not used when computing the loss. I am not sure how to use a cross-entropy loss on X and E alone while still back-propagating to the layers of 'y'.

It's also not clear why the None gradients only appear in the last layers.

@haoming-codes

y is indeed not used for computing the loss. The input y to the transformer is the graph-level feature of the noisy_data, computed by compute_extra_data. The output y from the transformer is not used as input to the next denoising step.
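A small self-contained sketch of the effect (plain PyTorch, not the actual DiGress code; the two layer names are only illustrative): a parameter that only feeds the unused y output ends up with a None gradient, while a y parameter that also mixes into X still receives one.

    import torch

    # y modulates X inside the blocks, so this weight influences the loss
    y_to_x = torch.nn.Linear(4, 4)
    # this weight only produces the final y output, which the loss never uses
    y_out = torch.nn.Linear(4, 2)

    x, y = torch.randn(8, 4), torch.randn(8, 4)
    x = x + y_to_x(y)              # y is mixed back into the node features
    unused_y = y_out(y)            # final y read-out, never reaches the loss

    loss = x.pow(2).mean()         # loss computed on X only
    loss.backward()

    print(y_to_x.weight.grad is None)  # False: trained through X
    print(y_out.weight.grad is None)   # True: None gradient, as reported above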


najwalb commented Apr 11, 2023

@haoming-codes yes, and this means the network layers operating on y are not updated during training.

cvignac (Owner) commented Apr 11, 2023

Hello,
all transformer layers take X, E and y as input. Even if the output dimension of y is eventually 0, y is still useful. The only thing that is not trained is mlp_out_y, which you can disable if you want.
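For example, a minimal sketch of disabling it by freezing those parameters (assuming the parameter names contain 'mlp_out_y', matching the prints in the reproduction snippet above):

    for name, p in model.named_parameters():
        if 'mlp_out_y' in name:
            p.requires_grad_(False)  # exclude the unused y read-out from training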

For the regressor model used in the conditional generation experiments, on the contrary, the output dimensions of X and E are 0, but the output dimension of y is 1 or 2.

Clement


najwalb commented Apr 13, 2023

The part of the network transforming y in the last transformer layer (y_y, e_y, x_y) is also not training. But I see what you mean by 'y' still being useful, since it at least incorporates the time information into the other variables of the network. Thanks for clarifying!

Best,

najwalb closed this as completed Apr 17, 2023