In [1]:
import argparse
import os
from data import CIFAR10Dataset, Imagenet32Dataset
from models.embedders import BERTEncoder, OneHotClassEmbedding, UnconditionalClassEmbedding, GPTEncoder
import torch
from models.cgan import CDCGAN_G, CDCGAN_D
from torch.optim import lr_scheduler
import time
from tqdm import tqdm
from tensorboardX import SummaryWriter
import numpy as np
import torchvision.utils as vutils

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
%reload_ext autoreload
%autoreload 2
from utils.evalutils import sample_image, sample_final, load_model, sample_for_inception, real_imgs, sample_final_tier2
from inception import inception_score

# Setup

In [3]:
# Evaluation settings for this notebook.
batch_size = 32      # batch size given to the validation DataLoaders
use_cuda = 1         # truthy -> run on the GPU when one is available
n_filters = 128      # forwarded to CDCGAN_G(n_filters=...)
z_dim = 100          # forwarded to CDCGAN_G(z_dim=...)
output_dir = "outputs/cgan_cifar10"
model_checkpoint = "outputs/cgan_cifar10/models/epoch_81.pt"  # initial value; reassigned before each load_model call
print_every = 10
dataset = "cifar10"
conditioning = "bert"

# Setup device: CUDA only when it is both available and requested via use_cuda.
if torch.cuda.is_available() and use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Options that are not set in this notebook; kept commented out for reference.
# n_epochs =150
# lr=0.0001
# lr_decay=0.99
# n_cpu=8
# sample_interval=100
# eval_dir = "outputs/cgan_cifar10/eval"
# debug=0
# train_on_val=0
# train=1
# choices=["unconditional", "one-hot", "bert", "gpt"]  (presumably the allowed `conditioning` values — TODO confirm)

In [4]:
# Report which torch device the rest of the notebook will use.
device_message = "Device is {}".format(device)
print(device_message)

Device is cuda


In [5]:
# ImageNet32 and CIFAR-10 datasets, built with train=0 and max_size=-1
# (presumably the validation split with no size cap — TODO confirm against data.py).
imagenet_dataset = Imagenet32Dataset(train=0, max_size=-1)
cifar_dataset = CIFAR10Dataset(train=0, max_size=-1)

# Both loaders shuffle and drop the final incomplete batch.
imagenet_val_dataloader = torch.utils.data.DataLoader(
    imagenet_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)
cifar_val_dataloader = torch.utils.data.DataLoader(
    cifar_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)

loading data file 1/1, datasets/ImageNet32/val/val_data.npz
Files already downloaded and verified


In [6]:
# Sanity check: number of batches the ImageNet32 loader yields per pass.
len(imagenet_val_dataloader)

1562

In [7]:
# Sanity check: number of batches the CIFAR-10 loader yields per pass.
len(cifar_val_dataloader)

312

In [8]:
# Initialize the caption/class embedders.
# bert_encoder and gpt_encoder are passed to the sampling helpers in later cells;
# unconditional_encoder is created here but not referenced again in the visible cells.
unconditional_encoder = UnconditionalClassEmbedding()
bert_encoder = BERTEncoder()
gpt_encoder = GPTEncoder()

# Initialize models

In [9]:
# Build one untrained generator per experiment. Every generator shares the same
# architecture settings, so they are stated once instead of being repeated per model.
# Weights are filled in later by the load_model(...) cells.
EMBED_DIM = 768  # presumably the width of the BERT/GPT embeddings — TODO confirm


def _new_generator():
    """Return a fresh CDCGAN_G built from the notebook-wide z_dim / n_filters settings."""
    return CDCGAN_G(z_dim=z_dim, embed_dim=EMBED_DIM, n_filters=n_filters)


model_G_cifar10_baseline = _new_generator()
model_G_imagenet_baseline = _new_generator()
model_G_cifar10_gpt = _new_generator()
model_G_cifar10_gptsigmoid = _new_generator()
model_G_cifar10_wgan = _new_generator()

In [10]:
# Directories holding the saved model checkpoints, one per experiment.
# Later cells join these with an "epoch_<N>.pt" file name before calling load_model.
cgan_cifar10_baseline = "outputs/cgan_cifar10/models"
cgan_imagenet_baseline = "outputs/cgan_imagenet/models"
cgan_cifar10_gpt = "outputs/cgan_gpt/models"
cgan_cifar10_gptsigmoid = "outputs/cgan_gpt_sigmoid/models"
cgan_cifar10_wgan = "outputs/wcgan_cifar10/models"

# Sample images from most trained model

