In [12]:
from torch.utils.data import DataLoader, dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import image_dataset
import torch.optim as optim
from torchvision import transforms, models
from tqdm import tqdm

In [2]:
#set manual seed
torch.manual_seed(42)

# Set device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
#define the generator network
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.tanh = nn.Tanh()

    def forward(self, x):
        return self.tanh(self.fc2(self.relu(self.fc1(x))))
    
#define the discriminator network
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.fc2(self.relu(self.fc1(x))))

In [5]:
latent_size = 100
hidden_size = 256
image_size = 512
lr = .0002
epochs = 1
batch_size = 2

dataset = image_dataset.ImageDataset(transform=transforms.ToTensor())

generator = Generator(latent_size, hidden_size, image_size * image_size)
discrimator = Discriminator(image_size * image_size, hidden_size, 1)
dataloader = DataLoader(dataset, batch_size, shuffle=True)

criterion = nn.BCELoss()
gen_optim = optim.Adam(generator.parameters(), lr)
disc_optim = optim.Adam(discrimator.parameters(), lr)

In [17]:
for epoch in range(epochs):
    for i, (real_image, _) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}", leave=False)):

        real_image = real_image.view(-1, image_size * image_size)

        noise = torch.randn(batch_size, latent_size)

        #train the discriminator
        disc_optim.zero_grad()
        real_labels = torch.ones(batch_size, 1)

        real_outputs = discrimator(real_image)
        real_loss = criterion(real_outputs, real_labels)

        fake_images = generator(noise)
        fake_outputs = discrimator(fake_images.detach())
        fake_labels = torch.zeros_like(fake_outputs)
        fake_loss = criterion(fake_outputs, fake_labels)

        disc_loss = real_loss + fake_loss
        disc_loss.backward()
        disc_optim.step()

        gen_optim.zero_grad()
        outputs = discrimator(fake_images)
        gen_loss = criterion(outputs, real_labels)
        
        gen_loss.backward()
        gen_optim.step()

        tqdm.write(f"Generator Loss: {gen_loss.item():.4f}, Discriminator Loss: {disc_loss.item():.4f}")

Epoch 0/1:   2%|▏         | 1/48 [00:01<01:16,  1.63s/it]

Generator Loss: 9.5065, Discriminator Loss: 0.0065


Epoch 0/1:   4%|▍         | 2/48 [00:03<01:21,  1.77s/it]

Generator Loss: 10.6996, Discriminator Loss: 0.0219


Epoch 0/1:   6%|▋         | 3/48 [00:05<01:17,  1.73s/it]

Generator Loss: 7.6465, Discriminator Loss: 0.0048


Epoch 0/1:   8%|▊         | 4/48 [00:06<01:13,  1.67s/it]

Generator Loss: 10.6524, Discriminator Loss: 0.0001


Epoch 0/1:  10%|█         | 5/48 [00:08<01:11,  1.65s/it]

Generator Loss: 4.3229, Discriminator Loss: 3.5758


Epoch 0/1:  12%|█▎        | 6/48 [00:09<01:08,  1.64s/it]

Generator Loss: 3.7464, Discriminator Loss: 4.9808


Epoch 0/1:  15%|█▍        | 7/48 [00:11<01:08,  1.67s/it]

Generator Loss: 8.4528, Discriminator Loss: 0.3688


Epoch 0/1:  17%|█▋        | 8/48 [00:13<01:06,  1.65s/it]

Generator Loss: 8.2216, Discriminator Loss: 0.1098


Epoch 0/1:  19%|█▉        | 9/48 [00:14<01:03,  1.64s/it]

Generator Loss: 6.4122, Discriminator Loss: 4.2536


Epoch 0/1:  21%|██        | 10/48 [00:16<01:03,  1.68s/it]

Generator Loss: 9.8856, Discriminator Loss: 0.1015


Epoch 0/1:  23%|██▎       | 11/48 [00:18<01:01,  1.66s/it]

Generator Loss: 8.3106, Discriminator Loss: 0.0037


Epoch 0/1:  25%|██▌       | 12/48 [00:20<00:59,  1.67s/it]

Generator Loss: 6.0665, Discriminator Loss: 0.3096


                                                          

KeyboardInterrupt: 