
Commit

add distributed API for model, datasets, trainer.
308188605@qq.com committed Jun 1, 2019
1 parent fa71beb commit 4304b4d
Showing 4 changed files with 123 additions and 3 deletions.
26 changes: 25 additions & 1 deletion jdit/dataset.py
@@ -3,7 +3,7 @@
import psutil
from typing import Union
from abc import ABCMeta, abstractmethod

from torch.utils.data.distributed import DistributedSampler

class DataLoadersFactory(metaclass=ABCMeta):
"""This is a super class of dataloader.
@@ -156,6 +156,30 @@ def build_loaders(self):
self.loader_test = DataLoader(self.dataset_test, batch_size=self.batch_size, shuffle=self.shuffle)
self.nsteps_test = len(self.loader_test)

def convert_to_distributed(self, which_dataset=None, num_replicas=None, rank=None):
    samplers = {}
    if which_dataset is None:
        # By default only the training set is sharded across processes.
        samplers["train"] = DistributedSampler(self.dataset_train, num_replicas=num_replicas, rank=rank)
        self.loader_train = DataLoader(self.dataset_train, self.batch_size, False, sampler=samplers["train"])
    elif which_dataset == "train":
        samplers["train"] = DistributedSampler(self.dataset_train, num_replicas=num_replicas, rank=rank)
        self.loader_train = DataLoader(self.dataset_train, self.batch_size, False,
                                       sampler=samplers["train"])
    elif which_dataset == "valid":
        samplers["valid"] = DistributedSampler(self.dataset_valid, num_replicas=num_replicas, rank=rank)
        self.loader_valid = DataLoader(self.dataset_valid, self.batch_size, False,
                                       sampler=samplers["valid"])
    elif which_dataset == "test":
        samplers["test"] = DistributedSampler(self.dataset_test, num_replicas=num_replicas, rank=rank)
        self.loader_test = DataLoader(self.dataset_test, self.batch_size, False,
                                      sampler=samplers["test"])
    else:
        raise ValueError(
            "param `which_dataset` can only be one of 'train', 'valid' or 'test'. Got %s instead" % which_dataset)
    return samplers

@property
def samples_train(self):
return self._get_samples(self.dataset_train, self.sample_dataset_size)
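
A minimal usage sketch of the new dataset hook (not part of this commit): it assumes FashionMNIST is one of jdit's DataLoadersFactory subclasses with a batch_size argument, and that the process group is initialised with the standard torch.distributed API.

    import torch.distributed as dist
    from jdit.dataset import FashionMNIST  # assumed DataLoadersFactory subclass

    # One process per GPU; init_method / environment variables come from the launcher.
    dist.init_process_group(backend="nccl", init_method="env://")

    datasets = FashionMNIST(batch_size=64)
    # Rebuild only the training loader with a DistributedSampler so each rank
    # iterates over a disjoint shard of the training set.
    samplers = datasets.convert_to_distributed(which_dataset="train",
                                               num_replicas=dist.get_world_size(),
                                               rank=dist.get_rank())
    # set_epoch() should be called before every epoch so shuffling differs between
    # epochs; dist_train() in jdit/trainer/super.py does this for the train sampler.
    samplers["train"].set_epoch(0)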
71 changes: 70 additions & 1 deletion jdit/model.py
@@ -2,6 +2,7 @@
import torch
import os
from torch.nn import init, Conv2d, Linear, ConvTranspose2d, InstanceNorm2d, BatchNorm2d, DataParallel, Module
from torch.nn.parallel import DistributedDataParallel
from torch import save, load
from typing import Union
from collections import OrderedDict
@@ -100,7 +101,7 @@ def __init__(self, proto_model: Module,

if not isinstance(proto_model, Module):
raise TypeError(
"The type of `proto_model` must be `torch.nn.Module`, but got %s instead" % type(proto_model))
"The type of `proto_model` must be `torch.nn.Module`, but got %s instead" % type(proto_model))
self.model: Union[DataParallel, Module] = None
self.model_name = proto_model.__class__.__name__
self.weights_init = None
@@ -258,6 +259,74 @@ def is_checkpoint(self, model_name: str, epoch: int, logdir="log"):
self.check_point(model_name, epoch, logdir)
return is_check_point

def convert_to_distributed(self, device_ids=None,
output_device=None, dim=0, broadcast_buffers=True,
process_group=None, bucket_cap_mb=25,
find_unused_parameters=False,
check_reduction=False):
"""
Args:
module (Module): module to be parallelized
device_ids (list of int or torch.device): CUDA devices. This should
only be provided when the input module resides on a single
CUDA device. For single-device modules, the ``i``th
:attr:`module` replica is placed on ``device_ids[i]``. For
multi-device modules and CPU modules, device_ids must be None
or an empty list, and input data for the forward pass must be
placed on the correct device. (default: all devices for
single-device modules)
output_device (int or torch.device): device location of output for
single-device CUDA modules. For multi-device modules and
CPU modules, it must be None, and the module itself
dictates the output location. (default: device_ids[0] for
single-device modules)
broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
the module at beginning of the forward function.
(default: ``True``)
process_group: the process group to be used for distributed data
all-reduction. If ``None``, the default process group, which
is created by ```torch.distributed.init_process_group```,
will be used. (default: ``None``)
bucket_cap_mb: DistributedDataParallel will bucket parameters into
multiple buckets so that gradient reduction of each
bucket can potentially overlap with backward computation.
:attr:`bucket_cap_mb` controls the bucket size in MegaBytes (MB)
(default: 25)
find_unused_parameters (bool): Traverse the autograd graph of all tensors
contained in the return value of the wrapped
module's ``forward`` function.
Parameters that don't receive gradients as
part of this graph are preemptively marked
as being ready to be reduced.
(default: ``False``)
check_reduction: when setting to ``True``, it enables DistributedDataParallel
to automatically check if the previous iteration's
backward reductions were successfully issued at the
beginning of every iteration's forward function.
You normally don't need this option enabled unless you
are observing weird behaviors such as different ranks
are getting different gradients, which should not
happen if DistributedDataParallel is correctly used.
(default: ``False``)
Attributes:
            model (Module): ``self.model``, wrapped in ``DistributedDataParallel`` after this call
        Example::
            >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
            >>> pg = torch.distributed.new_group(ranks=[0, 1, 2, 3])
            >>> net.convert_to_distributed(process_group=pg)
            >>> # equivalent to
            >>> net.model = torch.nn.parallel.DistributedDataParallel(net.model, process_group=pg)
"""
assert isinstance(self.model, DataParallel), "please only use one gpu for one task"
self.model = DistributedDataParallel(self.model, device_ids,
output_device, dim, broadcast_buffers,
process_group, bucket_cap_mb,
find_unused_parameters,
check_reduction)

@staticmethod
def count_params(proto_model: Module):
"""count the total parameters of model.
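
A hedged sketch of converting a jdit Model to DistributedDataParallel (not part of this commit). The resnet18 prototype is only illustrative, and the gpu_ids_abs keyword is assumed to match jdit's Model constructor; the model must already be wrapped for GPU use, since convert_to_distributed() asserts that self.model is a DataParallel instance.

    import torch
    import torch.distributed as dist
    from torchvision.models import resnet18  # illustrative prototype network
    from jdit import Model

    dist.init_process_group(backend="nccl", init_method="env://")
    local_rank = dist.get_rank()
    torch.cuda.set_device(local_rank)

    # gpu_ids_abs is assumed to place the prototype on this rank's GPU and wrap it
    # in DataParallel, which convert_to_distributed() expects.
    net = Model(resnet18(), gpu_ids_abs=[local_rank])
    # The keywords map one-to-one onto torch.nn.parallel.DistributedDataParallel.
    net.convert_to_distributed(device_ids=[local_rank], output_device=local_rank)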
27 changes: 27 additions & 0 deletions jdit/trainer/super.py
@@ -81,6 +81,33 @@ def train(self, process_bar_header: str = None, process_bar_position: int = None
self.test()
self.watcher.close()

def dist_train(self, process_bar_header: str = None, process_bar_position: int = None,
subbar_disable=False,
record_configs=True, show_network=False, **kwargs):
"""The main training loop of epochs.
:param process_bar_header: The tag name of process bar header,
which is used in ``tqdm(desc=process_bar_header)``
:param process_bar_position: The process bar's position. It is useful in multitask,
which is used in ``tqdm(position=process_bar_position)``
:param subbar_disable: If show the info of every training set,
:param record_configs: If record the training processing data.
:param show_network: If show the structure of network. It will cost extra memory,
:param kwargs: Any other parameters that passing to ``tqdm()`` to control the behavior of process bar.
"""
if record_configs:
self._record_configs()
if show_network:
self.plot_graphs_lazy()
for epoch in tqdm(range(self.start_epoch, self.nepochs + 1), total=self.nepochs,
unit="epoch", desc=process_bar_header, position=process_bar_position, **kwargs):
self._datasets["datasets"].loader_train.sampler.set_epoch(epoch)
self.current_epoch = epoch
self.train_epoch(subbar_disable)
self.valid_epoch()
self.test()
self.watcher.close()

def __setattr__(self, key, value):
super(SupTrainer, self).__setattr__(key, value)

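
A sketch of driving the new dist_train() loop (not part of this commit). FashionClassTrainer is a hypothetical SupTrainer subclass standing in for a real trainer, and net, opt and datasets are assumed to be the converted model, an optimizer and the sharded dataset factory from the sketches above.

    # Launch one process per GPU, e.g. with the PyTorch 1.x launcher:
    #   python -m torch.distributed.launch --nproc_per_node=4 train_dist.py
    import torch.distributed as dist

    dist.init_process_group(backend="nccl", init_method="env://")
    rank = dist.get_rank()

    # Hypothetical trainer; the constructor arguments depend on your own subclass.
    trainer = FashionClassTrainer("log/dist", nepochs=50, gpu_ids_abs=[rank],
                                  net=net, opt=opt, datasets=datasets)
    # dist_train() mirrors train(), but calls sampler.set_epoch(epoch) on the
    # training loader at the start of every epoch.
    trainer.dist_train()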
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name="jdit", # pypi中的名称,pip或者easy_install安装时使用的名称,或生成egg文件的名称
version="0.0.13",
version="0.0.14",
author="Guanglei Ding",
author_email="dingguanglei.bupt@qq.com",
maintainer='Guanglei Ding',
