Extra cleanup
Summary: Pull Request resolved: #239

Differential Revision: D18628765

Pulled By: vreis

fbshipit-source-id: 26472bab473f50d7dbc153b5c4ee60641ea77f9f
vreis authored and facebook-github-bot committed Nov 21, 2019
1 parent e8eddcb commit 9cf3c9f
Showing 1 changed file with 56 additions and 31 deletions.
classy_train.py (87 changes: 56 additions, 31 deletions)
@@ -38,9 +38,13 @@
"""

import logging
import os
from datetime import datetime
from pathlib import Path

import torch
from torchvision import set_video_backend

from classy_vision.generic.args import parse_args
from classy_vision.generic.registry_utils import import_all_packages_from_directory
from classy_vision.generic.util import load_checkpoint
@@ -57,20 +57,22 @@
)
from classy_vision.tasks import FineTuningTask, build_task
from classy_vision.trainer import DistributedTrainer, LocalTrainer
from torchvision import set_video_backend


def main(args, config):
    # Global settings
    # Global flags
    torch.manual_seed(0)
    set_video_backend(args.video_backend)

    task = build_task(config)

    # Load checkpoint, if available
    # Load checkpoint, if available. This automatically resumes from an
    # existing checkpoint, in case training is being restarted.
    checkpoint = load_checkpoint(args.checkpoint_folder, args.device)
    task.set_checkpoint(checkpoint)

    # Load a checkpoint containing a pre-trained model. This is how we
    # implement fine-tuning of existing models.
    pretrained_checkpoint = load_checkpoint(
        args.pretrained_checkpoint_folder, args.device
    )
@@ -80,57 +80,76 @@ def main(args, config):
), "Can only use a pretrained checkpoint for fine tuning tasks"
task.set_pretrained_checkpoint(pretrained_checkpoint)

hooks = [
LossLrMeterLoggingHook(args.log_freq),
ModelComplexityHook(),
TimeMetricsHook(),
# Configure hooks to do tensorboard logging, checkpoints and so on
task.set_hooks(configure_hooks(args))

use_gpu = None
if args.device is not None:
use_gpu = args.device == "gpu"

# LocalTrainer is used for a single node. DistributedTrainer will setup
# training to use PyTorch's DistributedDataParallel.
trainer_class = {"none": LocalTrainer, "ddp": DistributedTrainer}[
args.distributed_backend
]

trainer = trainer_class(use_gpu=use_gpu, num_dataloader_workers=args.num_workers)

# That's it! When this call returns, training is done.
trainer.train(task)

output_folder = Path(args.checkpoint_folder).resolve()
logging.info("Training successful!")
logging.info(
f'Results of this training run are available at: "{output_folder}"'
)


def configure_hooks(args):
    hooks = [LossLrMeterLoggingHook(args.log_freq), TimeMetricsHook()]

    # Make a folder to store checkpoints and tensorboard logging outputs
    suffix = datetime.now().isoformat()
    base_folder = Path(__file__).parent / f"output_{suffix}"
    if args.checkpoint_folder == "":
        args.checkpoint_folder = base_folder / "checkpoints"
        os.makedirs(args.checkpoint_folder)

    logging.info(f"Logging outputs to {base_folder.resolve()}")

    if not args.skip_tensorboard:
        try:
            from tensorboardX import SummaryWriter

            tb_writer = SummaryWriter(log_dir="/tmp/tensorboard")
            tb_writer = SummaryWriter(log_dir=base_folder / "tensorboard")
            hooks.append(TensorboardPlotHook(tb_writer))
            hooks.append(ModelTensorboardHook(tb_writer))
        except ImportError:
            logging.warning("tensorboardX not installed, skipping tensorboard hooks")
    if args.checkpoint_folder != "":
        args_dict = vars(args)
        args_dict["config"] = config
        hooks.append(
            CheckpointHook(
                args.checkpoint_folder,
                args_dict,
                checkpoint_period=args.checkpoint_period,
            )

    args_dict = vars(args)
    args_dict["config"] = config
    hooks.append(
        CheckpointHook(
            args.checkpoint_folder, args_dict, checkpoint_period=args.checkpoint_period
        )
    )

    if args.profiler:
        hooks.append(ProfilerHook())
    if args.show_progress:
        hooks.append(ProgressBarHook())
    if args.visdom_server != "":
        hooks.append(VisdomHook(args.visdom_server, args.visdom_port))

    task.set_hooks(hooks)

    use_gpu = None
    if args.device is not None:
        use_gpu = args.device == "gpu"

    trainer_class = {"none": LocalTrainer, "ddp": DistributedTrainer}[
        args.distributed_backend
    ]

    trainer = trainer_class(use_gpu=use_gpu, num_dataloader_workers=args.num_workers)
    trainer.train(task)
    return hooks


# run all the things:
if __name__ == "__main__":
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    logging.info("Generic convolutional network trainer.")
    logging.info("Classy Vision's default training script.")

    # This imports all modules in the same directory as classy_train.py
    # Because of the way Classy Vision's registration decorators work,
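To illustrate the registration mechanism the truncated comment above refers to, here is a hedged sketch (not part of this commit) of a component that classy_train.py could pick up from its directory. It assumes the register_model decorator and ClassyModel base class from classy_vision.models; the "my_model" name, the file name, and the tiny architecture are made up for illustration.

# Hypothetical my_model.py placed next to classy_train.py. Because the file is
# imported by import_all_packages_from_directory(), the @register_model
# decorator runs at import time and the model becomes buildable from a task
# config with {"name": "my_model"}. The class itself is a placeholder.
import torch.nn as nn

from classy_vision.models import ClassyModel, register_model


@register_model("my_model")
class MyModel(ClassyModel):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(16, 2)  # toy architecture for illustration

    def forward(self, x):
        return self.layer(x)

    @classmethod
    def from_config(cls, config):
        # Build the model from the "model" section of the task config.
        return cls()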
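Relatedly, a minimal standalone sketch (also not part of the commit) of the per-run output layout that the new configure_hooks() sets up when no checkpoint folder is passed. The folder names follow the diff above; the prints and the explicit creation of the tensorboard directory are only for demonstration.

# Reproduces the folder naming from configure_hooks() above: a timestamped
# output_<suffix>/ directory holding checkpoints/ (written by CheckpointHook)
# and tensorboard/ (written by the tensorboardX SummaryWriter, which previously
# logged to /tmp/tensorboard). This sketch only creates the directories.
import os
from datetime import datetime
from pathlib import Path

suffix = datetime.now().isoformat()
base_folder = Path(".").resolve() / f"output_{suffix}"

checkpoint_folder = base_folder / "checkpoints"
tensorboard_folder = base_folder / "tensorboard"

os.makedirs(checkpoint_folder)
os.makedirs(tensorboard_folder)

print(f"checkpoints:  {checkpoint_folder}")
print(f"tensorboard:  {tensorboard_folder}")
# A TensorBoard instance can then be pointed at base_folder / "tensorboard".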
