
xformers ViT-B ImageNet MAE + Deepnorm training instability #219

Open
jramapuram opened this issue Mar 1, 2022 · 62 comments · Fixed by #221, #220 or #229

@jramapuram

jramapuram commented Mar 1, 2022

🐛 Bug

I'm trying to create a 1:1 config that can train a stable ViT-B with the MAE config (from appendix A.2).

Maybe I'm missing something (highly plausible), but when I use xformers instead of timm it creates an unstable training scenario [over numerous trials] with exactly the same hyper-parameters (batch_size=4096 + cutmix + mixup + label smoothing + AdamW[0.9, 0.95], lr=1e-4 [with scaling rule ofc], lr warmup + cosine decay, skip bias/CLS/pos_embed weight decay, etc, etc).
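(For reference, the scaling rule mentioned here is presumably the standard linear one from the MAE setup; a minimal sketch, assuming base_lr = 1e-4 as stated above:)

# linear lr scaling rule: effective_lr = base_lr * batch_size / 256
base_lr, batch_size = 1e-4, 4096
effective_lr = base_lr * batch_size / 256  # -> 1.6e-3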

[image: training curves showing the xformers run destabilizing while the timm run trains normally]

xformers ViT-B Config

reversible: False
block_type: "encoder"
num_layers: 12
dim_model: 768
layer_norm_style: "pre"

multi_head_config:
  num_heads: 12
  residual_dropout: 0.1  # (1) tried without this, (2) swapping this for DropPath, (3) with regular dropout
  use_rotary_embeddings: False

  attention:
    name: "scaled_dot_product"
    dropout: 0.0
    causal: False

feedforward_config:
  name: "MLP"
  dropout: 0.0
  activation: "gelu"
  hidden_layer_multiplier: 4

xformers ViT-B

"""A simple ViT-B in xformers."""

import typing as t
from pathlib import Path
import yaml

import torch
from torch import nn
from timm.models.vision_transformer import DropPath
from timm.models.layers.patch_embed import PatchEmbed
from timm.models.layers.weight_init import trunc_normal_
from xformers.factory import xFormer, xFormerConfig


def _init_vit_weights(module: nn.Module):
    """Transformer weight initialization from TIMM."""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)


class ViT(nn.Module):
    """Vision transformer + head module."""

    def __init__(
        self,
        img_size: int = 224,
        in_chans: int = 3,
        patch_size: int = 16,
        num_classes: int = 1000,
        drop_path_rate: float = 0,
        transformer_config_file: t.Union[Path, str] = "configs/vit_b.yaml",
    ):
        """A standard ViT module."""
        super().__init__()

        # read the model config
        with open(transformer_config_file, "rb") as fileptr:
            self.model_config = yaml.load(fileptr, Loader=yaml.FullLoader)

        # embed_dim = self.model_config["block_config"]["dim_model"]
        embed_dim = self.model_config["dim_model"]
        print(self.model_config)

        # build the patch embedding model
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            flatten=True,
        )
        self.num_patches = self.patch_embed.num_patches

        # Build the tokens / position embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(  # +1 for CLS token
            torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)
        )

        # build the backbone
        self.backbone = xFormer.from_config(xFormerConfig([self.model_config]))

        # Swap dropout with drop-path
        # Also tried (1) without this, (2) without dropout.
        if drop_path_rate > 0:
            dpr_idx = 0
            dpr = [
                x.item()
                for x in torch.linspace(0, drop_path_rate, len(self.backbone.encoders))
            ]
            for layer in self.backbone.encoders:
                if hasattr(layer.mha, "resid_drop"):
                    setattr(layer.mha, "resid_drop", DropPath(dpr[dpr_idx]))
                    dpr_idx += 1

        # build the head network
        self.head = nn.Sequential(
            nn.LayerNorm(embed_dim, eps=1e-6),
            nn.Linear(embed_dim, num_classes)
        )

        # do ViT initializations
        self.init_weights()

    def init_weights(self):
        """Initialize layers, pos_embed and CLS for ViTs."""
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(_init_vit_weights)

    def forward(self, inputs: torch.Tensor) -> t.Dict[str, torch.Tensor]:
        """Infer variates and return a dict with repr and logits.

        Example sizing:
        patches = [2, 196, 768] --> [2, 197, 768] (after CLS)
        representation = [2, 197, 768]
        logits = [2, 197, 1000]
        CLS = [2, 1000]  # select first of 197

        """
        patches = self.patch_embed(inputs)
        cls_token = self.cls_token.expand(inputs.shape[0], -1, -1)
        out = torch.cat((cls_token, patches), dim=1)
        out = out + self.pos_embed

        representation = self.backbone(out)
        logits = self.head(representation)
        return {
            "representation": representation.detach(),
            "logits": logits,
            "CLS": logits[:, 0],
        }

Command

vit_b = ViT()
vit_b_timm = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=1000)

number_of_parameters(vit_b)
Out[18]: 86567656

number_of_parameters(vit_b_timm)
Out[19]: 86567656
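(number_of_parameters is not defined in the snippet; presumably a simple helper along these lines, shown here only for completeness:)

from torch import nn

def number_of_parameters(model: nn.Module) -> int:
    """Count trainable parameters (hypothetical helper, not from the original snippet)."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)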

To Reproduce

Steps to reproduce the behavior:

  1. Train a ViT-B with xformers --> unstable
  2. Using same training setup with timm --> stable
  3. 😢
@blefaudeux
Contributor

Thanks for the issue, just saw it; looking into it!

@blefaudeux
Contributor

@jramapuram can you elaborate on your config? Do you have Triton installed, for instance? Could you share a print(model) here, to be sure of which parts are actually instantiated? After a quick look, it seems there could be a part where the gradients are not handled up to the same precision as torch; at least that's my #1 hypothesis.

@blefaudeux
Contributor

many thanks for the detailed issue and code snippets, this is perfect

@blefaudeux blefaudeux self-assigned this Mar 1, 2022
@blefaudeux
Contributor

If you're using Triton, could you test out installing a recent dev package? pip install triton==1.1.2.dev20220106

@blefaudeux
Contributor

Also @jramapuram, could you confirm that this is with torch AMP (fp16)?

@blefaudeux
Contributor

blefaudeux commented Mar 1, 2022

cc @dianaml0, @fmassa: is that something you've seen? I remember @xwhan saw this at some point, but I thought it was fixed. I just did a quick check in the triton code, and we keep the data type as fp32 in the softmax and layernorm when AMP is activated, which should lead to a precision similar to pytorch's (layernorm is a bit below). It looks like a vanishing gradient problem, and the parts here are very standard (MLP and scaled_dot_product attention), so I'm wondering whether it could be somewhere else in the code, or whether the timm ViT adds some parameter-less normalization, for instance. I'm not seeing this on the CIFAR example that we host.

edit: adding some more context and info

@blefaudeux
Contributor

@jramapuram the eps parameter for LayerNorm is not the same between timm and xformers (1e-5 vs. 1e-6). It's a long shot, but since your issue could be related to vanishing gradients, it could explain this. Fixing that.
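(For context, a minimal illustration of the mismatch, assuming xformers picks up the PyTorch default while timm sets its own:)

from torch import nn

ln_default = nn.LayerNorm(768)            # PyTorch default: eps=1e-5
ln_timm    = nn.LayerNorm(768, eps=1e-6)  # what timm's ViT uses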

@jramapuram
Author

Filling in details:

  • AMP FP16 ✅
  • triton==1.1.1 ( can test 1.1.2.dev20220106 👍 )
  • Will try the layernorm eps; good find! Might be relevant for AMP

Instantiated model print to STDOUT: https://gist.github.com/jramapuram/d284e0f261d3fdb15c213dd929d272b9

@blefaudeux
Contributor

> Filling in details:
>
> • AMP FP16 ✅
> • triton==1.1.1 (can test 1.1.2.dev20220106 👍)
> • Will try the layernorm eps; good find! Might be relevant for AMP
>
> Instantiated model print to STDOUT: https://gist.github.com/jramapuram/d284e0f261d3fdb15c213dd929d272b9

I can actually repro the problem with the minimal microViT example (prior to the linked PRs), I just need to wait long enough. Testing right now with the changes from the linked PRs.

@blefaudeux
Contributor

Seems fine with the updated eps @jramapuram; let me know if it fixes your issue?

@jramapuram
Author

jramapuram commented Mar 1, 2022

Training now; will update here :)

import torch
from torch import nn

def update_ln_eps(module: nn.Module, new_eps: float):
    """Recurse and update LN eps with this value."""
    from xformers.triton.layer_norm import FusedLayerNorm

    if isinstance(module, torch.nn.modules.LayerNorm):
        module.eps = new_eps

    if isinstance(module, FusedLayerNorm):
        module.epsilon = new_eps

    for _, child in module.named_children():
        update_ln_eps(child, new_eps)
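(Applied to the model above, matching timm's value:)

vit_b = ViT()
update_ln_eps(vit_b, new_eps=1e-6)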

@jramapuram
Author

@blefaudeux: Unfortunately this does not seem to have fixed it for me 😬. Not sure if the scaling from microViT --> ViT-B ImageNet might be causing some issues that are not easily evident.

With LN fix using function above:
[image: training curve, still unstable]

With Triton 1.1.2.dev20220106 (tested with pip freeze to validate)
[image: training curve, still unstable]

Commit d4c28fb (tried with and without triton 1.1.2.dev20220106):
[image: training curve, still unstable]

For sanity I also tried again swapping back to TIMM and it is still working 😬
[image: timm training curve, stable]

@blefaudeux blefaudeux reopened this Mar 2, 2022
@blefaudeux
Contributor

Ouch, this is not good... The issue auto-closed, it seems, but I'm keeping it open; I'll try to dig a bit more.

@blefaudeux
Contributor

@jramapuram to try to pinpoint this a little better (and if you have time), could you try in an environment which does not have Triton? A few parts will default back to PyTorch; if you don't see an issue there, then I would know where to look (well, softmax and layernorm).

@blefaudeux
Contributor

blefaudeux commented Mar 2, 2022

Else I can think of:

  • different init strategies for the weights (possible, but I think unlikely to explain it)
  • shared weights in the projection
  • different pre/post normalization

Looking into Timm's implementation. I can confirm that it does not happen on CIFAR with a smaller ViT unfortunately; it would have been nice to have an easy repro.

edit: adding more context

@blefaudeux
Contributor

> @jramapuram to try to pinpoint this a little better (and if you have time), could you try in an environment which does not have Triton? A few parts will default back to PyTorch; if you don't see an issue there, then I would know where to look (well, softmax and layernorm).

Testing with pure pytorch layers right now, and I'm not seeing any difference so far, so this might not be a good explanation.

@blefaudeux
Contributor

blefaudeux commented Mar 2, 2022

> Else I can think of:
>
> • different init strategies for the weights (possible, but I think unlikely to explain it)

Init is different indeed (see timm's implementation, for instance), while xformers mostly follows the PyTorch defaults.

> • shared weights in the projection

The projection seems to follow the same structure, an n x 3n matrix + bias; nothing different here.

> • different pre/post normalization

Nope, pre-norm in both cases.

In short, I don't see much difference (provided my home test with pytorch vs. triton parts is confirmed on your end @jramapuram) except for the weight init. Since AMP training is notoriously a little finicky, maybe that could explain it? Not super intuitive to me, but having a deeper look.

@jramapuram
Author

  • Will try vanilla pytorch (without triton) on ImageNet for my own sanity as well 😅
  • PreNorm ✅
  • I already do the custom TIMM init (see the code above, which distills this; will also try a lower std (std=0.01) as well):
def _init_vit_weights(module: nn.Module):
    """Transformer weight initialization from TIMM."""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)

@blefaudeux : are there any xformers linear layers that don't inherit nn.Linear that might be missed by this init function?

Thanks for the great suggestions btw!

@blefaudeux
Contributor

> • Will try vanilla pytorch (without triton) on ImageNet for my own sanity as well 😅
> • PreNorm ✅
> • I already do the custom TIMM init (see the code above, which distills this; will also try a lower std (std=0.01) as well)

Ah, I didn't know about the init on your side, so this rules it out too!

> @blefaudeux: are there any xformers linear layers that don't inherit nn.Linear that might be missed by this init function?

No, I don't think so, although FusedMLP uses a normal nn.Linear but fuses the dropout/bias/activation (so the bias init would be missed). It does not look like you're using FusedMLP, so that should not be the case.

> Thanks for the great suggestions btw!

No problem, this is a little perplexing to be honest, but we'll root it out!

@blefaudeux
Contributor

blefaudeux commented Mar 2, 2022

Seeing your curves, it does seem a little different from what I was seeing prior to the eps adjustment: in the microViT / CIFAR example the validation accuracy was collapsing, but over many steps, while yours looks like a complete breakdown, one update completely breaking the model. It really looks like a raw fp16 representation problem; an underflow or overflow would look like that.

This is what a faulty normalization floor looked like (eps = 1e-5, pre/post correction), not really what you're seeing, unless it's a logging issue (not logging often enough, though my guess is no, since I can see your steps axis and you seem to log per step):
[screenshot: validation accuracy collapsing over many steps with eps = 1e-5]

@blefaudeux
Contributor

Hmm, turns out I was testing with rotary embeddings turned on, and they make a huge difference.

@jramapuram
Author

Lower std on trunc normal init (0.01):
[image: training curve]

Without triton:
[image: training curve]

WARNING:root:Triton is not available, some optimizations will not be enabled.
Error No module named 'triton'

FusedMLP:
[image: training curve]

        (feedforward): FusedMLP(
          (mlp): Sequential(
            (0): Linear(in_features=768, out_features=3072, bias=True)
            (1): FusedDropoutBias(
              (pytorch_activation): GELU()
            )
            (2): Linear(in_features=3072, out_features=768, bias=True)
            (3): FusedDropoutBias(
              (pytorch_activation): Passthrough()
            )
          )
        )

@blefaudeux
Contributor

Thanks @jramapuram, this is very informative: no issues with the triton layers whatsoever, so the problem is in a pure pytorch definition... :/

@jramapuram
Author

jramapuram commented May 16, 2022

Still no joy on ef6de0f 😬.

Here I init only the pos_embed and cls_token with trunc_normal_(std=0.02) and use DeepNorm:
[images: training curves with DeepNorm, unstable]

Edit: updated curves to compare to prenorm.
[image: pre-norm reference curves]

@blefaudeux
Contributor

> Still no joy on ef6de0f 😬. Will show pre-norm plot for comparison soon.
>
> Here I init only the pos_embed and cls_token with trunc_normal_(std=0.02) and use DeepNorm: [images]

Oh yes, for the current main branch nothing has landed addressing this yet. Could you try #303 by any chance? I can try to start something later today, but a little bit underwater atm :/

@jramapuram
Author

jramapuram commented May 16, 2022

> Oh yes, for the current main branch nothing has landed addressing this yet. Could you try #303 by any chance? I can try to start something later today, but a little bit underwater atm :/

No worries! Will give that a shot :) [feel better!]

I added the reference pre-norm graphs above. Differences are basically:

  1. CLS + pos_embed init only: i.e. use the default xformers init
  2. CLS + pos_embed + weight init
  3. CLS + pos_embed + weight init + LN init

@blefaudeux blefaudeux changed the title xformers ViT-B ImageNet MAE training instability xformers ViT-B ImageNet MAE + Deepnorm training instability May 19, 2022
@blefaudeux
Contributor

Oh wow, it's pretty clear indeed, thanks @jramapuram. #303 definitely fixes a small bug, but I doubt that it really explains this; I'll dive back into deepnorm. I may actually have a repro: with the recent metaformer + cifar10, deepnorm does not work either, but I thought that was because of the decidedly different model structure. I'll give it a second look, sorry for the delay.

@blefaudeux
Contributor

blefaudeux commented May 20, 2022

Hmm, I did spend some time on that and found nothing obviously wrong; it's really perplexing. I'll give ImageNet a shot. If you have the option, would it be possible to test this without AMP, just in case it's a matter of numerical accuracy (which would not be caught by the grad scaler if not NaN)?
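(A minimal sketch of such a control run, disabling autocast so everything stays in fp32; images, labels and criterion are placeholders for the usual training loop:)

import torch

# full-fp32 forward/backward: any fp16 under/overflow should disappear here
with torch.cuda.amp.autocast(enabled=False):
    logits = vit_b(images)["CLS"]
    loss = criterion(logits, labels)
loss.backward()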

@blefaudeux
Contributor

Just in case @jramapuram, could you check that you're using triton == 2.0.0.dev20220403? I'm bumping into numerical stability issues with newer versions of triton (being worked on) and the symptoms are very similar (accuracy crash but no NaNs).
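(A quick way to check, mirroring the pip freeze verification used earlier in the thread:)

import importlib.metadata

print(importlib.metadata.version("triton"))  # want 2.0.0.dev20220403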

@blefaudeux
Contributor

(No, I've not forgotten this issue...) I would love to be able to repro on something a little smaller than a full-blown ImageNet training over a couple of nodes, so documenting that here. Attached is a minGPT training setup, with pre/post/deepnorm (8-layer transformer, 25M params). Deepnorm doesn't converge to a solution as good as the others, but there's no catastrophic failure for any of them.
[screenshot: minGPT training curves for pre-, post- and deepnorm]

@jramapuram
Author

> Just in case @jramapuram, could you check that you're using triton == 2.0.0.dev20220403? I'm bumping into numerical stability issues with newer versions of triton (being worked on) and the symptoms are very similar (accuracy crash but no NaNs).

Thanks for keeping this in mind @blefaudeux. Just checked, using triton==2.0.0.dev20220430 -- I can drop down and test!

Re the minGPT: I'm surprised there is a perf drop -- does the test loss / negative log-likelihood follow the same trend?

@blefaudeux
Contributor

> > Just in case @jramapuram, could you check that you're using triton == 2.0.0.dev20220403? I'm bumping into numerical stability issues with newer versions of triton (being worked on) and the symptoms are very similar (accuracy crash but no NaNs).
>
> Thanks for keeping this in mind @blefaudeux. Just checked, using triton==2.0.0.dev20220430 -- I can drop down and test!
>
> Re the minGPT: I'm surprised there is a perf drop -- does the test loss / negative log-likelihood follow the same trend?

20220430 was fine; the ones after that were broken, but fixed by triton-lang/triton@205a493, so it's back to being good at the moment! Re minGPT: I can check the other metrics. As mentioned in another thread, I think it may be due to the weight distribution being hardcoded right now for deepnorm. It's not very readable or hackable, and not a great design overall; I'd like to come up with something better and more explicit (for instance, a couple of possible inits as part of the xformers config, with deepnorm respecting that). It's always possible to init from the outside, but that's tied to parameter naming conventions (not super clear right now), and it kind of negates the point of supporting deepnorm to begin with, I think.

@jramapuram
Author

Unfortunately no joy @blefaudeux. I tried:

  1. Triton downgrade + pos + cls init
  2. Triton downgrade + pos + cls + weight init
  3. triton==2.0.0.dev20220430 + pos + cls init
  4. triton==2.0.0.dev20220430 + pos + cls + weight init

[image: training curves for all four configurations, still unstable]

@blefaudeux
Contributor

Thanks a bunch @jramapuram! I have a draft PR in the works which rewrites a lot of the input projections (something we discussed earlier), plus explicit handling of a couple of init methods (optional, users are still free to do as they please); I'm hoping that it solves this. To give an insight, I think that this setting is not well handled and could be the culprit (deepnorm assumes a different projection per Q/K/V, and the default here should probably be "true", I believe).
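(To illustrate the distinction; this is a sketch, not the actual xformers API. Deepnorm's init prescriptions assume three independent projections, which a single fused matrix cannot express:)

import torch
from torch import nn

dim = 768

# fused input projection: one (dim -> 3*dim) matrix sliced into Q/K/V,
# so the three parts cannot get different init gains
qkv = nn.Linear(dim, 3 * dim)
q, k, v = qkv(torch.randn(2, 197, dim)).chunk(3, dim=-1)

# separate projections: deepnorm can then downscale e.g. the value/output
# projections (its beta gain) while leaving Q/K at the default
q_proj, k_proj, v_proj = (nn.Linear(dim, dim) for _ in range(3))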

@blefaudeux
Contributor

I think that #312 is getting there @jramapuram; it's a lot cleaner to my eyes. Something I've seen, related to your curves above: it's not just deepnorm, the post-normalization path in general does not play well with ViT. GPT is fine with this normalization path; I don't know if that's a known fact, I would need to check the literature. Since deepnorm is a subset of the post-normalization code path, it makes a little more sense, or at least deepnorm is not alone.

@blefaudeux
Contributor

OK, beyond #312 which cleans things up, it looks like (given Timm, here) layernorm requires a specific treatment for ViT + post-norm: the weight is initialized to a very small value (vs. 1 typically). Since in our case Post and Deepnorm (same residual codepath) both fail with ViT but work well with GPT, that could explain why. I'll give it a shot.
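(A minimal sketch of what that treatment could look like, in the spirit of update_ln_eps above; the 1e-4 value is illustrative, not the exact value timm uses:)

from torch import nn

def downscale_ln_weights(module: nn.Module, init_value: float = 1e-4):
    """Hypothetical helper: set every LayerNorm weight to a small constant
    instead of the usual 1.0, for the ViT + post-norm path discussed above."""
    for m in module.modules():
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, init_value)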

@blefaudeux
Contributor

I've not forgotten this @jramapuram. It turns out that for vision / post-norm, Swin v2 already solved this (related to the message above), see their paper: the initial weights need to be scaled way down. I'll try to implement this in xformers when I get the time.
