Skip to content

Commit

Permalink
reqst
Browse files Browse the repository at this point in the history
  • Loading branch information
dingguanglei committed Oct 26, 2018
1 parent ce7702e commit f17d75e
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions jdit/trainer/gan/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,26 @@
from abc import abstractmethod
from tqdm import tqdm
from torch.autograd import Variable
from jdit.metric.inception import FID
# from jdit.metric.inception import FID
# from ...metric import FID
import torch


class GanTrainer(SupTrainer):
d_turn = 1
def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets, latent_shape):
""" a gan super class
:param logdir:
:param nepochs:
:param gpu_ids_abs:
:param netG:
:param netD:
:param optG:
:param optD:
:param datasets:
:param latent_shape:
"""
super(GanTrainer, self).__init__(nepochs, logdir, gpu_ids_abs=gpu_ids_abs)
self.netG = netG
self.netD = netD
Expand All @@ -19,7 +31,7 @@ def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, dataset
self.fake = None
self.fixed_input = None
self.latent_shape = latent_shape
self.metric = FID(self.gpu_ids)
# self.metric = FID(self.gpu_ids)
self.loger.regist_config(self.netG, config_filename="Generator")
self.loger.regist_config(self.netD, config_filename="Discriminator")
self.loger.regist_config(datasets)
Expand Down

0 comments on commit f17d75e

Please sign in to comment.