# Tutorial: a simple GAN built in PyTorch

In [None]:
import torch 
from torch import nn
import numpy as np
import glob 
import imageio
import math 
import matplotlib.pyplot as plt

It’s a good practice to set up a <b>random generator seed</b> so that the experiment can be replicated identically on any machine. To do that in PyTorch, run the following code:

In [None]:
# Fix the seed of PyTorch's random generator so that repeated runs of the
# notebook can be replicated.
torch.manual_seed(111)
# Select the 'dark_background' matplotlib style for the figures drawn below.
plt.style.use('dark_background')

Now we prepare the training data, which is composed of points as follows:
<br>
<img src="https://latex.codecogs.com/svg.image?\left&space;({x_1}{,}{x_2}\right&space;):" />
<br>
<img src='https://latex.codecogs.com/svg.image?{x_2}{=}{sin}{x_1}'/>
<img src='https://latex.codecogs.com/svg.image?0\leq&space;{x_1}\leq&space;2\pi'/>


Now we create a data loader called train_loader, which will shuffle the data from train_set and return batches of 32 samples that you’ll use to train the neural networks.

After setting up the training data, you need to create the neural networks for the discriminator and generator that will compose the GAN. In the following section, you’ll implement the discriminator.

In [None]:
# Real training set: `train_data_length` two-dimensional points (x1, x2).
train_data_length = 2 * 1024

# Column 0 holds x1, sampled uniformly from [0, 2*pi); column 1 holds x2,
# computed from x1.
# NOTE(review): x2 is sin(x1) squared here, while the accompanying text
# describes x2 = sin(x1) -- confirm which curve is intended.
train_data = torch.zeros(train_data_length, 2)
train_data[:, 0] = torch.rand(train_data_length) * (2 * math.pi)
train_data[:, 1] = torch.sin(train_data[:, 0]).pow(2)

# The training loop ignores the labels; they are all zeros and exist only so
# that each dataset item is a (sample, label) pair.
train_labels = torch.zeros(train_data_length)

# One (sample, label) tuple per row of the training data.
train_set = list(zip(train_data, train_labels))

# Scatter plot of the real data distribution.
plt.plot(train_data[:, 0], train_data[:, 1], ".")


In [None]:
# Number of real samples handed to the networks per training step.
batch_size = 32

# Serve the training set in shuffled mini-batches of `batch_size` samples.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    shuffle=True,
    batch_size=batch_size,
)

# Discriminator implementation
The discriminator is a model with a two-dimensional input and a one-dimensional output. It’ll receive a sample from the real data or from the generator and will provide the probability that the sample belongs to the real training data. The code below shows how to create a discriminator:

In [None]:
class Discriminator(nn.Module):
    """Map a 2-D sample to a single sigmoid output in (0, 1).

    The network is three hidden blocks of Linear -> ReLU -> Dropout(0.3)
    with widths 256, 128 and 64, followed by a Linear(64, 1) and a Sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Input width followed by the width of every hidden layer.
        widths = (2, 256, 128, 64)

        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.3)]
        # Output head: one unit squashed into (0, 1).
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the discriminator output for the batch `x`."""
        return self.model(x)


discriminator = Discriminator()


# Generator Implementation
In generative adversarial networks, the generator is the model that takes samples from a latent space as its input and generates data resembling the data in the training set. In this case, it’s a model with a two-dimensional input, which will receive random points (z₁, z₂), and a two-dimensional output that must provide (x̃₁, x̃₂) points resembling those from the training data.