In [11]:
# Baseline CIFAR-10: point load_model at the epoch-81 checkpoint for the baseline generator.
# NOTE(review): load_model is given a CPU device here, while the sampling cells pass `device` — confirm this is intended.
model_checkpoint = os.path.join(cgan_cifar10_baseline, "epoch_81.pt")
load_model(model_checkpoint, model_G_cifar10_baseline, torch.device("cpu"))

In [12]:
# Final samples for the baseline CIFAR-10 generator with the BERT encoder.
sample_final(
    model_G_cifar10_baseline,
    bert_encoder,
    "outputs/cgan_cifar10",
    n_row=4,
    dataloader=cifar_val_dataloader,
    device=device,
)



saved  outputs/cgan_cifar10/samples/final_sample


In [13]:
# Tier-2 samples for the baseline CIFAR-10 generator, captions read from map_clsloc2.txt.
sample_final_tier2(
    model_G_cifar10_baseline,
    bert_encoder,
    "outputs/cgan_cifar10",
    n_row=4,
    caption_file="map_clsloc2.txt",
    device=device,
)

saved  outputs/cgan_cifar10/samples/final_sample_tier2_2


In [14]:
# Tier-2 samples for the baseline CIFAR-10 generator, captions read from map_clsloc3.txt.
sample_final_tier2(
    model_G_cifar10_baseline,
    bert_encoder,
    "outputs/cgan_cifar10",
    n_row=4,
    caption_file="map_clsloc3.txt",
    device=device,
)

saved  outputs/cgan_cifar10/samples/final_sample_tier2_3


In [15]:
# Baseline ImageNet: point load_model at the epoch-7 checkpoint for the ImageNet generator.
# NOTE(review): load_model is given a CPU device here, while the sampling cells pass `device` — confirm this is intended.
model_checkpoint = os.path.join(cgan_imagenet_baseline, "epoch_7.pt")
load_model(model_checkpoint, model_G_imagenet_baseline, torch.device("cpu"))

In [16]:
# Final samples for the baseline ImageNet generator with the BERT encoder.
sample_final(
    model_G_imagenet_baseline,
    bert_encoder,
    "outputs/cgan_imagenet",
    n_row=4,
    dataloader=imagenet_val_dataloader,
    device=device,
)

saved  outputs/cgan_imagenet/samples/final_sample


In [49]:
# Tier-2 samples for the baseline ImageNet generator, captions read from map_clsloc2.txt.
sample_final_tier2(
    model_G_imagenet_baseline,
    bert_encoder,
    "outputs/cgan_imagenet",
    n_row=4,
    caption_file="map_clsloc2.txt",
    device=device,
)

saved  outputs/cgan_imagenet/samples/final_sample_tier2_2


In [50]:
# Tier-2 samples for the baseline ImageNet generator, captions read from map_clsloc3.txt.
sample_final_tier2(
    model_G_imagenet_baseline,
    bert_encoder,
    "outputs/cgan_imagenet",
    n_row=4,
    caption_file="map_clsloc3.txt",
    device=device,
)

saved  outputs/cgan_imagenet/samples/final_sample_tier2_3


In [19]:
# CIFAR-10 with GPT conditioning: point load_model at the epoch-99 checkpoint.
# NOTE(review): load_model is given a CPU device here, while the sampling cells pass `device` — confirm this is intended.
model_checkpoint = os.path.join(cgan_cifar10_gpt, "epoch_99.pt")
load_model(model_checkpoint, model_G_cifar10_gpt, torch.device("cpu"))

In [20]:
# Final samples for the GPT-conditioned CIFAR-10 generator.
sample_final(
    model_G_cifar10_gpt,
    gpt_encoder,
    "outputs/cgan_gpt",
    n_row=4,
    dataloader=cifar_val_dataloader,
    device=device,
)

saved  outputs/cgan_gpt/samples/final_sample


In [21]:
# Tier-2 samples for the GPT-conditioned CIFAR-10 generator, captions read from map_clsloc2.txt.
sample_final_tier2(
    model_G_cifar10_gpt,
    gpt_encoder,
    "outputs/cgan_gpt",
    n_row=4,
    caption_file="map_clsloc2.txt",
    device=device,
)

saved  outputs/cgan_gpt/samples/final_sample_tier2_2


In [22]:
# Tier-2 samples for the GPT-conditioned CIFAR-10 generator, captions read from map_clsloc3.txt.
sample_final_tier2(
    model_G_cifar10_gpt,
    gpt_encoder,
    "outputs/cgan_gpt",
    n_row=4,
    caption_file="map_clsloc3.txt",
    device=device,
)

