add model.save test #147

Merged: 5 commits, merged on Dec 29, 2021
4 changes: 1 addition & 3 deletions gradsflow/callbacks/raytune.py
@@ -14,7 +14,6 @@
import os
from typing import Optional

import torch
from ray import tune

from gradsflow.core.callbacks import Callback
@@ -40,11 +39,10 @@ class TorchTuneCheckpointCallback(Callback):

    def on_epoch_end(self):
        epoch = self.model.tracker.current_epoch
        model = self.model.learner

        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "filename")
            torch.save((model.state_dict()), path)
            self.model.save(path)


class TorchTuneReport(Callback):
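Note: the callback previously wrote a bare `state_dict` via `torch.save`; with this change it delegates to `Model.save` (see `gradsflow/models/base.py` below), which pickles the learner object itself. A minimal, hypothetical sketch of restoring such a trial checkpoint; the `restore_trial` helper and its signature are illustrative, not part of this PR:

```python
import os

from gradsflow import Model


def restore_trial(model: Model, checkpoint_dir: str) -> Model:
    # The callback writes the checkpoint under the literal name "filename"
    # inside the directory Ray Tune provides for the epoch.
    path = os.path.join(checkpoint_dir, "filename")
    # Because Model.save persists the learner object (not just its weights),
    # the matching restore is load_from_checkpoint rather than load_state_dict.
    model.load_from_checkpoint(path)
    return model
```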
17 changes: 12 additions & 5 deletions gradsflow/models/base.py
@@ -15,6 +15,7 @@
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Union

import smart_open
import torch
from accelerate import Accelerator
from torch import nn
@@ -75,9 +76,6 @@ def assert_compiled(self):
        if not self._compiled:
            raise UserWarning("Model not compiled yet! Please call `model.compile(...)` first.")

    def load_from_checkpoint(self, checkpoint):
        self.learner = torch.load(checkpoint)

    @torch.no_grad()
    def predict(self, x):
        return self.learner(x)
@@ -172,9 +170,18 @@ def train(self):
        self.learner.requires_grad_(True)
        self.learner.train()

    def save(self, path: str, save_extra: bool = True):
    def load_from_checkpoint(self, checkpoint):
        data = torch.load(checkpoint)
        if isinstance(data, dict):
            self.learner = data["model"]
            self.tracker = data["tracker"]
        else:
            self.learner = data

    def save(self, path: str, save_extra: bool = False):
        """save model"""
        model = self.learner
        if save_extra:
            model = {"model": self.learner, "tracker": self.tracker}
        torch.save(model, path)
        with smart_open.open(path, "wb") as f:
            torch.save(model, f)
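A minimal usage sketch of the new save/load API; the file names are illustrative. Since `save` now writes through `smart_open.open`, remote destinations such as `s3://...` URIs should also work when the matching smart_open extras are installed (an assumption, not something exercised in this PR):

```python
import timm

from gradsflow import Model

model = Model(timm.create_model("resnet18", pretrained=False, num_classes=10))

# Default: pickle only the learner module.
model.save("learner.pth")

# save_extra=True: pickle a dict holding both the learner and the tracker.
model.save("checkpoint.pth", save_extra=True)

# load_from_checkpoint accepts both formats; the dict form also restores model.tracker.
model.load_from_checkpoint("checkpoint.pth")
```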
33 changes: 33 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,33 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Arrange
import pytest
import timm

from gradsflow import Model


@pytest.fixture
def resnet18():
    cnn = timm.create_model("ssl_resnet18", pretrained=False, num_classes=10).eval()

    return cnn


@pytest.fixture
def cnn_model(resnet18):
    model = Model(resnet18)
    model.TEST = True

    return model
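For reviewers less familiar with pytest: fixtures are injected by argument name, which is why the updated tests below take `cnn_model` (or `resnet18`) as parameters instead of sharing a module-level `model`. A short illustrative sketch (this test is hypothetical, not part of the PR):

```python
import torch


def test_predict_shape(cnn_model):
    # pytest resolves `cnn_model` from tests/conftest.py and builds it on demand.
    out = cnn_model.predict(torch.randn(1, 3, 64, 64))
    assert out.shape == (1, 10)  # the resnet18 fixture is created with num_classes=10
```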
41 changes: 30 additions & 11 deletions tests/models/test_model.py
@@ -34,26 +34,26 @@
model.TEST = True


def test_predict():
def test_predict(cnn_model):
    x = torch.randn(1, 3, 64, 64)
    r1 = model.forward(x)
    r2 = model(x)
    r3 = model.predict(x)
    r1 = cnn_model.forward(x)
    r2 = cnn_model(x)
    r3 = cnn_model.predict(x)
    assert torch.all(torch.isclose(r1, r2))
    assert torch.all(torch.isclose(r2, r3))
    assert isinstance(model.predict(torch.randn(1, 3, 64, 64)), torch.Tensor)


def test_fit():
    model.TEST = True
def test_fit(cnn_model):
    cnn_model.compile()
    assert autodataset
    tracker = model.fit(autodataset, max_epochs=1, steps_per_epoch=1, show_progress=True)
    tracker = cnn_model.fit(autodataset, max_epochs=1, steps_per_epoch=1, show_progress=True)
    assert isinstance(tracker, Tracker)

    autodataset2 = AutoDataset(train_data.dataloader, num_classes=num_classes)
    model.TEST = False
    cnn_model.TEST = False
    ckpt_cb = ModelCheckpoint(save_extra=False)
    tracker2 = model.fit(
    tracker2 = cnn_model.fit(
        autodataset2,
        max_epochs=1,
        steps_per_epoch=1,
@@ -84,7 +84,26 @@ def compute_accuracy(*_, **__):
    assert model2.optimizer.param_groups[0]["lr"] == 0.01


def test_set_accelerator():
    model2 = Model(cnn, accelerator_config={"fp16": True})
def test_set_accelerator(resnet18):
    model2 = Model(resnet18, accelerator_config={"fp16": True})
    model2.compile()
    assert model2.accelerator


def test_save_model(tmp_path, resnet18, cnn_model):
    path = f"{tmp_path}/dummy_model.pth"
    cnn_model.save(path, save_extra=True)
    assert isinstance(torch.load(path), dict)

    cnn_model.save(path, save_extra=False)
    assert isinstance(torch.load(path), type(resnet18))


def test_load_from_checkpoint(tmp_path, cnn_model):
    path = f"{tmp_path}/dummy_model.pth"
    cnn_model.save(path, save_extra=True)
    assert isinstance(torch.load(path), dict)

    cnn_model.tracker.train.metrics["CHECK"] = True
    cnn_model.load_from_checkpoint(path)
    assert cnn_model.tracker.train.metrics["CHECK"]