In [None]:
import os
import sys

# Make the project's local modules (utils, configs) importable.
# Set CS543_PROJECT_ROOT to run on another machine; the default is the original path.
PROJECT_ROOT = os.environ.get('CS543_PROJECT_ROOT', '/common/users/ppk31/CS543_DL_Proj')
# Only append once so re-running this cell does not keep growing sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

##### import all libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
from PIL import Image
import torchvision.utils as vutils
from torch.autograd import Variable
from pytorch_model_summary import summary
from IPython.display import clear_output
import matplotlib.pyplot as plt
import shutil
from utils import (load_data, save_img_results, load_from_checkpoint)

torch.cuda.empty_cache()
from torch.utils.data import DataLoader
import random
import os
import pickle
import argparse

##### Select the CUDA device

In [None]:
print(f"using GPU: {torch.cuda.is_available()}")
gpus = list(range(torch.cuda.device_count()))
print(f"GPU ids: {gpus}")

# For reproducible sampling, seed here, e.g. torch.manual_seed(0).

# Preferred CUDA device index; falls back to the first GPU when fewer devices exist.
PREFERRED_GPU = 1
if gpus:
    selected_gpu = gpus[PREFERRED_GPU] if len(gpus) > PREFERRED_GPU else gpus[0]
    torch.cuda.set_device(selected_gpu)
    print(f"selected GPU: {selected_gpu}")
else:
    print("no CUDA device found; this notebook expects a GPU")
cudnn.benchmark = True

#### Define the tokenizer and text encoder

In [None]:
def get_tokenizer(text_encoder, split="train"):
    """Return the tokenizer (or precomputed embeddings) for ``text_encoder``.

    Args:
        text_encoder: "distilbert-base-uncased", "openai/clip-vit-base-patch32"
            or "text-cnn-rnn".
        split: which pickled char-CNN-RNN embedding file to load
            (``train_test_split/text-cnn-rnn/{split}_embeddings.pkl``); only
            used for "text-cnn-rnn". Defaults to "train".

    Returns:
        A Hugging Face tokenizer for the transformer encoders. For
        "text-cnn-rnn" there is no tokenizer: the unpickled embedding object
        is returned instead.

    Raises:
        ValueError: if ``text_encoder`` is not one of the supported names.
    """
    print(f"using {text_encoder} as text encoder")
    if text_encoder == "distilbert-base-uncased":
        return DistilBertTokenizer.from_pretrained(text_encoder)
    if text_encoder == "openai/clip-vit-base-patch32":
        return CLIPTokenizer.from_pretrained(text_encoder)
    if text_encoder == "text-cnn-rnn":
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        # NOTE(review): the notebook pairs these embeddings with test_captions.pkl;
        # confirm whether split="test" is the intended pairing.
        embeddings_path = f"train_test_split/text-cnn-rnn/{split}_embeddings.pkl"
        with open(embeddings_path, "rb") as f:
            return pickle.load(f)
    raise ValueError(f"unsupported text encoder: {text_encoder!r}")

class TextEncoder(nn.Module):
    """Wrap a pretrained Hugging Face text model and return one embedding per caption.

    * "distilbert-base-uncased": last-layer hidden state at token position
      ``retrieve_token_index`` (0) — 768-d per the original comment.
    * "openai/clip-vit-base-patch32": ``CLIPModel.get_text_features`` — 512-d per
      the original comment.

    Args:
        text_encoder: name of the pretrained model (one of ``SUPPORTED_ENCODERS``).
        pretrained: currently unused; weights are always loaded with ``from_pretrained``.

    Raises:
        ValueError: for any other name (e.g. "text-cnn-rnn", which has no model here).
    """

    SUPPORTED_ENCODERS = ("distilbert-base-uncased", "openai/clip-vit-base-patch32")

    def __init__(self, text_encoder, pretrained=True):
        super(TextEncoder, self).__init__()
        # Fail early: otherwise self.encoder is never set and forward() breaks later.
        if text_encoder not in self.SUPPORTED_ENCODERS:
            raise ValueError(f"unsupported text encoder for TextEncoder: {text_encoder!r}")
        self.text_encoder = text_encoder
        if text_encoder == "distilbert-base-uncased":
            self.encoder = DistilBertModel.from_pretrained(text_encoder)
        else:
            self.encoder = CLIPModel.from_pretrained(text_encoder)
        # Token position used as the sentence embedding (the first token, i.e. [CLS] for BERT-style tokenizers).
        self.retrieve_token_index = 0

    def forward(self, input_tokens, attention_mask):
        """Encode token ids (with their attention mask) into a (batch, dim) embedding tensor."""
        if self.text_encoder == "distilbert-base-uncased":
            out = self.encoder(input_ids=input_tokens, attention_mask=attention_mask)
            return out.last_hidden_state[:, self.retrieve_token_index, :]    # output_dimensions = 768
        return self.encoder.get_text_features(input_ids=input_tokens, attention_mask=attention_mask)  # output_dimensions = 512

#### Augmented Projection Block

In [None]:
class Augmented_Projection(nn.Module):
    """Conditioning augmentation: map a text embedding to a sampled condition code.

    The embedding is projected (Linear + ReLU) to ``2 * condition_dim`` values,
    split into ``mu`` and ``logvar``, and a code is sampled with the
    reparameterization trick. For stage 1 the code is concatenated with noise
    and projected to ``gen_channels * gen_dim * gen_dim`` features.

    Args:
        stage: 1 builds the noise projection; any other value returns the raw code.
        gen_channels: channels of the stage-1 feature map.
        gen_dim: spatial size of the (square) stage-1 feature map.

    Reads ``text_dim``, ``condition_dim`` and ``z_dim`` from the module-level ``config``.
    """

    def __init__(self, stage, gen_channels, gen_dim):
        super(Augmented_Projection, self).__init__()
        self.stage = stage
        self.t_dim = config.text_dim
        self.c_dim = config.condition_dim
        self.z_dim = config.z_dim
        self.gen_in = gen_channels
        self.fc = nn.Linear(self.t_dim, self.c_dim * 2)
        self.relu = nn.ReLU()
        if stage == 1:
            flat_features = self.gen_in * gen_dim * gen_dim
            self.project = nn.Sequential(
                nn.Linear(self.c_dim + self.z_dim, flat_features, bias=False),
                nn.BatchNorm1d(flat_features),
                nn.ReLU()
            )

    def augment(self, mu, logvar):
        """Sample ``mu + std * eps`` with ``std = exp(logvar / 2)`` and ``eps ~ N(0, 1)``."""
        std = logvar.mul(0.5).exp()
        # randn_like keeps eps on std's device and dtype (no hard-coded .cuda()).
        eps = torch.randn_like(std)
        return mu + (std * eps)

    def forward(self, text_embedding, noise=None):
        """Return ``(c_code, mu, logvar)`` for a (batch, text_dim) ``text_embedding``.

        For stage 1, ``noise`` of shape (batch, z_dim) is sampled on the
        embedding's device when omitted, concatenated to the sampled code and
        projected; otherwise the sampled (batch, condition_dim) code is returned.
        """
        if noise is None and self.stage == 1:
            noise = torch.randn((text_embedding.shape[0], self.z_dim),
                                device=text_embedding.device, dtype=torch.float32)
        # ReLU before the split makes mu >= 0 and logvar >= 0 (so std >= 1).
        # NOTE(review): presumably matches how the checkpoints were trained; keep unless retraining.
        x = self.relu(self.fc(text_embedding))
        mu = x[:, :self.c_dim]
        logvar = x[:, self.c_dim:]
        c_code = self.augment(mu, logvar)

        if self.stage == 1:
            c_code = torch.cat((c_code, noise), dim=1)
            c_code = self.project(c_code)

        return c_code, mu, logvar

#### Downsampling and Upsampling Block for Generator

In [None]:
class Downsample(nn.Module):
    """
    A downsampling layer: a convolution (or average pooling), then optional
    BatchNorm2d and LeakyReLU(0.2).

    :param channels: number of input channels.
    :param out_channels: number of output channels; defaults to ``channels``.
    :param kernel_size: convolution kernel size.
    :param stride: convolution stride; also the pooling window/stride when ``use_conv`` is False.
    :param padding: convolution padding.
    :param batch_norm: apply BatchNorm2d after the convolution/pooling.
    :param activation: apply LeakyReLU(0.2) at the end.
    :param use_conv: use a convolution; otherwise average pooling (requires ``out_channels == channels``).
    :param bias: whether the convolution has a bias term.
    """

    def __init__(self, channels, out_channels=None, kernel_size=4, stride=2, padding=1, batch_norm=True, activation=True, use_conv=True, bias=False):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.batch_norm = batch_norm
        self.activation = activation
        if use_conv:
            self.op = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        else:
            assert self.channels == self.out_channels
            self.op = nn.AvgPool2d(kernel_size=stride, stride=stride)
        if batch_norm:
            # Use the resolved count: the ``out_channels`` argument may be None.
            self.batchnorm = nn.BatchNorm2d(self.out_channels)
        if activation:
            self.activtn = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Apply op -> (BatchNorm) -> (LeakyReLU) to an (N, channels, H, W) tensor."""
        assert x.shape[1] == self.channels
        x = self.op(x)
        if self.batch_norm:
            x = self.batchnorm(x)
        if self.activation:
            x = self.activtn(x)
        return x

class Upsample(nn.Module):
    """
    A x2 upsampling layer: nearest-neighbour interpolation + 3x3 convolution, or a
    4x4 stride-2 transposed convolution; then optional BatchNorm2d, ReLU and Dropout2d(0.5).

    :param channels: number of input channels.
    :param out_channels: number of output channels; defaults to ``channels``.
    :param stride: stride of the 3x3 convolution (interpolation path only).
    :param padding: padding of the convolution / transposed convolution.
    :param batch_norm: apply BatchNorm2d.
    :param activation: apply ReLU.
    :param bias: whether the convolution has a bias term.
    :param use_deconv: use a transposed convolution instead of interpolate + conv.
    :param dropout: apply Dropout2d(0.5) at the end.
    """

    def __init__(self, channels, out_channels=None, stride=1, padding=1, batch_norm=True, activation=True, bias=False, use_deconv=False, dropout=False):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.batch_norm = batch_norm
        self.activation = activation
        self.dropout = dropout
        self.use_deconv = use_deconv

        if use_deconv:
            self.deconv = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=4, stride=2, padding=padding, bias=bias) # use when not using interpolate
        else:
            self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=3, stride=stride, padding=padding, bias=bias)
        if batch_norm:
            # Use the resolved count: the ``out_channels`` argument may be None.
            self.batchnorm = nn.BatchNorm2d(self.out_channels)
        if activation:
            self.activtn = nn.ReLU()
        if self.dropout:
            self.drop = nn.Dropout2d(0.5)

    def forward(self, x):
        """Upsample an (N, channels, H, W) tensor, then apply the optional BN / ReLU / dropout."""
        assert x.shape[1] == self.channels
        if self.use_deconv:
            x = self.deconv(x)
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
            x = self.conv(x)
        if self.batch_norm:
            x = self.batchnorm(x)
        if self.activation:
            x = self.activtn(x)
        if self.dropout:
            x = self.drop(x)
        return x

#### Residual Layer Block

In [None]:
class ResBlock(nn.Module):
    """
    A residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, added to a shortcut,
    followed by ReLU. The shortcut is the identity when shape is unchanged,
    otherwise a strided 1x1 convolution.

    :param in_channels: the number of input channels.
    :param out_channels: the number of output channels; defaults to ``in_channels``.
    :param stride: stride of the first convolution and of the shortcut.
    :param padding: padding of both 3x3 convolutions.
    """
    def __init__(
        self,
        in_channels,
        out_channels=None,
        stride=1,
        padding=1
    ):
        super(ResBlock, self).__init__()
        out_channels = out_channels or in_channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding)
        # Only the first conv strides, so the main path and the shortcut end at the same resolution.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=padding)
        if in_channels == out_channels and stride == 1:
            self.x_residual = nn.Identity()
        else:
            self.x_residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(shortcut(x) + residual(x))``."""
        g = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        x = self.x_residual(x)
        h = x + g
        return self.relu(h)

#### Stage-I Generator Block

In [None]:
class Generator1(nn.Module):
    """Stage-I generator: text embedding (+ noise) -> low-resolution RGB image in [-1, 1].

    The conditioning-augmented code is reshaped to a
    ``(config.generator_dim * 8, config.in_dims, config.in_dims)`` feature map.
    For each entry of ``config.channel_mul`` it is refined by
    ``config.n_resblocks[layer]`` residual blocks. Between consecutive entries it is
    upsampled x2 to width ``in_channels // channel_mul[layer + 1]``. Finally a 3x3
    conv + tanh maps it to 3 channels.
    """

    def __init__(self, stage):
        super(Generator1, self).__init__()
        self.stage = stage
        self.in_dims = config.in_dims  # 4
        self.in_channels = config.generator_dim * 8  # 192*8
        self.channel_mul = config.channel_mul
        self.num_resblocks = config.n_resblocks
        self.use_deconv = config.use_deconv
        self.dropout = config.dropout
        self.c_dim = config.condition_dim
        ch = self.in_channels

        self.aug_project = Augmented_Projection(self.stage, self.in_channels, self.in_dims)

        self.blocks = nn.ModuleList()
        for layer, cmul in enumerate(self.channel_mul):
            width = ch // cmul
            for _ in range(self.num_resblocks[layer]):
                self.blocks.append(ResBlock(width, width, stride=1, padding=1))
            if layer < len(self.channel_mul) - 1:
                # Double the resolution and move to the next level's channel width.
                next_width = ch // self.channel_mul[layer + 1]
                self.blocks.append(Upsample(width, next_width, use_deconv=self.use_deconv, dropout=self.dropout))

        self.out = nn.Sequential(
            nn.Conv2d(ch // self.channel_mul[-1], 3, kernel_size=3, padding=1, bias=True),
            nn.Tanh()
        )

    def forward(self, text_embedding, noise=None):
        """Return ``(image, mu, logvar)``; ``noise`` is sampled by the projection when omitted."""
        proj_x, mu, logvar = self.aug_project(text_embedding, noise)
        x = proj_x.view(-1, self.in_channels, self.in_dims, self.in_dims)
        for block in self.blocks:
            x = block(x)
        img_out = self.out(x)
        return img_out, mu, logvar

#### Stage-II Generator Block

In [None]:
class Generator2(nn.Module):
    """Stage-II generator: refine a Stage-I image conditioned on the text embedding.

    The Stage-I image is encoded to ``4 * config.generator_dim`` channels by one
    stride-1 and two stride-2 Downsample layers. A sampled condition code is tiled
    over the encoded grid and concatenated, then fused by a 3x3 Downsample. The
    result is refined and upsampled per ``config.channel_mul_stage2`` and mapped
    to a 3-channel tanh image.
    """

    def __init__(self, stage):
        super(Generator2, self).__init__()
        self.stage = stage
        # NOTE(review): in_dims squared (16 per the original comment); forward tiles the
        # condition code to the encoder's actual output size, so this only feeds the projection.
        self.in_dims = config.in_dims * config.in_dims
        self.in_channels = config.generator_dim  # 192
        self.channel_mul = config.channel_mul_stage2
        self.num_resblocks = config.n_resblocks_stage2
        self.use_deconv = config.use_deconv2
        self.dropout = config.dropout2
        self.c_dim = config.condition_dim
        ch = self.in_channels * 4  # channel width after the image encoder

        self.aug_project = Augmented_Projection(self.stage, self.in_channels, self.in_dims)

        # Image encoder: 3 -> C (same size), then C -> 2C -> 4C with stride-2 convolutions.
        self.downblocks = nn.Sequential(
            Downsample(3, self.in_channels, kernel_size=3, stride=1, padding=1, batch_norm=False),
            Downsample(self.in_channels, self.in_channels * 2),
            Downsample(self.in_channels * 2, self.in_channels * 4)
        )
        # Fuse image features with the tiled condition code back to 4C channels.
        self.combined = nn.Sequential(
            Downsample(self.in_channels * 4 + self.c_dim, self.in_channels * 4, kernel_size=3, stride=1, padding=1)
        )

        self.blocks = nn.ModuleList()
        for layer, cmul in enumerate(self.channel_mul):
            width = ch // cmul
            for _ in range(self.num_resblocks[layer]):
                self.blocks.append(ResBlock(width, width, stride=1, padding=1))
            if layer < len(self.channel_mul) - 1:
                # Dropout (when enabled) only on the first two upsampling stages.
                use_dropout = self.dropout if layer < 2 else False
                next_width = ch // self.channel_mul[layer + 1]
                self.blocks.append(Upsample(width, next_width, use_deconv=self.use_deconv, dropout=use_dropout))

        self.out = nn.Sequential(
            nn.Conv2d(ch // self.channel_mul[-1], 3, kernel_size=3, padding=1, bias=True),
            nn.Tanh()
        )

    def forward(self, text_embedding, stage1_out):
        """Return ``(image, mu, logvar)`` for a Stage-I image batch and its text embeddings."""
        enc_img = self.downblocks(stage1_out)

        c_code, mu, logvar = self.aug_project(text_embedding)
        # Tile the (batch, c_dim) code over the encoded image's spatial grid.
        height, width = enc_img.shape[2], enc_img.shape[3]
        c_map = c_code.view(-1, self.c_dim, 1, 1).repeat(1, 1, height, width)
        x = self.combined(torch.cat([enc_img, c_map], dim=1))

        for block in self.blocks:
            x = block(x)
        img_out = self.out(x)
        return img_out, mu, logvar

#### Retrieve text embeddings for a given prompt

In [None]:
def get_text_embeddings(prompt, tokenizer, encoder):
    """Encode ``prompt`` with a freshly loaded pretrained ``TextEncoder``.

    Args:
        prompt: a caption string, or a list of caption strings.
        tokenizer: Hugging Face tokenizer matching ``encoder``.
        encoder: model name passed to ``TextEncoder``.

    Returns:
        A list with one 1-D embedding tensor per caption (a single string gives a
        one-element list).
    """
    # captions_dict = {'input_ids': tensor(batch, 77), 'attention_mask': tensor(batch, 77)}
    captions_dict = tokenizer(prompt, padding='max_length', truncation=True, max_length=77, return_tensors="pt")
    text_encoder = TextEncoder(encoder, pretrained=True)
    text_encoder.eval()
    with torch.no_grad():
        text_embeddings = text_encoder(captions_dict['input_ids'], captions_dict['attention_mask'])
    # Split along the batch axis so each caption gets its own embedding
    # (squeeze(0) would leave an (N, D) tensor for a list of N prompts).
    return list(text_embeddings.unbind(0))

#### Command-line arguments

In [None]:
# Supported text encoders; "text-cnn-rnn" uses precomputed char-CNN-RNN embeddings.
TEXT_ENCODERS = ("distilbert-base-uncased", "openai/clip-vit-base-patch32", "text-cnn-rnn")


def _positive_int(value):
    """argparse ``type``: parse ``value`` as an integer >= 1."""
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value!r}")
    return number


parser = argparse.ArgumentParser(description="Generate Text to Image arguments")
parser.add_argument('--text_encoder', required=True, type=str, choices=TEXT_ENCODERS, help="Which text encoder to use, distilbert-base-uncased or openai/clip-vit-base-patch32 or text-cnn-rnn")
parser.add_argument('--g1', type=str, required=True, help="Generator 1 Path")
parser.add_argument('--g2', type=str, required=True, help="Generator 2 Path")
parser.add_argument('--prompt', type=str, default="this is a large dark grey bird with a large beak.")
parser.add_argument('--n_images', type=_positive_int, default=1, help="number of images to generate for given prompt, max_allowed=6")
# NOTE(review): --test_dataset is parsed but never read elsewhere in this notebook.
parser.add_argument('--test_dataset', action="store_true", help="in case of text-cnn-rnn, provide this flag")
parser.add_argument('--out_path', type=str, default="valid_results", help="directory where generated images are saved")

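#### If using char-CNN-RNN

Run only one of the three argument cells (char-CNN-RNN, DistilBERT, CLIP): each one overwrites `args`, so with Run All the last one (CLIP) wins.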

In [None]:
# text-cnn-rnn
args = parser.parse_args(['--text_encoder', 'text-cnn-rnn',
                          '--g1', 'text-cnn-rnn/out_I/checkpoint_s1_ls/netG1_epoch_400.pth',
                          '--g2', 'text-cnn-rnn/out_I/checkpoint_s2_ls/netG2_epoch_80.pth',
                          '--test_dataset'])
os.makedirs(args.out_path, exist_ok=True)

#### if using DistilBERT

In [None]:
# bert
args = parser.parse_args(['--text_encoder', 'distilbert-base-uncased',
                          '--g1', 'bert/bert_out_V/checkpoint_s1_bert_ls/netG1_epoch_500.pth',
                          '--g2', 'bert/bert_out_V/checkpoint_s2_bert_ls/netG2_epoch_100.pth',
                          '--prompt', 'the bird has a small orange bill that has a black tip.'])
os.makedirs(args.out_path, exist_ok=True)

#### if using CLIP

In [None]:
# clip
args = parser.parse_args(['--text_encoder', 'openai/clip-vit-base-patch32',
                          '--g1', 'clip/clip_out_V/checkpoint_s1_clip_ls/netG1_epoch_460.pth',
                          '--g2', 'clip/clip_out_V/checkpoint_s2_clip_onels_1/netG2_epoch_80.pth',
                          '--prompt', 'the bird has a small orange bill that has a black tip.'])
os.makedirs(args.out_path, exist_ok=True)

#### Load text encoders

In [None]:
# Import the config and Hugging Face classes matching the chosen text encoder.
# The generator classes read the module-level `config` when they are instantiated.
if args.text_encoder == "distilbert-base-uncased":
    from configs import config
    from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
elif args.text_encoder == "openai/clip-vit-base-patch32":
    from configs import config2 as config
    from transformers import CLIPTokenizer, CLIPModel, CLIPProcessor
elif args.text_encoder == "text-cnn-rnn":
    from configs import config3 as config
    # Captions used to pick prompts for the precomputed char-CNN-RNN embeddings.
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open("train_test_split/text-cnn-rnn/test_captions.pkl", 'rb') as f:
        captions = pickle.load(f)
else:
    raise ValueError(f"unsupported text encoder: {args.text_encoder!r}")

#### Get text embeddings

In [None]:
tokenizer = get_tokenizer(args.text_encoder)

text_embeddings = []
if args.text_encoder != "text-cnn-rnn":
    # Transformer encoders: encode the prompt once and reuse it for every image.
    prompt = [args.prompt] * args.n_images
    text_embeddings = get_text_embeddings(args.prompt, tokenizer, args.text_encoder) * args.n_images
else:
    # char-CNN-RNN: `tokenizer` holds precomputed embeddings, indexed
    # [caption_group][sentence] like `captions`.
    # NOTE(review): get_tokenizer loads train_embeddings.pkl while `captions` comes from
    # test_captions.pkl — confirm the two are aligned.
    prompt = []
    for group in np.random.randint(0, len(captions), args.n_images):
        sentence = random.randint(0, len(captions[group]) - 1)
        prompt.append(captions[group][sentence])
        text_embeddings.append(torch.tensor(tokenizer[group][sentence]))

text_embeddings = torch.stack(text_embeddings, dim=0).float().cuda()
print(text_embeddings.shape)
assert args.n_images == text_embeddings.size(0), f"No. of text embeddings: {text_embeddings.size(0)} different from number of images: {args.n_images} to be generated"
print(prompt)

#### load Stage-I and Stage-II Generators

In [None]:
# Build each generator and load its checkpoint weights.
gen1 = load_from_checkpoint(Generator1(stage=1), args.g1)
gen2 = load_from_checkpoint(Generator2(stage=2), args.g2)
# Cast to float32 on the current CUDA device and switch to inference mode.
for generator in (gen1, gen2):
    generator.float().cuda()
    generator.eval()
clear_output()

#### Generate images

In [None]:
# Run Stage-I then Stage-II without tracking gradients.
with torch.no_grad():
    # One noise vector per image; its width must match config.z_dim, which the
    # Stage-I projection was built with.
    noise = torch.randn(args.n_images, config.z_dim, device=text_embeddings.device, dtype=torch.float32)
    low_res, _, _ = gen1(text_embeddings, noise)
    out, _, _ = gen2(text_embeddings, low_res)

# Move results to the CPU for saving and display (no_grad already detached them).
low_res = low_res.cpu()
out = out.cpu()

#### Save generated images

In [None]:
# File name templates for the Stage-I and Stage-II samples.
s1_template = '%s/generated_sample_s1_%03d.png'
s2_template = '%s/generated_sample_s2_%03d.png'

# Start from the number of entries in out_path, then skip any counter that is already
# in use (e.g. after files were deleted) so earlier results are never overwritten.
next_counter = len(os.listdir(args.out_path))
while (os.path.exists(s1_template % (args.out_path, next_counter))
       or os.path.exists(s2_template % (args.out_path, next_counter))):
    next_counter += 1

vutils.save_image(low_res, s1_template % (args.out_path, next_counter), normalize=True)
vutils.save_image(out, s2_template % (args.out_path, next_counter), normalize=True)
print(f"Image saved at: {args.out_path}, counter: {next_counter}")

#### Open generated images

In [None]:
# Display the Stage-II (high-resolution) sample saved above.
gen2_path = '%s/generated_sample_s2_%03d.png' % (args.out_path, next_counter)
gen2_image = Image.open(gen2_path)
gen2_image

In [None]:
# Display the Stage-I (low-resolution) sample saved above.
gen1_path = '%s/generated_sample_s1_%03d.png' % (args.out_path, next_counter)
gen1_image = Image.open(gen1_path)
gen1_image