Changed save_results to use .state_dict and created corresponding load_model function #70
Conversation
Hi Kaare!
Well done on your first PR. Next time, add a comment in the PR!
Your PR is now approved - you may click 'Merge'!
Rasmus
Thanks so much for the contribution, @kaareendrup! One comment on this topic (not affecting this PR in any way) is that using

@RasmusOrsoe, is the reason for not wanting to save the entire model basically what it says in the guide mentioned in #63?
@asogaard yes. In fact, you can't save the models using torch.save; it throws a pickle error (so the example code for training has an error in it!). And secondly, it's sensitive to changes in
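The pickle error described above typically comes from `pickle`'s inability to serialize locally defined or anonymous callables stored on a model object. A minimal, stdlib-only illustration of that failure mode (not the actual gnn_reco code, just a stand-in class):

```python
import pickle


class Model:
    def __init__(self):
        # A lambda stored on the instance cannot be pickled, similar to
        # models that hold locally defined callables or closures.
        self.activation = lambda x: x * 2


model = Model()
try:
    payload = pickle.dumps(model)
    outcome = "pickled OK"
except (pickle.PicklingError, AttributeError, TypeError) as err:
    outcome = f"pickle failed: {type(err).__name__}"

print(outcome)
```

This is why swapping in `dill` (which can serialize such objects) or saving only tensors via `state_dict` sidesteps the problem.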
Hi @RasmusOrsoe, I saw a pickle error when saving and loading, but it seemed to be fixed by using `dill`. To the second point, wrt.

```python
# Import(s)
import dill
import torch
from torch.utils.data import DataLoader
from torch_geometric.data import Batch

from gnn_reco.components.loss_functions import VonMisesFisher2DLoss
from gnn_reco.data.sqlite_dataset import SQLiteDataset
from gnn_reco.data.constants import FEATURES, TRUTH
from gnn_reco.models import Model
from gnn_reco.models.detector import IceCubeDeepCore
from gnn_reco.models.gnn import DynEdge
from gnn_reco.models.graph_builders import KNNGraphBuilder
from gnn_reco.models.task.reconstruction import ZenithReconstructionWithKappa

# Load data
db = "/groups/icecube/leonbozi/datafromrasmus/GNNReco/data/databases/dev_level7_noise_muon_nu_classification_pass2_fixedRetro_v3/data/dev_level7_noise_muon_nu_classification_pass2_fixedRetro_v3.db"
dataset = SQLiteDataset(db, "SRTTWOfflinePulsesDC", FEATURES.ICECUBE86, TRUTH.ICECUBE86)
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=False,
    num_workers=1,
    collate_fn=Batch.from_data_list,
    persistent_workers=True,
    prefetch_factor=2,
)
batch = next(iter(dataloader))

# Wrap code in functions to make it clear that these two operations are wholly independent
model_path = "test_model.pth"

def build_save_model():
    detector = IceCubeDeepCore(
        graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8),
    )
    gnn = DynEdge(
        nb_inputs=detector.nb_outputs,
    )
    task = ZenithReconstructionWithKappa(
        hidden_size=gnn.nb_outputs,
        target_label='zenith',
        loss_function=VonMisesFisher2DLoss(),
    )
    model = Model(
        detector=detector,
        gnn=gnn,
        tasks=[task],
        device='cpu',
    )
    model.eval()
    print(model(batch))
    torch.save(model, model_path, pickle_module=dill)
    del model

def load_model():
    model = torch.load(model_path, pickle_module=dill)
    model.eval()
    print(model(batch))

# Test model saving and loading
build_save_model()
load_model()
```

Both methods yield the same output:
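For comparison, the `state_dict`-based approach this PR introduces can be sketched roughly as follows. This is an illustrative sketch only: the tiny `nn.Linear` stands in for the real gnn_reco model, and `make_model` is a hypothetical helper, not part of the repo.

```python
import torch
from torch import nn


def make_model() -> nn.Module:
    # Hypothetical stand-in for whatever code builds the real model;
    # the architecture must be re-created in code before loading.
    return nn.Linear(4, 2)


model_path = "test_state_dict.pth"

# Save only the learned parameters, not the pickled class definition.
model = make_model()
torch.save(model.state_dict(), model_path)

# Load: instantiate a fresh (randomly initialised) model, then
# overwrite its parameters with the saved ones.
restored = make_model()
restored.load_state_dict(torch.load(model_path))
restored.eval()

# Both models now produce identical outputs.
x = torch.ones(1, 4)
with torch.no_grad():
    same = torch.allclose(model(x), restored(x))
print(same)
```

Because only tensors are serialised, this avoids the pickle error entirely, at the cost of having to keep the model-building code around to reconstruct the architecture before loading.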
@asogaard Aha! The changes @kaareendrup made are in components/utils.py, and your dill-fix is in models/training/.. . So I think the pickle error we saw is the same as the one you fixed with dill, and the reason we saw it is that the function lives in both places and was only fixed in one of the locations. We need to spring clean! Regarding the

Rasmus
Brilliant! In that case, I will take the liberty of merging this PR (hope that's okay, @kaareendrup!), do some spring cleaning to remove duplicate code, and add the two methods discussed above. 💪
Wow, good job guys! I will need to read through your discussion carefully to fully get what's going on. But everything sounds good, and it sounds like a proper cleaning is not a bad idea.
@asogaard @kaareendrup @RasmusOrsoe Let's get this saving/loading snippet in as an example in the repo.