In [1]:
  
import errno
import os
import random
import shutil
import sys

import os
import io
from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
import pdb
from PIL import Image
import torch
from torch.autograd import Variable
import pdb
import torch.nn.functional as F

import argparse
import os
import sys

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torchvision.datasets as dset
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from PIL import Image

  
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import itertools
import os
import time

from datetime import datetime

import numpy as np
import torch
import torchvision.utils as vutils

In [2]:
def to_np(var):
    """Exports torch.Tensor to Numpy array.
    """
    return var.detach().cpu().numpy()


def create_folder(folder_path):
    """Create a folder if it does not exist.
    """
    try:
        os.makedirs(folder_path)
    except OSError as _e:
        if _e.errno != errno.EEXIST:
            raise


def clear_folder(folder_path):
    """Clear all contents recursively if the folder exists.
    Create the folder if it has been accidently deleted.
    """
    create_folder(folder_path)
    for the_file in os.listdir(folder_path):
        _file_path = os.path.join(folder_path, the_file)
        try:
            if os.path.isfile(_file_path):
                os.unlink(_file_path)
            elif os.path.isdir(_file_path):
                shutil.rmtree(_file_path)
        except OSError as _e:
            print(_e)


class StdOut(object):
    """Redirect stdout to file, and print to console as well.
    """
    def __init__(self, output_file):
        self.terminal = sys.stdout
        self.log = open(output_file, "a")

    def write(self, message):
        self.terminal.write(message)
        self.terminal.flush()
        self.log.write(message)
        self.log.flush()

    def flush(self):
        self.terminal.flush()
        self.log.flush()


def boolean_string(s):
    if s not in {'False', 'True'}:
        raise ValueError('Not a valid boolean string')
    return s == 'True'

In [3]:
# This is a custom dataset class which exports the HDF5 
# Original dataset into a pytorch utils.data.Dataset here 

class Text2ImageDataset(Dataset):        # A subclass of the Pytorch dataset here 

    def __init__(self, datasetFile, transform=None, split=0):
        self.datasetFile = datasetFile
        self.transform = transform
        self.dataset = None
        self.dataset_keys = None
        self.split = 'train' if split == 0 else 'valid' if split == 1 else 'test'
        self.h5py2int = lambda x: int(np.array(x))

    def __len__(self):
        f = h5py.File(self.datasetFile, 'r')
        self.dataset_keys = [str(k) for k in f[self.split].keys()]
        length = len(f[self.split])
        f.close()

        return length

    def __getitem__(self, idx):
        if self.dataset is None:
            self.dataset = h5py.File(self.datasetFile, mode='r')
            self.dataset_keys = [str(k) for k in self.dataset[self.split].keys()]

        example_name = self.dataset_keys[idx]
        example = self.dataset[self.split][example_name]

        # pdb.set_trace()

        right_image = bytes(np.array(example['img']))
        right_embed = np.array(example['embeddings'], dtype=float)
        wrong_image = bytes(np.array(self.find_wrong_image(example['class'])))
        inter_embed = np.array(self.find_inter_embed())

        right_image = Image.open(io.BytesIO(right_image)).resize((64, 64))
        wrong_image = Image.open(io.BytesIO(wrong_image)).resize((64, 64))

        right_image = self.validate_image(right_image)
        wrong_image = self.validate_image(wrong_image)

        txt = np.array(example['txt']).astype(str)

        sample = {
                'right_images': torch.FloatTensor(right_image),
                'right_embed': torch.FloatTensor(right_embed),
                'wrong_images': torch.FloatTensor(wrong_image),
                'inter_embed': torch.FloatTensor(inter_embed),
                'txt': str(txt)
                 }

        sample['right_images'] = sample['right_images'].sub_(127.5).div_(127.5)
        sample['wrong_images'] =sample['wrong_images'].sub_(127.5).div_(127.5)

        return sample

    def find_wrong_image(self, category):
        idx = np.random.randint(len(self.dataset_keys))
        example_name = self.dataset_keys[idx]
        example = self.dataset[self.split][example_name]
        _category = example['class']

        if _category != category:
            return example['img']

        return self.find_wrong_image(category)

    def find_inter_embed(self):
        idx = np.random.randint(len(self.dataset_keys))
        example_name = self.dataset_keys[idx]
        example = self.dataset[self.split][example_name]
        return example['embeddings']


    def validate_image(self, img):
        img = np.array(img, dtype=float)
        if len(img.shape) < 3:
            rgb = np.empty((64, 64, 3), dtype=np.float32)
            rgb[:, :, 0] = img
            rgb[:, :, 1] = img
            rgb[:, :, 2] = img
            img = rgb

        return img.transpose(2, 0, 1)


