deprecate LayerNormFp32 #850

Closed · wants to merge 1 commit
2 changes: 1 addition & 1 deletion requirements-training.txt
@@ -1,4 +1,4 @@
-torch>=1.9.0
+torch>=1.10.0
 torchvision
 webdataset>=0.2.5
 regex
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
-torch>=1.9.0
+torch>=1.10.0
 torchvision
 regex
 ftfy
5 changes: 1 addition & 4 deletions src/open_clip/coca_model.py
@@ -7,7 +7,6 @@
 from dataclasses import dataclass

 from .transformer import (
-    LayerNormFp32,
     LayerNorm,
     QuickGELU,
     MultimodalTransformer,
@@ -58,9 +57,7 @@ def _build_text_decoder_tower(
 ):
     multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
     act_layer = QuickGELU if quick_gelu else nn.GELU
-    norm_layer = (
-        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
-    )
+    norm_layer = LayerNorm

     decoder = MultimodalTransformer(
         context_length=multimodal_cfg.context_length,
6 changes: 3 additions & 3 deletions src/open_clip/model.py
@@ -18,7 +18,7 @@
 from .hf_model import HFTextEncoder
 from .modified_resnet import ModifiedResNet
 from .timm_model import TimmModel
-from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\
+from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\
     text_global_pool
 from .utils import to_2tuple

@@ -139,7 +139,7 @@ def _build_vision_tower(
         )
     else:
         vision_heads = vision_cfg.width // vision_cfg.head_width
-        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+        norm_layer = LayerNorm
         if vision_cfg.norm_kwargs:
             norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
         if vision_cfg.act_kwargs is not None:
@@ -190,7 +190,7 @@ def _build_text_tower(
         )
     else:
         act_layer = QuickGELU if quick_gelu else nn.GELU
-        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+        norm_layer = LayerNorm
         if text_cfg.norm_kwargs:
             norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
         if text_cfg.act_kwargs is not None:
5 changes: 4 additions & 1 deletion src/open_clip/transformer.py
@@ -13,7 +13,10 @@


 class LayerNormFp32(nn.LayerNorm):
-    """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
+    """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).
+
+    Deprecated: pytorch 1.10+ always performs LayerNorm in fp32. Retained for checkpoint compatibility.
+    """

     def forward(self, x: torch.Tensor):
         orig_type = x.dtype
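
The new docstring carries the rationale for the whole PR: since PyTorch 1.10, `nn.LayerNorm` upcasts fp16/bf16 activations internally, so the explicit fp32 round-trip in `LayerNormFp32` is redundant and plain `LayerNorm` can be used everywhere. A minimal sketch (not part of this PR) comparing the two paths; it assumes a CUDA device with fp16 support, and the tensor shape and tolerance are arbitrary choices:

```python
# Sketch (not from this PR): compare plain nn.LayerNorm on an fp16 tensor with the
# explicit fp32 round-trip that LayerNormFp32 performs. Assumes a CUDA device.
import torch
import torch.nn as nn
import torch.nn.functional as F

assert torch.cuda.is_available(), "this sketch assumes a CUDA device"

x = torch.randn(2, 8, 512, device="cuda", dtype=torch.float16)
ln = nn.LayerNorm(512).to(device="cuda", dtype=torch.float16)

# Plain LayerNorm on the fp16 tensor (fp32 accumulation happens internally on torch >= 1.10,
# per the docstring above).
y_plain = ln(x)

# The LayerNormFp32 recipe: cast input and parameters to fp32, normalize, cast back.
y_fp32 = F.layer_norm(
    x.float(), ln.normalized_shape, ln.weight.float(), ln.bias.float(), ln.eps
).to(x.dtype)

print(torch.allclose(y_plain, y_fp32, atol=1e-3))
```

The class itself is kept, as the docstring notes, only so that existing checkpoints and configs referencing `LayerNormFp32` continue to load.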