# Even number GAN
Below, we collect into a single notebook the [Python code](https://github.com/nbertagnolli/pytorch-simple-gan) that accompanies the author's blog post on building a [simple GAN](https://towardsdatascience.com/build-a-super-simple-gan-in-pytorch-54ba349920e4).

In [1]:
from typing import Tuple
import math

import torch
import torch.nn as nn

from models import Discriminator, Generator
from utils import generate_even_data, convert_float_matrix_to_int_list

Trains the even GAN. This text was the docstring of the original training function. In this notebook, the arguments below are plain variables set at the top of the next cell.

    Args:
        max_int: The maximum integer our dataset goes to. It is used to set the size of the binary
            lists
        batch_size: The number of examples in a training batch
        training_steps: The number of steps to train on.
        learning_rate: The learning rate for the generator and discriminator
        print_output_every_n_steps: The number of training steps between printouts of generated output

    Returns (left in scope as notebook variables):
        generator: The trained generator model
        discriminator: The trained discriminator model

In [2]:
# Hyperparameters (see the description above this cell).
max_int = 128
batch_size = 16
training_steps = 500
learning_rate = 0.001
print_output_every_n_steps = 10

# Width of the binary vectors the models operate on: floor(log2(max_int)).
# NOTE(review): presumably this must equal the encoding width used by
# generate_even_data in utils -- keep the two expressions in sync.
input_length = int(math.log(max_int, 2))

# Models
generator = Generator(input_length)
discriminator = Discriminator(input_length)

# Optimizers: one Adam instance per network so each step only updates its own model.
generator_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)

# Binary cross-entropy: the discriminator outputs a probability that a sample is real.
loss = nn.BCELoss()

for i in range(training_steps):
    # Zero the generator's gradients at the start of every step.
    generator_optimizer.zero_grad()

    # Random 0/1 noise vectors as generator input; cast to float because the
    # networks expect floating-point tensors rather than integers.
    noise = torch.randint(0, 2, size=(batch_size, input_length)).float()

    # Sample a batch of real (even) examples and their labels.
    true_labels, true_data = generate_even_data(max_int, batch_size=batch_size)
    true_labels = torch.tensor(true_labels).float()
    true_data = torch.tensor(true_data).float()

    # --- Train the generator ---
    # Score the generated samples against the *real* labels: the generator is
    # rewarded when the discriminator classifies its output as real. Only the
    # generator optimizer steps here, so the discriminator's weights are unchanged.
    G_of_noise = generator(noise)
    D_of_G_of_noise = discriminator(G_of_noise)
    # BCELoss requires target and input to have the same shape; the labels come
    # back as a 1-D (batch_size,) tensor, so reshape them to the output's shape.
    generator_loss = loss(D_of_G_of_noise, true_labels.reshape_as(D_of_G_of_noise))
    generator_loss.backward()
    generator_optimizer.step()

    # --- Train the discriminator on real and generated data ---
    # zero_grad also clears the gradients that generator_loss.backward()
    # propagated into the discriminator's parameters above.
    discriminator_optimizer.zero_grad()
    true_discriminator_out = discriminator(true_data)
    true_discriminator_loss = loss(
        true_discriminator_out, true_labels.reshape_as(true_discriminator_out)
    )

    # Detach the generated samples so this loss does not backpropagate into the
    # generator; only the discriminator is being trained in this phase.
    generator_discriminator_out = discriminator(G_of_noise.detach())
    # Generated samples are labelled fake (0); zeros_like matches the output's shape.
    generator_discriminator_loss = loss(
        generator_discriminator_out, torch.zeros_like(generator_discriminator_out)
    )
    discriminator_loss = (true_discriminator_loss + generator_discriminator_loss) / 2
    discriminator_loss.backward()
    discriminator_optimizer.step()

    # Periodically show what the generator currently produces.
    if i % print_output_every_n_steps == 0:
        print(convert_float_matrix_to_int_list(G_of_noise))

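(stderr) Warning raised inside `nn.BCELoss.forward` — most likely PyTorch's target/input size-mismatch warning, since the label tensors were 1-D and were not reshaped to the discriminator output's shape. The source line it points to is:
  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)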


[2, 100, 98, 37, 115, 84, 66, 99, 103, 109, 115, 33, 100, 115, 114, 96]
[99, 98, 115, 100, 69, 85, 119, 100, 33, 98, 103, 64, 96, 100, 102, 117]
[103, 101, 103, 99, 68, 103, 100, 64, 84, 118, 100, 111, 69, 99, 98, 99]
[103, 101, 97, 97, 103, 99, 99, 96, 101, 102, 100, 101, 98, 103, 117, 34]
[103, 33, 103, 101, 64, 68, 100, 99, 100, 102, 99, 101, 103, 100, 101, 96]
[101, 100, 101, 103, 98, 96, 99, 117, 101, 99, 97, 69, 96, 101, 101, 99]
[99, 34, 101, 99, 101, 101, 101, 97, 100, 101, 101, 97, 113, 97, 101, 97]
[100, 117, 117, 97, 100, 101, 101, 117, 99, 100, 99, 100, 100, 101, 101, 117]
[101, 101, 99, 101, 103, 101, 99, 101, 97, 68, 101, 101, 100, 101, 100, 101]
[101, 103, 101, 101, 33, 101, 101, 97, 101, 97, 101, 37, 117, 97, 100, 101]
[101, 101, 113, 101, 101, 101, 101, 101, 100, 101, 99, 99, 101, 101, 101, 101]
[99, 101, 103, 101, 97, 101, 101, 101, 101, 100, 101, 103, 101, 103, 100, 101]
[97, 101, 101, 101, 117, 103, 101, 101, 101, 103, 101, 101, 101, 101, 99, 101]
[101, 100, 100, 97