In [4]:

class Generator(nn.Module):
    def __init__(self, channels, latent_dim=100, embed_dim=1024, embed_out_dim=128):
        super(Generator, self).__init__()
        self.channels = channels
        self.latent_dim = latent_dim
        self.embed_dim = embed_dim
        self.embed_out_dim = embed_out_dim

        self.text_embedding = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_out_dim),
            nn.BatchNorm1d(self.embed_out_dim),
            nn.LeakyReLU(0.2, inplace=True)
        )

        model = []
        model += self._create_layer(self.latent_dim + self.embed_out_dim, 512, 4, stride=1, padding=0)
        # The noise vector and the text encoded vectors are conatenated 
        model += self._create_layer(512, 256, 4, stride=2, padding=1)
        model += self._create_layer(256, 128, 4, stride=2, padding=1)
        model += self._create_layer(128, 64, 4, stride=2, padding=1)
        model += self._create_layer(64, self.channels, 4, stride=2, padding=1, output=True)

        self.model = nn.Sequential(*model)      # This does away with the requirement of us defining a forward path here 
        

    # Transpose Convolution Layers are present here :: to increase the filter activation map dimension 
    # The input to the generator would be a concatenation of the noise vector and the actual text vector 
    
    def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, output=False):
        layers = [nn.ConvTranspose2d(size_in, size_out, kernel_size, stride=stride, padding=padding, bias=False)]
        if output:
            layers.append(nn.Tanh())
        else:
            layers += [nn.BatchNorm2d(size_out),
                        nn.ReLU(True)]
        return layers

    def forward(self, noise, text):
        text = self.text_embedding(text)
        text = text.view(text.shape[0], text.shape[1], 1, 1)
        z = torch.cat([text, noise], 1)
        return self.model(z)

# The text vector is replicated in order to feed to the Discriminator here : 
class Embedding(nn.Module):
    def __init__(self, size_in, size_out):
        super(Embedding, self).__init__()
        self.text_embedding = nn.Sequential(
            nn.Linear(size_in, size_out),
            nn.BatchNorm1d(size_out),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x, text):
        embed_out = self.text_embedding(text)
        embed_out_resize = embed_out.repeat(4, 4, 1, 1).permute(2, 3, 0, 1)
        out = torch.cat([x, embed_out_resize], 1)
        return out

