In [1]:
import os
import sys

In [2]:
# Make the parent of the working directory importable so the `src.*` imports
# further down resolve (assumes the notebook runs from a direct subdirectory
# of the repository root — TODO confirm).
current_cwd = os.getcwd()
# os.path.dirname honours the platform's path separator; splitting on a
# hard-coded '/' silently yields a wrong path on Windows.
src_path = os.path.dirname(current_cwd)
# Guard the append so re-running this cell does not pile up duplicate entries.
if src_path not in sys.path:
    sys.path.append(src_path)

In [3]:
import numpy as np
import torch
import torchvision.utils as vutils
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

from src.generator.model import Generator
from src.encoders.text_encoder import RNNEncoder
from src.objects.dataset import AttnGANDataset
from src.objects.utils import prepare_data

In [4]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Instantiate the generator and place it on the selected device.
generator = Generator()
generator = generator.to(device)

# Restore the saved weights; map_location remaps the stored tensors onto
# `device` so a GPU-saved checkpoint also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
generator.load_state_dict(
    torch.load("../gen_weights_70/gen_weights_epoch_69.pth", map_location=device)
)

# Switch to evaluation mode for inference (presumably a torch.nn.Module,
# whose eval() returns the module itself — TODO confirm).
generator = generator.eval()

In [6]:
# Data-loading configuration for the evaluation split.
batch_size = 16
split_dir, bshuffle = 'test', True
# 64 * 2**(3 - 1) == 256 pixels per side.
image_size = 64 * (2 ** (3 - 1))

# Resize to 76/64 of the target size, then take a random crop of the target
# size and flip at random.
# `transforms.Scale` was deprecated and later removed from torchvision;
# `transforms.Resize` is its drop-in replacement.
image_transform = transforms.Compose([
    transforms.Resize(int(image_size * 76 / 64)),
    transforms.RandomCrop(image_size),
    transforms.RandomHorizontalFlip()
])

data_dir = "../data"
dataset = AttnGANDataset(data_dir, split_dir, image_transform)

# drop_last=True keeps every batch at exactly `batch_size` samples.
dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=True, shuffle=bshuffle)

In [8]:
# Vocabulary size exposed by the dataset (the same object the DataLoader
# wraps); it is passed to the text-encoder loader below.
n_words = dataset.n_words

In [9]:
# Restore the pretrained text encoder for a vocabulary of `n_words` entries.
text_encoder = RNNEncoder.load("../encoder_weights/text_encoder200.pth", n_words)
# Move to the target device and switch to evaluation mode so that any
# dropout layers are disabled during inference (assumes RNNEncoder is a
# torch.nn.Module, whose to()/eval() return the module — TODO confirm).
# Assigning the result also keeps the module repr out of the cell output.
text_encoder = text_encoder.to(device).eval()

# Own birds

In [10]:
def save_image(image: np.ndarray, save_dir: str, file_name: str):
    """Write a channel-first image with values in [-1, 1] to disk as a PNG.

    Args:
        image: Array laid out as (channels, height, width); values are
            expected in [-1, 1] and anything outside is clipped.
        save_dir: Target directory; created if it does not exist yet.
        file_name: File name without extension. Any '/' is replaced by '_'
            so the name cannot escape ``save_dir``.
    """
    # [-1, 1] --> [0, 255]; clip first so slightly out-of-range values do
    # not wrap around when cast to uint8.
    pixels = np.clip((image + 1.0) * 127.5, 0.0, 255.0).astype(np.uint8)
    # Channel-first --> channel-last, the layout Image.fromarray is given here.
    pixels = np.transpose(pixels, (1, 2, 0))

    os.makedirs(save_dir, exist_ok=True)
    fullpath = os.path.join(save_dir, f"{file_name.replace('/', '_')}.png")
    Image.fromarray(pixels).save(fullpath)

