In [30]:
import argparse
import os
import numpy as np
import math
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

# Notebook-wide configuration shared by the cells below.
batch_size = 64          # number of samples generated per batch
latent_dim = 100         # length of the noise vector fed to the generator
n_classes = 10           # number of distinct class labels the generator embeds
img_shape = (1, 32, 32)  # per-sample output shape; presumably (channels, height, width)

In [31]:
class Generator(nn.Module):
    """Label-conditioned, fully connected generator.

    A learned embedding of the class label is concatenated with the noise
    vector and passed through a stack of Linear/LeakyReLU layers (BatchNorm1d
    on every hidden layer except the first). The Tanh output is reshaped to
    ``img_shape``.

    Reads the module-level ``latent_dim``, ``n_classes`` and ``img_shape``.
    """

    def __init__(self):
        super().__init__()

        # One n_classes-dimensional vector per class label.
        self.label_emb = nn.Embedding(n_classes, n_classes)

        hidden_widths = (128, 256, 512, 1024)
        stack = []
        fan_in = latent_dim + n_classes
        for depth, fan_out in enumerate(hidden_widths):
            stack.append(nn.Linear(fan_in, fan_out))
            if depth:  # the first hidden layer is left un-normalised
                # NOTE(review): 0.8 is the second positional argument, which the
                # saved notebook output reports as eps=0.8 (not momentum). Left
                # untouched; presumably the stored weights were trained this
                # way -- confirm before changing.
                stack.append(nn.BatchNorm1d(fan_out, 0.8))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
            fan_in = fan_out

        # Project to one value per output element, then squash with Tanh.
        stack.append(nn.Linear(1024, int(np.prod(img_shape))))
        stack.append(nn.Tanh())
        self.model = nn.Sequential(*stack)

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: Float tensor whose last dimension is ``latent_dim``.
            labels: Integer tensor of class indices, one per sample.

        Returns:
            Tensor of shape ``(batch, *img_shape)``.
        """
        # Concatenate label embedding and noise to produce the network input.
        label_vecs = self.label_emb(labels)
        flat = self.model(torch.cat((label_vecs, noise), -1))
        return flat.view(flat.size(0), *img_shape)

In [108]:
# Rebuild the network and restore its trained weights from disk.
generator = Generator()
# map_location keeps every storage on the CPU while deserialising, so the file
# can still be opened on a machine without CUDA if it was saved from GPU tensors.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
generator.load_state_dict(
    torch.load("models/G-140.model", map_location=lambda storage, loc: storage)
)
# Switch to inference mode (BatchNorm then uses its running statistics).
generator.eval()

Generator(
  (label_emb): Embedding(10, 10)
  (model): Sequential(
    (0): Linear(in_features=110, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): BatchNorm1d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Linear(in_features=256, out_features=512, bias=True)
    (6): BatchNorm1d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): Linear(in_features=512, out_features=1024, bias=True)
    (9): BatchNorm1d(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace)
    (11): Linear(in_features=1024, out_features=1024, bias=True)
    (12): Tanh()
  )
)

In [109]:
def sample_image(n_row, batches_done, label=7, path="./test.png"):
    """Save a one-row grid of ``n_row`` generated images of a single class.

    Args:
        n_row: Number of images to generate; also used as the grid's images
            per row, so all samples end up on one row.
        batches_done: Unused; kept so existing calls keep working.
        label: Class index every sample is conditioned on. Defaults to 7,
            the value that used to be hard-coded.
        path: File the grid is written to. Defaults to "./test.png".
    """
    # Sample noise
    z = torch.FloatTensor(np.random.normal(0, 1, (n_row, latent_dim)))
    # Every sample is conditioned on the same class label.
    labels = torch.full((n_row,), label, dtype=torch.long)
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        gen_imgs = generator(z, labels)
    save_image(gen_imgs, path, nrow=n_row, normalize=True)

In [110]:
# Generate six samples and write them to disk; the second argument
# (batches_done) is not used by sample_image.
sample_image(6, 10)

In [115]:
# Write one image grid per class ("./0.png", "./1.png", ...), each holding
# batch_size samples laid out 8 per row.
with torch.no_grad():  # inference only, no autograd graph needed
    for i in range(n_classes):
        z = torch.FloatTensor(np.random.normal(0, 1, (batch_size, latent_dim)))
        # Condition the whole batch on class i.
        gen_labels = torch.full((batch_size,), i, dtype=torch.long)
        gen_imgs = generator(z, gen_labels)
        save_image(gen_imgs, "./%d.png" % (i), nrow=8, normalize=True)