In [None]:
class Generator(nn.Module):
    """Map a 2-D latent point to a 2-D generated sample.

    The network is Linear(2, 16) -> ReLU -> Linear(16, 32) -> ReLU ->
    Linear(32, 2); the output layer has no activation.
    """

    def __init__(self):
        super().__init__()
        stack = [
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, 2),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the generated samples for the latent batch `x`."""
        return self.model(x)


generator = Generator()


In [None]:
# Print the matplotlib style names that can be passed to plt.style.use().
print(plt.style.available)

# Training
Now we can move on to the training phase, starting with the choice of hyperparameters.
<br>
The binary cross-entropy function is a suitable loss function for training the discriminator because it considers a binary classification task. It’s also suitable for training the generator since it feeds its output to the discriminator, which provides a binary observable output.

In [None]:
# Hyperparameters.
lr = 0.001
epochs = 400
loss_function = nn.BCELoss()
generator_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)

# Per-epoch history: the losses of the last batch of each epoch.
loss_d = []
loss_g = []
epochs_list = []

# PNG frames saved by this run, in epoch order; used to build the GIF.
frame_files = []

# Index of the final batch of an epoch. Using the loader length (instead of
# batch_size - 1) keeps the logging at the end of the epoch for any dataset
# size.
last_batch = len(train_loader) - 1

for epoch in range(epochs):
    for n, (real_samples, _) in enumerate(train_loader):
        # Size the labels and latent samples from the batch actually received,
        # so a smaller final batch does not cause a shape mismatch.
        current_batch = real_samples.size(0)

        # Data for training the discriminator: real samples labelled 1,
        # generated samples labelled 0.
        real_samples_labels = torch.ones((current_batch, 1))
        latent_space_samples = torch.randn((current_batch, 2))
        # Detach: the discriminator step must not back-propagate into the
        # generator (those gradients were computed and then thrown away).
        generated_samples = generator(latent_space_samples).detach()
        generated_samples_labels = torch.zeros((current_batch, 1))
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat(
            (real_samples_labels, generated_samples_labels)
        )

        # Training the discriminator
        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(
            output_discriminator, all_samples_labels)
        loss_discriminator.backward()
        discriminator_optimizer.step()

        # Data for training the generator
        latent_space_samples = torch.randn((current_batch, 2))

        # Training the generator: it is rewarded when the discriminator
        # labels its output as real, hence the "real" labels as target.
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        loss_generator = loss_function(
            output_discriminator_generated, real_samples_labels
        )
        loss_generator.backward()
        generator_optimizer.step()

        # Progress report: every 10th epoch, on the last batch of the epoch.
        if epoch % 10 == 9 and n == last_batch:
            print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
            print(f"Epoch: {epoch} Loss G.: {loss_generator}")

            # Loss curves over the epochs completed so far.
            fig, axs = plt.subplots(2)
            axs[0].plot(np.array(epochs_list), loss_d)
            axs[0].set_title('Discriminator loss')
            axs[1].plot(np.array(epochs_list), loss_g)
            axs[1].set_title('Generator loss')
            plt.subplots_adjust(left=0.1,
                                bottom=0.1,
                                right=0.9,
                                top=0.99,
                                wspace=0.4,
                                hspace=0.6)
            plt.show()
            # Drawn after show(), so these points land on the figure that is
            # saved as this epoch's frame below.
            plt.plot(generated_samples.detach()[:, 0],
                     generated_samples.detach()[:, 1], '.')

    # Snapshot of the generator at the end of the epoch: 100 generated points
    # saved as "<epoch>.png".
    latent_space_samples = torch.randn(100, 2)
    generated_samples = generator(latent_space_samples).detach()
    plt.plot(generated_samples[:, 0], generated_samples[:, 1], ".")
    frame_file = f"{epoch}.png"
    plt.savefig(frame_file)
    frame_files.append(frame_file)
    plt.show()

    epochs_list.append(epoch)
    loss_d.append(loss_discriminator.item())
    loss_g.append(loss_generator.item())

# Assemble the animation once, after training. The frames come from the
# tracked list, so they are in true epoch order (a lexicographic sort of
# "*.png" would put "10.png" before "2.png") and unrelated PNG files in the
# working directory are not pulled in.
with imageio.get_writer('mygif.gif', mode='I') as writer:
    for filename in frame_files:
        writer.append_data(imageio.imread(filename))
    # Append the final frame a second time, as the original code did.
    if frame_files:
        writer.append_data(imageio.imread(frame_files[-1]))
    

In [None]:
# Draw 100 latent points and map them through the generator; the result is
# detached from the autograd graph so it can be plotted.
latent_space_samples = torch.randn((100, 2))
generated_samples = generator(latent_space_samples).detach()

# Generated points first, then the real training data on the same axes.
plt.plot(generated_samples[:, 0], generated_samples[:, 1], ".")
plt.plot(train_data[:, 0], train_data[:, 1], ".")