In [11]:
def gen_own_bird(word_caption, name):
    """Generate one image from a free-text caption and save it as a PNG.

    The caption is lower-cased, split on whitespace and mapped to codes via
    ``dataset.word2code``. It is zero-padded to 18 tokens, or, when longer,
    reduced to 18 randomly chosen tokens kept in their original order. The
    codes are encoded with ``text_encoder`` and fed to ``generator`` together
    with Gaussian noise. The first image of the generator's third output is
    written to ``../gen_images`` through ``save_image``.

    Args:
        word_caption: Caption text; every word must be in the vocabulary.
        name: Output file name, handed to ``save_image`` unchanged.

    Raises:
        ValueError: If the caption contains no words.
        KeyError: If a word is missing from ``dataset.word2code``.
    """
    max_words = 18   # fixed caption length; presumably what the encoder expects — TODO confirm
    noise_dim = 100  # size of the noise vector handed to the generator

    words = word_caption.lower().split()
    if not words:
        raise ValueError("word_caption must contain at least one word")
    try:
        codes = [dataset.word2code[w] for w in words]
    except KeyError as err:
        raise KeyError(f"word {err.args[0]!r} is not in the dataset vocabulary") from err

    caption = np.array(codes)
    pad_caption = np.zeros((max_words, 1), dtype='int64')

    if len(caption) <= max_words:
        pad_caption[:len(caption), 0] = caption
        len_ = len(caption)
    else:
        # Too long: keep a random subset of max_words tokens, in original order.
        indices = list(np.arange(len(caption)))
        np.random.shuffle(indices)
        pad_caption[:, 0] = caption[np.sort(indices[:max_words])]
        len_ = max_words

    # The token ids must live on the same device as the models; leaving them
    # on the CPU fails as soon as the models run on CUDA.
    captions = torch.tensor(pad_caption).reshape(1, -1).to(device)
    # NOTE(review): lengths stay on the CPU as before; confirm this matches
    # what prepare_data hands to the encoder.
    captions_len = torch.tensor([len_])

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        word_embeds, sentence_embeds = text_encoder(captions, captions_len)

        # True at padding positions (code 0), trimmed to the number of word
        # embeddings the encoder returned.
        mask = (captions == 0)
        num_words = word_embeds.size(2)
        if mask.size(1) > num_words:
            mask = mask[:, :num_words]

        noise = torch.randn(sentence_embeds.shape[0], noise_dim, device=device)
        fake_images, _, _, _ = generator(noise, sentence_embeds, word_embeds, mask)

    save_image(fake_images[2][0].detach().cpu().numpy(), "../gen_images", name)

In [12]:
# The caption text doubles as the output file name.
caption = "Small brown chicken with red crown and yellow wings"
gen_own_bird(word_caption=caption, name=caption)

In [13]:
# Draw a single batch from the test loader.
batch_iter = iter(dataloader)
batch = next(batch_iter)
del batch_iter  # leave only `batch` behind in the notebook namespace

In [14]:
# Generate one image per caption in the sampled batch and save them as a grid.
images, captions, captions_len, _, file_names = prepare_data(batch, device)

# Inference only: no_grad avoids building an autograd graph for the forward
# passes, which saves memory and time.
with torch.no_grad():
    word_embeds, sentence_embeds = text_encoder(captions, captions_len)

    # True at padding positions (code 0), trimmed to the number of word
    # embeddings the encoder returned.
    mask = (captions == 0)
    num_words = word_embeds.size(2)

    if mask.size(1) > num_words:
        mask = mask[:, :num_words]

    batch_size = sentence_embeds.shape[0]

    noise = torch.randn(batch_size, 100, device=device)
    fake_images, _, mu, log_var = generator(noise, sentence_embeds, word_embeds, mask)

# Create the output directory up front; saving fails if it is missing.
os.makedirs("../gen_images_70", exist_ok=True)
# fake_images[2] is the third generator output; normalize=True rescales the
# values into the displayable range before writing the grid.
vutils.save_image(fake_images[2].detach(), "../gen_images_70/birds.png", normalize=True)