In [None]:
import os

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt
import torch
import torch.nn as nn

from tqdm.notebook import tqdm
import torch.nn.functional as F
from torchvision.utils import save_image
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Fetch the SESYD floorplans archive and unpack it under data/. Dataset page:
# http://mathieu.delalandre.free.fr/projects/sesyd/symbols/floorplans.html
# `wget -nc` skips an existing zip and `unzip -o` overwrites existing files, so the cell can be re-run.
!wget -nc http://mathieu.delalandre.free.fr/projects/sesyd/symbols/floorplans/floorplans16-01.zip
!unzip -q -o -d data/ floorplans16-01.zip

# Root folder handed to ImageFolder in the next cell.
# NOTE(review): presumably the zip unpacks into a sub-directory that ImageFolder treats as a class — verify.
DATA_DIR = 'data/'
print(os.listdir(DATA_DIR))

In [None]:
image_size = 64
batch_size = 128

# Per-channel (mean, std) for ONE grayscale channel. Currently unused: images are kept
# in [0, 1] (ToTensor without Normalize), and the generator's output activation must
# produce that same range. If Normalize is ever enabled, these values map [0, 1] to
# [-1, 1]; the former 3-channel (1, 1, 1) stats could not be applied to the 1-channel
# tensors produced by Grayscale below.
stats = (0.5,), (0.5,)

transform = tt.Compose([
    tt.Grayscale(num_output_channels=1),  # single-channel input for the 1-channel models
    tt.Resize(image_size),                # shorter side -> image_size
    tt.CenterCrop(image_size),            # square image_size x image_size crop
    tt.ToTensor(),                        # PIL image -> float tensor in [0, 1]
])

train_ds = ImageFolder(DATA_DIR, transform = transform)

In [None]:
# Shuffle every epoch; pinned host memory only helps when batches are copied to a CUDA device.
train_dl = DataLoader(train_ds, batch_size, shuffle = True, num_workers = 0, pin_memory = torch.cuda.is_available())

In [None]:
#def denorm(img_tensors):
#    return img_tensors * stats[1][0] + stats[0][0]

def show_images(images, nmax = 16):
    """Plot up to ``nmax`` images of a batch as a grid with 4 images per row.

    Args:
        images: image batch as accepted by ``make_grid`` — presumably (N, C, H, W)
            with values in [0, 1]; may require grad or live on the GPU.
        nmax: maximum number of images to plot.
    """
    _, ax = plt.subplots(figsize = (16,16))
    ax.set_xticks([])
    ax.set_yticks([])
    # Detach from autograd and bring to host memory so generator outputs can be plotted.
    batch = images.detach()[:nmax].cpu()
    grid = make_grid(batch, nrow = 4).permute(1, 2, 0)
    # Clamp before the uint8 cast: values outside [0, 1] (e.g. negative generator
    # outputs) would otherwise wrap around instead of saturating.
    ax.imshow((grid.clamp(0, 1) * 255).byte())

def show_batch(dl, nmax = 16):
    """Display the images of the first (images, labels) batch yielded by ``dl``.

    Labels are ignored; nothing is shown when ``dl`` yields no batches.
    """
    first = next(iter(dl), None)
    if first is None:
        return
    images, _labels = first
    show_images(images, nmax)

In [None]:
# Preview the first batch of real training images (after the grayscale/resize/crop transform).
show_batch(train_dl)

In [None]:
def get_default_device():
    """Return ``torch.device('cuda')`` if CUDA is available, else ``torch.device('cpu')``."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def to_device(data, device):
    """Move ``data`` to ``device``.

    ``data`` is either a single object with a ``.to`` method (e.g. a tensor or a
    module) or an arbitrarily nested list/tuple of such objects. Containers are
    rebuilt as lists; leaves are moved with ``non_blocking=True``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking = True)
    return list(map(lambda item: to_device(item, device), data))

