This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

- can now specify pickle module, since dill causes a recursion depth exceeded error in pytorch 0.4.1
nasimrahaman committed Aug 7, 2018
1 parent 1dd60ec commit 9ce802c
Showing 1 changed file with 30 additions and 6 deletions.
36 changes: 30 additions & 6 deletions inferno/trainers/basic.py
@@ -1,11 +1,17 @@
-import dill
 from datetime import datetime
 from inspect import signature
 import os
 import shutil
 import contextlib
 import warnings
 
+# These are fetched from globals, they're not unused
+# noinspection PyUnresolvedReferences
+import dill
+# noinspection PyUnresolvedReferences
+import pickle
+
+
 import torch
 from numpy import inf
 from torch.autograd import Variable
@@ -102,6 +108,7 @@ def __init__(self, model=None):
         # Checkpointing
         self._save_every = None
         self._save_to_directory = None
+        self._pickle_module = 'pickle'
         # Defaults for file names
         self._checkpoint_filename = 'checkpoint.pytorch'
         self._best_checkpoint_filename = 'best_checkpoint.pytorch'
@@ -599,6 +606,22 @@ def log_directory(self, value):
         if value is not None:
             self.set_log_directory(value)
 
+    @property
+    def pickle_module(self):
+        module_ = globals().get(self._pickle_module, None)
+        assert_(module_ is not None, "Pickle module not found!", ModuleNotFoundError)
+        return module_
+
+    _ALLOWED_PICKLE_MODULES = {'pickle', 'dill'}
+
+    @pickle_module.setter
+    def pickle_module(self, value):
+        assert_(isinstance(value, str), "`pickle_module` must be set to a string.", TypeError)
+        assert_(value in self._ALLOWED_PICKLE_MODULES,
+                f"Pickle module must be one of {self._ALLOWED_PICKLE_MODULES}, "
+                f"got {value} instead.", ValueError)
+        self._pickle_module = value
+
     @property
     def saving_every(self):
         """Gets the frequency at which checkpoints are made."""
@@ -1605,8 +1628,8 @@ def save(self, exclude_loader=True, stash_best_checkpoint=True):
 
         # Save the state dictionary
         torch.save(self.get_config(exclude_loader=exclude_loader),
-                   checkpoint_path)
-                   # pickle_module=dill)
+                   checkpoint_path,
+                   pickle_module=self.pickle_module)
 
         self.callbacks.call(self.callbacks.END_OF_SAVE,
                             save_to_directory=self._save_to_directory,
@@ -1631,7 +1654,7 @@ def save_model(self, to_directory=None):
         # Save the state dictionary
         torch.save(self.model,
                    os.path.join(to_directory, 'model.pytorch'),
-                   pickle_module=dill)
+                   pickle_module=self.pickle_module)
         return self
 
     def load(self, from_directory=None, best=False, filename=None):
@@ -1661,7 +1684,7 @@ def load(self, from_directory=None, best=False, filename=None):
         filename = self._best_checkpoint_filename if best else self._checkpoint_filename
         # Load the dictionary
         config_dict = torch.load(os.path.join(from_directory, filename),
-                                 pickle_module=dill)
+                                 pickle_module=self.pickle_module)
         # This is required to prevent an infinite save loop?
         self._is_iteration_with_best_validation_score = False
         # Set config
@@ -1672,7 +1695,8 @@ def load_model(self, from_directory=None, filename=None):
         from_directory = self._save_to_directory if from_directory is None else from_directory
         filename = 'model.pytorch' if filename is None else filename
         # Load the model
-        model = torch.load(os.path.join(from_directory, filename), pickle_module=dill)
+        model = torch.load(os.path.join(from_directory, filename),
+                           pickle_module=self.pickle_module)
         # Set model
         self.model = model
         return self

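For context, a minimal usage sketch of the new setter (not part of the commit; it assumes the Trainer class exported from inferno.trainers.basic and a throwaway model):

import pickle

import torch.nn as nn
from inferno.trainers.basic import Trainer

# Build a trainer around any small model; Trainer(model=...) matches the
# __init__ signature shown in the diff above.
trainer = Trainer(model=nn.Linear(4, 2))

# The default after this commit is the standard-library pickle, which avoids
# the recursion depth exceeded error dill triggers under pytorch 0.4.1.
trainer.pickle_module = 'pickle'
assert trainer.pickle_module is pickle  # the property resolves the string to the module object

# 'dill' remains available for objects the stdlib pickle cannot serialize;
# any other string is rejected by the setter with a ValueError.
trainer.pickle_module = 'dill'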