Skip to content

Commit

Permalink
Merge pull request #4574 from LiChenda/fix_init
Browse files Browse the repository at this point in the history
update checks for bias in initialization
  • Loading branch information
sw005320 committed Aug 24, 2022
2 parents f274ebe + ed7db10 commit f51b3be
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions espnet2/torch_utils/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

"""Initialize modules for espnet2 neural networks."""

import logging
import math

import torch
Expand All @@ -24,11 +25,17 @@ def initialize(model: torch.nn.Module, init: str):

if init == "chainer":
# 1. lecun_normal_init_parameters
for p in model.parameters():
for name, p in model.named_parameters():
data = p.data
if data.dim() == 1:
if ".bias" in name and data.dim() == 1:
# bias
data.zero_()
logging.info(f"Initialize {name} to zeros")
elif data.dim() == 1:
# linear weight
n = data.size(0)
stdv = 1.0 / math.sqrt(n)
data.normal_(0, stdv)
elif data.dim() == 2:
# linear weight
n = data.size(1)
Expand Down Expand Up @@ -75,9 +82,10 @@ def initialize(model: torch.nn.Module, init: str):
else:
raise ValueError("Unknown initialization: " + init)
# bias init
for p in model.parameters():
if p.dim() == 1:
for name, p in model.named_parameters():
if ".bias" in name and p.dim() == 1:
p.data.zero_()
logging.info(f"Initialize {name} to zeros")

# reset some modules with default init
for m in model.modules():
Expand Down

0 comments on commit f51b3be

Please sign in to comment.