Error running detection baseline #2

Closed
danperazzo opened this issue Jul 20, 2022 · 3 comments

@danperazzo

Hello everyone, thanks for releasing the starter kit!

I have been trying to run the baseline code in a Jupyter notebook, but I have run into the following error:
error_detection_baseline.txt

What were the versions of the packages that you used for the tests?

I am currently using PyTorch 1.12.

Thanks for your attention!

@PythonNut

PythonNut commented Jul 20, 2022

The approximate argument for GELU was added in PyTorch 1.12, but the pickled models appear to have been saved with PyTorch 1.11, so the loaded GELU modules lack that attribute. To fix this, you can patch the models like so immediately after they are loaded:

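import torch
from packaging import version

def fix_gelus(model):
    # Models pickled under PyTorch 1.11 lack the `approximate` attribute
    # that GELU's forward pass reads in 1.12 and later.
    if version.parse(torch.__version__) >= version.parse("1.12.0"):
        for mod in model.modules():
            if isinstance(mod, torch.nn.modules.activation.GELU):
                mod.approximate = "none"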

Use like so:

class NetworkDatasetDetection(torch.utils.data.Dataset):
    # ...
    def __getitem__(self, index):
-       return torch.load(os.path.join(self.model_paths[index], 'model.pt')), \
-              self.labels[index], self.data_sources[index]
+       model = torch.load(os.path.join(self.model_paths[index], "model.pt"))
+       fix_gelus(model)
+       return model, self.labels[index], self.data_sources[index]


class NetworkDatasetDetectionTest(torch.utils.data.Dataset):
    # ...

    def __getitem__(self, index):
-       return torch.load(os.path.join(self.model_paths[index], 'model.pt')), self.data_sources[index]
+       model = torch.load(os.path.join(self.model_paths[index], "model.pt"))
+       fix_gelus(model)
+       return model, self.data_sources[index]
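
For anyone wondering why patching the attribute after loading is enough, here is a rough, torch-free sketch of the mechanism as I understand it: unpickling restores an object's saved attributes without calling `__init__`, so an attribute introduced by a newer class definition is simply missing. `FakeGELU` and `restore_like_unpickle` are hypothetical stand-ins made up for illustration; they are not part of PyTorch or the starter kit, and the real error in the attached log may differ in detail.

```python
# Torch-free sketch of the failure mode. FakeGELU is a hypothetical stand-in,
# not the real torch.nn.GELU.

class FakeGELU:
    """Stand-in for the newer class definition, which expects `approximate`."""

    def __init__(self, approximate="none"):
        self.approximate = approximate

    def forward(self, x):
        # Newer code reads an attribute that older saved state never stored.
        return (x, self.approximate)


def restore_like_unpickle(cls, saved_state):
    """Rebuild an object roughly as default unpickling does: no __init__ call."""
    obj = cls.__new__(cls)
    obj.__dict__.update(saved_state)
    return obj


# State as saved by the older class definition: no `approximate` key.
old_state = {"training": False}
restored = restore_like_unpickle(FakeGELU, old_state)

try:
    restored.forward(1.0)
    failed = False
except AttributeError:
    # The attribute was never set because __init__ did not run.
    failed = True

# Setting the attribute after loading, as fix_gelus does, repairs the object.
restored.approximate = "none"
result = restored.forward(1.0)
```

This is also why the fix has to run on every loaded model: each pickled module carries its own saved state.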

@danperazzo
Author

danperazzo commented Jul 21, 2022

Fixed it! Based on your comment, I changed my PyTorch version from 1.12 to 1.11 and it worked just fine. Maybe putting this info in the README would help other teams :)

@mmazeika
Owner

mmazeika commented Aug 2, 2022

Sorry for the late reply! Yes, it looks like you might run into errors if you don't use PyTorch 1.11.0. I just updated the README with this information.
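
In case it helps other teams, a small guard at the top of the notebook could catch a mismatched install early. The sketch below is only an illustration: the helper names are made up, and treating any 1.11.x release as acceptable is an assumption (the README may be stricter and require exactly 1.11.0).

```python
import re

# Assumption: any 1.11.x release is acceptable; adjust if the README says otherwise.
EXPECTED_MAJOR_MINOR = (1, 11)


def major_minor(version_string):
    """Extract (major, minor) from strings such as '1.12.0+cu113'."""
    match = re.match(r"(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"unrecognized version string: {version_string!r}")
    return int(match.group(1)), int(match.group(2))


def matches_expected_version(version_string, expected=EXPECTED_MAJOR_MINOR):
    """Return True when the version's major.minor equals the expected pair."""
    return major_minor(version_string) == expected
```

One way to use it would be to pass `torch.__version__` to `matches_expected_version` right after importing torch and print a warning when it returns False.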

mmazeika closed this as completed Aug 2, 2022