<a href="https://colab.research.google.com/github/marcomarchesi/giraffe/blob/master/notebooks/gdcgan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive: later cells copy the dataset from
# /content/drive/My Drive/giraffe and save samples/checkpoints back there.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!cp /content/drive/My\ Drive/giraffe/root.zip root.zip 
!cp /content/drive/My\ Drive/giraffe/images_50000.pkl images.pkl
!unzip root.zip > ziplog.txt


replace root/.DS_Store? [y]es, [n]o, [A]ll, [N]one, [r]ename: A


In [None]:
import torch
from torch import nn
from torch.utils.data import random_split, Dataset, DataLoader
from torchvision import transforms
from torch.autograd import Variable
import torchvision
from torchvision.datasets import ImageFolder
import numpy as np
from PIL import Image
import os
import pickle

In [None]:
class autoencoder(nn.Module):
    """Small convolutional autoencoder with a Tanh output in [-1, 1].

    Shape comments assume a (b, 3, 128, 128) input; for that input the
    decoder reproduces the input resolution exactly.
    """

    def __init__(self):
        super().__init__()
        # Encoder: (b, 3, 128, 128) -> (b, 8, 10, 10)
        encoder_layers = [
            nn.Conv2d(3, 16, 3, stride=3, padding=1),  # b, 16, 43, 43
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),  # b, 16, 21, 21
            nn.Conv2d(16, 8, 3, stride=2, padding=1),  # b, 8, 11, 11
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=1),  # b, 8, 10, 10
        ]
        # Decoder: (b, 8, 10, 10) -> (b, 3, 128, 128)
        decoder_layers = [
            nn.ConvTranspose2d(8, 16, 3, stride=2),  # b, 16, 21, 21
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1),  # b, 8, 63, 63
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(8, 3, 6, stride=2, padding=1),  # b, 3, 128, 128
            nn.Tanh(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and decode it back to image space."""
        code = self.encoder(x)
        return self.decoder(code)

In [None]:
# Sizes shared by the model definitions and the training cells.
img_size = 128    # side length (pixels) of the square generated/real images
z_dim = 100       # latent-noise dimensionality
scene_size = 35   # length of the scene vector fed to the Generator


def weights_init_normal(m):
    """DCGAN-style initializer intended for ``nn.Module.apply``.

    Modules whose class name contains "Conv" get weights drawn from
    N(0, 0.02). Modules whose class name contains "BatchNorm2d" get weights
    from N(1, 0.02) and a zero bias. Any other module is left untouched.
    """
    layer_kind = type(m).__name__
    if "Conv" in layer_kind:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm2d" in layer_kind:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


class Generator(nn.Module):
    """Scene-conditioned generator: maps a scene vector to a 3-channel image.

    The scene vector is projected by a linear layer to a 128-channel feature
    map at 1/4 of the output resolution. It is upsampled twice (x2 each) and
    mapped to 3 channels in [-1, 1] by a final conv + Tanh.

    Args:
        image_size: side length of the square output image; must be a
            multiple of 4. Defaults to the module-level ``img_size``.
        scene_dim: length of the input scene vector. Defaults to the
            module-level ``scene_size``.

    Raises:
        ValueError: if the image size is not a multiple of 4 (two x2
            upsamplings could not reach it exactly).
    """

    def __init__(self, image_size=None, scene_dim=None):
        super(Generator, self).__init__()
        # Resolve defaults lazily so explicit arguments never read the globals.
        size = img_size if image_size is None else image_size
        in_features = scene_size if scene_dim is None else scene_dim
        if size % 4 != 0:
            raise ValueError(
                "image_size must be a multiple of 4, got {}".format(size))

        self.init_size = size // 4
        self.l1 = nn.Sequential(nn.Linear(in_features, 128 * self.init_size ** 2))

        # BatchNorm2d's second positional parameter is ``eps`` (not
        # ``momentum``), so a bare 0.8 there would mean eps=0.8; keep defaults.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, scene):
        """Map a (batch, scene_dim) tensor to a (batch, 3, H, W) image in [-1, 1]."""
        out = self.l1(scene)
        # Reshape the projection into a (batch, 128, H/4, W/4) feature map.
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

class Discriminator(nn.Module):
    """Real/fake classifier over 3-channel square images.

    Four stride-2 conv blocks (16 -> 32 -> 64 -> 128 filters, each followed
    by LeakyReLU and Dropout2d, plus BatchNorm on all but the first) feed a
    linear layer + Sigmoid that outputs one probability per image.

    Args:
        image_size: side length of the square input images. Defaults to the
            module-level ``img_size``.
    """

    def __init__(self, image_size=None):
        super(Discriminator, self).__init__()
        # Resolve the default lazily so an explicit size never reads the global.
        size = img_size if image_size is None else image_size

        def discriminator_block(in_filters, out_filters, bn=True):
            """Conv(3x3, stride 2, pad 1) -> LeakyReLU -> Dropout2d [-> BatchNorm]."""
            block = [
                nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25),
            ]
            if bn:
                # BatchNorm2d's second positional parameter is ``eps`` (not
                # ``momentum``), so a bare 0.8 would mean eps=0.8; keep defaults.
                block.append(nn.BatchNorm2d(out_filters))
            return block

        self.model = nn.Sequential(
            *discriminator_block(3, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # Each 3x3 / stride-2 / padding-1 conv maps n -> ceil(n / 2); floor
        # division by 16 is only correct when size is a multiple of 16.
        ds_size = size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        """Return a (batch, 1) tensor of probabilities that ``img`` is real."""
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity

In [None]:
# Per-image preprocessing: PIL image -> float tensor in [0, 1] (ToTensor),
# then per-channel (x - 0.5) / 0.5, giving values in [-1, 1].
# Append further transforms to this list as needed.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

class RenderDataset(Dataset):
    """Pairs rendered images with a slice of their scene descriptions.

    Sample ``idx`` is the ``idx``-th ``.png``/``.jpg`` file (by sorted file
    name) in ``root_dir``, together with elements ``[8:43]`` (35 values) of
    the ``idx``-th entry of the pickled scene list. Each sample is a dict
    ``{'image': ..., 'scene': tensor}``.
    """

    def __init__(self, scenes_file, root_dir, transform=None):
        """
        Args:
            scenes_file (string): Path to the pkl file with scene description.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform applied to the
                image of each sample.
        """
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # trusted files. The attribute keeps its name but holds the loaded data.
        with open(scenes_file, 'rb') as f:
            self.scenes_file = pickle.load(f)
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order, but images are paired
        # with scenes by position, so sort to make the pairing deterministic.
        # NOTE(review): assumes the pickle lists scenes in sorted-filename
        # order — confirm against the script that rendered the dataset.
        self.images = sorted(
            image for image in os.listdir(self.root_dir)
            if image.endswith(('.png', '.jpg'))
        )
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.images[idx])
        # Decode as 3-channel RGB (the models expect 3 channels) and close the
        # file right away instead of leaving the lazy handle to the GC.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        scene = torch.tensor(self.scenes_file[idx][8:43])

        sample = {'image': image, 'scene': scene}
        if self.transform:
            sample['image'] = self.transform(sample['image'])

        return sample


In [None]:
b_size = 64

# Image/scene pairs; each batch is a dict with 'image' and 'scene' entries.
train_dataset = RenderDataset('images.pkl', 'root/images', transform=transform)
train_loader = DataLoader(
    train_dataset, batch_size=b_size, shuffle=True, num_workers=4,
)

num_epochs = 300
learning_rate = 0.0002
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# model = autoencoder()
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Losses: BCE drives the adversarial game; the MSE criterion is available for
# an optional reconstruction term and is not used by the adversarial updates.
adversarial_loss = nn.BCELoss().to(device)
criterion = nn.MSELoss().to(device)

# Report model sizes as total parameter counts.
for model_name, model in (('generator', generator), ('discriminator', discriminator)):
    n_params = sum(p.numel() for p in model.parameters())
    print('{} parameters: {:,}'.format(model_name, n_params))

generator_optimizer = torch.optim.Adam(
    generator.parameters(), lr=learning_rate,
    weight_decay=1e-5
)

discriminator_optimizer = torch.optim.Adam(
    discriminator.parameters(), lr=learning_rate,
    weight_decay=1e-5
)




131072
131072
128
128
128
128
128
128
64
64
64
64
3
3


In [None]:
def to_img(x):
    """Map generator output from [-1, 1] back to [0, 1] for saving.

    Args:
        x: image batch in (N, C, H, W) layout, as produced by the Generator.

    Returns:
        Tensor of the same shape holding ``0.5 * (x + 1)`` clamped to [0, 1].
        The NCHW layout is kept because ``save_image`` expects it.
    """
    # Do not reshape to (N, H, W, 3): ``view`` does not permute axes, so it
    # would reinterpret NCHW memory as NHWC and scramble the pixels.
    return (0.5 * (x + 1)).clamp(0, 1)

from torchvision.utils import save_image

# Legacy tensor-type aliases: CUDA types when a GPU is visible, CPU otherwise.
cuda = torch.cuda.is_available()
if cuda:
    Tensor, LongTensor = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    Tensor, LongTensor = torch.FloatTensor, torch.LongTensor

losses = []  # generator loss per iteration, stored as Python floats

# Output locations on Drive; create them up front so saving cannot fail on a
# missing directory.
sample_dir = '/content/drive/My Drive/giraffe/train_dcgan'
checkpoint_dir = '/content/drive/My Drive/giraffe/train'
os.makedirs(sample_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

try:
    for epoch in range(num_epochs):
        for data in train_loader:
            # Move the batch to the training device. The Generator is
            # conditioned only on the scene, so no noise vector is sampled.
            render = data['image'].to(device)
            scene = data['scene'].to(device, dtype=torch.float32)
            n_batch = render.shape[0]

            # Adversarial ground truths (1 = real, 0 = fake)
            valid = torch.ones(n_batch, 1, device=device)
            fake = torch.zeros(n_batch, 1, device=device)

            # Generator: push the discriminator to label generated images real.
            generator_optimizer.zero_grad()
            generated = generator(scene)
            validity = discriminator(generated)
            generator_loss = adversarial_loss(validity, valid)
            # generator_loss = criterion(generated, render)
            generator_loss.backward()
            generator_optimizer.step()
            # Store a float: appending the tensor itself would keep every
            # iteration's autograd graph (and device memory) alive.
            losses.append(generator_loss.item())

            # Discriminator: real renders -> 1, detached generated images -> 0.
            discriminator_optimizer.zero_grad()
            real_loss = adversarial_loss(discriminator(render), valid)
            fake_loss = adversarial_loss(discriminator(generated.detach()), fake)
            discriminator_loss = (real_loss + fake_loss) / 2
            discriminator_loss.backward()
            discriminator_optimizer.step()

        if epoch % 10 == 0:
            print('epoch [{}/{}], g_loss:{:.4f}'.format(epoch + 1, num_epochs,
                                                         generator_loss.item()))
            pic = to_img(generated.detach().cpu())
            save_image(pic, os.path.join(sample_dir, '{}.png'.format(epoch + 1)))
except KeyboardInterrupt:
    # Interrupting the cell should not discard the training done so far.
    print('Training interrupted; saving the current weights.')

import matplotlib.pyplot as plt

plt.plot(losses)
plt.show()
torch.save(generator.state_dict(), os.path.join(checkpoint_dir, 'generator.pth'))
torch.save(discriminator.state_dict(), os.path.join(checkpoint_dir, 'discriminator.pth'))

KeyboardInterrupt: ignored