saved  outputs/cgan_gpt/samples/final_sample_tier2_3


In [23]:
# CIFAR-10 with GPT conditioning (sigmoid variant): point load_model at the epoch-122 checkpoint.
# NOTE(review): load_model is given a CPU device here, while the sampling cells pass `device` — confirm this is intended.
model_checkpoint = os.path.join(cgan_cifar10_gptsigmoid, "epoch_122.pt")
load_model(model_checkpoint, model_G_cifar10_gptsigmoid, torch.device("cpu"))

In [24]:
# Final samples for the GPT-sigmoid CIFAR-10 generator.
sample_final(
    model_G_cifar10_gptsigmoid,
    gpt_encoder,
    "outputs/cgan_gpt_sigmoid",
    n_row=4,
    dataloader=cifar_val_dataloader,
    device=device,
)

saved  outputs/cgan_gpt_sigmoid/samples/final_sample


In [25]:
# Tier-2 samples for the GPT-sigmoid CIFAR-10 generator, captions read from map_clsloc2.txt.
sample_final_tier2(
    model_G_cifar10_gptsigmoid,
    gpt_encoder,
    "outputs/cgan_gpt_sigmoid",
    n_row=4,
    caption_file="map_clsloc2.txt",
    device=device,
)

saved  outputs/cgan_gpt_sigmoid/samples/final_sample_tier2_2


In [26]:
# Tier-2 samples for the GPT-sigmoid CIFAR-10 generator, captions read from map_clsloc3.txt.
sample_final_tier2(
    model_G_cifar10_gptsigmoid,
    gpt_encoder,
    "outputs/cgan_gpt_sigmoid",
    n_row=4,
    caption_file="map_clsloc3.txt",
    device=device,
)

saved  outputs/cgan_gpt_sigmoid/samples/final_sample_tier2_3


In [27]:
# WCGAN on CIFAR-10: point load_model at the epoch-92 checkpoint.
# NOTE(review): load_model is given a CPU device here, while the sampling cells pass `device` — confirm this is intended.
model_checkpoint = os.path.join(cgan_cifar10_wgan, "epoch_92.pt")
load_model(model_checkpoint, model_G_cifar10_wgan, torch.device("cpu"))

In [28]:
# Final samples for the WCGAN CIFAR-10 generator with the BERT encoder.
sample_final(
    model_G_cifar10_wgan,
    bert_encoder,
    "outputs/wcgan_cifar10",
    n_row=4,
    dataloader=cifar_val_dataloader,
    device=device,
)

saved  outputs/wcgan_cifar10/samples/final_sample


In [29]:
# Tier-2 samples for the WCGAN CIFAR-10 generator, captions read from map_clsloc2.txt.
sample_final_tier2(
    model_G_cifar10_wgan,
    bert_encoder,
    "outputs/wcgan_cifar10",
    n_row=4,
    caption_file="map_clsloc2.txt",
    device=device,
)

saved  outputs/wcgan_cifar10/samples/final_sample_tier2_2


In [30]:
# Tier-2 samples for the WCGAN CIFAR-10 generator, captions read from map_clsloc3.txt.
sample_final_tier2(
    model_G_cifar10_wgan,
    bert_encoder,
    "outputs/wcgan_cifar10",
    n_row=4,
    caption_file="map_clsloc3.txt",
    device=device,
)

saved  outputs/wcgan_cifar10/samples/final_sample_tier2_3


# Calculate inception score for true data
(9.562877186875763, 0.0)

In [63]:
# Rebuild the CIFAR-10 loader with 100-image batches for the scoring cells that follow.
# NOTE: this rebinds `cifar_val_dataloader`, which an earlier cell created with
# `batch_size`; re-running the earlier sampling cells after this one uses 100-image batches.
cifar_val_dataloader = torch.utils.data.DataLoader(
    cifar_dataset,
    batch_size=100,
    shuffle=True,
    drop_last=True,
)

# Only the first batch is needed, so fetch just that one instead of materialising
# every batch of the dataset in a list. Each batch is (images, labels, captions).
imgs = next(iter(cifar_val_dataloader))[0].numpy()

In [64]:
# Inception score of the real CIFAR-10 batch held in `imgs`.
# Score noted from an earlier run: (8.821857043548619, 0.0)
inception_score(imgs, resize=True)

(8.821857043548619, 0.0)

# Inception score for our models

