An issue about the default initialize methods. #4555

Closed · LiChenda opened this issue Aug 3, 2022 · 10 comments
@LiChenda (Contributor) commented Aug 3, 2022

Hello, I'm training some custom models with ESPnet2, and I found that the default initialization method identifies the bias parameters of the network by checking p.dim(). This check is not precise (some non-bias parameters may also make p.dim() == 1 true), and setting them to zero can lead to abnormal training. (One of my modules always output zeros, and the training could go dead with ReLU. After removing the following lines, the training becomes OK.)

# bias init
for p in model.parameters():
    if p.dim() == 1:
        p.data.zero_()
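
For example (a minimal sketch, not ESPnet code), the check also matches the 1-D weight of BatchNorm, and zeroing it makes the layer output all zeros:

import torch

# p.dim() == 1 matches BatchNorm's weight (gamma) as well as its bias (beta),
# so the loop above zeroes both.
bn = torch.nn.BatchNorm1d(4)
for p in bn.parameters():
    if p.dim() == 1:
        p.data.zero_()

bn.eval()
x = torch.randn(2, 4)
print(bn(x))  # all zeros: with gamma == 0 and beta == 0, the output vanishes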

@sw005320 (Contributor) commented Aug 8, 2022

@LiChenda, thanks for the report.
Can you give me more examples of parameters that also make p.dim() == 1 true?
How about making your treatment an option?

@LiChenda (Contributor, Author) commented Aug 10, 2022

Some parameters, like the weight $\gamma$ in BatchNorm, the weight $\alpha$ in PReLU, and custom parameters defined with torch.nn.Parameter(), may also make p.dim() == 1 true.
One possible solution might be to update these lines https://github.com/espnet/espnet/blob/96bd74641ceb463096067223d0734f70bddd8def/espnet2/torch_utils/initialize.py#L77-L80
with:

for name, p in model.named_parameters(): 
    if 'bias' in name:
        p.data.zero_() 
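
As a quick sanity check (illustrative only), the name-based loop leaves BatchNorm's weight untouched:

import torch

# Only parameters whose name contains "bias" are zeroed, so BatchNorm's
# weight (gamma) keeps its default value of one.
bn = torch.nn.BatchNorm1d(4)
for name, p in bn.named_parameters():
    if 'bias' in name:
        p.data.zero_()

print(bn.weight)  # still all ones (gamma untouched)
print(bn.bias)    # all zeros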

@LiChenda (Contributor, Author)

BTW, this function is called by almost all the tasks in ESPnet2; I'm not sure whether this update keeps previously trained models reproducible.

@sw005320 (Contributor)

for name, p in model.named_parameters(): 
    if 'bias' in name:
        p.data.zero_() 

This sounds good to me.
Yes, we can make this an option, with the default set to false for now.
Meanwhile, we can run some tests and later make the default true.
Can you make a PR?
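
For reference, a purely hypothetical sketch of what the option could look like (the flag name zero_bias_by_name and the simplified signature are assumptions, not the actual change):

import torch

def initialize(model: torch.nn.Module, zero_bias_by_name: bool = False):
    # Hypothetical flag, not the actual ESPnet API: switch between the old
    # dim-based check and the proposed name-based check.
    if zero_bias_by_name:
        for name, p in model.named_parameters():
            if 'bias' in name:
                p.data.zero_()
    else:
        for p in model.parameters():
            if p.dim() == 1:
                p.data.zero_()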

@LiChenda (Contributor, Author)

Sure!

@popcornell (Contributor) commented Aug 10, 2022

for name, p in model.named_parameters(): 
    if 'bias' in name:
        p.data.zero_() 

Maybe make it case-insensitive too, just to be sure. This is difficult to implement in a scalable way, so the current solution makes sense.

Raising a warning could also be useful, so the user knows that the bias is set to zero. But the warning should probably be raised only once (e.g., per layer class), otherwise it may be too verbose.
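
A minimal sketch of both ideas (the helper name zero_bias_params is hypothetical, not the merged code):

import warnings

import torch

def zero_bias_params(model: torch.nn.Module):
    # Case-insensitive "bias" match; warn once per module class rather than
    # once per parameter to keep the output readable.
    warned = set()
    for module in model.modules():
        for name, p in module.named_parameters(recurse=False):
            if 'bias' in name.lower():
                p.data.zero_()
                cls = type(module).__name__
                if cls not in warned:
                    warned.add(cls)
                    warnings.warn(f"Setting bias parameters of {cls} to zero")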

@b-flo (Member) commented Aug 10, 2022

> BTW, this function is called by almost all the tasks in ESPnet2; I'm not sure whether this update keeps previously trained models reproducible.

It should be noted that args.init is set to None by default for all tasks, and most configs don't use initialization.
We currently have 1026 configs (!!). 172 configs use either chainer or xavier_uniform initialization, of which 110 are for the ASR task and 59 for the ENH task.
From what I see, it doesn't seem to impact the configs we usually use, at least for ASR.

@LiChenda (Contributor, Author)

> BTW, this function is called by almost all the tasks in ESPnet2; I'm not sure whether this update keeps previously trained models reproducible.

> It should be noted that args.init is set to None by default for all tasks, and most configs don't use initialization. We currently have 1026 configs (!!). 172 configs use either chainer or xavier_uniform initialization, of which 110 are for the ASR task and 59 for the ENH task. From what I see, it doesn't seem to impact the configs we usually use, at least for ASR.

Thanks for pointing it out. I just made a PR for this issue; see #4574.

@LiChenda (Contributor, Author)

The PR is merged, so I'm closing this issue.

@kan-bayashi (Member)

I found that this issue has an impact on most of the TTS models, since TTS modules are initialized using this method by default.

# initialize parameters
self._reset_parameters(
    init_type=init_type,
    init_enc_alpha=init_enc_alpha,
    init_dec_alpha=init_dec_alpha,
)

if init_type != "pytorch":
    initialize(self, init_type)

Models including BatchNorm2d or BatchNorm1d modules should be improved by this change.
(I'm not sure what happened before this issue was fixed...)
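
One way to check whether an existing checkpoint was affected (a hedged sketch; the checkpoint path and a flat state-dict layout are assumptions):

import torch

# Flag 1-D ".weight" tensors (e.g. BatchNorm gamma) that are all zeros,
# which the old initializer would produce.
state = torch.load("exp/tts_model/checkpoint.pth", map_location="cpu")  # hypothetical path
for name, tensor in state.items():
    if name.endswith(".weight") and tensor.dim() == 1 and torch.all(tensor == 0):
        print(f"{name}: all zeros, likely affected")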
