Skip to content

Commit

Permalink
docs
Browse files Browse the repository at this point in the history
  • Loading branch information
dingguanglei committed Oct 26, 2018
1 parent 460f42f commit b35f6d5
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
1 change: 0 additions & 1 deletion generate_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from mypackage.tricks import gradPenalty, spgradPenalty
from mypackage.model.Tnet import NLayer_D, TWnet_G, NThickLayer_D


# from mypackage.tricks import jcbClamp


Expand Down
48 changes: 47 additions & 1 deletion jdit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,53 @@
from torch import save, load

class Model(object):
"""a model
r"""A warapper of pytorch ``module`` .
In the simplest case, we use a raw pytorch ``module`` to assemble a ``Model`` of this class.
It can be more convenient to use some feather method, such ``checkPoint`` , ``loadModel`` and so on.
* :attr:`proto_model` is the core model in this class. It is no necessary to passing a ``module``
when you init a ``Model`` . you can build a model later by using ``Model.define(module)`` or load a model from a file.
* :attr:`gpu_ids_abs` controls the gpus which you want to use. you should use a absolute id of gpus.
* :attr:`init_method` controls the weights init method.
* At init_method="xavier", it will use ``init.xavier_normal_``, in ``pytorch.nn.init``, to init the Conv layers of model.
* At init_method="kaiming", it will use ``init.kaiming_normal_``, in ``pytorch.nn.init``, to init the Conv layers of model.
* At init_method=your_own_method, it will be used on weights, just like what ``pytorch.nn.init`` method does.
* :attr:`show_structure` controls whether to show your network structure.
.. note::
Don't try to pass a :attr:``DataParallel`` model. Only :attr:``module`` is accessable.
.. note::
:attr:`gpu_ids_abs` must be a tuple or list. If you want to use cpu, just passing an ampty list like ``[]``.
Args:
proto_model (module): A pytroch module. Default: ``None``.
gpu_ids_abs (tuple or list): The absolute id of gpus. if [] using cpu. Default: ``()``.
init_method (str or def): Weights init method. Default: ``"Kaiming"``
show_structure (bool): Is the structure shown. Default: ``False``
Attributes:
num_params (int): the totals amount of weights in this model.
gpu_ids (list or tuple): which device is this model on.
Examples::
>>> # using a square kernels and equal stride
>>> module = Sequential(Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)))
>>> # using cpu to init a Model by module.
>>> net = Model(module, [], show_structure=False)
>>> input = torch.randn(20, 16, 10, 50, 100)
>>> output = net(input)
"""
def __init__(self, proto_model=None, gpu_ids_abs=(), init_method="kaiming", show_structure=False):
Expand Down
1 change: 0 additions & 1 deletion jdit/trainer/gan/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# from ...metric import FID
import torch


class GanTrainer(SupTrainer):
d_turn = 1
def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets, latent_shape):
Expand Down

0 comments on commit b35f6d5

Please sign in to comment.