Skip to content

Commit

Permalink
SupGanTrain refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
dingguanglei committed Nov 9, 2018
1 parent 6159088 commit a704c53
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#mypackage/metric/Cifar10_M.csv
#mypackage/metric/Cifar10_S.csv

log
jdit/data
.idea
datasets
Expand Down
2 changes: 1 addition & 1 deletion jdit/trainer/instances/cifarPix2pixGan.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class CifarPix2pixGanTrainer(Pix2pixGanTrainer):
def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets):
super(CifarPix2pixGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD,
datasets)
self.datasets.dataset_train = self.datasets.dataset_valid

self.watcher.graph(netG, (4, 1, 32, 32), self.use_gpu)

def get_data_from_loader(self, batch_data):
Expand Down
1 change: 0 additions & 1 deletion jdit/trainer/instances/fashingClassification.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class FashingClassTrainer(ClassificationTrainer):

def __init__(self, logdir, nepochs, gpu_ids, net, opt, datasets):
super(FashingClassTrainer, self).__init__(logdir, nepochs, gpu_ids, net, opt, datasets)

self.watcher.graph(net, self.datasets.batch_shape, self.use_gpu)
data, label = self.datasets.samples_train
self.watcher.embedding(data, data, label)
Expand Down

0 comments on commit a704c53

Please sign in to comment.