## Notebook for testing the trained generator on real images with text descriptions

In [1]:
import os
import argparse
import fasttext
from PIL import Image
import cv2
import numpy as np
import random

import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image

from model import Generator

In [2]:
# Select the compute device: use the GPU when CUDA is available, otherwise fall back to CPU.
# (The previous version also set `args.no_cuda`, but no `args` object exists in this
# notebook, so the CPU branch crashed with NameError.)
use_cuda = torch.cuda.is_available()
if not use_cuda:
    print('Warning: cuda is not available on this machine.')
device = torch.device('cuda' if use_cuda else 'cpu')

In [3]:
# Display the selected torch device ('cuda' when available, otherwise 'cpu').
device

device(type='cuda')

In [4]:
# Path to the fastText model used to embed caption words.
FASTTEXT_MODEL_PATH = "caption_vec.bin"

print('Loading a pretrained fastText model...')
word_embedding = fasttext.load_model(FASTTEXT_MODEL_PATH)

Loading a pretrained fastText model...




In [90]:
# Build the generator on the selected device and load the trained weights.
print('Loading a pretrained model...')
G = Generator().to(device)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU
# still loads when the notebook falls back to CPU.
G.load_state_dict(torch.load("models/birds_GEN.pth", map_location=device))
# Switch to inference mode (e.g. BatchNorm uses its running statistics).
G.eval()

Loading a pretrained model...


Generator(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU(inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
  )
  (residual_blocks): Sequential(
    (0): Conv2d(640, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=Tru

In [41]:
# Preprocessing for input images: take the central 128x128 region,
# then convert the PIL image to a float tensor.
transform = transforms.Compose(
    [
        transforms.CenterCrop(128),
        transforms.ToTensor(),
    ]
)

In [48]:
# Class names to generate images for (passed to `image_generation` below).
classes = ['Caspian Tern']

In [49]:
# Number of images to generate per class. This is a fixed value rather than input()
# so that Restart & Run All is non-interactive and reproducible; 4 is the value
# entered in the recorded interactive run.
num_images = 4
assert num_images > 0, "num_images must be a positive integer"

Please input number of images per class: 4


In [50]:
# Presumably generates `num_images` images for each class name in `classes`
# (the recorded output prints "Done:" per class).
# NOTE(review): `image_generation` is not defined anywhere in this notebook; it likely
# lived in a since-deleted cell or another module. Define or import it before this
# cell, otherwise Restart & Run All fails here with NameError.
image_generation(classes, num_images)

Done:  Caspian Tern


In [14]:
# Load every test image of one class, preprocess it with `transform`, and save the
# originals as a reference grid next to the generated results.
path = "CUB_200/test/Rose Breasted Grosbeak/"
result_dir = "CUB_200/result/"
os.makedirs(result_dir, exist_ok=True)

# os.listdir order is arbitrary; sorting keeps the grid (and the row-for-row
# correspondence with the generated outputs) reproducible across runs and platforms.
filenames = sorted(os.listdir(path))
img = []
for fn in filenames:
    # `with` closes the file handle. convert('RGB') guards against grayscale/RGBA
    # files, which would otherwise yield tensors whose channel count is not 3 and
    # break torch.stack / the generator.
    with Image.open(os.path.join(path, fn)) as im:
        img.append(transform(im.convert('RGB')))
img = torch.stack(img)
save_image(img, os.path.join(result_dir, 'rose_original.jpg'), pad_value=1)
# Rescale from ToTensor's [0, 1] range to [-1, 1]. This is presumably the range the
# generator was trained on; its output is mapped back with mul(0.5).add(0.5) when saved.
img = img.mul(2).sub(1).to(device)

In [15]:
# Batch shape of the preprocessed images: (N, C, H, W).
img.shape

torch.Size([6, 3, 128, 128])

In [91]:
# Caption used to condition the generator on the loaded images.
text = "a light blue bird having black beak"

In [92]:
# Whitespace tokenization; each token is embedded separately with fastText further down.
words = str.split(text)

In [93]:
# Inspect the tokenized caption.
# NOTE(review): the saved output of this cell does not match `text.split()` for the
# current caption (it lists words from an earlier caption), so the saved results below
# are stale. Re-run the notebook top to bottom.
words

['a', 'bird', 'having', 'light', 'blue', 'head']

In [94]:
# Embed the caption and replicate it for every image in the batch, giving a
# (seq_len, batch, dim) tensor, assuming get_word_vector returns a 1-D vector per word.
if not words:
    raise ValueError('text must contain at least one word')
# Stack into one ndarray first: torch.tensor on a list of ndarrays is very slow (and
# emits a UserWarning); np.stack yields the same values in one contiguous array.
txt = torch.tensor(np.stack([word_embedding.get_word_vector(w) for w in words]), device=device)
txt = txt.unsqueeze(1).repeat(1, img.size(0), 1)
# One length per batch element; every image shares the same caption, so all equal len(words).
len_txt = torch.full((img.size(0),), len(words), dtype=torch.long, device=device)

In [95]:
# Run the generator on the image batch conditioned on the caption. This is inference only:
# no_grad skips building the autograd graph, saving memory and time. The second return
# value is unused here (its meaning is defined in model.Generator).
with torch.no_grad():
    output, _ = G(img, (txt, len_txt))

In [96]:
# Shape of the generated batch: (N, C, H, W).
output.shape

torch.Size([6, 3, 128, 128])

In [97]:
# File name for the generated image grid, written under CUB_200/result/ by the next cell.
out_filename = 'rose_output_1.jpg'

In [98]:
# Save the generated batch as one image grid. Create the directory first so this cell
# does not fail with FileNotFoundError on a fresh checkout.
result_dir = "CUB_200/result/"
os.makedirs(result_dir, exist_ok=True)
# Map the output back to [0, 1] (the inverse of the input's mul(2).sub(1) rescaling);
# the generator output is presumably in [-1, 1].
save_image(output.mul(0.5).add(0.5), os.path.join(result_dir, out_filename), pad_value=1)