# Error in Conneau's reported loss function (Eqn. 4)

Conneau et al. (2017) report that the generator in their GAN would also benefit from a loss term proportional to the discriminator's misclassification rate on all-real minibatches. The generator does not contribute to producing these minibatches, so it is not part of the backpropagation graph of the loss incurred when the discriminator misclassifies them. That term therefore cannot pass any gradient to the generator's parameters.

This script demonstrates that adding an independent loss term to a model's loss before backpropagation does not change the gradients computed for that model.
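The claim can be written out as a short derivation. The symbols here are notation assumed for this note, not taken from the paper: $\theta_G$ and $\theta_D$ are the generator's and discriminator's parameters, $\mathcal{L}_{gen}$ is a loss that depends on $\theta_G$, and $\mathcal{L}_{dis}$ is a loss computed without any use of the generator's output.

$$
\nabla_{\theta_G}\big(\mathcal{L}_{gen}(\theta_G) + \mathcal{L}_{dis}(\theta_D)\big)
= \nabla_{\theta_G}\mathcal{L}_{gen}(\theta_G) + \underbrace{\nabla_{\theta_G}\mathcal{L}_{dis}(\theta_D)}_{=\,0}
= \nabla_{\theta_G}\mathcal{L}_{gen}(\theta_G)
$$

Because $\mathcal{L}_{dis}$ is constant with respect to $\theta_G$, its gradient with respect to $\theta_G$ vanishes, and the generator's update is the same with or without the extra term.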

In [1]:
# Get list of all gradients in a model

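def grads(model):
    # Collect the gradient of every parameter of `model` in one list.
    # Note: extending a list with a tensor iterates over its first dimension,
    # so a (1, 5) weight gradient is stored as a single tensor of shape (5,).
    grad_list = []
    for p in model.parameters():
        grad_list.extend(p.grad)
    return grad_list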

In [2]:
import torch

generator = torch.nn.Linear(5,1)
discriminator = torch.nn.Linear(5,1)

train_data = torch.rand(7,5)
y_true = torch.rand(7,1)

print('Gen:', generator)
print('Dis:', discriminator, '\n')

generator.train()
discriminator.eval()  # Doesn't matter whether .train() or .eval() at this point

criterion = torch.nn.MSELoss()

# First, compute generator's gradients given loss for generator only.
generator.zero_grad()

y_pred = generator(train_data)
loss = criterion(y_pred, y_true)
loss.backward()
print('Grads for generator for loss_gen only:\n', grads(generator), '\n')

# Second, compute generator's gradients given loss for generator+discriminator.
generator.zero_grad()
discriminator.zero_grad()

y_pred = generator(train_data)
loss_gen = criterion(y_pred, y_true)

y_pred = discriminator(train_data)
loss_dis = criterion(y_pred, y_true)

# Add the independent discriminator loss to the generator loss.
loss_total = loss_gen + loss_dis

loss_total.backward()
print('Grads for generator for loss_gen + loss_dis:\n', grads(generator))

Gen: Linear(in_features=5, out_features=1, bias=True)
Dis: Linear(in_features=5, out_features=1, bias=True) 

Grads for generator for loss_gen only:
 [tensor([-0.0993, -0.0389, -0.1039,  0.0662, -0.0270]), tensor(-0.2005)] 

Grads for generator for loss_gen + loss_dis:
 [tensor([-0.0993, -0.0389, -0.1039,  0.0662, -0.0270]), tensor(-0.2005)]
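
The printed values agree to four decimals. As a complementary check, the comparison can be done programmatically instead of by eye. The following is a minimal sketch under the same setup as above; the helper `generator_grads` and the fixed seed are introduced here for illustration and are not part of the original script.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed so the check is reproducible

generator = torch.nn.Linear(5, 1)
discriminator = torch.nn.Linear(5, 1)
train_data = torch.rand(7, 5)
y_true = torch.rand(7, 1)
criterion = torch.nn.MSELoss()


def generator_grads(extra_loss=None):
    """Backpropagate the generator loss (plus an optional extra term) and
    return detached copies of the generator's gradients."""
    generator.zero_grad()
    loss = criterion(generator(train_data), y_true)
    if extra_loss is not None:
        loss = loss + extra_loss
    loss.backward()
    return [p.grad.detach().clone() for p in generator.parameters()]


# Gradients from the generator loss alone.
grads_gen_only = generator_grads()

# Gradients after adding a loss that depends only on the discriminator.
loss_dis = criterion(discriminator(train_data), y_true)
grads_with_dis = generator_grads(extra_loss=loss_dis)

grads_match = all(
    torch.allclose(a, b) for a, b in zip(grads_gen_only, grads_with_dis)
)
# The extra term does reach the discriminator, whose parameters it depends on.
dis_has_grads = all(p.grad is not None for p in discriminator.parameters())

print('Generator gradients unchanged:', grads_match)
print('Discriminator received gradients:', dis_has_grads)
```

Copies of the gradients are taken with `clone()` so that a later `zero_grad()` call cannot alter the values stored from the first pass.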