class Discriminator(nn.Module):
    def __init__(self, channels, embed_dim=1024, embed_out_dim=128):
        super(Discriminator, self).__init__()
        self.channels = channels
        self.embed_dim = embed_dim
        self.embed_out_dim = embed_out_dim

        self.model = nn.Sequential(
            *self._create_layer(self.channels, 64, 4, 2, 1, normalize=False),
            *self._create_layer(64, 128, 4, 2, 1),
            *self._create_layer(128, 256, 4, 2, 1),
            *self._create_layer(256, 512, 4, 2, 1)
        )
        self.text_embedding = Embedding(self.embed_dim, self.embed_out_dim)       # The Text is basically replicated here 
        # And then stacked with the feature maps here 
        self.output = nn.Sequential(
            nn.Conv2d(512 + self.embed_out_dim, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, normalize=True):
        layers = [nn.Conv2d(size_in, size_out, kernel_size=kernel_size, stride=stride, padding=padding)]
        if normalize:
            layers.append(nn.BatchNorm2d(size_out))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def forward(self, x, text):
        x_out = self.model(x)       # Get a convolutional feature map 
        # The input x is either the real image or the fake image here 
        out = self.text_embedding(x_out, text)     # Do stacking here 
        out = self.output(out)                     # Pass through a convolution again and if required a ANN to get the outputs 
        return out.squeeze(), x_out

In [5]:

def _weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


class Model(object):
    def __init__(self,
                 name,
                 device,
                 data_loader,
                 channels,
                 l1_coef,
                 l2_coef):
        
        self.name = name
        self.device = device
        self.data_loader = data_loader
        self.channels = channels
        self.l1_coef = l1_coef
        self.l2_coef = l2_coef
        self.netG = Generator(self.channels)
        self.netG.apply(_weights_init)
        self.netG.to(self.device)
        self.netD = Discriminator(self.channels)
        self.netD.apply(_weights_init)
        self.netD.to(self.device)
        self.optim_G = None
        self.optim_D = None
        self.loss_adv = torch.nn.BCELoss()
        self.loss_l1 = torch.nn.L1Loss()
        self.loss_l2 = torch.nn.MSELoss()

    
# The property() method in Python provides an interface to instance attributes. It encapsulates instance attributes and provides a property, same as Java and C#. 
# The property() method takes the get, set and delete methods as arguments and returns an object of the property class.    
        
    @property
    def generator(self):
        return self.netG

    @property
    def discriminator(self):
        return self.netD

    def create_optim(self, lr, alpha=0.5, beta=0.999):
        self.optim_G = torch.optim.Adam(self.netG.parameters(),
                                        lr=lr, betas=(alpha, beta))
        
        self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                        lr=lr, betas=(alpha, beta))

    def train(self,
              epochs,
              log_interval=100,
              out_dir='',
              verbose=True):
        
        self.netG.train()
        self.netD.train()          # In train mode 
        total_time = time.time()
        for epoch in range(epochs):
            batch_time = time.time()
            for batch_idx, data in enumerate(self.data_loader):
                image = data['right_images'].to(self.device)       # Get only the correct images here 
                embed = data['right_embed'].to(self.device)

                real_label = torch.ones((image.shape[0]), device=self.device)
                fake_label = torch.zeros((image.shape[0]), device=self.device)

                # Train D
                self.optim_D.zero_grad()

                out_real, _ = self.netD(image, embed)
                loss_d_real = self.loss_adv(out_real, real_label)         # Adverserial Loss for the Disc.

                noise = torch.randn((image.shape[0], 100, 1, 1), device=self.device)
                image_fake = self.netG(noise, embed)
                out_fake, _ = self.netD(image_fake, embed)
                loss_d_fake = self.loss_adv(out_fake, fake_label)           # Adverserial Loss for the Disc.

                d_loss = loss_d_real + loss_d_fake
                d_loss.backward()
                self.optim_D.step()                  # Train the Disc. 

                # Train G
                self.optim_G.zero_grad()
                noise = torch.randn((image.shape[0], 100, 1, 1), device=self.device)
                image_fake = self.netG(noise, embed)
                out_fake, act_fake = self.netD(image_fake, embed)   # Activation Map and the Fake Output here 
                _, act_real = self.netD(image, embed)               # 

                l1_loss = self.loss_l1(torch.mean(act_fake, 0), torch.mean(act_real, 0).detach())
                # L1 Loss between the Activation Maps mean 
                
                g_loss = self.loss_adv(out_fake, real_label) + \
                    self.l1_coef * l1_loss + \
                    self.l2_coef * self.loss_l2(image_fake, image) # Okay 1 adverserial loss term 
                    # One L1 loss term and one L2 loss term here 
                    
                g_loss.backward()
                self.optim_G.step()

                if verbose and batch_idx % log_interval == 0 and batch_idx > 0:
                    print('Epoch {} [{}/{}] loss_D: {:.4f} loss_G: {:.4f} time: {:.2f}'.format(
                          epoch, batch_idx, len(self.data_loader),
                          d_loss.mean().item(),
                          g_loss.mean().item(),
                          time.time() - batch_time))
                    with torch.no_grad():
                        viz_sample = torch.cat((image[:32], image_fake[:32]), 0)
                        vutils.save_image(viz_sample,
                                          os.path.join(out_dir, 'samples_{}_{}.png'.format(epoch, batch_idx)),
                                          nrow=8,
                                          normalize=True)
                    batch_time = time.time()

            self.save_to(path=out_dir, name=self.name, verbose=False)
        if verbose:
            print('Total train time: {:.2f}'.format(time.time() - total_time))

    def eval(self,
             batch_size=None):
        self.netG.eval()
        self.netD.eval()
        if batch_size is None:
            batch_size = self.data_loader.batch_size

        with torch.no_grad():                               # In evaluation mode here :: 
            for batch_idx, data in enumerate(self.data_loader):
                image = data['right_images'].to(self.device)[:batch_size]
                embed = data['right_embed'].to(self.device)[:batch_size]
                text = data['txt'][:batch_size]
                noise = torch.randn((image.shape[0], 100, 1, 1), device=self.device)
                viz_sample = self.netG(noise, embed)
                vutils.save_image(viz_sample,
                                  'img_{}.png'.format(batch_idx),
                                  nrow=batch_size//8,
                                  normalize=True)
                for t in text:
                    print(t)
                break

    def save_to(self,
                path='',
                name=None,
                verbose=True):
        if name is None:
            name = self.name
        if verbose:
            print('\nSaving models to {}_G.pt and {}_D.pt ...'.format(name, name))
        torch.save(self.netG, os.path.join(path, '{}_G.pt'.format(name)))
        torch.save(self.netD, os.path.join(path, '{}_D.pt'.format(name)))

    def load_from(self,
                  path='',
                  name=None,
                  verbose=True):
        if name is None:
            name = self.name
        if verbose:
            print('\nLoading models from {}_G.pt and {}_D.pt ...'.format(name, name))
        ckpt_G = torch.load(os.path.join(path, '{}_G.pt'.format(name)))
        if isinstance(ckpt_G, dict) and 'state_dict' in ckpt_G:
            self.netG.load_state_dict(ckpt_G['state_dict'], strict=True)
        elif isinstance(ckpt_G, torch.nn.Module):
            self.netG = ckpt_G
        else:
            self.netG.load_state_dict(ckpt_G, strict=True)
        ckpt_D = torch.load(os.path.join(path, '{}_D.pt'.format(name)))
        if isinstance(ckpt_D, dict) and 'state_dict' in ckpt_D:
            self.netD.load_state_dict(ckpt_D['state_dict'], strict=True)
        elif isinstance(ckpt_D, torch.nn.Module):
            self.netD = ckpt_D
        else:
            self.netD.load_state_dict(ckpt_D, strict=True)
            
            
