
[BUG] loss results are different even though random seed is set #1770

Closed
ljm565 opened this issue Apr 14, 2023 · 15 comments
Labels
bug Something isn't working

Comments


ljm565 commented Apr 14, 2023

import timm
import torch.nn as nn


class EfficientFormerV2(nn.Module):
    def __init__(self, config):
        super(EfficientFormerV2, self).__init__()
        # Pre-trained EfficientFormerV2-S0 from timm, used as the backbone.
        self.model = timm.create_model("efficientformerv2_s0.snap_dist_in1k", pretrained=True)
        # Flattened 176x7x7 feature map -> 2-class linear head.
        self.fc = nn.Linear(176 * 7 * 7, 2, bias=config.bias)

    def forward(self, x):
        batch_size = x.size(0)
        output = self.model.forward_features(x) 
        output = output.view(batch_size, -1)
        output = self.fc(output)
        return output

Loss results differ between training runs even though I use the EfficientFormerV2 model with a fixed random seed.
More specifically, each training epoch has 75 steps, and after a few steps (around step 30) the loss values diverge.
I think the issue comes from randomness in timm, because when I use our own custom model the loss results are identical.
Is there any solution for this issue?
Do I have to use only the train.py that timm provides?

Thanks

ljm565 added the bug label Apr 14, 2023

hacktmz commented Apr 14, 2023

Can you change the input image size when training a classification network with EfficientFormerV2? If I change the input to something other than 224, an error is reported.

@rwightman
Collaborator

@ljm565 there is non-determinism inherent in default PyTorch models, see https://pytorch.org/docs/stable/notes/randomness.html

There is also likely to be some in most train scripts.

timm isn't doing anything unusual; I typically don't find it worth going to extremes in this area, but you are welcome to try.

#853


ljm565 commented Apr 16, 2023

@hacktmz Actually, our data size is 3 * 224 * 224, the same as the pre-trained EfficientFormer-V2. We have to use 224*224 inputs if we want to use the pre-trained model.


ljm565 commented Apr 16, 2023

@rwightman If I don't use the timm EfficientFormer-V2 (just nn.Linear and nn.Conv layers), all trials produce the same results, because I set the torch and random module seeds. Also, Hugging Face transformer-based models do not show this problem...

@rwightman
Collaborator

@ljm565 as per the randomness notes for PyTorch, you don't get true determinism in PyTorch unless you change the default flags; not sure if transformers changes anything by default.

Also, these models have batchnorm, so you do have to flip between .train() and .eval()
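
For reference, a minimal sketch of a fully seeded, deterministic setup along the lines of the PyTorch randomness notes linked above (this is not the script used in this issue; CUBLAS_WORKSPACE_CONFIG must be set before any CUDA kernels run, and torch.use_deterministic_algorithms(True) will raise an error on ops that have no deterministic implementation):

import os
import random

import numpy as np
import torch

# Required by cuBLAS for deterministic GEMMs; must be set before CUDA is used.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

seed = 999
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

torch.backends.cudnn.benchmark = False
# Errors out on any op without a deterministic implementation.
torch.use_deterministic_algorithms(True)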

@rwightman
Collaborator

@ljm565 if you can compare your model with two others, a resnet50 and a vit_small_patch16_224, that would be helpful to see whether you observe the non-determinism in either or both of those.
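
For example, something along these lines; timm.create_model accepts num_classes to swap the classifier head, so the rest of the training code can stay identical (just a sketch, not the exact setup used in this issue):

import timm

# Same training loop, different backbones; only the model construction changes.
backbones = ["resnet50", "vit_small_patch16_224", "efficientformerv2_s0"]
models = {name: timm.create_model(name, pretrained=True, num_classes=2) for name in backbones}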


ljm565 commented Apr 16, 2023

@rwightman Yes, of course, I apply model.train() and model.eval() in the respective phases. I also tested a resnet from torchvision.models, and that model shows the same results in every training run.

Below are the seeds and flags I applied.

import os
import random

import numpy as np
import torch

torch.cuda.manual_seed(999)
torch.manual_seed(999)
np.random.seed(999)
random.seed(999)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUDNN_DETERMINISTIC'] = '1'
os.environ['PYTHONHASHSEED'] = str(999)


ljm565 commented Apr 16, 2023

EfficientFormer-v2

trial 1

1 / 100
----------
Phase: train
Epoch 1: 0/75 step loss: 0.8326853513717651, acc: 0.5625
Epoch 1: 20/75 step loss: 0.34142473340034485, acc: 0.859375
Epoch 1: 40/75 step loss: 0.19201204180717468, acc: 0.90625
Epoch 1: 60/75 step loss: 0.11295119673013687, acc: 0.96875
train loss: 0.297733, acc: 0.868818
trial 2
1 / 100
----------
Phase: train
Epoch 1: 0/75 step loss: 0.8326853513717651, acc: 0.5625
Epoch 1: 20/75 step loss: 0.37356138229370117, acc: 0.90625
Epoch 1: 40/75 step loss: 0.177852600812912, acc: 0.90625
Epoch 1: 60/75 step loss: 0.05695604532957077, acc: 0.984375
train loss: 0.260036, acc: 0.882230

resnet50

trial 1
1 / 100
----------
Phase: train
Epoch 1: 0/75 step loss: 0.7814404368400574, acc: 0.53125
Epoch 1: 20/75 step loss: 0.02220681682229042, acc: 1.0
Epoch 1: 40/75 step loss: 0.0027756004128605127, acc: 1.0
Epoch 1: 60/75 step loss: 0.0006665001856163144, acc: 1.0
train loss: 0.087285, acc: 0.961442
trial 2
1 / 100
----------
Phase: train
Epoch 1: 0/75 step loss: 0.7814404368400574, acc: 0.53125
Epoch 1: 20/75 step loss: 0.02220681682229042, acc: 1.0
Epoch 1: 40/75 step loss: 0.0027756004128605127, acc: 1.0
Epoch 1: 60/75 step loss: 0.0006665001856163144, acc: 1.0
train loss: 0.087285, acc: 0.961442


rwightman commented Apr 16, 2023

@ljm565 hmmm, does it happen if you disable the attention bias cache for both attention modules

ie, in two locations, change

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        if torch.jit.is_tracing() or self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        else:
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.attention_bias_cache[device_key]

to

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        return self.attention_biases[:, self.attention_bias_idxs]
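
If editing the installed package is inconvenient, roughly the same experiment can be done by monkey-patching the method at runtime. A sketch, assuming the two attention classes in timm's efficientformer_v2.py are named Attention2d and Attention2dDownsample (verify the names against the installed timm version):

import torch
from timm.models import efficientformer_v2 as efv2

def _uncached_attention_biases(self, device: torch.device) -> torch.Tensor:
    # Always recompute the biases instead of reading the per-device eval cache.
    return self.attention_biases[:, self.attention_bias_idxs]

# Class names are assumptions; getattr keeps this safe if a name doesn't exist.
for cls_name in ("Attention2d", "Attention2dDownsample"):
    cls = getattr(efv2, cls_name, None)
    if cls is not None and hasattr(cls, "get_attention_biases"):
        cls.get_attention_biases = _uncached_attention_biases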


ljm565 commented Apr 17, 2023

@rwightman I changed the function in both locations, but the loss results are still different.
I also added torch.manual_seed(999) at the top of efficientformer_v2.py, but that doesn't help either.

Since the loss at step 0 is always the same, the model initialization seems to be identical thanks to the fixed random seed.

@rwightman
Collaborator

@ljm565 hmm, this is definitely odd; even without all those flags but with the same seeds, the typical level of non-determinism inherent in benchmarking, cudnn, etc. is quite small, and results don't diverge as much as you see there.

For good measure, have you forced drop_path_rate=0.? Some model sizes have it enabled by default, and that should be the only source of randomness by default since drop_rate is 0. Though this should be deterministic with those flags and the same seed even if it is active...

Other thought is that maybe there is some numeric stability issue making it sensitive to very small changes. Does enabling gradient clipping, or lowering the adam/adamw beta from .999/.99 -> .95 (if using adam), change anything?
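
A minimal sketch of those two knobs, assuming an AdamW optimizer; the tiny model and random batch are placeholders so the loop runs end to end:

import torch
import torch.nn as nn

model = nn.Linear(16, 2)  # placeholder for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95))  # lower beta2
criterion = nn.CrossEntropyLoss()

x = torch.randn(8, 16)
y = torch.randint(0, 2, (8,))

optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
# Clip the global gradient norm before the optimizer step.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()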


ljm565 commented Apr 17, 2023

@rwightman As in the code in my question, I used efficientformerv2_s0, so its drop_path_rate may be 0.
Also, I found that if I set the pretrained argument to False, the loss difference is very tiny (around 1e-7).
But the torchvision.models resnet still shows exactly the same loss results regardless of the pretrained argument.


rwightman commented Apr 17, 2023

@ljm565 try changing self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear') to self.upsample = nn.Upsample(scale_factor=stride, mode='nearest') ... we don't really want to run in nearest, but it might be the source of non-deterministic behaviour
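
One way to try this without editing efficientformer_v2.py is to rewrite the mode on the already-built model; nn.Upsample keeps mode as a plain attribute, so a sketch like this should be enough for the experiment:

import timm
import torch.nn as nn

model = timm.create_model("efficientformerv2_s0.snap_dist_in1k", pretrained=True)
for m in model.modules():
    if isinstance(m, nn.Upsample) and m.mode == "bilinear":
        m.mode = "nearest"
        m.align_corners = None  # 'nearest' does not accept align_corners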


ljm565 commented Apr 17, 2023

Unfortunately, it still shows different results on every trial...
I thought it might be a problem with my environment, so I reinstalled CUDA, cuDNN, and torch, but the result was the same.
So far I have tested on Ubuntu; I will check the problem on a Mac as well and report the results.

Tested OS: Ubuntu 22.04
CUDA: 11.6
torch: 1.13.1+cu116

-------------added-------------
On the Mac, the loss results are reproducible. My Ubuntu environment may be affecting the loss results...

@rwightman
Collaborator

@ljm565 to be clear, on the Mac the losses are the same? It may not be your environment but CUDA/cuDNN vs CPU. Did you try forcing CPU on Ubuntu? I guess a fresh environment with PyTorch 2.0 might be worth a check too.
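
If it helps narrow things down, a small CPU-only repro could look like this: build the model twice with the same seed and check that a forward pass on a fixed input matches exactly. If this passes on CPU but the training divergence only appears on CUDA, the non-determinism is in the GPU kernels rather than in timm (a sketch, not the training code from this issue):

import timm
import torch

def forward_once(seed: int = 999) -> torch.Tensor:
    torch.manual_seed(seed)
    model = timm.create_model("efficientformerv2_s0.snap_dist_in1k", pretrained=False).eval()
    x = torch.randn(1, 3, 224, 224, generator=torch.Generator().manual_seed(seed))
    with torch.no_grad():
        return model(x)

print(torch.equal(forward_once(), forward_once()))  # expected: True on CPU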
