diff --git a/alphafold2_pytorch/alphafold2.py b/alphafold2_pytorch/alphafold2.py
index ade0242..8a10aa3 100644
--- a/alphafold2_pytorch/alphafold2.py
+++ b/alphafold2_pytorch/alphafold2.py
@@ -49,6 +49,11 @@ def default(val, d):
 def cast_tuple(val, depth = 1):
     return val if isinstance(val, tuple) else (val,) * depth

+def init_zero_(layer):  # fill a layer's weight, and its bias when it has one, with zeros (in place, returns None)
+    nn.init.constant_(layer.weight, 0.)  # overwrite every weight entry with 0.
+    if exists(layer.bias):  # `exists` is defined outside this hunk; presumably a not-None check -- confirm
+        nn.init.constant_(layer.bias, 0.)  # overwrite every bias entry with 0.
+
 # helper classes

 class Always(nn.Module):
@@ -82,6 +87,7 @@ def __init__(
             nn.Dropout(dropout),
             nn.Linear(dim * mult, dim)
         )
+        init_zero_(self.net[-1])  # zero the last module of self.net; presumably the nn.Linear(dim * mult, dim) above -- confirm

     def forward(self, x, **kwargs):
         x = self.norm(x)
@@ -114,6 +120,7 @@ def __init__(
         nn.init.constant_(self.gating.bias, 1.)

         self.dropout = nn.Dropout(dropout)
+        init_zero_(self.to_out)  # zero self.to_out at construction time (its definition is outside this hunk)

     def forward(self, x, mask = None, attn_bias = None, context = None, context_mask = None, tie_dim = None):
         device, orig_shape, h, has_context = x.device, x.shape, self.heads, exists(context)
@@ -606,6 +613,8 @@ def __init__(

         self.to_quaternion_update = nn.Linear(dim, 6)

+        init_zero_(self.ipa_block.attn.to_out)  # zero to_out on self.ipa_block.attn (both defined outside this hunk)
+
         self.to_points = nn.Linear(dim, 3)

         # aux confidence measure
diff --git a/setup.py b/setup.py
index 8a896df..f31e166 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'alphafold2-pytorch',
   packages = find_packages(),
-  version = '0.4.30',
+  version = '0.4.31',
   license='MIT',
   description = 'AlphaFold2 - Pytorch',
   author = 'Phil Wang, Eric Alcaide',