In [79]:
# GPT-conditioned generator: draw samples with sample_for_inception, then score them.
# Notes kept from earlier runs:
#   (3.1605518320034443, 0.0) 500 #this at 300
#   size of imgs is (320,3,32,32)
#   (3.069024780561485, 0.0)
imgs_gpt = sample_for_inception(
    model_G_cifar10_gpt,
    gpt_encoder,
    100,
    dataloader=cifar_val_dataloader,
    device=device,
)
inception_score(imgs_gpt, resize=True)

(3.1159828196134054, 0.0)

In [80]:
# Check how many GPT samples were produced and their per-image dimensions.
imgs_gpt.shape

(300, 3, 32, 32)

In [104]:
# Baseline generator (BERT encoder): draw samples with sample_for_inception, then score them.
# Original note: "this has batch size 300".
# Notes kept from earlier runs:
#   (1.064679586315566, 0.0)
#   (3.0453448825573903, 0.0)
#   (3.110995859711864, 0.0) 500
#   (3.1048524864446625, 0.0) 300
imgs_baseline = sample_for_inception(
    model_G_cifar10_baseline,
    bert_encoder,
    100,
    dataloader=cifar_val_dataloader,
    device=device,
)
inception_score(imgs_baseline, resize=True)

(3.0566969947586062, 0.0)

In [82]:
# WCGAN generator (BERT encoder): draw samples with sample_for_inception, then score them.
# Notes kept from earlier runs:
#   (2.102412739434052, 0.0)
#   (2.081028986475838, 0.0) 300
imgs_wcgan = sample_for_inception(
    model_G_cifar10_wgan,
    bert_encoder,
    100,
    dataloader=cifar_val_dataloader,
    device=device,
)
inception_score(imgs_wcgan, resize=True)

(2.081028986475838, 0.0)

# Save real images for FID score

In [107]:
# Collect real CIFAR-10 images for the FID comparison.
# The first argument is 200, yet the recorded run shows a result of shape
# [300, 3, 32, 32], i.e. 300 images — NOTE(review): confirm how real_imgs counts.
imgs_real = real_imgs(200, cifar_val_dataloader, device)

In [108]:
# Wrap the NumPy array as a torch tensor so vutils.save_image can consume it.
imgs_real_tensor = torch.from_numpy(imgs_real)

In [109]:
# Check how many real images were collected and their per-image dimensions.
imgs_real_tensor.shape

torch.Size([300, 3, 32, 32])

In [111]:
# Write each real image to outputs/real/<index>.png.
# Iterate over the tensor itself rather than a hard-coded count, so the loop
# always matches however many images were collected upstream.
os.makedirs('outputs/real', exist_ok=True)  # make sure the target folder exists before writing
for i, img in enumerate(imgs_real_tensor):
    vutils.save_image(img, 'outputs/real/{}.png'.format(i))

# Create fake images for FID Score

In [112]:
# Baseline: write each generated image to outputs/cgan_cifar10/fid/<index>.png.
# Iterate over the tensor itself rather than a hard-coded count, so the loop
# always matches however many samples were generated upstream.
imgs_baseline_tensor = torch.from_numpy(imgs_baseline)
os.makedirs('outputs/cgan_cifar10/fid', exist_ok=True)  # make sure the target folder exists before writing
for i, img in enumerate(imgs_baseline_tensor):
    vutils.save_image(img, 'outputs/cgan_cifar10/fid/{}.png'.format(i))

In [89]:
# GPT: write each generated image to outputs/cgan_gpt/fid/<index>.png.
# Iterate over the tensor itself rather than a hard-coded count, so the loop
# always matches however many samples were generated upstream.
imgs_gpt_tensor = torch.from_numpy(imgs_gpt)
os.makedirs('outputs/cgan_gpt/fid', exist_ok=True)  # make sure the target folder exists before writing
for i, img in enumerate(imgs_gpt_tensor):
    vutils.save_image(img, 'outputs/cgan_gpt/fid/{}.png'.format(i))

In [91]:
# WCGAN: write each generated image to outputs/wcgan_cifar10/fid/<index>.png.
# Iterate over the tensor itself rather than a hard-coded count, so the loop
# always matches however many samples were generated upstream.
imgs_wcgan_tensor = torch.from_numpy(imgs_wcgan)
os.makedirs('outputs/wcgan_cifar10/fid', exist_ok=True)  # make sure the target folder exists before writing
for i, img in enumerate(imgs_wcgan_tensor):
    vutils.save_image(img, 'outputs/wcgan_cifar10/fid/{}.png'.format(i))

# Results: FID scores

131.25505595644205 for baseline

135.2906583566559 for gpt

164.5993704587727 for wcgan