
Changed save_results to use .state_dict and created corresponding load_model function #70

Merged · 3 commits · Nov 16, 2021

Conversation

kaareendrup
Collaborator

No description provided.

@RasmusOrsoe RasmusOrsoe self-requested a review November 15, 2021 15:15
Collaborator

@RasmusOrsoe RasmusOrsoe left a comment


Hi Kaare!

Well done on your first PR. Next time, add a comment in the PR!

Your PR is now approved - you may click 'Merge'!

Rasmus

@asogaard
Collaborator

Thanks so much for the contribution, @kaareendrup! One comment on this topic (not affecting this PR in any way) is that using .state_dict means that the user needs to keep track of the various model classes. For a single, monolithic model class, that's not that big of a deal, but if we are going for a more modular approach (detector, gnn, task) then it's quickly a handful of classes that the user needs to configure manually when loading the model, as evidenced by the large number of arguments to the load_model function implemented in this PR.

@RasmusOrsoe, is the reason for not wanting to save the entire model basically what it says in the guide mentioned in #63?
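To make the tradeoff concrete, here is a minimal sketch in plain Python (stdlib pickle instead of torch, and hypothetical class names): a state dict carries only parameter values, so the loader must re-instantiate every component class with the right constructor arguments, whereas pickling the whole object restores it in one call but ties the saved file to the classes' import paths.

```python
import io
import pickle

# Hypothetical stand-ins for the modular components (detector, gnn, task);
# stdlib pickle is used here instead of torch for brevity.
class Detector:
    def __init__(self, nb_neighbours):
        self.nb_neighbours = nb_neighbours

class Model:
    def __init__(self, detector, hidden_size):
        self.detector = detector
        self.hidden_size = hidden_size

    def state_dict(self):
        # A state dict carries only values, not the class structure.
        return {
            "detector.nb_neighbours": self.detector.nb_neighbours,
            "hidden_size": self.hidden_size,
        }

# state_dict route: the loader must know and re-instantiate every class
state = Model(Detector(8), 128).state_dict()
reloaded = Model(Detector(state["detector.nb_neighbours"]), state["hidden_size"])

# full-object route: one call, but the file is tied to the classes' import paths
buf = io.BytesIO()
pickle.dump(Model(Detector(8), 128), buf)
buf.seek(0)
restored = pickle.load(buf)
```

With many modular components, the state_dict route multiplies the constructor arguments the user must track, which is exactly the `load_model` signature growth noted above.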

@RasmusOrsoe
Collaborator

RasmusOrsoe commented Nov 16, 2021

@asogaard yes. In fact, you can't save the models using torch.save: it throws a pickle error (so the example code for training has an error in it!). And secondly, it's sensitive to changes in $PATH.

@asogaard
Collaborator

Hi @RasmusOrsoe,

I saw a pickle error when saving and loading, but it seemed to be fixed by using dill as the pickle module (see https://github.com/icecube/gnn-reco/blob/main/src/gnn_reco/models/training/utils.py#L35). Perhaps you encountered another error? I have attached some example code below that seems to save and load the entire model fine.

To the second point, wrt. $PATH, I don't have a good feeling for whether this is a conceptual problem or a very practical one. But perhaps it makes sense to provide both functionalities (e.g. .save and .save_state_dict) to allow users to choose, since the maintenance overhead for each is minimal? Should I try to write such methods for the Model class?


# Import(s)
import dill
import torch
from torch.utils.data import DataLoader
from torch_geometric.data import Batch

from gnn_reco.components.loss_functions import VonMisesFisher2DLoss
from gnn_reco.data.sqlite_dataset import SQLiteDataset
from gnn_reco.data.constants import FEATURES, TRUTH
from gnn_reco.models import Model
from gnn_reco.models.detector import IceCubeDeepCore
from gnn_reco.models.gnn import DynEdge
from gnn_reco.models.graph_builders import KNNGraphBuilder
from gnn_reco.models.task.reconstruction import ZenithReconstructionWithKappa

# Load data
db = "/groups/icecube/leonbozi/datafromrasmus/GNNReco/data/databases/dev_level7_noise_muon_nu_classification_pass2_fixedRetro_v3/data/dev_level7_noise_muon_nu_classification_pass2_fixedRetro_v3.db"
dataset = SQLiteDataset(db, "SRTTWOfflinePulsesDC", FEATURES.ICECUBE86, TRUTH.ICECUBE86)
dataloader = DataLoader(
    dataset,
    batch_size=4, 
    shuffle=False,
    num_workers=1, 
    collate_fn=Batch.from_data_list,
    persistent_workers=True,
    prefetch_factor=2,
)
batch = next(iter(dataloader))

# Wrap code in functions to make it clear that these two operations are wholly independent
model_path = "test_model.pth"

def build_save_model():
    detector = IceCubeDeepCore(
        graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8),
    )
    gnn = DynEdge(
        nb_inputs=detector.nb_outputs,
    )
    task = ZenithReconstructionWithKappa(
        hidden_size=gnn.nb_outputs, 
        target_label='zenith', 
        loss_function=VonMisesFisher2DLoss(),
    )
    model = Model(
        detector=detector,
        gnn=gnn,
        tasks=[task],
        device='cpu'
    )

    model.eval()
    print(model(batch))

    torch.save(model, model_path, pickle_module=dill)
    del model

def load_model():
    model = torch.load(model_path, pickle_module=dill)
    model.eval()
    print(model(batch))

# Test model saving and loading
build_save_model()
load_model()

Both methods yield the same output:

[tensor([[1.6040, 0.7016],
        [1.7503, 0.7677],
        [1.7475, 0.8868],
        [1.6987, 0.4920]], grad_fn=<StackBackward>)]

@RasmusOrsoe
Collaborator

@asogaard Aha! The changes @kaareendrup made are in components.utils.py and your dill fix is in models/training/.. . So I think the pickle error we saw is the same one you fixed with dill, and the reason we saw it is that the function lives in both places and was only fixed in one of them. We need to spring clean!

Regarding the $PATH dependency - it is really only a problem if we move the code around again. It's just important to remember that if we save the entire thing using pickle and later move the model code, it's likely that old trained models won't be compatible with the future code. Maybe, as you propose, supporting both types of saving is the right thing to do - one could write a load function that first tries to read the full pickle and, if that fails, falls back to the state_dict.

Rasmus
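The fallback loader described above could be sketched like this (hypothetical names and file layout, stdlib pickle instead of torch; a real version would catch whatever exception torch.load raises when the model classes have moved):

```python
import pickle

def robust_load(path, build_model):
    """Sketch: prefer the fully pickled model at `path`; fall back to a
    state-dict file at `path + ".state_dict"`, rebuilt via the
    `build_model` factory. Names and layout are assumptions."""
    try:
        with open(path, "rb") as f:
            # Fails if the model classes were moved or renamed since saving
            return pickle.load(f)
    except (FileNotFoundError, AttributeError,
            ModuleNotFoundError, pickle.UnpicklingError):
        with open(path + ".state_dict", "rb") as f:
            state = pickle.load(f)
        model = build_model()
        model.load_state_dict(state)
        return model
```

The `build_model` factory is where the many constructor arguments of the state_dict route get concentrated, so users who only ever move the code (not the classes) never pay that cost.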

@asogaard
Collaborator

Brilliant! In that case, I will take the liberty of merging this PR (hope that's okay, @kaareendrup!), do some spring cleaning to remove duplicate code, and add the two methods discussed above. 💪

@asogaard asogaard merged commit 905a15f into graphnet-team:main Nov 16, 2021
@kaareendrup
Collaborator Author

Wow, good job guys! I will need to read through your discussion carefully to fully get what's going on. But everything sounds good and it sounds like a proper cleaning is not a bad idea.

@RasmusOrsoe
Collaborator

@asogaard @kaareendrup @RasmusOrsoe Let's get this saving/loading snippet into the repo as an example.