# A checkpoint is an intermediate dump of a model's entire internal state (its weights, current learning rate, etc.) 
# so that the framework can resume the training from this point whenever desired.

In [None]:

def main():
#     device = torch.device("cuda:0" if FLAGS.cuda else "cpu")
    device = "cpu"

    print('Loading data...\n')
    dataloader = DataLoader(Text2ImageDataset(os.path.join(FLAGS.data_dir, '{}.hdf5'.format(FLAGS.dataset)), split=0),
                            batch_size=FLAGS.batch_size, shuffle=True, num_workers=8)

    print('Creating model...\n')
    model = Model(FLAGS.model, device, dataloader, FLAGS.channels, FLAGS.l1_coef, FLAGS.l2_coef)

    if FLAGS.train:
        model.create_optim(FLAGS.lr)

        print('Training...\n')
        model.train(FLAGS.epochs, FLAGS.log_interval, FLAGS.out_dir, True)

        model.save_to('')
    else:
        model.load_from('')

        print('Evaluating...\n')
        model.eval(batch_size=64)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--model', type=str, default='text2image', help='text2image')
    parser.add_argument('--cuda', default=True, help='enable CUDA.')
    parser.add_argument('--train', default=True, help='train mode or eval mode.')
    parser.add_argument('--data_dir', type=str, default='D:\\Text-to-Image-GAN\\', help='Directory for dataset.')
    parser.add_argument('--dataset', type=str, default='birds', help='Dataset name.')
    parser.add_argument('--out_dir', type=str, default='D:\\Text-to-Image-GAN\\output\\', help='Directory for output.')
    parser.add_argument('--epochs', type=int, default=200, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=256, help='size of batches in training')
    parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
    parser.add_argument('--channels', type=int, default=3, help='number of image channels')
    parser.add_argument('--l1_coef', type=float, default=50, help='l1 coefficient')
    parser.add_argument('--l2_coef', type=float, default=100, help='l2 coefficient')
    parser.add_argument('--log_interval', type=int, default=20, help='interval between logging and image sampling')
    parser.add_argument('--seed', type=int, default=1, help='random seed')

    FLAGS, unknown = parser.parse_known_args()
    FLAGS.cuda = FLAGS.cuda and torch.cuda.is_available()

    if FLAGS.seed is not None:
        torch.manual_seed(FLAGS.seed)
        if FLAGS.cuda:
            torch.cuda.manual_seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)

    cudnn.benchmark = True

    create_folder(FLAGS.out_dir)
    if FLAGS.train:
        clear_folder(FLAGS.out_dir)

    log_file = os.path.join(FLAGS.out_dir, 'log.txt')
    print("Logging to {}\n".format(log_file))
    sys.stdout = StdOut(log_file)

    print("PyTorch version: {}".format(torch.__version__))
    print("CUDA version: {}\n".format(torch.version.cuda))

    print(" " * 9 + "Args" + " " * 9 + "|    " + "Type" + \
          "    |    " + "Value")
    print("-" * 50)
    for arg in vars(FLAGS):
        arg_str = str(arg)
        var_str = str(getattr(FLAGS, arg))
        type_str = str(type(getattr(FLAGS, arg)).__name__)
        print("  " + arg_str + " " * (20-len(arg_str)) + "|" + \
              "  " + type_str + " " * (10-len(type_str)) + "|" + \
              "  " + var_str)

    main()

