In [3]:
# Mount Google Drive into the Colab filesystem at /content/drive. Later cells
# copy the dataset from this mount and write samples/checkpoints back to it.
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
# Copy the image archive and the pickled scene descriptions from the mounted
# Drive into the local working directory.
!cp /content/drive/My\ Drive/giraffe/root.zip root.zip
!cp /content/drive/My\ Drive/giraffe/images.pkl images.pkl
# -o overwrites existing files without asking, so re-running this cell does not
# stall on unzip's interactive "replace?" prompt. The file listing is kept in
# ziplog.txt instead of flooding the cell output.
!unzip -o root.zip > ziplog.txt


In [0]:
import os
import pickle
import re

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image
from torch import nn
from torch.autograd import Variable
from torch.utils.data import random_split, Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image

In [0]:
class autoencoder(nn.Module):
    """Small convolutional autoencoder for 3-channel images.

    The encoder compresses the input to an 8-channel feature map and the
    decoder expands it back to 3 channels with a final ``Tanh``, so outputs
    lie in [-1, 1]. The shape comments below are for a (b, 3, 128, 128)
    input, which is reconstructed at the same resolution.
    """

    def __init__(self):
        super().__init__()

        encoder_layers = [
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=1),   # b, 16, 43, 43
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # b, 16, 21, 21
            nn.Conv2d(16, 8, kernel_size=3, stride=2, padding=1),   # b, 8, 11, 11
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=1),                  # b, 8, 10, 10
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2),             # b, 16, 21, 21
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 8, kernel_size=5, stride=3, padding=1),  # b, 8, 63, 63
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(8, 3, kernel_size=6, stride=2, padding=1),   # b, 3, 128, 128
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction from the decoder."""
        latent = self.encoder(x)
        return self.decoder(latent)

In [0]:
def weights_init_normal(m):
    """Re-initialise one module in place; intended for ``model.apply``.

    Modules whose class name contains ``"Conv"`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``"BatchNorm2d"`` get
    weights drawn from N(1, 0.02) and a zero bias. Every other module is
    left untouched.
    """
    layer_type = type(m).__name__

    if "Conv" in layer_type:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if "BatchNorm2d" in layer_type:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias.data, val=0.0)


class Generator(nn.Module):
    """Decode a flat scene vector into a 3-channel image in [-1, 1].

    A linear layer expands the scene vector to a 128-channel feature map of
    side ``img_size // 4``. Two nearest-neighbour 2x upsampling stages with
    3x3 convolutions then bring it to ``4 * (img_size // 4)`` pixels per
    side, which equals ``img_size`` when ``img_size`` is a multiple of 4.

    Args:
        img_size (int): Side length of the square output image.
        scene_size (int): Length of the input scene vector.
    """

    def __init__(self, img_size, scene_size):
        super().__init__()

        # Two 2x upsamplings follow, so start at a quarter of the target side.
        self.init_size = img_size // 4
        seed_features = 128 * self.init_size ** 2
        self.l1 = nn.Sequential(nn.Linear(scene_size, seed_features))

        # NOTE(review): the second positional argument of nn.BatchNorm2d is
        # `eps`, so the two layers below use eps=0.8 (the default is 1e-5).
        # If a momentum of 0.8 was intended, this should be `momentum=...`.
        # Confirm before changing, since it alters training behaviour.
        stages = [
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(*stages)

    def forward(self, z):
        """Return the image batch generated from scene batch ``z``."""
        features = self.l1(z)
        side = self.init_size
        feature_map = features.view(features.shape[0], 128, side, side)
        return self.conv_blocks(feature_map)

In [0]:
# Image pipeline: convert to a tensor, then normalise each of the three
# channels as (x - 0.5) / 0.5. Further transformations can be added to the list.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Same normalisation with a single scalar mean/std instead of per-channel values.
transform_scene = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5),
])

class RenderDataset(Dataset):
    """Dataset pairing rendered images on disk with pickled scene descriptions.

    Sample ``i`` combines the ``i``-th image of ``root_dir`` (in natural
    filename order, e.g. ``2.png`` before ``10.png``) with entry ``i`` of the
    unpickled scenes object.
    NOTE(review): this assumes the scenes were stored in that same order --
    confirm against the code that produced the pickle.
    """

    def __init__(self, scenes_file, root_dir, transform=None):
        """
        Args:
            scenes_file (string): Path to the pkl file with scene description.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on the image of a sample.
        """
        # SECURITY: pickle.load can execute arbitrary code; only point this at
        # files from a trusted source.
        with open(scenes_file, 'rb') as handle:
            self.scenes_file = pickle.load(handle)
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order. Samples are paired
        # with scenes by index, so the file list must be put into a stable order.
        self.images = sorted(
            (name for name in os.listdir(self.root_dir)
             if name.endswith(('.png', '.jpg'))),
            key=self._natural_key,
        )
        self.transform = transform

    @staticmethod
    def _natural_key(name):
        """Sort key that compares digit runs numerically ('2' < '10')."""
        # re.split with a capturing group puts the digit runs at odd indices.
        pieces = re.split(r'(\d+)', name)
        chunks = [int(piece) if pos % 2 else piece
                  for pos, piece in enumerate(pieces)]
        # The raw name breaks ties such as '01.png' vs '1.png'.
        return chunks, name

    def __len__(self):
        """Number of image files found in ``root_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'scene': tensor}`` for sample ``idx``.

        The image is a PIL image, or whatever ``transform`` returns when a
        transform was supplied.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.images[idx])
        image = Image.open(img_name)
        scene = torch.tensor(self.scenes_file[idx])

        sample = {'image': image, 'scene': scene}
        if self.transform:
            sample['image'] = self.transform(sample['image'])

        return sample


In [0]:
# Hyperparameters
b_size = 64           # mini-batch size
img_size = 128        # side length of the square images
scene_size = 57       # length of the scene vector fed to the Generator
num_epochs = 300
learning_rate = 1e-3

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data: renders under root/images paired with scenes from images.pkl
train_dataset = RenderDataset('images.pkl', 'root/images', transform=transform)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=b_size,
    shuffle=False,
    num_workers=4,
)

# Model: move to the target device first, then re-initialise the weights
model = Generator(img_size, scene_size)
model.to(device)
model.apply(weights_init_normal)

# Pixel-wise MSE objective, optimised with Adam and a small L2 weight decay
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-5,
)


In [0]:
def to_img(x, size=None):
    """Map a batch from the [-1, 1] range back to [0, 1] image tensors.

    Args:
        x (Tensor): Batch holding ``3 * size * size`` values per sample.
        size (int, optional): Side length of the square images. Defaults to
            the notebook-level ``img_size``, looked up when the function is
            called.

    Returns:
        Tensor of shape ``(batch, 3, size, size)`` with values clamped to [0, 1].
    """
    side = img_size if size is None else size
    # Undo the (x - 0.5) / 0.5 normalisation, then clip anything out of range.
    scaled = (0.5 * (x + 1)).clamp(0, 1)
    # reshape (rather than view) also accepts non-contiguous input.
    return scaled.reshape(scaled.size(0), 3, side, side)

# Train the generator to reproduce each render from its scene vector, saving a
# sample grid every 10 epochs, then plot the loss curve and store the weights.

# Output locations on the mounted Drive. They are created up front so that
# save_image / torch.save do not fail on a missing directory.
sample_dir = '/content/drive/My Drive/giraffe/train_dcgan'
checkpoint_path = '/content/drive/My Drive/giraffe/train/checkpoint.pth'
os.makedirs(sample_dir, exist_ok=True)
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Per-batch loss values as plain Python floats. Storing the loss tensors
# themselves would keep graph-attached (possibly CUDA) tensors alive and
# cannot be plotted by matplotlib.
losses = []
for epoch in range(num_epochs):
    for data in train_loader:
        render = data['image']
        render = render.view(render.size(0), 3, img_size, img_size).to(device)
        scene = data['scene'].to(device)

        optimizer.zero_grad()
        output = model(scene)
        loss = criterion(output, render)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    if epoch % 10 == 0:
        # Reports the loss of the last batch of the epoch and saves that
        # batch's generated images as a grid.
        print('epoch [{}/{}], loss:{:.4f}'.format(epoch+1, num_epochs, loss.item()))
        pic = to_img(output.detach().cpu())
        save_image(pic, os.path.join(sample_dir, '{}.png'.format(str(epoch+1))))

fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title='Training loss', xlabel='batch', ylabel='MSE loss')
plt.show()

torch.save(model.state_dict(), checkpoint_path)

epoch [1/300], loss:0.2165
epoch [11/300], loss:0.2170
epoch [21/300], loss:0.2157
