Error running detection baseline #2
GELU approximation was added in PyTorch 1.12, but the pickled models appear to be from PyTorch 1.11. To fix this, you can process the models like so immediately after they are loaded:

```python
import torch
from packaging import version

def fix_gelus(model):
    # Models pickled under PyTorch 1.11 lack the `approximate` attribute that
    # GELU gained in 1.12, so set it to the default ("none", i.e. exact GELU).
    if version.parse(torch.__version__) >= version.parse("1.12.0"):
        for mod in model.modules():
            if isinstance(mod, torch.nn.modules.activation.GELU):
                mod.approximate = "none"
```

Use like so:

```diff
 class NetworkDatasetDetection(torch.utils.data.Dataset):
     # ...
     def __getitem__(self, index):
-        return torch.load(os.path.join(self.model_paths[index], 'model.pt')), \
-            self.labels[index], self.data_sources[index]
+        model = torch.load(os.path.join(self.model_paths[index], "model.pt"))
+        fix_gelus(model)
+        return model, self.labels[index], self.data_sources[index]

 class NetworkDatasetDetectionTest(torch.utils.data.Dataset):
     # ...
     def __getitem__(self, index):
-        return torch.load(os.path.join(self.model_paths[index], 'model.pt')), self.data_sources[index]
+        model = torch.load(os.path.join(self.model_paths[index], "model.pt"))
+        fix_gelus(model)
+        return model, self.data_sources[index]
```
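For anyone wondering why a 1.11 pickle breaks under 1.12: unpickling restores an object's saved attributes without re-running `__init__`, so an attribute that a newer `__init__` would set (like GELU's `approximate`) is simply missing. The torch-free sketch below reproduces that with a hypothetical `Activation` class (redefined mid-script to simulate a library upgrade) and a hypothetical helper `patch_missing_attr`. It illustrates the mechanism under those assumptions; it is not PyTorch's actual code, and it is meant to be run as a script.

```python
import pickle


class Activation:
    """Hypothetical stand-in for a module class as it was when the model was saved."""

    def __init__(self):
        self.inplace = False


# Save an instance using the "old" class definition.
payload = pickle.dumps(Activation())


class Activation:  # deliberate redefinition to simulate upgrading the library
    """Hypothetical newer version whose __init__ adds an attribute."""

    def __init__(self, approximate="none"):
        self.inplace = False
        self.approximate = approximate

    def describe(self):
        # Newer code reads the new attribute, which fails on old pickles.
        return f"approximate={self.approximate}"


def patch_missing_attr(obj, name, default):
    """Hypothetical helper: set `name` only if unpickling left it missing."""
    if not hasattr(obj, name):
        setattr(obj, name, default)
    return obj


# pickle looks the class up by name at load time, so we get the *new* class,
# but its __init__ never ran, so `approximate` is absent.
restored = pickle.loads(payload)
print(hasattr(restored, "approximate"))  # False

patch_missing_attr(restored, "approximate", "none")
print(restored.describe())  # approximate=none
```

Unlike `fix_gelus`, which overwrites `approximate` unconditionally, this helper only fills the attribute in when it is missing, so a module that genuinely uses `approximate="tanh"` would be left alone.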
Fixed it! Based on your comment, I changed my PyTorch version from 1.12 to 1.11 and it worked just fine. Putting this info in the README might help other teams :)
Sorry for the late reply! Yes, it looks like you might run into errors if you don't use PyTorch 1.11.0. I just updated the README with this information.
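To make a version mismatch fail loudly instead of surfacing as an obscure `AttributeError`, a small check could warn when the installed PyTorch is not the 1.11.x series the README recommends. `torch_version_matches` and `warn_if_unsupported` are hypothetical helpers using only the standard library; they drop local build tags such as `+cu113` before comparing major.minor. This is a sketch, not part of the starter kit.

```python
import warnings


def torch_version_matches(version_string, required=(1, 11)):
    """Hypothetical helper: True if version_string's major.minor equals `required`.

    Local build tags (e.g. "+cu113") are dropped before comparing.
    """
    public = version_string.split("+", 1)[0]
    parts = public.split(".")
    try:
        major_minor = (int(parts[0]), int(parts[1]))
    except (IndexError, ValueError):
        # Unparseable version strings are treated as "not a match".
        return False
    return major_minor == tuple(required)


def warn_if_unsupported(version_string):
    # Emit a RuntimeWarning when the running version is not 1.11.x.
    if not torch_version_matches(version_string):
        warnings.warn(
            f"PyTorch {version_string} detected; the pickled models were saved "
            "with 1.11 and may fail to load or run under other versions.",
            RuntimeWarning,
        )
```

In practice you would call `warn_if_unsupported(torch.__version__)` right after `import torch` in the notebook.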
Hello everyone, thanks for releasing the starter kit!
I have been trying to run the baseline code in a Jupyter notebook, but I ran into the following error:
error_detection_baseline.txt
Which package versions did you use for testing?
I am currently using PyTorch 1.12.
Thanks for your attention!