# In[ ]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
# Seed PyTorch's global RNG so model weight initialisation and the latent
# noise drawn with torch.randn during training are reproducible across runs.
torch.manual_seed(47)

# Generator network
class Generator(nn.Module):
    """Fully connected generator mapping latent noise to flattened samples.

    Architecture: ``Linear(latent_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, out_features) -> sigmoid``. The defaults reproduce the
    original sizes: 1-d noise, 512 hidden units, and 3*32*32 outputs.

    The submodule names ``fc``, ``relu`` and ``fc2`` are kept unchanged so
    existing ``state_dict`` checkpoints still load.

    Args:
        latent_dim: Size of the last dimension of the noise input.
        hidden_dim: Width of the hidden layer.
        out_features: Number of values produced per sample. The output is
            flat, so reshape it (e.g. to ``(3, 32, 32)``) before treating it
            as an image.
    """

    def __init__(self, latent_dim=1, hidden_dim=512, out_features=32 * 32 * 3):
        super().__init__()
        self.fc = nn.Linear(latent_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, out_features)

    def forward(self, x):
        """Map noise of shape ``(..., latent_dim)`` to ``(..., out_features)``.

        The sigmoid keeps every output value in [0, 1].
        """
        hidden = self.relu(self.fc(x))
        return torch.sigmoid(self.fc2(hidden))


# Discriminator network
class Discriminator(nn.Module):
    """Fully connected binary classifier scoring samples as real (~1) or fake (~0).

    Architecture: ``flatten -> Linear(in_features, hidden_dim) -> ReLU ->
    Linear(hidden_dim, 1) -> sigmoid``. The defaults reproduce the original
    sizes: 3*32*32 inputs and 512 hidden units. The submodule names ``fc1``,
    ``relu`` and ``fc2`` are unchanged so ``state_dict`` keys stay stable.

    Args:
        in_features: Number of values per sample after flattening.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, in_features=32 * 32 * 3, hidden_dim=512):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 1)  # Output size 1 for binary classification

    def forward(self, x):
        """Return probabilities of shape ``(batch, 1)`` for the batch ``x``.

        Every dimension after the first is flattened, so both image batches
        ``(batch, C, H, W)`` and flat ``(batch, in_features)`` inputs work.
        """
        # torch.flatten (unlike .view) accepts non-contiguous tensors, e.g.
        # permuted channels-last images. It also handles empty batches, where
        # view(0, -1) is ambiguous and raises.
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(x))  # Using sigmoid for binary classification

def _discriminator_accuracy(generator, discriminator, dataloader, latent_dim):
    """Return how often ``discriminator`` classifies real and generated samples correctly.

    A real sample counts as correct when it scores above 0.5. A generated
    sample (from fresh noise) counts as correct when it scores below 0.5.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    correct = 0
    total = 0
    with torch.no_grad():  # evaluation only; no autograd graph needed
        for images, _ in dataloader:
            z = torch.randn(images.size(0), latent_dim)
            real_output = discriminator(images)
            fake_output = discriminator(generator(z))
            correct += (real_output > 0.5).sum().item()
            correct += (fake_output < 0.5).sum().item()
            total += real_output.size(0) + fake_output.size(0)
    if total == 0:
        raise ValueError("dataloader yielded no batches; cannot compute accuracy")
    return correct / total


def APEGAN(dataloader=None, num_epochs=None, generator=None, discriminator=None,
           latent_dim=1, lr=0.01):
    """Train a generator/discriminator pair and return the discriminator accuracy.

    Every argument is optional. A bare ``APEGAN()`` call reads the
    ``test_dataloader`` and ``epochs`` globals, which are defined elsewhere
    in the notebook.

    Args:
        dataloader: Re-iterable of ``(images, labels)`` batches (e.g. a
            ``DataLoader``). It is iterated once per epoch and once more for
            evaluation, and labels are ignored. Images must flatten to the
            discriminator's input size (3*32*32 for the default model).
            Defaults to the global ``test_dataloader``.
        num_epochs: Number of training passes over ``dataloader``. Defaults
            to the global ``epochs``.
        generator: Model mapping ``(batch, latent_dim)`` noise to samples.
            A fresh ``Generator()`` is built when omitted; pass your own
            instance to keep the trained weights afterwards.
        discriminator: Model returning ``(batch, 1)`` probabilities in
            [0, 1]. A fresh ``Discriminator()`` is built when omitted.
        latent_dim: Width of the noise fed to ``generator``. It must match
            the generator's input size (1 for the default ``Generator``).
        lr: SGD learning rate for both optimizers.

    Returns:
        Fraction of loader images and generated fakes that the trained
        discriminator classifies correctly (threshold 0.5).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if dataloader is None:
        dataloader = test_dataloader
    if num_epochs is None:
        num_epochs = epochs

    # Build models, losses and optimizers once, outside the loops, so that
    # weights and optimizer momentum persist from batch to batch.
    if generator is None:
        generator = Generator()
    if discriminator is None:
        discriminator = Discriminator()
    adversarial_loss = nn.BCELoss()
    mse_loss = nn.MSELoss()  # Mean Squared Error Loss
    optimizer_G = optim.SGD(generator.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    optimizer_D = optim.SGD(discriminator.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)

    for epoch in range(1, num_epochs + 1):
        for i, (adversarial_images, _) in enumerate(dataloader):
            batch_size = adversarial_images.size(0)
            valid = torch.ones(batch_size, 1)

            # ---- Generator step ----
            optimizer_G.zero_grad()
            # The generator is fed latent noise, not the adversarial images.
            # NOTE(review): APE-GAN proper feeds the adversarial images to the
            # generator; that would need a Generator whose input is an image.
            z = torch.randn(batch_size, latent_dim)
            gen_images = generator(z)
            # The batch just loaded is scored as the "real" class.
            # NOTE(review): APE-GAN usually uses clean images as "real";
            # confirm what this loader yields.
            real_output = discriminator(adversarial_images)
            fake_output = discriminator(gen_images.detach())  # Detach so the D loss never reaches G
            real_loss = adversarial_loss(real_output, torch.ones_like(real_output))
            fake_loss = adversarial_loss(fake_output, torch.zeros_like(fake_output))

            g_loss = mse_loss(discriminator(gen_images), valid)
            g_loss.backward()
            optimizer_G.step()

            # ---- Discriminator step ----
            # zero_grad also discards the D gradients left by g_loss.backward().
            optimizer_D.zero_grad()
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            if i % 100 == 0:
                print(
                    "[Epoch %d]  [D loss: %f] [G loss: %f]"
                    % (epoch, d_loss.item(), g_loss.item())
                )

    # NOTE(review): accuracy is measured on the same loader used for
    # training, so it is not a held-out test score.
    return _discriminator_accuracy(generator, discriminator, dataloader, latent_dim)

# Run training with the notebook-level settings, then report the
# discriminator accuracy returned by APEGAN to four decimal places.
accuracy = APEGAN()
message = "Test Accuracy {:.4f}".format(accuracy)
print(message)
