
Commit

docs update
dingguanglei committed Nov 21, 2018
1 parent 989b162 commit f2254da
Showing 7 changed files with 70 additions and 81 deletions.
8 changes: 4 additions & 4 deletions docs/source/Build your own trainer.rst
@@ -87,7 +87,7 @@ First, you need to build a pytorch ``module`` like this:
... out = self.layer1(input)
... out = self.layer2(out)
... return out
... network = LinearModel()
>>> network = LinearModel()
.. note::

@@ -126,7 +126,7 @@ However, ``do_lr_decay()`` will be called automatically at the end of every epoch or at certain epochs.
Actually, you don't need to do anything to apply learning rate decay.
If you don't want any decay, just set ``lr_decay = 1.`` or set a decay epoch larger than the number of training epochs.
I will show you how it works. If you want to implement something special strategies.
I will show you how it works so that you can implement your own special strategies.

.. code-block:: python
@@ -150,12 +150,12 @@ I will show you how it works. If you want to implement something special strategies.
>>> opt.lr
1
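For intuition, here is a minimal sketch of what one decay step amounts to on a plain pytorch optimizer (``torch.optim.Adam`` here; jdit's ``Optimizer`` wrapper automates this, so the snippet is illustrative only):

.. code-block:: python

    >>> import torch
    >>> from torch import optim
    >>> params = [torch.nn.Parameter(torch.zeros(1))]
    >>> opt = optim.Adam(params, lr=1.0)
    >>> for group in opt.param_groups:
    ...     group["lr"] *= 0.92  # a decay step is just this multiplication
    >>> opt.param_groups[0]["lr"]
    0.92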
It contains two main optimizer RMSprop and Adam. You can pass a certain name to use it with its own parameters.
It contains two main optimizers, ``RMSprop`` and ``Adam``. You can pass a certain name to use one with its own parameters.
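A hedged sketch of that dispatch-by-name idea (``build_opt`` is illustrative; the real ``jdit.optimizer.Optimizer`` signature may differ):

.. code-block:: python

    from torch import optim

    def build_opt(opt_name, params, **hyperparams):
        # pick an optimizer class by its name, as jdit lets you do
        classes = {"RMSprop": optim.RMSprop, "Adam": optim.Adam}
        return classes[opt_name](params, **hyperparams)

    # e.g. build_opt("Adam", net.parameters(), lr=1e-3, betas=(0.9, 0.999))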

.. note::

As for spectral normalization, the optimizer will filter the differentiable weights for you.
So, you don't need to write something like this.
So, you don't need to write something like this
``filter(lambda p: p.requires_grad, params)``
Merely passing ``model.parameters()`` is enough.
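A runnable illustration of the filtering this note describes (plain pytorch; the frozen bias stands in for whatever a spectral-norm wrapper leaves non-trainable):

.. code-block:: python

    import torch.nn as nn

    net = nn.Linear(4, 2)
    net.bias.requires_grad = False  # pretend this tensor is non-trainable

    # what the optimizer does for you internally (sketch):
    trainable = [p for p in net.parameters() if p.requires_grad]
    assert len(trainable) == 1  # only the weight matrix survives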
2 changes: 1 addition & 1 deletion docs/source/dataset.rst
@@ -10,7 +10,7 @@ jdit.dataset
Dataloaders_factory
-------------------

.. autoclass:: Dataloaders_factory
.. autoclass:: DataLoadersFactory
:members:

HandMNIST
12 changes: 6 additions & 6 deletions docs/source/index.rst
@@ -48,23 +48,23 @@ Then you will see something like the following.
.. code-block:: python
===> Build dataset
use 8 thread!
use 8 thread
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Processing...
Done!
Done
===> Building model
ResNet Total number of parameters: 2776522
ResNet model use CPU!
apply kaiming weight init!
ResNet model use CPU
apply kaiming weight init
===> Building optimizer
===> Training
using `tensorboard --logdir=log` to see learning curves and net structure.
training and valid_epoch data, configures info and checkpoint were save in `log` directory.
0%| | 0/10 [00:00<?, ?epoch/s]
0step [00:00, ?step/s]
0%| | 0/10 [00:00<.., ..epoch/s]
0step [00:00, step/s]
* It will search for a fashion mnist dataset.
* Then build a resnet18 for classification.
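The output above comes from running one of the built-in instances, presumably along these lines (hedged; check ``jdit.trainer.instances`` for the actual entry-point name):

.. code-block:: python

    # hypothetical entry point; the real name lives in jdit.trainer.instances
    from jdit.trainer.instances import start_fashingClassTrainer

    start_fashingClassTrainer()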
25 changes: 21 additions & 4 deletions docs/source/trainer.rst
@@ -19,15 +19,32 @@ ClassificationTrainer
.. autoclass:: ClassificationTrainer
:members:

GanTrainer
----------
Generative Adversarial Networks Trainer
---------------------------------------

.. automodule:: jdit.trainer
.. currentmodule:: jdit.trainer

.. autoclass:: GanTrainer
SupGanTrainer
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

.. autoclass:: SupGanTrainer
:members:

Pix2pixGanTrainer
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

.. autoclass:: Pix2pixGanTrainer
:members:

GenerateGanTrainer
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

.. autoclass:: GenerateGanTrainer
:members:

instances
----------------------
---------

.. automodule:: jdit.trainer.instances
.. currentmodule:: jdit.trainer.instances
2 changes: 1 addition & 1 deletion jdit/model.py
@@ -160,7 +160,7 @@ def load_model(self, model_or_path: Union[Module, DataParallel, str, None] = Non
:param weights_or_path: Pytorch weights or weights file path.
:param gpu_ids: If using gpus. default:``()``
:param strict: The same function in pytorch ``model.load_state_dict(weights,strict = strict)`` .
default:``True``
default:``True``
:return: ``module``
Example::
60 changes: 22 additions & 38 deletions jdit/trainer/gan/generate.py
@@ -2,28 +2,28 @@
from abc import abstractmethod
from torch.autograd import Variable
import torch


from jdit.optimizer import Optimizer
from jdit.model import Model
from jdit.dataset import DataLoadersFactory

class GenerateGanTrainer(SupGanTrainer):
d_turn = 1

def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets, latent_shape):
""" a gan super class
:param logdir:
:param nepochs:
:param gpu_ids_abs:
:param netG:
:param netD:
:param optG:
:param optD:
:param datasets:
:param latent_shape:
:param logdir: Path of the log directory.
:param nepochs: Number of epochs.
:param gpu_ids_abs: The ids of the gpus to be used. To use the CPU, set ``[]``.
:param netG: Generator model.
:param netD: Discriminator model.
:param optG: Optimizer of the generator.
:param optD: Optimizer of the discriminator.
:param datasets: Datasets.
:param latent_shape: The shape of the input noise.
"""
super(GenerateGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets)
self.latent_shape = latent_shape
self.loger.regist_config(self)
# self.metric = FID(self.gpu_ids)

def get_data_from_loader(self, batch_data):
@@ -36,9 +36,7 @@ def valid_epoch(self):
self.netG.eval()
self.netD.eval()
for iteration, batch in enumerate(self.datasets.loader_valid, 1):
input_cpu, ground_truth_cpu = self.get_data_from_loader(batch)
self.mv_inplace(input_cpu, self.input) # input data
self.mv_inplace(ground_truth_cpu, self.ground_truth) # real data
self.input, self.ground_truth = self.get_data_from_loader(batch)
self.fake = self.netG(self.input)
dic = self.compute_valid()
if avg_dic == {}:
@@ -68,11 +66,8 @@ def compute_d_loss(self):
d_fake = self.netD(self.fake.detach())
d_real = self.netD(self.ground_truth)
var_dic = {}
var_dic["GP"] = gp = gradPenalty(self.netD, self.ground_truth, self.fake, input=self.input,
use_gpu=self.use_gpu)
var_dic["WD"] = w_distance = (d_real.mean() - d_fake.mean()).detach()
var_dic["LOSS_D"] = loss_d = d_fake.mean() - d_real.mean() + gp + sgp
return: loss_d, var_dic
var_dic["LS_LOSSD"] = loss_d = 0.5 * (torch.mean((d_real - 1) ** 2) + torch.mean(d_fake ** 2))
return loss_d, var_dic
"""
loss_d = None
@@ -89,11 +84,10 @@ def compute_g_loss(self):
Example::
d_fake = self.netD(self.fake)
d_fake = self.netD(self.fake, self.input)
var_dic = {}
var_dic["JC"] = jc = jcbClamp(self.netG, self.input, use_gpu=self.use_gpu)
var_dic["LOSS_D"] = loss_g = -d_fake.mean() + jc
return: loss_g, var_dic
var_dic["LS_LOSSG"] = loss_g = 0.5 * torch.mean((d_fake - 1) ** 2)
return loss_g, var_dic
"""
loss_g = None
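For reference, the docstring examples in this file follow the least-squares GAN objective; in equation form (standard LSGAN, stated here for context):

.. math::

    \mathcal{L}_D = \tfrac{1}{2}\,\mathbb{E}_{x}\big[(D(x)-1)^2\big]
                  + \tfrac{1}{2}\,\mathbb{E}_{z}\big[D(G(z))^2\big],
    \qquad
    \mathcal{L}_G = \tfrac{1}{2}\,\mathbb{E}_{z}\big[(D(G(z))-1)^2\big].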
@@ -102,23 +96,13 @@ def compute_valid(self):

@abstractmethod
def compute_valid(self):
g_loss, _ = self.compute_g_loss()
d_loss, _ = self.compute_d_loss()
var_dic = {"LOSS_D": d_loss, "LOSS_G": g_loss}
# var_dic = {}
# fake = self.netG(self.input).detach()
# d_fake = self.netD(self.fake, self.input).detach()
# d_real = self.netD(self.ground_truth, self.input).detach()
#
# var_dic["G"] = loss_g = (-d_fake.mean()).detach()
# var_dic["GP"] = gp = (
# gradPenalty(self.netD, self.ground_truth, self.fake, input=self.input, use_gpu=self.use_gpu)).detach()
# var_dic["D"] = loss_d = (d_fake.mean() - d_real.mean() + gp).detach()
# var_dic["WD"] = w_distance = (d_real.mean() - d_fake.mean()).detach()
_, g_var_dic = self.compute_g_loss()
_, d_var_dic = self.compute_d_loss()
var_dic = dict(d_var_dic, **g_var_dic)
return var_dic

def test(self):
self.mv_inplace(Variable(torch.randn((16, *self.latent_shape))), self.input)
self.input = Variable(torch.randn((16, *self.latent_shape))).to(self.device)
self.netG.eval()
with torch.no_grad():
fake = self.netG(self.input).detach()
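Putting the pieces together, a hedged usage sketch for this trainer (model, optimizer, and dataset construction elided; see ``jdit.trainer.instances`` for a complete, real example):

.. code-block:: python

    # hypothetical wiring; netG, netD, optG, optD and datasets are built elsewhere
    trainer = GenerateGanTrainer("log", 10, [], netG, netD, optG, optD,
                                 datasets, latent_shape=(256, 1, 1))
    trainer.train()  # assuming the usual train() entry point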
42 changes: 15 additions & 27 deletions jdit/trainer/gan/pix2pix.py
@@ -11,21 +11,18 @@ class Pix2pixGanTrainer(SupGanTrainer):
def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets):
""" A pixel to pixel gan trainer
:param logdir:
:param nepochs:
:param gpu_ids_abs:
:param netG:
:param netD:
:param optG:
:param optD:
:param datasets:
:param latent_shape:
:param logdir: Path of the log directory.
:param nepochs: Number of epochs.
:param gpu_ids_abs: The ids of the gpus to be used. To use the CPU, set ``[]``.
:param netG: Generator model.
:param netD: Discriminator model.
:param optG: Optimizer of the generator.
:param optD: Optimizer of the discriminator.
:param datasets: Datasets.
"""
super(Pix2pixGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets)
# self.loger.regist_config(self)

def get_data_from_loader(self, batch_data):

input_cpu, ground_truth_cpu = batch_data[0], batch_data[1]
return input_cpu.to(self.device), ground_truth_cpu.to(self.device)
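If a dataset yields a different batch layout, this method is the natural hook to override, for example (a sketch; the dict keys are hypothetical):

.. code-block:: python

    def get_data_from_loader(self, batch_data):
        # hypothetical: the loader yields {"A": input, "B": target} dicts
        return batch_data["A"].to(self.device), batch_data["B"].to(self.device)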

@@ -61,11 +58,8 @@ def compute_d_loss(self):
d_fake = self.netD(self.fake.detach())
d_real = self.netD(self.ground_truth)
var_dic = {}
var_dic["GP"] = gp = gradPenalty(self.netD, self.ground_truth, self.fake, input=self.input,
use_gpu=self.use_gpu)
var_dic["WD"] = w_distance = (d_real.mean() - d_fake.mean()).detach()
var_dic["LOSS_D"] = loss_d = d_fake.mean() - d_real.mean() + gp
return: loss_d, var_dic
var_dic["LS_LOSSD"] = loss_d = 0.5 * (torch.mean((d_real - 1) ** 2) + torch.mean(d_fake ** 2))
return loss_d, var_dic
"""
loss_d = None
@@ -82,10 +76,9 @@ def compute_g_loss(self):
Example::
d_fake = self.netD(self.fake)
d_fake = self.netD(self.fake, self.input)
var_dic = {}
var_dic["JC"] = jc = jcbClamp(self.netG, self.input, use_gpu=self.use_gpu)
var_dic["LOSS_D"] = loss_g = -d_fake.mean() + jc
var_dic["LS_LOSSG"] = loss_g = 0.5 * torch.mean((d_fake - 1) ** 2)
return loss_g, var_dic
"""
@@ -113,8 +106,9 @@ def compute_valid(self):
return var_dic
"""

var_dic = {}
_, g_var_dic = self.compute_g_loss()
_, d_var_dic = self.compute_d_loss()
var_dic = dict(d_var_dic, **g_var_dic)
return var_dic

def valid_epoch(self):
@@ -169,9 +163,3 @@ def test(self):
fake = self.netG(self.input).detach()
self.watcher.image(fake, self.current_epoch, tag="Test/fake", grid_size=(7, 7), shuffle=False)
self.netG.train()

# @property
# def configure(self):
# dict = super(Pix2pixGanTrainer, self).configure
# dict["d_turn"] = self.d_turn
# return dict
