In [1]:
import torch
from torch import nn
from torchvision import transforms
from torchvision.datasets import CelebA
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np

In [4]:
def show_tensor_images(image_tensor,num_images = 25,size = (1,28,28),rows = 3,show = True):
    image_tensor = (image_tensor+1)/2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images],rows)
    plt.imshow(image_grid.permute(1,2,0).squeeze())
    if show:
        plt.show()

In [3]:
class Generator(nn.Module):
    def __init__(self,im_chan = 3,input_dim = 100,n_classes = 40):
        self.im_chan = im_chan
        self.input_dim = input_dim
        self.n_classes = n_classes
        self.gen = nn.Sequential(
            self.gen_block(input_dim+n_classes,512,4,1,0),
            self.gen_block(512,256,4),
            self.gen_block(256,128,4),
            self.gen_block(128,64,4),
            self.gen_block(64,im_chan,4,final = True)
        )

    def gen_block(self,input_channels,output_channels,kernel_size = 3,stride = 2,padding = 1,final = False):
        if not final:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels,output_channels,kernel_size,stride,padding),
                nn.BatchNorm2d(output_channels),
                nn.ReLU()
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels,output_channels,kernel_size,stride,padding),
                nn.Tanh()
            )
    def forward(self,noise):
        x = noise.view(len(noise),self.input_dim,1,1)
        return self.gen(x)
    def get_noise(n_samples,input_dim,device = 'cpu'):
        return torch.randn(n_samples,input_dim,device)

In [2]:
class Discriminator(nn.Module):
    def __init__(self,im_channels = 3,n_classes = 40):
        super(Discriminator,self).__init__()
        self.im_channels = im_channels
        self.n_classes = n_classes
        self.disc = nn.Sequential(
            self.disc_block(im_channels+n_classes,64),
            self.disc_block(64,128),
            self.disc_block(128,1,final= True)
        )


    def disc_block(self,input_channels,output_channels,kernel_size = 3,stride = 2,padding = 1,final = False):
        if not final:
            return nn.Sequential(
                nn.Conv2d(input_channels,output_channels,kernel_size,stride,padding),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2,inplace = True)
            )
        else :
            return nn.Sequential(
                nn.Conv2d(input_channels,output_channels,kernel_size,stride,padding),
            )
    def forward(self,x):
        disc_pred = self.disc(x)
        return disc_pred.view(len(disc_pred),-1)


In [5]:
import torch.nn.functional as f
def get_one_hot(labels,n_classes):
    return f.one_hot(labels,n_classes)

In [6]:
def combine_vectors(x,y):
    combined = torch.cat((x.float(),y.float()),dim = -1)
    return combined

In [7]:
celebA_shape = (218,178,3)
n_classes = 40

In [10]:
loss = nn.BCEWithLogitsLoss
input_dim = z_dim = 64
display_step = 100
batch_size = 32
device = 'cuda'
transform = transforms.Compose([
    transforms.CenterCrop(178),
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [11]:
dataloader = DataLoader(
    CelebA('.', download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True)

FileURLRetrievalError: Failed to retrieve file url:

	Too many users have viewed or downloaded this file recently. Please
	try accessing the file again later. If the file you are trying to
	access is particularly large or is shared with many people, it may
	take up to 24 hours to be able to view or download the file. If you
	still can't access a file after 24 hours, contact your domain
	administrator.

You may still be able to access the file from the browser:

	https://drive.google.com/uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM

but Gdown can't. Please check connections and permissions.