Skip to content

Commit

Permalink
Fix Instances name
Browse files Browse the repository at this point in the history
  • Loading branch information
dingguanglei committed Dec 22, 2019
1 parent b1e52e1 commit 099eec1
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions jdit/trainer/instances/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .fashionClassification import FashionClassTrainer, start_fashionClassTrainer
from .fashionGenerateGan import FashionGenerateGenerateGanTrainer, start_fashionGenerateGanTrainer
from .fashionGenerateGan import FashionGenerateGanTrainer, start_fashionGenerateGanTrainer
from .cifarPix2pixGan import start_cifarPix2pixGanTrainer
from .fashionClassParallelTrainer import start_fashionClassPrarallelTrainer
from .fashionAutoencoder import FashionAutoEncoderTrainer, start_fashionAutoencoderTrainer
__all__ = ['FashionClassTrainer', 'start_fashionClassTrainer',
'FashionGenerateGenerateGanTrainer', 'start_fashionGenerateGanTrainer',
'FashionGenerateGanTrainer', 'start_fashionGenerateGanTrainer',
'cifarPix2pixGan', 'start_cifarPix2pixGanTrainer', 'start_fashionClassPrarallelTrainer',
'start_fashionAutoencoderTrainer', 'FashionAutoEncoderTrainer']
12 changes: 6 additions & 6 deletions jdit/trainer/instances/fashionGenerateGan.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ def forward(self, input_data):
return out


class FashionGenerateGenerateGanTrainer(GenerateGanTrainer):
class FashionGenerateGanTrainer(GenerateGanTrainer):
d_turn = 1

def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, dataset, latent_shape):
super(FashionGenerateGenerateGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD,
dataset,
latent_shape=latent_shape)
super(FashionGenerateGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD,
dataset,
latent_shape=latent_shape)

data, label = self.datasets.samples_train
self.watcher.embedding(data, data, label, global_step=1)
Expand Down Expand Up @@ -130,8 +130,8 @@ def start_fashionGenerateGanTrainer(gpus=(), nepochs=50, lr=1e-3, depth_G=32, de
print('===> Training')
print("using `tensorboard --logdir=log` to see learning curves and net structure."
"training and valid_epoch data, configures info and checkpoint were save in `log` directory.")
Trainer = FashionGenerateGenerateGanTrainer("log/fashion_generate", nepochs, gpus, G, D, opt_G, opt_D, mnist,
latent_shape)
Trainer = FashionGenerateGanTrainer("log/fashion_generate", nepochs, gpus, G, D, opt_G, opt_D, mnist,
latent_shape)
if run_type == "train":
Trainer.train()
elif run_type == "debug":
Expand Down

0 comments on commit 099eec1

Please sign in to comment.