Logging to D:\Text-to-Image-GAN\output\log.txt

PyTorch version: 1.9.0
CUDA version: 10.2

         Args         |    Type    |    Value
--------------------------------------------------
  model               |  str       |  text2image
  cuda                |  bool      |  True
  train               |  bool      |  True
  data_dir            |  str       |  D:\Text-to-Image-GAN\
  dataset             |  str       |  birds
  out_dir             |  str       |  D:\Text-to-Image-GAN\output\
  epochs              |  int       |  200
  batch_size          |  int       |  256
  lr                  |  float     |  0.0002
  channels            |  int       |  3
  l1_coef             |  int       |  50
  l2_coef             |  int       |  100
  log_interval        |  int       |  20
  seed                |  int       |  1
Loading data...

Creating model...

Training...



In [None]:

class Generator_modified(nn.Module):
    def __init__(self, channels, latent_dim=100, embed_dim=1024, embed_out_dim=128, text_stack = 10):
        super(Generator, self).__init__()
        self.channels = channels
        self.latent_dim = latent_dim
        self.embed_dim = embed_dim
        self.embed_out_dim = embed_out_dim
        
        self.text_stack = text_stack               # We are also stacking the text instead of just concatenating here 

        self.text_embedding = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_out_dim),
            nn.BatchNorm1d(self.embed_out_dim),
            nn.LeakyReLU(0.2, inplace=True)
        )

        model = []
        model += self._create_layer(self.latent_dim + self.embed_out_dim, 512, 4, stride=1, padding=0)
        # The noise vector and the text encoded vectors are conatenated 
        model += self._create_layer(512, 256 + self.text_stack, 4, stride=2, padding=1)
        model += self._create_layer(256, 128 + self.text_stack, 4, stride=2, padding=1)
        model += self._create_layer(128, 64 + self.text_stack, 4, stride=2, padding=1)
        model += self._create_layer(64 + self.text_stack, self.channels, 4, stride=2, padding=1, output=True)

        self.model = nn.Sequential(*model)      # This does away with the requirement of us defining a forward path here 
        

    # Transpose Convolution Layers are present here :: to increase the filter activation map dimension 
    # The input to the generator would be a concatenation of the noise vector and the actual text vector 
    
    def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, output=False):
        layers = [nn.ConvTranspose2d(size_in, size_out, kernel_size, stride=stride, padding=padding, bias=False)]
        if output:
            layers.append(nn.Tanh())
        else:
            layers += [nn.BatchNorm2d(size_out),
                        nn.ReLU(True)]
        return layers

    def forward(self, noise, text):
        text = self.text_embedding(text)
        text = text.view(text.shape[0], text.shape[1], 1, 1)
        z = torch.cat([text, noise], 1)
        return self.model(z)

# The text vector is replicated in order to feed to the Discriminator here : 
class Embedding(nn.Module):
    def __init__(self, size_in, size_out):
        super(Embedding, self).__init__()
        self.text_embedding = nn.Sequential(
            nn.Linear(size_in, size_out),
            nn.BatchNorm1d(size_out),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x, text):
        embed_out = self.text_embedding(text)
        embed_out_resize = embed_out.repeat(4, 4, 1, 1).permute(2, 3, 0, 1)
        out = torch.cat([x, embed_out_resize], 1)
        return out

