split G and D update step
carpedm20 committed Apr 6, 2017
1 parent 7b0777f commit 013c1c4
Showing 2 changed files with 40 additions and 13 deletions.
25 changes: 25 additions & 0 deletions models.py
@@ -102,3 +102,28 @@ def main(self, x):
         fc2_out = self.fc2(fc1_out).view([-1] + self.conv2_input_dim)
         conv2_out = self.conv2(fc2_out)
         return conv2_out
+
+class _Loss(nn.Module):
+
+    def __init__(self, size_average=True):
+        super(_Loss, self).__init__()
+        self.size_average = size_average
+
+    def forward(self, input, target):
+        # _assert_no_grad(target)
+        backend_fn = getattr(self._backend, type(self).__name__)
+        return backend_fn(self.size_average)(input, target)
+
+class L1Loss(_Loss):
+    r"""Creates a criterion that measures the mean absolute value of the
+    element-wise difference between input `x` and target `y`:
+    :math:`{loss}(x, y) = 1/n \sum |x_i - y_i|`
+    `x` and `y` arbitrary shapes with a total of `n` elements each.
+    The sum operation still operates over all the elements, and divides by `n`.
+    The division by `n` can be avoided if one sets the constructor argument `sizeAverage=False`
+    """
+    pass
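
As a point of reference, the criterion the L1Loss docstring above describes reduces to a mean absolute error. A minimal sketch using plain tensor ops, side-stepping the `_backend` dispatch the new class relies on (the function name is illustrative, not from the commit):

import torch

def l1_criterion(input, target, size_average=True):
    # loss(x, y) = 1/n * sum |x_i - y_i|  (or the plain sum if size_average=False)
    diff = (input - target).abs()
    return diff.mean() if size_average else diff.sum()

x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.5, 2.0, 1.0])
print(l1_criterion(x, y))   # (0.5 + 0.0 + 2.0) / 3 = 0.8333...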
28 changes: 15 additions & 13 deletions trainer.py
@@ -102,7 +102,7 @@ def build_model(self):
         self.D.apply(weights_init)

     def train(self):
-        l1 = nn.L1Loss()
+        l1 = L1Loss()

         z_D = Variable(torch.FloatTensor(self.batch_size, self.z_num))
         z_G = Variable(torch.FloatTensor(self.batch_size, self.z_num))
@@ -121,11 +121,10 @@ def train(self):
             raise Exception("[!] Caution! Paper didn't use {} opimizer other than Adam".format(config.optimizer))

         def get_optimizer(lr):
-            return optimizer(
-                chain(self.G.parameters(), self.D.parameters()),
-                lr=lr, betas=(self.beta1, self.beta2))
+            return optimizer(self.G.parameters(), lr=lr, betas=(self.beta1, self.beta2)), \
+                   optimizer(self.D.parameters(), lr=lr, betas=(self.beta1, self.beta2))

-        optim = get_optimizer(self.lr)
+        g_optim, d_optim = get_optimizer(self.lr)

         data_loader = iter(self.data_loader)
         x_fixed = self._get_variable(next(data_loader))
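
Below, a minimal standalone sketch of what the rewritten `get_optimizer` does: one Adam instance per network rather than a single optimizer over the chained parameter lists. The linear stand-in modules and the beta values are illustrative assumptions, not the repository's models or config:

import torch
import torch.nn as nn

# Hypothetical stand-in networks; the repository's G and D are conv models.
G = nn.Linear(8, 16)
D = nn.Linear(16, 16)

def get_optimizer(lr, beta1=0.5, beta2=0.999):
    # One Adam instance per network, instead of a single optimizer over
    # itertools.chain(G.parameters(), D.parameters()) as before the commit.
    return (torch.optim.Adam(G.parameters(), lr=lr, betas=(beta1, beta2)),
            torch.optim.Adam(D.parameters(), lr=lr, betas=(beta1, beta2)))

g_optim, d_optim = get_optimizer(1e-4)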
@@ -151,21 +150,24 @@ def get_optimizer(lr):
             z_D.data.normal_(0, 1)
             z_G.data.normal_(0, 1)

-            sample_z_D = self.G(z_D)
+            #sample_z_D = self.G(z_D)
             sample_z_G = self.G(z_G)

             AE_x = self.D(x)
+            AE_G_d = self.D(sample_z_G.detach())
+            AE_G_g = self.D(sample_z_G)

             d_loss_real = l1(AE_x, x)
-            d_loss_fake = l1(self.D(sample_z_G.detach()), sample_z_G.detach())
+            d_loss_fake = l1(AE_G_d, sample_z_G.detach())

-            AE_G = self.D(sample_z_G).detach()
             d_loss = d_loss_real - k_t * d_loss_fake
-            g_loss = l1(sample_z_G, AE_G)
+            g_loss = l1(sample_z_G, AE_G_g)

             loss = d_loss + g_loss

             loss.backward()
-            optim.step()
+
+            g_optim.step()
+            d_optim.step()

             g_d_balance = (self.gamma * d_loss_real - d_loss_fake).data[0]
             k_t += self.lambda_k * g_d_balance
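
For orientation, a self-contained sketch of the training step the hunk above arrives at, using plain tensors instead of the now-deprecated `Variable` wrapper. The tiny linear `G`/`D`, the hyper-parameter values, and the explicit `zero_grad()` calls are illustrative assumptions, not part of the commit:

import torch
import torch.nn as nn
import torch.nn.functional as F

G, D = nn.Linear(8, 16), nn.Linear(16, 16)   # hypothetical stand-ins for the conv models
g_optim = torch.optim.Adam(G.parameters(), lr=1e-4)
d_optim = torch.optim.Adam(D.parameters(), lr=1e-4)
k_t, gamma, lambda_k = 0.0, 0.5, 0.001       # illustrative BEGAN hyper-parameters

x = torch.randn(16, 16)                      # one "real" batch
z_G = torch.randn(16, 8)                     # z_G.data.normal_(0, 1) in the diff

sample_z_G = G(z_G)

AE_x = D(x)                                  # D is an autoencoder in BEGAN
AE_G_d = D(sample_z_G.detach())              # fake reconstruction, cut off from G's graph
AE_G_g = D(sample_z_G)                       # fake reconstruction, kept in G's graph

d_loss_real = F.l1_loss(AE_x, x)
d_loss_fake = F.l1_loss(AE_G_d, sample_z_G.detach())

d_loss = d_loss_real - k_t * d_loss_fake
g_loss = F.l1_loss(sample_z_G, AE_G_g)

g_optim.zero_grad()
d_optim.zero_grad()
(d_loss + g_loss).backward()                 # one backward over the combined objective
g_optim.step()                               # then one step per optimizer, as in the hunk
d_optim.step()

# BEGAN balance update: k_{t+1} = k_t + lambda_k * (gamma * L(x) - L(G(z)))
g_d_balance = (gamma * d_loss_real - d_loss_fake).item()
k_t = min(max(k_t + lambda_k * g_d_balance, 0.0), 1.0)   # clamped to [0, 1] per the paper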
@@ -196,7 +198,7 @@ def get_optimizer(lr):
                     self.inject_summary(self.summary_writer, tag, value, step)

                 self.inject_summary(
-                    self.summary_writer, "AE_G", AE_G.data.cpu().numpy(), step)
+                    self.summary_writer, "AE_G", AE_G_g.data.cpu().numpy(), step)
                 self.inject_summary(
                     self.summary_writer, "AE_x", AE_x.data.cpu().numpy(), step)
                 self.inject_summary(
@@ -211,7 +213,7 @@ def get_optimizer(lr):
                 cur_measure = np.mean(measure_history)
                 if cur_measure > prev_measure * 0.9999:
                     self.lr *= 0.5
-                    optim = get_optimizer(self.lr)
+                    g_optim, d_optim = get_optimizer(self.lr)
                 prev_measure = cur_measure

     def generate(self, inputs, path, idx=None):
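
The last hunk halves `self.lr` and rebuilds both optimizers once the averaged convergence measure stops improving. As a rough sketch of the same schedule, one can also lower the rate in place on the existing `param_groups`; this is an assumed alternative shown for illustration, not what the commit does, and the network and measure values below are dummies:

from collections import deque

import numpy as np
import torch
import torch.nn as nn

net = nn.Linear(4, 4)                                      # hypothetical stand-in network
optim = torch.optim.Adam(net.parameters(), lr=1e-4)

measure_history = deque([0.42, 0.41, 0.43], maxlen=100)    # dummy convergence measures
prev_measure = 0.40

cur_measure = float(np.mean(measure_history))
if cur_measure > prev_measure * 0.9999:                    # same stall test as in the diff
    for group in optim.param_groups:                       # decay in place instead of rebuilding
        group["lr"] *= 0.5
prev_measure = cur_measure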
