In [60]:
import torchvision.datasets as dsets
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
from torch import nn

batch_size = 32
path_to_root_data = "data"
transforms = transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize(64),
                transforms.Normalize(mean=(0.5, ), 
                                     std=(0.5, ))])

train_dataset = dsets.CelebA(root=path_to_root_data,
                             split='train',
                             transform=transforms,
                             download=True)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

Files already downloaded and verified


In [61]:
from torchvision.utils import save_image
import matplotlib.pyplot as plt

def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

train_dataset[0][0].shape

torch.Size([3, 78, 64])

In [62]:
class Descriminator(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Define the layers of the CNN
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1),
            nn.GELU(), 
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1),
            nn.GELU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1),
            nn.GELU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2)
        )

        self.layer4 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=2, stride=1),
            nn.GELU(),
            nn.BatchNorm2d(256),
            nn.AdaptiveAvgPool2d(1)
        )

        # self.layer5 = nn.Sequential(
        #     nn.Conv2d(in_channels=128, out_channels=256, kernel_size=2, stride=1),
        #     nn.GELU(),
        #     nn.BatchNorm2d(256),
        #     nn.AdaptiveAvgPool2d(1)
        # )

        self.linear_layer = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        # Apply the layers in the forward pass
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # x = self.layer5(x)
#         x = self.layer6(x)
        x = x.view(x.size(0), -1)
        # print(x.size())
        x = self.linear_layer(x)
        
        return x
    
# class Generator(nn.Module):
#     def __init__(self):
#         super().__init__()

#         self.linear_layer = nn.Sequential(
#             nn.Linear(64, 256),
#             nn.ReLU(),
    
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(64, 128 * 7 * 7),  # Increase the size to match the required image size
            nn.BatchNorm1d(256 * 14 * 14),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (256, 14, 14)),  # Reshape to 256x14x14
            
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # Upsample to 28x28
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # Upsample to 56x56
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # Upsample to 112x112
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),  # Upsample to 224x224
            nn.Tanh()  # Output layer with tanh activation function
        )

    def forward(self, z):
        return self.model(z)
            

In [57]:
D = Descriminator()

# Generator 
G = Generator()

if torch.cuda.is_available():
    D.cuda()
    G.cuda()

# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()

d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

In [59]:
for epoch in range(2):
    accuracy_real = 0
    accuracy_fake = 0

    for i, (images, _) in enumerate(train_loader):
        # Build mini-batch dataset
        batch_size = images.size(0)
        # images = images.view(batch_size, -1)
        
        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(batch_size)
        fake_labels = torch.zeros(batch_size)
        
        #============= Train the discriminator =============#
        # Compute BCE_Loss using real images where BCE_Loss(x, y):
        #         - y * log(D(x)) - (1-y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels == 1
        D.train()
        G.train(False) # <-> G.eval()

        outputs = D(images) # Real images
        d_loss_real = criterion(outputs.squeeze(1), real_labels)
        real_score = outputs
        
        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = torch.randn(batch_size, 64)
        fake_images = G(z) # Generate fake images
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs.squeeze(1), fake_labels)
        fake_score = outputs
        
        # Backprop + Optimize
        d_loss = d_loss_real + d_loss_fake
        D.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        #=============== Train the generator ===============#
        # Compute loss with fake images
        D.train(False)
        G.train()
        z = torch.randn(batch_size, 64)
        fake_images = G(z)
        outputs = D(fake_images)
        
        # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))        
        g_loss = criterion(outputs.squeeze(1), real_labels)
        
        # Backprop + Optimize
        D.zero_grad()
        G.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 300 == 0:
            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, '
                  'g_loss: %.4f, D(x): %.2f, D(G(z)): %.2f' 
                  %(epoch, 200, i+1, 600, d_loss.data, g_loss.data,
                    real_score.data.mean(), fake_score.data.mean()))
    
    # Save real images
    if (epoch+1) == 1:
        images = images.view(images.size(0), 1, 28, 28)
        save_image(denorm(images.data), './data/real_images.png')

    plt.imshow(denorm(fake_images.data[0]).view(28, 28).cpu().numpy(), cmap='gray')
    plt.show()
    
    # Save sampled images
    fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images.data), './data/fake_images-%d.png' %(epoch+1))

KeyboardInterrupt: 