Fix bug in loading the final train phase checkpoint (#387)
Summary:
Pull Request resolved: #387

Setting a checkpoint on a task invoked `task._set_model_train_mode()` if the checkpoint came from a train phase. That function used to call `self.optimizer.update_schedule_on_epoch`, which accesses `self.where`. For the final checkpoint this raised an exception, since `where` evaluates to 1.0 there.

Moved this call to `advance_phase()`, which is a better place for the schedule update than inside `_set_model_train_mode()`; a sketch of the failure mode follows.
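
For context, here is a minimal sketch of the failure mode. The names `where`, `advance_phase`, and `_set_model_train_mode` mirror the task API, but the bookkeeping and the assertion below are simplified assumptions for illustration, not the actual implementation:

```python
class TinyTask:
    """Toy stand-in for a training task (illustrative only)."""

    def __init__(self, updates_per_phase, num_phases):
        self.updates_per_phase = updates_per_phase
        self.num_phases = num_phases
        self.num_updates = 0

    @property
    def where(self):
        # Fraction of training completed; assumed to be validated to [0, 1).
        where = self.num_updates / (self.updates_per_phase * self.num_phases)
        assert 0 <= where < 1, f"Invalid where: {where}"
        return where

    def _set_model_train_mode(self):
        # Before the fix, the schedule update lived here. This method also runs
        # when a checkpoint is loaded, and for the final train checkpoint the
        # number of updates equals the total, so accessing self.where asserts.
        pass

    def advance_phase(self):
        # After the fix: advance_phase() only runs while phases remain, so
        # self.where is guaranteed to be < 1.0 at this point.
        self._set_model_train_mode()
        print(f"updating schedule at where={self.where:.2f}")
        self.num_updates += self.updates_per_phase


task = TinyTask(updates_per_phase=10, num_phases=2)
task.advance_phase()          # where = 0.00
task.advance_phase()          # where = 0.50
task._set_model_train_mode()  # safe now; before the fix this path hit where == 1.0
```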

Added a test case which failed before the fix and passes now.

Reviewed By: vreis

Differential Revision: D19815391

fbshipit-source-id: 05b51a0e76d20c7009e93fff4d01a6fdfb14e2cc
mannatsingh authored and facebook-github-bot committed Feb 10, 2020
1 parent 8a690af commit bbfc3ae
Showing 3 changed files with 45 additions and 8 deletions.
7 changes: 4 additions & 3 deletions classy_vision/tasks/classification_task.py
@@ -719,6 +719,10 @@ def advance_phase(self):
# Set up pytorch module in train vs eval mode, update optimizer.
self._set_model_train_mode()

# Update the optimizer schedule
if self.train and self.train_phase_idx >= 0:
self.optimizer.update_schedule_on_epoch(self.where)

def done_training(self):
"""Stop condition for training
"""
@@ -778,9 +782,6 @@ def _set_model_train_mode(self):
):
self._broadcast_buffers()

if self.train and self.train_phase_idx >= 0:
self.optimizer.update_schedule_on_epoch(self.where)

def _broadcast_buffers(self):
"""Explicitly synchronize buffers across all devices."""
if self.distributed_model is None:
3 changes: 0 additions & 3 deletions classy_vision/tasks/fine_tuning_task.py
@@ -66,9 +66,6 @@ def _set_model_train_mode(self):
else:
self.base_model.train(phase["train"])

if self.train and self.train_phase_idx >= 0:
self.optimizer.update_schedule_on_epoch(self.where)

def prepare(
self,
num_dataloader_workers: int = 0,
43 changes: 41 additions & 2 deletions test/tasks_classification_task_test.py
@@ -4,14 +4,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import shutil
import tempfile
import unittest
from test.generic.config_utils import get_fast_test_task_config, get_test_task_config
from test.generic.utils import compare_model_state, compare_samples, compare_states

import torch
from classy_vision.dataset import build_dataset
from classy_vision.generic.util import get_checkpoint_dict
from classy_vision.hooks import LossLrMeterLoggingHook
from classy_vision.generic.util import get_checkpoint_dict, load_checkpoint
from classy_vision.hooks import CheckpointHook, LossLrMeterLoggingHook
from classy_vision.losses import build_loss
from classy_vision.models import build_model
from classy_vision.optim import build_optimizer
@@ -29,6 +31,14 @@ def _compare_samples(self, sample_1, sample_2):
def _compare_states(self, state_1, state_2, check_heads=True):
compare_states(self, state_1, state_2)

def setUp(self):
# create a base directory to write checkpoints to
self.base_dir = tempfile.mkdtemp()

def tearDown(self):
# delete all the temporary data created
shutil.rmtree(self.base_dir)

def test_build_task(self):
config = get_test_task_config()
task = build_task(config)
@@ -90,6 +100,35 @@ def test_checkpointing(self):
task_2.train_step(use_gpu, local_variables)
self._compare_states(task.get_classy_state(), task_2.get_classy_state())

def test_final_train_checkpoint(self):
"""Test that a train phase checkpoint with a where of 1.0 can be loaded"""

config = get_fast_test_task_config()
task = build_task(config).set_hooks(
[CheckpointHook(self.base_dir, {}, phase_types=["train"])]
)
task_2 = build_task(config)

use_gpu = torch.cuda.is_available()

trainer = LocalTrainer(use_gpu=use_gpu)
trainer.train(task)

# load the final train checkpoint
checkpoint = load_checkpoint(self.base_dir)

# make sure fetching the where raises an exception, which means that
# where is >= 1.0
with self.assertRaises(Exception):
task.where

# set task_2's state as task's final train checkpoint
task_2.set_checkpoint(checkpoint)
task_2.prepare(use_gpu=use_gpu)

# we should be able to train the task
trainer.train(task_2)

def test_test_only_checkpointing(self):
"""
Tests checkpointing by running train_steps to make sure the