class DeviceDataLoader():
    """Wrap an iterable of batches so every batch is moved to ``device`` as it is yielded."""

    def __init__(self, dl, device):
        self.dl = dl          # wrapped loader (any iterable of batches)
        self.device = device  # destination device for every batch

    def __iter__(self):
        # Lazy: each batch is transferred only when the consumer asks for it.
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Number of batches, delegated to the wrapped loader."""
        return len(self.dl)
    
        

In [None]:
# CUDA if available, otherwise CPU; the loader and both models are moved here in later cells.
device = get_default_device()
device

In [None]:
# Make every batch land on `device`. Guarded so re-running this cell does not wrap an
# already-wrapped loader a second time (the name is reused for the wrapped loader).
if not isinstance(train_dl, DeviceDataLoader):
    train_dl = DeviceDataLoader(train_dl, device)

In [None]:
# The first conv shrinks image_size -> 32 in one step; the rest of the stack assumes a
# 32x32 input. With stride s = image_size // 32, kernel 2*s and padding s // 2 the output
# side is (image_size + 2*(s//2) - 2*s) // s + 1 = 32 for power-of-two sizes >= 64
# (image_size=64 -> the usual kernel 4 / stride 2 / padding 1). The kernel always spans
# the full stride, so no input pixels are skipped.
assert image_size >= 64 and image_size & (image_size - 1) == 0, \
    "discriminator assumes image_size is a power of two >= 64"
first_stride = image_size // 32
first_kernel = first_stride * 2
first_padding = first_stride // 2
discriminator = nn.Sequential(
    # in: 1 x image_size x image_size (grayscale)
    nn.Conv2d(1, 64, kernel_size = first_kernel, stride = first_stride, padding = first_padding, bias = False),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.1, inplace = True),
    # out: 64 x 32 x 32

    nn.Conv2d(64, 128, kernel_size = 4, stride = 2, padding = 1, bias = False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.1, inplace = True),
    # out: 128 x 16 x 16

    nn.Conv2d(128, 256, kernel_size = 4, stride = 2, padding = 1, bias = False),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.1, inplace = True),
    # out: 256 x 8 x 8

    nn.Conv2d(256, 512, kernel_size = 4, stride = 2, padding = 1, bias = False),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.1, inplace = True),
    # out: 512 x 4 x 4

    nn.Conv2d(512, 1, kernel_size = 4, stride = 1, padding = 0, bias = False),
    # out: 1 x 1 x 1

    nn.Flatten(),   # -> (batch, 1)
    nn.Sigmoid()    # probability that the input is real
)

discriminator = to_device(discriminator, device)

In [None]:
# Sanity-check the discriminator on one real batch: expect one score per image.
print(first_kernel, first_stride, first_padding)
real_images, _ = next(iter(train_dl))  # one batch is enough; no need to iterate the loader
print(real_images.shape)
with torch.no_grad():  # inspection only — no autograd graph needed
    d = discriminator(to_device(real_images, device))
print(d.shape)
assert d.shape == (real_images.size(0), 1), d.shape

In [None]:
# Length of the noise vector fed to the generator (its first layer consumes latent_size x 1 x 1).
latent_size = 512

In [None]:
import math

# The generator upsamples a 4x4 seed through `num_layers` stride-2 stages plus a final
# stride-2 output layer, so its output side is 4 * 2**num_layers * 2 == image_size.
# That only holds for power-of-two sizes; fail loudly rather than emit the wrong shape.
assert image_size >= 8 and image_size & (image_size - 1) == 0, \
    "generator assumes image_size is a power of two >= 8"
# math.log2 is exact for powers of two; int(math.log(x, 2)) goes through a float
# division that can land just below the integer and truncate to the wrong count.
num_layers = int(math.log2(image_size // 8))
print(num_layers)
startsize = 64 * 2 ** num_layers
generator = nn.Sequential(
    # in: latent_size x 1 x 1 -> out: startsize x 4 x 4
    nn.ConvTranspose2d(latent_size, startsize, kernel_size = 4, stride = 1, padding = 0, bias = False),
    nn.BatchNorm2d(startsize),
    nn.ReLU(True)
    )

channels = startsize
for i in reversed(range(num_layers)):
    # Each stage doubles the spatial size and halves the channel count.
    half = channels // 2
    generator.add_module("convtranspose2d"+str(i), nn.ConvTranspose2d(channels, half, kernel_size=4, stride=2, padding=1, bias=False))
    generator.add_module("batchnorm"+str(i), nn.BatchNorm2d(half))
    generator.add_module("relu"+str(i), nn.ReLU(True))
    channels = half

# `channels` is 64 here; map to one grayscale channel at full resolution (image_size).
generator.add_module("lastconv", nn.ConvTranspose2d(channels, 1, kernel_size= 4, stride= 2, padding= 1, output_padding= 0, bias = False))
# Sigmoid rather than Tanh: real images reach the discriminator in [0, 1] (ToTensor with
# Normalize disabled), so fakes must lie in the same range — otherwise the discriminator
# can separate them by pixel range alone, and show_images' *255 conversion would see
# negative values.
generator.add_module("sigmoid", nn.Sigmoid())

generator = to_device(generator, device)
print(generator)

In [None]:
# Smoke-test the untrained generator -> discriminator pipeline on one latent batch.
with torch.no_grad():  # inspection only — no autograd graph needed
    latent = torch.randn(batch_size, latent_size, 1, 1, device = device)
    fake_images = generator(latent)
    fake_preds = discriminator(fake_images)
print(fake_images.shape, torch.max(fake_images))
show_images(fake_images.cpu())
print(fake_preds.shape)
assert fake_images.shape == (batch_size, 1, image_size, image_size), fake_images.shape
assert fake_preds.shape == (batch_size, 1), fake_preds.shape

In [None]:
def train_discriminator(real_images, opt_d):
    """Run one optimisation step of the discriminator.

    Args:
        real_images: a batch of real images, expected on ``device``
            (``train_dl`` yields batches there).
        opt_d: optimizer over the discriminator's parameters.

    Returns:
        ``(loss, real_score, fake_score)``: the summed BCE loss as a float, and the
        mean discriminator output on the real and on the generated batch.
    """
    opt_d.zero_grad()
    n = real_images.size(0)

    # Real images should be classified as 1.
    real_preds = discriminator(real_images)
    real_targets = torch.ones(n, 1, device = device)
    real_loss = F.binary_cross_entropy(real_preds, real_targets)
    real_score = torch.mean(real_preds).item()

    # Generate as many fakes as there are reals (the last batch of an epoch can be
    # smaller than batch_size). No gradient is needed through the generator here:
    # this step only updates the discriminator.
    latent = torch.randn(n, latent_size, 1, 1, device = device)
    with torch.no_grad():
        fake_images = generator(latent)

    # Fake images should be classified as 0.
    fake_targets = torch.zeros(n, 1, device = device)
    fake_preds = discriminator(fake_images)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
    fake_score = torch.mean(fake_preds).item()

    # Update discriminator weights only.
    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()

    return loss.item(), real_score, fake_score

In [None]:
def train_generator(opt_g):
    """Run one optimisation step of the generator.

    Samples ``batch_size`` latent vectors, scores the generated images with the
    discriminator and pushes those scores towards the "real" label (1).

    Args:
        opt_g: optimizer over the generator's parameters.

    Returns:
        The generator's BCE loss as a Python float.
    """
    opt_g.zero_grad()

    noise = torch.randn(batch_size, latent_size, 1, 1, device = device)
    scores = discriminator(generator(noise))
    # Label every fake as real: the generator is rewarded for fooling the discriminator.
    loss = F.binary_cross_entropy(scores, torch.ones(batch_size, 1, device = device))

    loss.backward()
    opt_g.step()
    return loss.item()

In [None]:
sample_dir = 'generated'
os.makedirs(sample_dir, exist_ok = True)

def save_samples(index, latent_tensors, show= False):
    """Generate images from ``latent_tensors`` and save them as one grid PNG.

    The file is ``<sample_dir>/generated-NNNN.png`` (``index`` zero-padded to 4
    digits), with 4 images per row.

    Args:
        index: number used in the file name (``fit`` passes ``epoch + start_idx``).
        latent_tensors: latent batch fed to the generator.
        show: if True, also display the grid inline via ``show_images``.
    """
    # Sampling needs no gradients; skipping autograd avoids building a graph on every call.
    with torch.no_grad():
        fake_images = generator(latent_tensors).cpu()
    fake_filename = 'generated-{0:0=4d}.png'.format(index)
    filename = os.path.join(sample_dir, fake_filename)
    save_image(fake_images, filename, nrow = 4)
    if show:
        show_images(fake_images)

In [None]:
# Fixed noise reused for every saved sample grid, so progress across epochs is comparable.
# The sample count is its own constant; it was previously borrowed from image_size,
# which only coincidentally equals 64.
n_fixed_samples = 64
fixed_latent = torch.randn(n_fixed_samples, latent_size, 1, 1, device = device)

In [None]:
def fit(epochs, lr, start_idx = 1, weight_decay_d = 0.89, weight_decay_g = 0.99):
    """Train the GAN for ``epochs`` passes over ``train_dl``.

    Each batch runs one discriminator step followed by one generator step. Losses
    and scores are recorded for every batch. Once per epoch, a progress line is
    printed and a sample grid from ``fixed_latent`` is saved with file index
    ``epoch + start_idx``.

    Args:
        epochs: number of passes over ``train_dl``.
        lr: generator learning rate; the discriminator uses ``lr / 10``.
        start_idx: file index used for the first epoch's sample grid.
        weight_decay_d: Adam weight decay for the discriminator.
        weight_decay_g: Adam weight decay for the generator.

    Returns:
        ``(losses_g, losses_d, real_scores, fake_scores)``, one entry per batch.
    """
    torch.cuda.empty_cache()

    # Per-batch losses and scores.
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []

    # Create optimizers.
    # NOTE(review): weight decays around 0.9 are orders of magnitude above typical values;
    # kept as defaults so existing calls behave as before — consider much smaller ones.
    opt_d = torch.optim.Adam(discriminator.parameters(), lr = lr / 10, weight_decay = weight_decay_d, betas = (0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(), lr = lr, weight_decay = weight_decay_g, betas = (0.5, 0.999))

    message_template = "Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}"
    for epoch in range(epochs):
        for real_images, _ in tqdm(train_dl):
            loss_d, real_score, fake_score = train_discriminator(real_images, opt_d)
            loss_g = train_generator(opt_g)

            losses_g.append(loss_g)
            losses_d.append(loss_d)
            real_scores.append(real_score)
            fake_scores.append(fake_score)

        # Report and save once per epoch. Doing this inside the batch loop flooded the
        # output and rewrote the same epoch's sample file after every batch.
        if losses_g:  # nothing to report if train_dl has yielded no batches
            print(message_template.format(
                epoch + 1,
                epochs,
                losses_g[-1],
                losses_d[-1],
                real_scores[-1],
                fake_scores[-1]
            ))
        save_samples(epoch + start_idx, fixed_latent, show = False)

    return losses_g, losses_d, real_scores, fake_scores

In [None]:
# Training hyper-parameters; fit() gives the discriminator lr / 10.
lr = 5e-6
epochs = 111

losses_g, losses_d, real_scores, fake_scores = history = fit(epochs, lr)

In [None]:
from IPython.display import Image

# fit() saves epoch e under index e + start_idx (start_idx defaults to 1), so the
# final grid carries index `epochs`. Use sample_dir rather than repeating the path.
print(sorted(os.listdir(sample_dir)))
Image(os.path.join(sample_dir, 'generated-{:04d}.png'.format(epochs)))



In [None]:
# fit() appends one loss per training batch, so the x-axis counts batches, not epochs.
plt.plot(losses_d, '-')
plt.plot(losses_g, '-')
plt.xlabel('batch')
plt.ylabel('loss')
plt.legend(['Discriminator', 'Generator'])
plt.title('Losses');

In [None]:
# Mean discriminator output on real vs. generated images, recorded once per training batch.
plt.plot(real_scores, '-')
plt.plot(fake_scores, '-')
plt.xlabel('batch')
plt.ylabel('score')
plt.legend(['Real', 'Fake'])
plt.title('Scores');